arasuezofis committed on
Commit
9037c59
·
verified ·
1 Parent(s): 3a1ba6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -49
app.py CHANGED
@@ -1,7 +1,7 @@
1
  # app.py
2
  # ------------------------------------------------------------
3
  # Invoice Chat • SmolVLM-Instruct-250M
4
- # Gradio Space with resilient page picker + streaming chat
5
  # ------------------------------------------------------------
6
 
7
  import io
@@ -16,7 +16,7 @@ import fitz # PyMuPDF
16
  from transformers import (
17
  AutoProcessor,
18
  AutoTokenizer,
19
- AutoModelForVision2Seq,
20
  TextIteratorStreamer,
21
  )
22
 
@@ -26,13 +26,14 @@ from transformers import (
26
  MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
27
 
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
- # float16 only if CUDA is available; on CPU use float32
30
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
31
 
32
- # Load tokenizer (has the chat template), processor (images), and model
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
 
34
  processor = AutoProcessor.from_pretrained(MODEL_ID)
35
- model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
 
36
  model.to(DEVICE).eval()
37
 
38
  SYSTEM_PROMPT = (
@@ -44,7 +45,6 @@ SYSTEM_PROMPT = (
44
  # Utilities
45
  # -----------------------------
46
  def pdf_to_images_from_bytes(pdf_bytes: bytes, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
47
- """Render first N pages of a PDF (in-memory) as PIL RGB images."""
48
  doc = fitz.open(stream=pdf_bytes, filetype="pdf")
49
  images: List[Image.Image] = []
50
  for i, page in enumerate(doc):
@@ -55,9 +55,7 @@ def pdf_to_images_from_bytes(pdf_bytes: bytes, max_pages: int = 8, dpi: int = 21
55
  images.append(img)
56
  return images
57
 
58
-
59
  def pdf_to_images_from_path(path: str, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
60
- """Render first N pages of a PDF (file path) as PIL RGB images."""
61
  doc = fitz.open(path)
62
  images: List[Image.Image] = []
63
  for i, page in enumerate(doc):
@@ -68,12 +66,11 @@ def pdf_to_images_from_path(path: str, max_pages: int = 8, dpi: int = 216) -> Li
68
  images.append(img)
69
  return images
70
 
71
-
72
  def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> List[Image.Image]:
73
  """
74
- Accept PDF/PNG/JPEG via Gradio File. Handles multiple shapes of input:
75
  - str path (tempfile path)
76
- - dict with 'name' or 'path'
77
  - bytes / BytesIO
78
  Returns a list of PIL images. PDFs => multi-image; PNG/JPEG => single image.
79
  """
@@ -98,7 +95,6 @@ def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> Li
98
  elif isinstance(file_val, io.BytesIO):
99
  raw_bytes = file_val.getvalue()
100
 
101
- # PDF vs Image
102
  def is_pdf_name(name: str) -> bool:
103
  return name.lower().endswith(".pdf")
104
 
@@ -117,35 +113,28 @@ def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> Li
117
 
118
  return []
119
 
120
-
121
  def parse_page_selection(value, num_pages: int) -> int:
122
  """
123
  Accept 'Page 3', '3', 3, 'pg-2', etc. Return safe 0-based index clamped to [0, num_pages-1].
124
- Defaults to 0 if unusable.
125
  """
126
  if num_pages <= 0:
127
  return 0
128
  if value is None:
129
  return 0
130
-
131
  if isinstance(value, int):
132
  idx = value - 1
133
  else:
134
  s = str(value).strip()
135
  m = re.search(r"(\d+)", s)
136
  idx = int(m.group(1)) - 1 if m else 0
137
-
138
  return max(0, min(num_pages - 1, idx))
139
 
140
-
141
  def build_messages(history: List[Tuple[str, str]], user_text: str, images: List[Image.Image]):
142
  """
143
- Construct chat-format messages compatible with tokenizer.apply_chat_template.
144
- We trim the history to avoid runaway context growth.
145
  """
146
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
147
  trimmed = history[-4:] if history else []
148
-
149
  for u, a in trimmed:
150
  messages.append({"role": "user", "content": u})
151
  messages.append({"role": "assistant", "content": a})
@@ -155,45 +144,52 @@ def build_messages(history: List[Tuple[str, str]], user_text: str, images: List[
155
  multimodal.append(im)
156
  if user_text.strip():
157
  multimodal.append(user_text.strip())
158
-
159
  messages.append({"role": "user", "content": multimodal})
160
  return messages
161
 
162
-
 
 
163
  def generate_reply(images: List[Image.Image], user_text: str, chat_history: List[Tuple[str, str]]):
164
  """
165
  Stream a model reply grounded on provided images + user question + compact chat history.
166
- Key fix: build text with chat template (string), then tokenize to get a dict.
 
 
167
  """
168
  messages = build_messages(chat_history, user_text, images)
169
 
170
- # 1) Get the chat prompt as TEXT (not tokens)
171
  prompt_text = tokenizer.apply_chat_template(
172
  messages,
173
  add_generation_prompt=True,
174
- tokenize=False, # <-- IMPORTANT: return string
175
  )
176
 
177
- # 2) Tokenize to get a dict (input_ids, attention_mask)
178
- text_inputs = tokenizer(
179
- prompt_text,
180
- return_tensors="pt"
181
- ).to(DEVICE)
182
 
183
- # 3) Vision tensors (dict with pixel_values)
184
  vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)
185
 
186
- # 4) Merge dicts safely
187
- model_inputs = {**text_inputs, **vision_inputs}
 
 
 
 
 
 
188
 
189
- # 5) Stream with the same tokenizer
190
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
191
  gen_kwargs = dict(
192
  **model_inputs,
193
  streamer=streamer,
194
  max_new_tokens=512,
195
  do_sample=False,
196
- temperature=0.0,
197
  )
198
 
199
  import threading
@@ -205,8 +201,6 @@ def generate_reply(images: List[Image.Image], user_text: str, chat_history: List
205
  partial += token
206
  yield partial
207
 
208
-
209
-
210
  # -----------------------------
211
  # Gradio UI Orchestration
212
  # -----------------------------
@@ -219,11 +213,9 @@ def start_chat(file_val, page_index):
219
  None,
220
  "No file loaded. Please upload a PDF/PNG/JPEG.",
221
  )
222
-
223
  choices = [f"Page {i+1}" for i in range(len(imgs))]
224
  safe_idx = 0 if page_index is None else max(0, min(len(imgs) - 1, int(page_index)))
225
  default_value = choices[safe_idx]
226
-
227
  return (
228
  gr.update(choices=choices, value=default_value),
229
  imgs,
@@ -231,19 +223,16 @@ def start_chat(file_val, page_index):
231
  "Document ready. Select a page and ask questions.",
232
  )
233
 
234
-
235
  def page_picker_changed(pages_dropdown, images_state):
236
  if not images_state:
237
  return None, gr.update()
238
  idx = parse_page_selection(pages_dropdown, len(images_state))
239
  selected = images_state[idx]
240
- return selected, selected # for preview and selected state
241
-
242
 
243
  def chat(user_text, history, images_state, selected_img):
244
  if not user_text or not user_text.strip():
245
  return gr.update(), history
246
-
247
  sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
248
  if sel_img is None:
249
  history = history + [(user_text, "Please upload a document first.")]
@@ -255,7 +244,6 @@ def chat(user_text, history, images_state, selected_img):
255
  acc = chunk
256
  yield history + [(user_text, acc)], history + [(user_text, acc)]
257
 
258
-
259
  # -----------------------------
260
  # App definition
261
  # -----------------------------
@@ -265,7 +253,6 @@ with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
265
  "Upload a PDF/PNG/JPEG, pick a page, and interrogate the document. "
266
  "Optimized for CPU-friendly, low-latency insights."
267
  )
268
-
269
  with gr.Row():
270
  with gr.Column(scale=1):
271
  file = gr.File(label="Upload invoice (PDF / PNG / JPEG)")
@@ -273,14 +260,15 @@ with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
273
  label="Select page (for PDFs)",
274
  choices=[],
275
  value=None,
276
- allow_custom_value=True, # set False to lock to dropdown values
277
  info="Type a page number (e.g., 2) or choose from the list."
278
  )
279
  load_btn = gr.Button("Prepare Document", variant="primary")
280
  with gr.Column(scale=2):
281
  image_view = gr.Image(label="Current page/image", interactive=False)
282
 
283
- chatbot = gr.Chatbot(height=400)
 
284
  user_box = gr.Textbox(
285
  label="Your question",
286
  placeholder="e.g., What is the invoice number and total with tax?",
@@ -297,13 +285,11 @@ with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
297
  inputs=[file, gr.State(0)],
298
  outputs=[pages, images_state, image_view, gr.Textbox(visible=False)]
299
  )
300
-
301
  pages.change(
302
  page_picker_changed,
303
  inputs=[pages, images_state],
304
  outputs=[image_view, selected_img_state]
305
  )
306
-
307
  ask_btn.click(
308
  chat,
309
  inputs=[user_box, chatbot, images_state, selected_img_state],
 
1
  # app.py
2
  # ------------------------------------------------------------
3
  # Invoice Chat • SmolVLM-Instruct-250M
4
+ # Gradio Space with robust page picker + safe streaming chat
5
  # ------------------------------------------------------------
6
 
7
  import io
 
16
  from transformers import (
17
  AutoProcessor,
18
  AutoTokenizer,
19
+ AutoModelForImageTextToText, # <= new, replaces AutoModelForVision2Seq
20
  TextIteratorStreamer,
21
  )
22
 
 
26
  MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
27
 
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
29
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
30
 
31
+ # Tokenizer has the chat template
32
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
33
+ # Processor handles vision tensors
34
  processor = AutoProcessor.from_pretrained(MODEL_ID)
35
+ # New class to avoid deprecation warnings
36
+ model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype=DTYPE)
37
  model.to(DEVICE).eval()
38
 
39
  SYSTEM_PROMPT = (
 
45
  # Utilities
46
  # -----------------------------
47
  def pdf_to_images_from_bytes(pdf_bytes: bytes, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
 
48
  doc = fitz.open(stream=pdf_bytes, filetype="pdf")
49
  images: List[Image.Image] = []
50
  for i, page in enumerate(doc):
 
55
  images.append(img)
56
  return images
57
 
 
58
  def pdf_to_images_from_path(path: str, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
 
59
  doc = fitz.open(path)
60
  images: List[Image.Image] = []
61
  for i, page in enumerate(doc):
 
66
  images.append(img)
67
  return images
68
 
 
69
  def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> List[Image.Image]:
70
  """
71
+ Accept PDF/PNG/JPEG via Gradio File. Handles:
72
  - str path (tempfile path)
73
+ - dict with 'name'/'path' or 'data'
74
  - bytes / BytesIO
75
  Returns a list of PIL images. PDFs => multi-image; PNG/JPEG => single image.
76
  """
 
95
  elif isinstance(file_val, io.BytesIO):
96
  raw_bytes = file_val.getvalue()
97
 
 
98
  def is_pdf_name(name: str) -> bool:
99
  return name.lower().endswith(".pdf")
100
 
 
113
 
114
  return []
115
 
 
116
  def parse_page_selection(value, num_pages: int) -> int:
117
  """
118
  Accept 'Page 3', '3', 3, 'pg-2', etc. Return safe 0-based index clamped to [0, num_pages-1].
 
119
  """
120
  if num_pages <= 0:
121
  return 0
122
  if value is None:
123
  return 0
 
124
  if isinstance(value, int):
125
  idx = value - 1
126
  else:
127
  s = str(value).strip()
128
  m = re.search(r"(\d+)", s)
129
  idx = int(m.group(1)) - 1 if m else 0
 
130
  return max(0, min(num_pages - 1, idx))
131
 
 
132
  def build_messages(history: List[Tuple[str, str]], user_text: str, images: List[Image.Image]):
133
  """
134
+ Construct chat-format messages for tokenizer.apply_chat_template.
 
135
  """
136
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
137
  trimmed = history[-4:] if history else []
 
138
  for u, a in trimmed:
139
  messages.append({"role": "user", "content": u})
140
  messages.append({"role": "assistant", "content": a})
 
144
  multimodal.append(im)
145
  if user_text.strip():
146
  multimodal.append(user_text.strip())
 
147
  messages.append({"role": "user", "content": multimodal})
148
  return messages
149
 
150
+ # -----------------------------
151
+ # Core generation (streaming)
152
+ # -----------------------------
153
  def generate_reply(images: List[Image.Image], user_text: str, chat_history: List[Tuple[str, str]]):
154
  """
155
  Stream a model reply grounded on provided images + user question + compact chat history.
156
+ - Build prompt as TEXT (chat template) -> tokenize to dict (input_ids, attention_mask)
157
+ - Vision tensors via processor (pixel_values)
158
+ - Pass ONLY allowed kwargs to model.generate (avoid rows/cols etc.)
159
  """
160
  messages = build_messages(chat_history, user_text, images)
161
 
162
+ # 1) Build prompt text
163
  prompt_text = tokenizer.apply_chat_template(
164
  messages,
165
  add_generation_prompt=True,
166
+ tokenize=False, # IMPORTANT: return a string
167
  )
168
 
169
+ # 2) Tokenize to get a dict
170
+ text_inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
 
 
 
171
 
172
+ # 3) Vision tensors
173
  vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)
174
 
175
+ # 4) Allow-list only the keys generate() expects
176
+ model_inputs = {
177
+ "input_ids": text_inputs["input_ids"],
178
+ # attention_mask may or may not exist depending on tokenizer; include if present
179
+ **({"attention_mask": text_inputs["attention_mask"]} if "attention_mask" in text_inputs else {}),
180
+ # vision inputs
181
+ **({"pixel_values": vision_inputs["pixel_values"]} if "pixel_values" in vision_inputs else {}),
182
+ }
183
 
184
+ # 5) Streamer uses the same tokenizer
185
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
186
+
187
  gen_kwargs = dict(
188
  **model_inputs,
189
  streamer=streamer,
190
  max_new_tokens=512,
191
  do_sample=False,
192
+ # NOTE: some I2T models ignore temperature/top_p; avoid passing unsupported flags
193
  )
194
 
195
  import threading
 
201
  partial += token
202
  yield partial
203
 
 
 
204
  # -----------------------------
205
  # Gradio UI Orchestration
206
  # -----------------------------
 
213
  None,
214
  "No file loaded. Please upload a PDF/PNG/JPEG.",
215
  )
 
216
  choices = [f"Page {i+1}" for i in range(len(imgs))]
217
  safe_idx = 0 if page_index is None else max(0, min(len(imgs) - 1, int(page_index)))
218
  default_value = choices[safe_idx]
 
219
  return (
220
  gr.update(choices=choices, value=default_value),
221
  imgs,
 
223
  "Document ready. Select a page and ask questions.",
224
  )
225
 
 
226
  def page_picker_changed(pages_dropdown, images_state):
227
  if not images_state:
228
  return None, gr.update()
229
  idx = parse_page_selection(pages_dropdown, len(images_state))
230
  selected = images_state[idx]
231
+ return selected, selected # preview + selected state
 
232
 
233
  def chat(user_text, history, images_state, selected_img):
234
  if not user_text or not user_text.strip():
235
  return gr.update(), history
 
236
  sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
237
  if sel_img is None:
238
  history = history + [(user_text, "Please upload a document first.")]
 
244
  acc = chunk
245
  yield history + [(user_text, acc)], history + [(user_text, acc)]
246
 
 
247
  # -----------------------------
248
  # App definition
249
  # -----------------------------
 
253
  "Upload a PDF/PNG/JPEG, pick a page, and interrogate the document. "
254
  "Optimized for CPU-friendly, low-latency insights."
255
  )
 
256
  with gr.Row():
257
  with gr.Column(scale=1):
258
  file = gr.File(label="Upload invoice (PDF / PNG / JPEG)")
 
260
  label="Select page (for PDFs)",
261
  choices=[],
262
  value=None,
263
+ allow_custom_value=True,
264
  info="Type a page number (e.g., 2) or choose from the list."
265
  )
266
  load_btn = gr.Button("Prepare Document", variant="primary")
267
  with gr.Column(scale=2):
268
  image_view = gr.Image(label="Current page/image", interactive=False)
269
 
270
+ # Lock Chatbot type to silence deprecation warning
271
+ chatbot = gr.Chatbot(height=400, type="tuples")
272
  user_box = gr.Textbox(
273
  label="Your question",
274
  placeholder="e.g., What is the invoice number and total with tax?",
 
285
  inputs=[file, gr.State(0)],
286
  outputs=[pages, images_state, image_view, gr.Textbox(visible=False)]
287
  )
 
288
  pages.change(
289
  page_picker_changed,
290
  inputs=[pages, images_state],
291
  outputs=[image_view, selected_img_state]
292
  )
 
293
  ask_btn.click(
294
  chat,
295
  inputs=[user_box, chatbot, images_state, selected_img_state],