# app.py
# ------------------------------------------------------------
# Invoice Chat • SmolVLM-Instruct-250M (messages-mode, streaming)
# ------------------------------------------------------------
import io
import os
import re
import threading
from typing import List, Optional, Union

import gradio as gr
import torch
from PIL import Image
import fitz  # PyMuPDF
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoModelForImageTextToText,  # modern replacement for AutoModelForVision2Seq
    TextIteratorStreamer,
)
# -----------------------------
# Model bootstrap
# -----------------------------
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype=DTYPE).to(DEVICE).eval()
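# Recent transformers releases accept `dtype=` in from_pretrained; older ones
# spell the same argument `torch_dtype=`. If loading rejects the keyword on
# your version, swap it in:
#   model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, torch_dtype=DTYPE)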
SYSTEM_PROMPT = (
    "You are an invoice assistant. Respond ONLY using details visible in the uploaded document. "
    "If a field (invoice number, date, totals, tax, vendor, etc.) is not clearly visible, say so."
)
# -----------------------------
# Utilities
# -----------------------------
def pdf_to_images_from_bytes(pdf_bytes: bytes, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    images: List[Image.Image] = []
    for i, page in enumerate(doc):
        if i >= max_pages:
            break
        pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72))
        img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
        images.append(img)
    doc.close()
    return images


def pdf_to_images_from_path(path: str, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    doc = fitz.open(path)
    images: List[Image.Image] = []
    for i, page in enumerate(doc):
        if i >= max_pages:
            break
        pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72))
        img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
        images.append(img)
    doc.close()
    return images
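
# PDF user space is 72 dpi, so fitz.Matrix(dpi / 72, dpi / 72) is the zoom that
# renders pages at the requested resolution (216 dpi = 3x zoom). A quick sketch
# of how these helpers behave (hypothetical file path, for illustration only):
#   pages = pdf_to_images_from_path("sample_invoice.pdf", max_pages=2, dpi=144)
#   pages[0].save("page_1.png")
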
def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> List[Image.Image]:
    """
    Accept PDF/PNG/JPEG (path/dict/bytes/BytesIO) and return a list of PIL images.
    """
    if not file_val:
        return []
    path: Optional[str] = None
    raw_bytes: Optional[bytes] = None
    if isinstance(file_val, str) and os.path.exists(file_val):
        path = file_val
    elif isinstance(file_val, dict):
        maybe_path = file_val.get("name") or file_val.get("path")
        if isinstance(maybe_path, str) and os.path.exists(maybe_path):
            path = maybe_path
        else:
            data = file_val.get("data")
            if isinstance(data, (bytes, bytearray)):
                raw_bytes = bytes(data)
    elif isinstance(file_val, (bytes, bytearray)):
        raw_bytes = bytes(file_val)
    elif isinstance(file_val, io.BytesIO):
        raw_bytes = file_val.getvalue()
    if path:
        if path.lower().endswith(".pdf"):
            return pdf_to_images_from_path(path)
        with open(path, "rb") as f:
            img = Image.open(io.BytesIO(f.read())).convert("RGB")
        return [img]
    if raw_bytes:
        if raw_bytes[:5] == b"%PDF-":
            return pdf_to_images_from_bytes(raw_bytes)
        img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        return [img]
    return []
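
# The byte sniffing above keys off the PDF magic number: every PDF file begins
# with b"%PDF-". Anything else falls through to PIL, which detects PNG/JPEG from
# their own signatures. Illustration with a hypothetical in-memory upload:
#   imgs = ensure_images(open("scan.jpg", "rb").read())  # -> [<PIL RGB image>]
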
def parse_page_selection(value, num_pages: int) -> int:
    """
    Accept 'Page 3', '3', 3, 'pg-2', etc. Return safe 0-based index.
    """
    if num_pages <= 0 or value is None:
        return 0
    if isinstance(value, int):
        idx = value - 1
    else:
        m = re.search(r"(\d+)", str(value).strip())
        idx = int(m.group(1)) - 1 if m else 0
    return max(0, min(num_pages - 1, idx))
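
# Examples of inputs this parser accepts (all clamped to valid 0-based indices):
#   parse_page_selection("Page 3", 5)  # -> 2
#   parse_page_selection("pg-2", 5)    # -> 1
#   parse_page_selection(99, 5)        # -> 4 (clamped to the last page)
#   parse_page_selection(None, 5)      # -> 0
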
def build_messages(history_msgs: list, user_text: str, images: List[Image.Image]):
    """
    Compose the model prompt using OpenAI-style messages:
      - system prompt
      - trimmed prior messages (contents normalized to typed parts)
      - current user turn (image placeholders + text)
    """
    messages = [{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}]
    for msg in (history_msgs[-8:] if history_msgs else []):  # keep the window tight
        content = msg.get("content", "")
        if isinstance(content, str):
            # Gradio stores plain strings; the chat template expects typed parts.
            content = [{"type": "text", "text": content}]
        messages.append({"role": msg["role"], "content": content})
    user_content = [{"type": "image"} for _ in images]  # one placeholder per image
    if user_text.strip():
        user_content.append({"type": "text", "text": user_text.strip()})
    messages.append({"role": "user", "content": user_content})
    return messages
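
# For a single image and the question "Total due?", build_messages returns:
#   [
#     {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
#     ...up to 8 prior messages...
#     {"role": "user", "content": [{"type": "image"},
#                                  {"type": "text", "text": "Total due?"}]},
#   ]
# The {"type": "image"} entries are placeholders; the actual pixels are handed
# to the processor at generation time.
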
# -----------------------------
# Core generation (streaming)
# -----------------------------
def generate_reply(images: List[Image.Image], user_text: str, history_msgs: list):
    """
    Stream a model reply grounded on provided images + user question + compact chat history.
    - Render the chat template to prompt TEXT (with image placeholders)
    - Let the processor expand placeholders and build text+vision tensors together
    - Run generate() on a background thread and stream decoded tokens
    """
    messages = build_messages(history_msgs, user_text, images)
    # 1) Build the prompt as TEXT (not tokens); the processor owns the chat template.
    prompt_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    # 2) Tokenize text and images together so each {"type": "image"} placeholder
    #    is expanded to the correct number of vision tokens.
    model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(DEVICE)
    if DEVICE == "cuda":
        # Match the half-precision weights; integer token ids must stay untouched.
        model_inputs["pixel_values"] = model_inputs["pixel_values"].to(DTYPE)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=False,  # deterministic decoding for predictable answers
    )
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()
    partial = ""
    for token in streamer:
        partial += token
        yield partial
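
# A minimal non-streaming equivalent, handy when debugging generation issues
# (a sketch assuming the same `messages`/`images` as above, not production code):
#   prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
#   inputs = processor(text=prompt, images=images, return_tensors="pt").to(DEVICE)
#   ids = model.generate(**inputs, max_new_tokens=512, do_sample=False)
#   print(tokenizer.decode(ids[0], skip_special_tokens=True))
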
# -----------------------------
# Gradio UI Orchestration
# -----------------------------
def start_chat(file_val, page_index):
    imgs = ensure_images(file_val)
    if not imgs:
        return (
            gr.update(choices=[], value=None),
            [],
            None,
            "No file loaded. Please upload a PDF/PNG/JPEG.",
        )
    choices = [f"Page {i+1}" for i in range(len(imgs))]
    safe_idx = 0 if page_index is None else max(0, min(len(imgs) - 1, int(page_index)))
    default_value = choices[safe_idx]
    return (
        gr.update(choices=choices, value=default_value),
        imgs,
        imgs[safe_idx],
        "Document ready. Select a page and ask questions.",
    )

def page_picker_changed(pages_dropdown, images_state):
    if not images_state:
        return None, None
    idx = parse_page_selection(pages_dropdown, len(images_state))
    selected = images_state[idx]
    return selected, selected  # preview + selected state

def chat(user_text, history_msgs, images_state, selected_img):
    history_msgs = history_msgs or []
    if not user_text or not user_text.strip():
        yield history_msgs  # nothing asked; re-emit the current history
        return
    sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
    if sel_img is None:
        yield history_msgs + [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": "Please upload a document first."},
        ]
        return
    for chunk in generate_reply([sel_img], user_text, history_msgs):
        yield history_msgs + [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": chunk},
        ]
# -----------------------------
# App definition
# -----------------------------
with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
    gr.Markdown(
        "## Invoice Chat • SmolVLM-Instruct-250M\n"
        "Upload a PDF/PNG/JPEG, pick a page, and ask questions about the document. "
        "The model is small enough to stay responsive on CPU-only hardware."
    )
    with gr.Row():
        with gr.Column(scale=1):
            file = gr.File(label="Upload invoice (PDF / PNG / JPEG)")
            pages = gr.Dropdown(
                label="Select page (for PDFs)",
                choices=[],
                value=None,
                allow_custom_value=True,
                info="Type a page number (e.g., 2) or choose from the list.",
            )
            load_btn = gr.Button("Prepare Document", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=2):
            image_view = gr.Image(label="Current page/image", interactive=False)
            # ✅ messages mode (avoids the deprecated "tuples" format warnings)
            chatbot = gr.Chatbot(height=400, type="messages")
            user_box = gr.Textbox(
                label="Your question",
                placeholder="e.g., What is the invoice number and total with tax?",
            )
            ask_btn = gr.Button("Ask", variant="primary")

    # Session state
    images_state = gr.State([])
    selected_img_state = gr.State(None)

    # Events
    load_btn.click(
        start_chat,
        inputs=[file, gr.State(0)],
        outputs=[pages, images_state, image_view, status],
    )
    pages.change(
        page_picker_changed,
        inputs=[pages, images_state],
        outputs=[image_view, selected_img_state],
    )
    ask_btn.click(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot],
    )
    user_box.submit(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot],
    )

if __name__ == "__main__":
    demo.launch()