arasuezofis committed on
Commit
1e56cd8
·
verified ·
1 Parent(s): 55b1735

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -102
app.py CHANGED
@@ -1,6 +1,13 @@
 
 
 
 
 
 
1
  import io
2
- import time
3
- from typing import List, Tuple, Optional
 
4
 
5
  import gradio as gr
6
  import torch
@@ -12,95 +19,168 @@ from transformers import (
12
  TextIteratorStreamer,
13
  )
14
 
15
- MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M" # 250M instruct variant
16
- # If you ever need to swap models (e.g., 256M/500M), just change the ID.
17
-
18
- # Load once at startup
19
- device = "cuda" if torch.cuda.is_available() else "cpu"
20
- dtype = torch.float16 if device == "cuda" else torch.float32
21
 
22
  processor = AutoProcessor.from_pretrained(MODEL_ID)
23
- model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=dtype)
24
- model.to(device)
25
- model.eval()
26
 
27
  SYSTEM_PROMPT = (
28
- "You are an invoice assistant. Answer strictly based on the uploaded document. "
29
- "If asked for fields (invoice number, date, totals, etc.), extract them from the image."
30
  )
31
 
32
- def pdf_to_images(pdf_bytes: bytes, max_pages: int = 5, dpi: int = 216) -> List[Image.Image]:
33
- """
34
- Render first N pages of a PDF to PIL images (RGB).
35
- """
 
36
  doc = fitz.open(stream=pdf_bytes, filetype="pdf")
37
- images = []
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  for i, page in enumerate(doc):
39
  if i >= max_pages:
40
  break
41
- # Render page
42
- pix = page.get_pixmap(matrix=fitz.Matrix(dpi/72, dpi/72))
43
  img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
44
  images.append(img)
45
  return images
46
 
47
- def ensure_images(file: Optional[gr.File]) -> List[Image.Image]:
 
48
  """
49
- Accepts a PDF/PNG/JPEG and returns a list of PIL images.
50
- - PDF => multiple images (page picker will handle selection)
51
- - PNG/JPG => single image
 
 
52
  """
53
- if file is None:
54
  return []
55
- mime = file.mime_type or ""
56
- data = file.read()
57
 
58
- if "pdf" in mime or (file.name and file.name.lower().endswith(".pdf")):
59
- return pdf_to_images(data, max_pages=8)
60
- # Image path
61
- img = Image.open(io.BytesIO(data)).convert("RGB")
62
- return [img]
63
 
64
- def generate_reply(images: List[Image.Image], user_text: str, chat_history: List[Tuple[str, str]]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  """
66
- Stream a reply grounded on chosen image(s) + chat history.
67
- We only keep a compact history to stay lean on memory.
68
  """
69
- # Build multimodal messages per transformers' chat template
70
- # Format: [{"role":"system","content":...}, {"role":"user","content":[text, image, ...]}, ...]
71
- messages = [{"role": "system", "content": SYSTEM_PROMPT}]
 
72
 
73
- # Keep only last 4 exchanges to avoid context bloat
74
- trimmed = chat_history[-4:] if chat_history else []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  for u, a in trimmed:
77
  messages.append({"role": "user", "content": u})
78
  messages.append({"role": "assistant", "content": a})
79
 
80
- # Add the current turn with images
81
- multimodal_content = []
82
- if images:
83
- # SmolVLM supports multiple images; push them before the text question
84
- for im in images:
85
- multimodal_content.append(im)
86
  if user_text.strip():
87
- multimodal_content.append(user_text.strip())
88
 
89
- messages.append({"role": "user", "content": multimodal_content})
 
 
 
 
 
 
 
 
90
 
91
- # Tokenize with chat template
92
- inputs = processor.apply_chat_template(
93
  messages,
94
  add_generation_prompt=True,
95
  tokenize=True,
96
  return_tensors="pt"
97
- ).to(device)
98
 
99
- # Vision inputs: processor handles images separately
100
- vision_inputs = processor(images=images, return_tensors="pt").to(device)
101
 
102
- # Merge text & vision inputs
103
- model_inputs = {**inputs, **vision_inputs}
104
 
105
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
106
  gen_kwargs = dict(
@@ -111,86 +191,126 @@ def generate_reply(images: List[Image.Image], user_text: str, chat_history: List
111
  temperature=0.0,
112
  )
113
 
114
- # Non-blocking generation
115
  import threading
116
- thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
117
- thread.start()
118
 
119
  partial = ""
120
  for token in streamer:
121
  partial += token
122
  yield partial
123
 
124
- def start_chat(file, page_index):
125
- # Convert to images and preselect a page
126
- imgs = ensure_images(file)
 
 
 
127
  if not imgs:
128
- return gr.update(choices=[], value=None), None, "No file loaded yet."
 
 
 
 
 
 
129
 
130
  choices = [f"Page {i+1}" for i in range(len(imgs))]
131
- value = choices[min(page_index, len(imgs)-1)] if page_index is not None else choices[0]
132
- return gr.update(choices=choices, value=value), imgs, "Document ready. Select a page and ask questions."
 
 
 
 
 
 
 
 
133
 
134
  def page_picker_changed(pages_dropdown, images_state):
135
  if not images_state:
136
- return None
137
- idx = max(0, int(pages_dropdown.split()[-1]) - 1)
138
- return images_state[idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- with gr.Blocks(title="Invoice Chat (SmolVLM-250M)") as demo:
141
- gr.Markdown("# Invoice Chat • SmolVLM-Instruct-250M\nAsk questions grounded on your uploaded invoice.")
142
  with gr.Row():
143
  with gr.Column(scale=1):
144
- file = gr.File(label="Upload invoice (PDF/PNG/JPEG)")
145
- pages = gr.Dropdown(label="Select page (for PDFs)", choices=[], value=None)
146
- load_btn = gr.Button("Prepare Document")
 
 
 
 
 
 
147
  with gr.Column(scale=2):
148
  image_view = gr.Image(label="Current page/image", interactive=False)
149
- chatbot = gr.Chatbot(height=380)
150
- user_box = gr.Textbox(label="Your question", placeholder="e.g., What is the invoice number and total?")
151
- ask_btn = gr.Button("Ask")
152
 
153
- # Hidden states
 
 
 
 
 
 
 
154
  images_state = gr.State([])
155
  selected_img_state = gr.State(None)
156
 
157
- # Wire events
158
  load_btn.click(
159
  start_chat,
160
  inputs=[file, gr.State(0)],
161
- outputs=[pages, images_state, gr.Textbox(visible=False)]
 
 
 
 
 
 
 
162
  )
163
- pages.change(page_picker_changed, inputs=[pages, images_state], outputs=[image_view])
164
-
165
- def chat(user_text, history, images_state, image_view):
166
- if not user_text.strip():
167
- return gr.update(), history
168
- # Choose the selected image; if none, fall back to first
169
- sel_img = None
170
- if image_view is not None and isinstance(image_view, dict) and image_view.get("image"):
171
- # gr.Image returns a dict in some contexts; handle robustly
172
- sel_img = Image.open(image_view["image"]).convert("RGB")
173
- elif images_state:
174
- sel_img = images_state[0]
175
-
176
- if sel_img is None:
177
- history = history + [(user_text, "Please upload a document first.")]
178
- return gr.update(value=history), history
179
-
180
- stream = generate_reply([sel_img], user_text, history)
181
- acc = ""
182
- for chunk in stream:
183
- acc = chunk
184
- yield history + [(user_text, acc)], history + [(user_text, acc)]
185
 
 
186
  ask_btn.click(
187
  chat,
188
- inputs=[user_box, chatbot, images_state, image_view],
189
  outputs=[chatbot, chatbot]
190
  )
191
  user_box.submit(
192
  chat,
193
- inputs=[user_box, chatbot, images_state, image_view],
194
  outputs=[chatbot, chatbot]
195
  )
196
 
 
1
+ # app.py
2
+ # ------------------------------------------------------------
3
+ # Invoice Chat • SmolVLM-Instruct-250M
4
+ # Operationalized for Hugging Face Spaces (Gradio SDK)
5
+ # ------------------------------------------------------------
6
+
7
  import io
8
+ import os
9
+ import re
10
+ from typing import List, Tuple, Optional, Union
11
 
12
  import gradio as gr
13
  import torch
 
19
  TextIteratorStreamer,
20
  )
21
 
22
+ # -----------------------------
23
+ # Model bootstrap (lean & mean)
24
+ # -----------------------------
25
+ MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
26
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+ DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
28
 
29
  processor = AutoProcessor.from_pretrained(MODEL_ID)
30
+ model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
31
+ model.to(DEVICE).eval()
 
32
 
33
  SYSTEM_PROMPT = (
34
+ "You are an invoice assistant. Respond ONLY using details visible in the uploaded document. "
35
+ "If a field (invoice number, date, totals, tax, vendor, etc.) is not clearly visible, say so."
36
  )
37
 
38
+ # -----------------------------
39
+ # Utilities
40
+ # -----------------------------
41
def pdf_to_images_from_bytes(pdf_bytes: bytes, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    """Render the first *max_pages* pages of an in-memory PDF as PIL RGB images.

    Args:
        pdf_bytes: Raw PDF file contents.
        max_pages: Upper bound on the number of pages rendered.
        dpi: Render resolution; PDF-native is 72 dpi, so dpi/72 is the zoom.

    Returns:
        List of PIL.Image.Image objects in page order (possibly empty).
    """
    images: List[Image.Image] = []
    # Context manager closes the document handle deterministically
    # (the original leaked it).
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for i, page in enumerate(doc):
            if i >= max_pages:
                break
            pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72))
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
    return images
52
+
53
+
54
def pdf_to_images_from_path(path: str, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    """Render the first *max_pages* pages of a PDF file as PIL RGB images.

    Args:
        path: Filesystem path to the PDF.
        max_pages: Upper bound on the number of pages rendered.
        dpi: Render resolution; dpi/72 is the zoom applied to the 72-dpi page.

    Returns:
        List of PIL.Image.Image objects in page order (possibly empty).
    """
    images: List[Image.Image] = []
    # Context manager closes the document handle deterministically
    # (the original leaked it).
    with fitz.open(path) as doc:
        for i, page in enumerate(doc):
            if i >= max_pages:
                break
            pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72))
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
    return images
65
 
66
+
67
def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> List[Image.Image]:
    """Normalize a Gradio File value into a list of PIL images.

    Handles the input shapes different Gradio versions produce:
      - str tempfile path
      - dict with 'name'/'path' (and sometimes inline 'data' bytes)
      - raw bytes / BytesIO

    Returns:
        PDFs => one image per page (up to the renderer's page cap);
        PNG/JPEG => a single-image list; anything unusable => [].
    """
    if not file_val:
        return []

    # Normalize the value to either a filesystem path or raw bytes.
    path: Optional[str] = None
    raw_bytes: Optional[bytes] = None

    if isinstance(file_val, str) and os.path.exists(file_val):
        path = file_val
    elif isinstance(file_val, dict):
        # Gradio sometimes passes {'name': '/tmp/..', 'orig_name': 'x.pdf', ...}
        maybe_path = file_val.get("name") or file_val.get("path")
        if isinstance(maybe_path, str) and os.path.exists(maybe_path):
            path = maybe_path
        else:
            data = file_val.get("data")
            if isinstance(data, (bytes, bytearray)):
                raw_bytes = bytes(data)
    elif isinstance(file_val, (bytes, bytearray)):
        raw_bytes = bytes(file_val)
    elif isinstance(file_val, io.BytesIO):
        raw_bytes = file_val.getvalue()

    if path:
        # Sniff the magic header as well as the extension so a PDF uploaded
        # without a .pdf suffix is still rendered page-by-page (the bytes
        # branch below already did this; the path branch previously did not).
        with open(path, "rb") as f:
            head = f.read(5)
        if path.lower().endswith(".pdf") or head == b"%PDF-":
            return pdf_to_images_from_path(path)
        with open(path, "rb") as f:
            return [Image.open(io.BytesIO(f.read())).convert("RGB")]

    if raw_bytes:
        if raw_bytes[:5] == b"%PDF-":
            return pdf_to_images_from_bytes(raw_bytes)
        return [Image.open(io.BytesIO(raw_bytes)).convert("RGB")]

    # Fallback: nothing usable
    return []
121
+
122
+
123
def parse_page_selection(value, num_pages: int) -> int:
    """Coerce a page-picker value ('Page 3', '3', 3, 'pg-2', ...) to a 0-based index.

    The result is clamped to [0, num_pages - 1]; unusable input maps to 0.
    """
    if num_pages <= 0 or value is None:
        return 0

    if isinstance(value, int):
        index = value - 1
    else:
        # Pull the first run of digits out of whatever string shape we got.
        match = re.search(r"(\d+)", str(value).strip())
        index = (int(match.group(1)) - 1) if match else 0

    return min(max(index, 0), num_pages - 1)
141
+
142
+
143
def build_messages(history: List[Tuple[str, str]], user_text: str, images: List[Image.Image]):
    """Assemble chat-template messages: system prompt, trimmed history, current turn.

    Only the last four exchanges are kept so the prompt stays compact.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    for prev_user, prev_assistant in (history or [])[-4:]:
        messages.append({"role": "user", "content": prev_user})
        messages.append({"role": "assistant", "content": prev_assistant})

    # Images go before the text question in the multimodal turn.
    content = list(images)
    question = user_text.strip()
    if question:
        content.append(question)

    messages.append({"role": "user", "content": content})
    return messages
163
+
164
+
165
+ def generate_reply(images: List[Image.Image], user_text: str, chat_history: List[Tuple[str, str]]):
166
+ """
167
+ Stream a model reply grounded on provided images + user question + compact chat history.
168
+ """
169
+ messages = build_messages(chat_history, user_text, images)
170
 
171
+ # Text context
172
+ text_inputs = processor.apply_chat_template(
173
  messages,
174
  add_generation_prompt=True,
175
  tokenize=True,
176
  return_tensors="pt"
177
+ ).to(DEVICE)
178
 
179
+ # Vision tensors
180
+ vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)
181
 
182
+ # Merge dicts
183
+ model_inputs = {**text_inputs, **vision_inputs}
184
 
185
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
186
  gen_kwargs = dict(
 
191
  temperature=0.0,
192
  )
193
 
 
194
  import threading
195
+ t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
196
+ t.start()
197
 
198
  partial = ""
199
  for token in streamer:
200
  partial += token
201
  yield partial
202
 
203
+
204
+ # -----------------------------
205
+ # Gradio UI Orchestration
206
+ # -----------------------------
207
def start_chat(file_val, page_index):
    """Load an uploaded file, populate the page dropdown, and preview a page.

    Returns (dropdown update, images list, preview image, status string).
    """
    imgs = ensure_images(file_val)
    if not imgs:
        # Nothing usable: clear the dropdown and the preview.
        return (
            gr.update(choices=[], value=None),
            [],
            None,
            "No file loaded. Please upload a PDF/PNG/JPEG.",
        )

    choices = [f"Page {i+1}" for i in range(len(imgs))]
    if page_index is None:
        idx = 0
    else:
        idx = min(max(int(page_index), 0), len(imgs) - 1)

    return (
        gr.update(choices=choices, value=choices[idx]),
        imgs,
        imgs[idx],
        "Document ready. Select a page and ask questions.",
    )
228
+
229
 
230
def page_picker_changed(pages_dropdown, images_state):
    """Sync the preview image and the selected-image state to the dropdown."""
    if not images_state:
        # No document loaded: clear the preview, leave the state untouched.
        return None, gr.update()
    chosen = images_state[parse_page_selection(pages_dropdown, len(images_state))]
    return chosen, chosen  # for preview and selected state
236
+
237
+
238
def chat(user_text, history, images_state, selected_img):
    """Stream an answer about the selected page into the chatbot.

    NOTE: this function contains ``yield``, so Python treats it as a
    generator. A bare ``return value`` inside a generator only sets
    StopIteration.value and never reaches the UI — every exit path must
    *yield* its outputs first (this was the bug being fixed here).
    """
    if not user_text or not user_text.strip():
        # Nothing asked; leave the chatbot unchanged.
        yield gr.update(), history
        return

    # Prefer the explicitly selected page; fall back to the first page.
    sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
    if sel_img is None:
        history = history + [(user_text, "Please upload a document first.")]
        yield gr.update(value=history), history
        return

    # Stream tokens; each yield re-renders the in-progress assistant turn.
    stream = generate_reply([sel_img], user_text, history)
    acc = ""
    for chunk in stream:
        acc = chunk
        yield history + [(user_text, acc)], history + [(user_text, acc)]
254
+
255
+
256
# -----------------------------
# App definition
# -----------------------------
with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
    gr.Markdown(
        "## Invoice Chat • SmolVLM-Instruct-250M\n"
        "Upload a PDF/PNG/JPEG, pick a page, and interrogate the document. "
        "This is a CPU-friendly, low-latency experience designed for rapid insight capture."
    )

    with gr.Row():
        # Left column: upload + page-selection controls.
        with gr.Column(scale=1):
            file = gr.File(label="Upload invoice (PDF / PNG / JPEG)")
            pages = gr.Dropdown(
                label="Select page (for PDFs)",
                choices=[],
                value=None,
                allow_custom_value=True,  # set False to hard-lock to dropdown values
                info="Type a page number (e.g., 2) or choose from the list."
            )
            load_btn = gr.Button("Prepare Document", variant="primary")
        # Right column: preview of the currently selected page.
        with gr.Column(scale=2):
            image_view = gr.Image(label="Current page/image", interactive=False)

    # Chat area. NOTE(review): diff rendering makes the nesting ambiguous —
    # placed at Blocks level (full width, below the row); confirm intended layout.
    chatbot = gr.Chatbot(height=400)
    user_box = gr.Textbox(
        label="Your question",
        placeholder="e.g., What is the invoice number and total with tax?",
    )
    ask_btn = gr.Button("Ask", variant="primary")

    # Hidden session state
    images_state = gr.State([])
    selected_img_state = gr.State(None)

    # Wire up events
    load_btn.click(
        start_chat,
        inputs=[file, gr.State(0)],
        # NOTE(review): the trailing gr.Textbox(visible=False) is created inline
        # purely to absorb the status string — it is never shown; confirm intended.
        outputs=[pages, images_state, image_view, gr.Textbox(visible=False)]
    )

    # When the page dropdown changes, update both preview and the selected image state
    pages.change(
        page_picker_changed,
        inputs=[pages, images_state],
        outputs=[image_view, selected_img_state]
    )

    # Ask flows (streaming): button click and Enter-in-textbox run the same handler.
    ask_btn.click(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot, chatbot]
    )
    user_box.submit(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot, chatbot]
    )
316