arasuezofis committed on
Commit
1fcca49
·
verified ·
1 Parent(s): 1e56cd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -28
app.py CHANGED
@@ -1,7 +1,7 @@
1
  # app.py
2
  # ------------------------------------------------------------
3
  # Invoice Chat • SmolVLM-Instruct-250M
4
- # Operationalized for Hugging Face Spaces (Gradio SDK)
5
  # ------------------------------------------------------------
6
 
7
  import io
@@ -15,17 +15,22 @@ from PIL import Image
15
  import fitz # PyMuPDF
16
  from transformers import (
17
  AutoProcessor,
 
18
  AutoModelForVision2Seq,
19
  TextIteratorStreamer,
20
  )
21
 
22
  # -----------------------------
23
- # Model bootstrap (lean & mean)
24
  # -----------------------------
25
  MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
 
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
27
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
28
 
 
 
29
  processor = AutoProcessor.from_pretrained(MODEL_ID)
30
  model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
31
  model.to(DEVICE).eval()
@@ -68,26 +73,23 @@ def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> Li
68
  """
69
  Accept PDF/PNG/JPEG via Gradio File. Handles multiple shapes of input:
70
  - str path (tempfile path)
71
- - dict with 'name' or 'path' (some Gradio versions)
72
  - bytes / BytesIO
73
  Returns a list of PIL images. PDFs => multi-image; PNG/JPEG => single image.
74
  """
75
  if not file_val:
76
  return []
77
 
78
- # Normalize to path/bytes
79
  path: Optional[str] = None
80
  raw_bytes: Optional[bytes] = None
81
 
82
  if isinstance(file_val, str) and os.path.exists(file_val):
83
  path = file_val
84
  elif isinstance(file_val, dict):
85
- # Gradio sometimes passes a dict with keys like {'name': '/tmp/..', 'orig_name': 'x.pdf', 'size': ...}
86
  maybe_path = file_val.get("name") or file_val.get("path")
87
  if isinstance(maybe_path, str) and os.path.exists(maybe_path):
88
  path = maybe_path
89
  else:
90
- # if dict contains 'data' or similar
91
  data = file_val.get("data")
92
  if isinstance(data, (bytes, bytearray)):
93
  raw_bytes = bytes(data)
@@ -96,27 +98,23 @@ def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> Li
96
  elif isinstance(file_val, io.BytesIO):
97
  raw_bytes = file_val.getvalue()
98
 
99
- # Branch by PDF vs Image
100
- def is_pdf_from_name(name: str) -> bool:
101
  return name.lower().endswith(".pdf")
102
 
103
  if path:
104
- if is_pdf_from_name(path):
105
  return pdf_to_images_from_path(path)
106
- # Image path
107
  with open(path, "rb") as f:
108
  img = Image.open(io.BytesIO(f.read())).convert("RGB")
109
  return [img]
110
 
111
  if raw_bytes:
112
- # Try sniffing PDF header
113
  if raw_bytes[:5] == b"%PDF-":
114
  return pdf_to_images_from_bytes(raw_bytes)
115
- # Else treat as image bytes
116
  img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
117
  return [img]
118
 
119
- # Fallback: nothing usable
120
  return []
121
 
122
 
@@ -142,7 +140,7 @@ def parse_page_selection(value, num_pages: int) -> int:
142
 
143
  def build_messages(history: List[Tuple[str, str]], user_text: str, images: List[Image.Image]):
144
  """
145
- Construct chat-format messages compatible with processor.apply_chat_template.
146
  We trim the history to avoid runaway context growth.
147
  """
148
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
@@ -165,24 +163,25 @@ def build_messages(history: List[Tuple[str, str]], user_text: str, images: List[
165
  def generate_reply(images: List[Image.Image], user_text: str, chat_history: List[Tuple[str, str]]):
166
  """
167
  Stream a model reply grounded on provided images + user question + compact chat history.
 
168
  """
169
  messages = build_messages(chat_history, user_text, images)
170
 
171
- # Text context
172
- text_inputs = processor.apply_chat_template(
173
  messages,
174
  add_generation_prompt=True,
175
  tokenize=True,
176
  return_tensors="pt"
177
  ).to(DEVICE)
178
 
179
- # Vision tensors
180
  vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)
181
 
182
- # Merge dicts
183
  model_inputs = {**text_inputs, **vision_inputs}
184
 
185
- streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
186
  gen_kwargs = dict(
187
  **model_inputs,
188
  streamer=streamer,
@@ -207,7 +206,6 @@ def generate_reply(images: List[Image.Image], user_text: str, chat_history: List
207
  def start_chat(file_val, page_index):
208
  imgs = ensure_images(file_val)
209
  if not imgs:
210
- # Reset the dropdown & return empty
211
  return (
212
  gr.update(choices=[], value=None),
213
  [],
@@ -237,10 +235,8 @@ def page_picker_changed(pages_dropdown, images_state):
237
 
238
  def chat(user_text, history, images_state, selected_img):
239
  if not user_text or not user_text.strip():
240
- # No update; just echo current state
241
  return gr.update(), history
242
 
243
- # Choose selected image; fallback to first page if needed
244
  sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
245
  if sel_img is None:
246
  history = history + [(user_text, "Please upload a document first.")]
@@ -260,7 +256,7 @@ with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
260
  gr.Markdown(
261
  "## Invoice Chat • SmolVLM-Instruct-250M\n"
262
  "Upload a PDF/PNG/JPEG, pick a page, and interrogate the document. "
263
- "This is a CPU-friendly, low-latency experience designed for rapid insight capture."
264
  )
265
 
266
  with gr.Row():
@@ -270,7 +266,7 @@ with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
270
  label="Select page (for PDFs)",
271
  choices=[],
272
  value=None,
273
- allow_custom_value=True, # set False to hard-lock to dropdown values
274
  info="Type a page number (e.g., 2) or choose from the list."
275
  )
276
  load_btn = gr.Button("Prepare Document", variant="primary")
@@ -284,25 +280,23 @@ with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
284
  )
285
  ask_btn = gr.Button("Ask", variant="primary")
286
 
287
- # Hidden session state
288
  images_state = gr.State([])
289
  selected_img_state = gr.State(None)
290
 
291
- # Wire up events
292
  load_btn.click(
293
  start_chat,
294
  inputs=[file, gr.State(0)],
295
  outputs=[pages, images_state, image_view, gr.Textbox(visible=False)]
296
  )
297
 
298
- # When the page dropdown changes, update both preview and the selected image state
299
  pages.change(
300
  page_picker_changed,
301
  inputs=[pages, images_state],
302
  outputs=[image_view, selected_img_state]
303
  )
304
 
305
- # Ask flows (streaming)
306
  ask_btn.click(
307
  chat,
308
  inputs=[user_box, chatbot, images_state, selected_img_state],
 
1
  # app.py
2
  # ------------------------------------------------------------
3
  # Invoice Chat • SmolVLM-Instruct-250M
4
+ # Gradio Space with resilient page picker + streaming chat
5
  # ------------------------------------------------------------
6
 
7
  import io
 
15
  import fitz # PyMuPDF
16
  from transformers import (
17
  AutoProcessor,
18
+ AutoTokenizer,
19
  AutoModelForVision2Seq,
20
  TextIteratorStreamer,
21
  )
22
 
23
  # -----------------------------
24
+ # Model bootstrap
25
  # -----------------------------
26
  MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
27
+
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
+ # float16 only if CUDA is available; on CPU use float32
30
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
31
 
32
+ # Load tokenizer (has the chat template), processor (images), and model
33
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
34
  processor = AutoProcessor.from_pretrained(MODEL_ID)
35
  model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
36
  model.to(DEVICE).eval()
 
73
  """
74
  Accept PDF/PNG/JPEG via Gradio File. Handles multiple shapes of input:
75
  - str path (tempfile path)
76
+ - dict with 'name' or 'path'
77
  - bytes / BytesIO
78
  Returns a list of PIL images. PDFs => multi-image; PNG/JPEG => single image.
79
  """
80
  if not file_val:
81
  return []
82
 
 
83
  path: Optional[str] = None
84
  raw_bytes: Optional[bytes] = None
85
 
86
  if isinstance(file_val, str) and os.path.exists(file_val):
87
  path = file_val
88
  elif isinstance(file_val, dict):
 
89
  maybe_path = file_val.get("name") or file_val.get("path")
90
  if isinstance(maybe_path, str) and os.path.exists(maybe_path):
91
  path = maybe_path
92
  else:
 
93
  data = file_val.get("data")
94
  if isinstance(data, (bytes, bytearray)):
95
  raw_bytes = bytes(data)
 
98
  elif isinstance(file_val, io.BytesIO):
99
  raw_bytes = file_val.getvalue()
100
 
101
+ # PDF vs Image
102
+ def is_pdf_name(name: str) -> bool:
103
  return name.lower().endswith(".pdf")
104
 
105
  if path:
106
+ if is_pdf_name(path):
107
  return pdf_to_images_from_path(path)
 
108
  with open(path, "rb") as f:
109
  img = Image.open(io.BytesIO(f.read())).convert("RGB")
110
  return [img]
111
 
112
  if raw_bytes:
 
113
  if raw_bytes[:5] == b"%PDF-":
114
  return pdf_to_images_from_bytes(raw_bytes)
 
115
  img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
116
  return [img]
117
 
 
118
  return []
119
 
120
 
 
140
 
141
  def build_messages(history: List[Tuple[str, str]], user_text: str, images: List[Image.Image]):
142
  """
143
+ Construct chat-format messages compatible with tokenizer.apply_chat_template.
144
  We trim the history to avoid runaway context growth.
145
  """
146
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
 
163
  def generate_reply(images: List[Image.Image], user_text: str, chat_history: List[Tuple[str, str]]):
164
  """
165
  Stream a model reply grounded on provided images + user question + compact chat history.
166
+ Key fix: use tokenizer.apply_chat_template and a streamer built with the same tokenizer.
167
  """
168
  messages = build_messages(chat_history, user_text, images)
169
 
170
+ # Text inputs via tokenizer chat template
171
+ text_inputs = tokenizer.apply_chat_template(
172
  messages,
173
  add_generation_prompt=True,
174
  tokenize=True,
175
  return_tensors="pt"
176
  ).to(DEVICE)
177
 
178
+ # Vision tensors via processor
179
  vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)
180
 
181
+ # Merge dicts (input_ids, attention_mask, pixel_values)
182
  model_inputs = {**text_inputs, **vision_inputs}
183
 
184
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
185
  gen_kwargs = dict(
186
  **model_inputs,
187
  streamer=streamer,
 
206
  def start_chat(file_val, page_index):
207
  imgs = ensure_images(file_val)
208
  if not imgs:
 
209
  return (
210
  gr.update(choices=[], value=None),
211
  [],
 
235
 
236
  def chat(user_text, history, images_state, selected_img):
237
  if not user_text or not user_text.strip():
 
238
  return gr.update(), history
239
 
 
240
  sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
241
  if sel_img is None:
242
  history = history + [(user_text, "Please upload a document first.")]
 
256
  gr.Markdown(
257
  "## Invoice Chat • SmolVLM-Instruct-250M\n"
258
  "Upload a PDF/PNG/JPEG, pick a page, and interrogate the document. "
259
+ "Optimized for CPU-friendly, low-latency insights."
260
  )
261
 
262
  with gr.Row():
 
266
  label="Select page (for PDFs)",
267
  choices=[],
268
  value=None,
269
+ allow_custom_value=True, # set False to lock to dropdown values
270
  info="Type a page number (e.g., 2) or choose from the list."
271
  )
272
  load_btn = gr.Button("Prepare Document", variant="primary")
 
280
  )
281
  ask_btn = gr.Button("Ask", variant="primary")
282
 
283
+ # Session state
284
  images_state = gr.State([])
285
  selected_img_state = gr.State(None)
286
 
287
+ # Events
288
  load_btn.click(
289
  start_chat,
290
  inputs=[file, gr.State(0)],
291
  outputs=[pages, images_state, image_view, gr.Textbox(visible=False)]
292
  )
293
 
 
294
  pages.change(
295
  page_picker_changed,
296
  inputs=[pages, images_state],
297
  outputs=[image_view, selected_img_state]
298
  )
299
 
 
300
  ask_btn.click(
301
  chat,
302
  inputs=[user_box, chatbot, images_state, selected_img_state],