arasuezofis committed on
Commit
3149ed3
·
verified ·
1 Parent(s): 62c1db6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -63
app.py CHANGED
@@ -1,13 +1,12 @@
1
  # app.py
2
  # ------------------------------------------------------------
3
- # Invoice Chat • SmolVLM-Instruct-250M
4
- # Gradio Space with robust page picker + safe streaming chat
5
  # ------------------------------------------------------------
6
 
7
  import io
8
  import os
9
  import re
10
- from typing import List, Tuple, Optional, Union
11
 
12
  import gradio as gr
13
  import torch
@@ -16,7 +15,7 @@ import fitz # PyMuPDF
16
  from transformers import (
17
  AutoProcessor,
18
  AutoTokenizer,
19
- AutoModelForImageTextToText, # <= new, replaces AutoModelForVision2Seq
20
  TextIteratorStreamer,
21
  )
22
 
@@ -24,17 +23,12 @@ from transformers import (
24
  # Model bootstrap
25
  # -----------------------------
26
  MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
27
-
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
30
 
31
- # Tokenizer has the chat template
32
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
33
- # Processor handles vision tensors
34
  processor = AutoProcessor.from_pretrained(MODEL_ID)
35
- # New class to avoid deprecation warnings
36
- model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype=DTYPE)
37
- model.to(DEVICE).eval()
38
 
39
  SYSTEM_PROMPT = (
40
  "You are an invoice assistant. Respond ONLY using details visible in the uploaded document. "
@@ -68,15 +62,10 @@ def pdf_to_images_from_path(path: str, max_pages: int = 8, dpi: int = 216) -> Li
68
 
69
  def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> List[Image.Image]:
70
  """
71
- Accept PDF/PNG/JPEG via Gradio File. Handles:
72
- - str path (tempfile path)
73
- - dict with 'name'/'path' or 'data'
74
- - bytes / BytesIO
75
- Returns a list of PIL images. PDFs => multi-image; PNG/JPEG => single image.
76
  """
77
  if not file_val:
78
  return []
79
-
80
  path: Optional[str] = None
81
  raw_bytes: Optional[bytes] = None
82
 
@@ -95,11 +84,8 @@ def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> Li
95
  elif isinstance(file_val, io.BytesIO):
96
  raw_bytes = file_val.getvalue()
97
 
98
- def is_pdf_name(name: str) -> bool:
99
- return name.lower().endswith(".pdf")
100
-
101
  if path:
102
- if is_pdf_name(path):
103
  return pdf_to_images_from_path(path)
104
  with open(path, "rb") as f:
105
  img = Image.open(io.BytesIO(f.read())).convert("RGB")
@@ -115,34 +101,28 @@ def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> Li
115
 
116
  def parse_page_selection(value, num_pages: int) -> int:
117
  """
118
- Accept 'Page 3', '3', 3, 'pg-2', etc. Return safe 0-based index clamped to [0, num_pages-1].
119
  """
120
- if num_pages <= 0:
121
- return 0
122
- if value is None:
123
  return 0
124
  if isinstance(value, int):
125
  idx = value - 1
126
  else:
127
- s = str(value).strip()
128
- m = re.search(r"(\d+)", s)
129
  idx = int(m.group(1)) - 1 if m else 0
130
  return max(0, min(num_pages - 1, idx))
131
 
132
  def build_messages(history_msgs: list, user_text: str, images: List[Image.Image]):
133
  """
134
- Compose the full prompt for the model:
135
  - system prompt
136
- - trimmed history (already in {'role','content'} format)
137
- - current user turn with images + text
138
  """
139
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
140
-
141
- # Keep last 8 messages to stay lean
142
- trimmed = history_msgs[-8:] if history_msgs else []
143
  messages.extend(trimmed)
144
 
145
- # Current user turn: images first, then text
146
  multimodal = []
147
  for im in images:
148
  multimodal.append(im)
@@ -155,46 +135,40 @@ def build_messages(history_msgs: list, user_text: str, images: List[Image.Image]
155
  # -----------------------------
156
  # Core generation (streaming)
157
  # -----------------------------
158
- def generate_reply(images: List[Image.Image], user_text: str, chat_history: List[Tuple[str, str]]):
159
  """
160
  Stream a model reply grounded on provided images + user question + compact chat history.
161
- - Build prompt as TEXT (chat template) -> tokenize to dict (input_ids, attention_mask)
162
- - Vision tensors via processor (pixel_values)
163
- - Pass ONLY allowed kwargs to model.generate (avoid rows/cols etc.)
164
  """
165
- messages = build_messages(chat_history, user_text, images)
166
 
167
- # 1) Build prompt text
168
  prompt_text = tokenizer.apply_chat_template(
169
  messages,
170
  add_generation_prompt=True,
171
- tokenize=False, # IMPORTANT: return a string
172
  )
173
 
174
- # 2) Tokenize to get a dict
175
  text_inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
176
 
177
- # 3) Vision tensors
178
  vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)
179
 
180
- # 4) Allow-list only the keys generate() expects
181
  model_inputs = {
182
  "input_ids": text_inputs["input_ids"],
183
- # attention_mask may or may not exist depending on tokenizer; include if present
184
  **({"attention_mask": text_inputs["attention_mask"]} if "attention_mask" in text_inputs else {}),
185
- # vision inputs
186
  **({"pixel_values": vision_inputs["pixel_values"]} if "pixel_values" in vision_inputs else {}),
187
  }
188
 
189
- # 5) Streamer uses the same tokenizer
190
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
191
-
192
  gen_kwargs = dict(
193
  **model_inputs,
194
  streamer=streamer,
195
  max_new_tokens=512,
196
- do_sample=False,
197
- # NOTE: some I2T models ignore temperature/top_p; avoid passing unsupported flags
198
  )
199
 
200
  import threading
@@ -241,19 +215,16 @@ def chat(user_text, history_msgs, images_state, selected_img):
241
 
242
  sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
243
  if sel_img is None:
244
- # push a system-style nudge
245
  history_msgs = history_msgs + [
246
  {"role": "user", "content": user_text},
247
- {"role": "assistant", "content": "Please upload a document first."}
248
  ]
249
  return gr.update(value=history_msgs), history_msgs
250
 
251
- # Stream the assistant reply
252
  stream = generate_reply([sel_img], user_text, history_msgs)
253
  acc = ""
254
  for chunk in stream:
255
  acc = chunk
256
- # do incremental streaming by replacing the last assistant message
257
  yield (
258
  history_msgs + [
259
  {"role": "user", "content": user_text},
@@ -262,11 +233,9 @@ def chat(user_text, history_msgs, images_state, selected_img):
262
  history_msgs + [
263
  {"role": "user", "content": user_text},
264
  {"role": "assistant", "content": acc},
265
- ]
266
  )
267
 
268
-
269
-
270
  # -----------------------------
271
  # App definition
272
  # -----------------------------
@@ -284,14 +253,14 @@ with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
284
  choices=[],
285
  value=None,
286
  allow_custom_value=True,
287
- info="Type a page number (e.g., 2) or choose from the list."
288
  )
289
  load_btn = gr.Button("Prepare Document", variant="primary")
290
  with gr.Column(scale=2):
291
  image_view = gr.Image(label="Current page/image", interactive=False)
292
 
293
- # Lock Chatbot type to silence deprecation warning
294
- chatbot = gr.Chatbot(height=400, type="tuples")
295
  user_box = gr.Textbox(
296
  label="Your question",
297
  placeholder="e.g., What is the invoice number and total with tax?",
@@ -306,22 +275,22 @@ with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
306
  load_btn.click(
307
  start_chat,
308
  inputs=[file, gr.State(0)],
309
- outputs=[pages, images_state, image_view, gr.Textbox(visible=False)]
310
  )
311
  pages.change(
312
  page_picker_changed,
313
  inputs=[pages, images_state],
314
- outputs=[image_view, selected_img_state]
315
  )
316
  ask_btn.click(
317
  chat,
318
  inputs=[user_box, chatbot, images_state, selected_img_state],
319
- outputs=[chatbot, chatbot]
320
  )
321
  user_box.submit(
322
  chat,
323
  inputs=[user_box, chatbot, images_state, selected_img_state],
324
- outputs=[chatbot, chatbot]
325
  )
326
 
327
  if __name__ == "__main__":
 
1
  # app.py
2
  # ------------------------------------------------------------
3
+ # Invoice Chat • SmolVLM-Instruct-250M (messages-mode, streaming)
 
4
  # ------------------------------------------------------------
5
 
6
  import io
7
  import os
8
  import re
9
+ from typing import List, Optional, Union
10
 
11
  import gradio as gr
12
  import torch
 
15
  from transformers import (
16
  AutoProcessor,
17
  AutoTokenizer,
18
+ AutoModelForImageTextToText, # modern replacement for AutoModelForVision2Seq
19
  TextIteratorStreamer,
20
  )
21
 
 
23
  # Model bootstrap
24
  # -----------------------------
25
  MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
 
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
28
 
 
29
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
 
30
  processor = AutoProcessor.from_pretrained(MODEL_ID)
31
+ model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype=DTYPE).to(DEVICE).eval()
 
 
32
 
33
  SYSTEM_PROMPT = (
34
  "You are an invoice assistant. Respond ONLY using details visible in the uploaded document. "
 
62
 
63
  def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> List[Image.Image]:
64
  """
65
+ Accept PDF/PNG/JPEG (path/dict/bytes/BytesIO) and return a list of PIL images.
 
 
 
 
66
  """
67
  if not file_val:
68
  return []
 
69
  path: Optional[str] = None
70
  raw_bytes: Optional[bytes] = None
71
 
 
84
  elif isinstance(file_val, io.BytesIO):
85
  raw_bytes = file_val.getvalue()
86
 
 
 
 
87
  if path:
88
+ if path.lower().endswith(".pdf"):
89
  return pdf_to_images_from_path(path)
90
  with open(path, "rb") as f:
91
  img = Image.open(io.BytesIO(f.read())).convert("RGB")
 
101
 
102
  def parse_page_selection(value, num_pages: int) -> int:
103
  """
104
+ Accept 'Page 3', '3', 3, 'pg-2', etc. Return safe 0-based index.
105
  """
106
+ if num_pages <= 0 or value is None:
 
 
107
  return 0
108
  if isinstance(value, int):
109
  idx = value - 1
110
  else:
111
+ m = re.search(r"(\d+)", str(value).strip())
 
112
  idx = int(m.group(1)) - 1 if m else 0
113
  return max(0, min(num_pages - 1, idx))
114
 
115
  def build_messages(history_msgs: list, user_text: str, images: List[Image.Image]):
116
  """
117
+ Compose the model prompt using OpenAI-style messages:
118
  - system prompt
119
+ - trimmed prior messages
120
+ - current user turn (images + text)
121
  """
122
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
123
+ trimmed = history_msgs[-8:] if history_msgs else [] # keep the window tight
 
 
124
  messages.extend(trimmed)
125
 
 
126
  multimodal = []
127
  for im in images:
128
  multimodal.append(im)
 
135
  # -----------------------------
136
  # Core generation (streaming)
137
  # -----------------------------
138
+ def generate_reply(images: List[Image.Image], user_text: str, history_msgs: list):
139
  """
140
  Stream a model reply grounded on provided images + user question + compact chat history.
141
+ - Build prompt text (chat template) -> tokenize (dict)
142
+ - Vision tensors via processor (dict)
143
+ - Allow-list kwargs to model.generate
144
  """
145
+ messages = build_messages(history_msgs, user_text, images)
146
 
147
+ # 1) Build prompt as TEXT (not tokens)
148
  prompt_text = tokenizer.apply_chat_template(
149
  messages,
150
  add_generation_prompt=True,
151
+ tokenize=False,
152
  )
153
 
154
+ # 2) Tokenize mapping with input_ids/attention_mask
155
  text_inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
156
 
157
+ # 3) Vision tensors (pixel_values)
158
  vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)
159
 
 
160
  model_inputs = {
161
  "input_ids": text_inputs["input_ids"],
 
162
  **({"attention_mask": text_inputs["attention_mask"]} if "attention_mask" in text_inputs else {}),
 
163
  **({"pixel_values": vision_inputs["pixel_values"]} if "pixel_values" in vision_inputs else {}),
164
  }
165
 
 
166
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
167
  gen_kwargs = dict(
168
  **model_inputs,
169
  streamer=streamer,
170
  max_new_tokens=512,
171
+ do_sample=False, # keep deterministic for enterprise-grade UX
 
172
  )
173
 
174
  import threading
 
215
 
216
  sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
217
  if sel_img is None:
 
218
  history_msgs = history_msgs + [
219
  {"role": "user", "content": user_text},
220
+ {"role": "assistant", "content": "Please upload a document first."},
221
  ]
222
  return gr.update(value=history_msgs), history_msgs
223
 
 
224
  stream = generate_reply([sel_img], user_text, history_msgs)
225
  acc = ""
226
  for chunk in stream:
227
  acc = chunk
 
228
  yield (
229
  history_msgs + [
230
  {"role": "user", "content": user_text},
 
233
  history_msgs + [
234
  {"role": "user", "content": user_text},
235
  {"role": "assistant", "content": acc},
236
+ ],
237
  )
238
 
 
 
239
  # -----------------------------
240
  # App definition
241
  # -----------------------------
 
253
  choices=[],
254
  value=None,
255
  allow_custom_value=True,
256
+ info="Type a page number (e.g., 2) or choose from the list.",
257
  )
258
  load_btn = gr.Button("Prepare Document", variant="primary")
259
  with gr.Column(scale=2):
260
  image_view = gr.Image(label="Current page/image", interactive=False)
261
 
262
+ # messages mode (no more tuples warnings)
263
+ chatbot = gr.Chatbot(height=400, type="messages")
264
  user_box = gr.Textbox(
265
  label="Your question",
266
  placeholder="e.g., What is the invoice number and total with tax?",
 
275
  load_btn.click(
276
  start_chat,
277
  inputs=[file, gr.State(0)],
278
+ outputs=[pages, images_state, image_view, gr.Textbox(visible=False)],
279
  )
280
  pages.change(
281
  page_picker_changed,
282
  inputs=[pages, images_state],
283
+ outputs=[image_view, selected_img_state],
284
  )
285
  ask_btn.click(
286
  chat,
287
  inputs=[user_box, chatbot, images_state, selected_img_state],
288
+ outputs=[chatbot, chatbot],
289
  )
290
  user_box.submit(
291
  chat,
292
  inputs=[user_box, chatbot, images_state, selected_img_state],
293
+ outputs=[chatbot, chatbot],
294
  )
295
 
296
  if __name__ == "__main__":