# app.py
# ------------------------------------------------------------
# Invoice Chat • SmolVLM-Instruct-250M (messages-mode, streaming)
# ------------------------------------------------------------
import io
import os
import re
import threading
from typing import List, Optional, Union

import gradio as gr
import torch
from PIL import Image
import fitz  # PyMuPDF
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoModelForImageTextToText,  # modern replacement for AutoModelForVision2Seq
    TextIteratorStreamer,
)

# -----------------------------
# Model bootstrap
# -----------------------------
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype=DTYPE).to(DEVICE).eval()

SYSTEM_PROMPT = (
    "You are an invoice assistant. Respond ONLY using details visible in the uploaded document. "
    "If a field (invoice number, date, totals, tax, vendor, etc.) is not clearly visible, say so."
)


# -----------------------------
# Utilities
# -----------------------------
def _render_pdf_pages(doc, max_pages: int, dpi: int) -> List[Image.Image]:
    """Rasterize up to *max_pages* pages of an open PyMuPDF document to RGB PIL images.

    Takes ownership of *doc* and always closes it (the original two copies of
    this loop leaked the document handle).
    """
    zoom = fitz.Matrix(dpi / 72, dpi / 72)  # 72 dpi is the PDF user-space baseline
    images: List[Image.Image] = []
    try:
        for i, page in enumerate(doc):
            if i >= max_pages:
                break
            pix = page.get_pixmap(matrix=zoom)
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
    finally:
        doc.close()
    return images


def pdf_to_images_from_bytes(pdf_bytes: bytes, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    """Convert in-memory PDF bytes into a list of page images (first *max_pages* pages)."""
    return _render_pdf_pages(fitz.open(stream=pdf_bytes, filetype="pdf"), max_pages, dpi)


def pdf_to_images_from_path(path: str, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    """Convert a PDF file on disk into a list of page images (first *max_pages* pages)."""
    return _render_pdf_pages(fitz.open(path), max_pages, dpi)


def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> List[Image.Image]:
    """Accept PDF/PNG/JPEG (path/dict/bytes/BytesIO) and return a list of PIL images.

    Returns [] when nothing usable was provided. PDFs are detected by file
    extension (path case) or by the %PDF- magic prefix (bytes case).
    """
    if not file_val:
        return []

    path: Optional[str] = None
    raw_bytes: Optional[bytes] = None

    if isinstance(file_val, str) and os.path.exists(file_val):
        path = file_val
    elif isinstance(file_val, dict):
        # Gradio file payloads have used both "name" and "path" keys across versions.
        maybe_path = file_val.get("name") or file_val.get("path")
        if isinstance(maybe_path, str) and os.path.exists(maybe_path):
            path = maybe_path
        else:
            data = file_val.get("data")
            if isinstance(data, (bytes, bytearray)):
                raw_bytes = bytes(data)
    elif isinstance(file_val, (bytes, bytearray)):
        raw_bytes = bytes(file_val)
    elif isinstance(file_val, io.BytesIO):
        raw_bytes = file_val.getvalue()

    if path:
        if path.lower().endswith(".pdf"):
            return pdf_to_images_from_path(path)
        with open(path, "rb") as f:
            img = Image.open(io.BytesIO(f.read())).convert("RGB")
        return [img]

    if raw_bytes:
        if raw_bytes[:5] == b"%PDF-":
            return pdf_to_images_from_bytes(raw_bytes)
        img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        return [img]

    return []


def parse_page_selection(value, num_pages: int) -> int:
    """Accept 'Page 3', '3', 3, 'pg-2', etc. and return a safe 0-based page index.

    Any value without a digit (or out of range) is clamped into [0, num_pages - 1].
    """
    if num_pages <= 0 or value is None:
        return 0
    if isinstance(value, int):
        idx = value - 1
    else:
        m = re.search(r"(\d+)", str(value).strip())
        idx = int(m.group(1)) - 1 if m else 0
    return max(0, min(num_pages - 1, idx))


def build_messages(history_msgs: list, user_text: str, images: List[Image.Image]):
    """Compose the model prompt using OpenAI-style messages.

    Layout: system prompt, the last 8 prior messages (kept tight to fit the
    250M model's context), then the current user turn (images + text).
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    trimmed = history_msgs[-8:] if history_msgs else []  # keep the window tight
    messages.extend(trimmed)

    multimodal = []
    for im in images:
        multimodal.append(im)
    if user_text.strip():
        multimodal.append(user_text.strip())

    messages.append({"role": "user", "content": multimodal})
    return messages


# -----------------------------
# Core generation (streaming)
# -----------------------------
def generate_reply(images: List[Image.Image], user_text: str, history_msgs: list):
    """Stream a model reply grounded on provided images + user question + compact chat history.

    Pipeline:
      1. Build prompt text (chat template) -> tokenize (dict)
      2. Vision tensors via processor (dict)
      3. Allow-list kwargs to model.generate, streamed on a background thread

    NOTE(review): text and vision inputs are prepared separately (tokenizer for
    the prompt, processor for pixel_values). SmolVLM normally expects image
    placeholder tokens interleaved via processor.apply_chat_template — confirm
    this split path aligns image features with the prompt for this checkpoint.
    """
    messages = build_messages(history_msgs, user_text, images)

    # 1) Build prompt as TEXT (not tokens)
    prompt_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # 2) Tokenize -> mapping with input_ids/attention_mask
    text_inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)

    # 3) Vision tensors (pixel_values)
    vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)

    model_inputs = {
        "input_ids": text_inputs["input_ids"],
        **({"attention_mask": text_inputs["attention_mask"]} if "attention_mask" in text_inputs else {}),
        **({"pixel_values": vision_inputs["pixel_values"]} if "pixel_values" in vision_inputs else {}),
    }

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=False,  # keep deterministic for enterprise-grade UX
    )

    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    partial = ""
    for token in streamer:
        partial += token
        yield partial


# -----------------------------
# Gradio UI Orchestration
# -----------------------------
def start_chat(file_val, page_index):
    """Load the uploaded file, populate the page dropdown, and show the first page.

    Returns (pages dropdown update, images list, preview image, status text).
    """
    imgs = ensure_images(file_val)
    if not imgs:
        return (
            gr.update(choices=[], value=None),
            [],
            None,
            "No file loaded. Please upload a PDF/PNG/JPEG.",
        )
    choices = [f"Page {i+1}" for i in range(len(imgs))]
    safe_idx = 0 if page_index is None else max(0, min(len(imgs) - 1, int(page_index)))
    default_value = choices[safe_idx]
    return (
        gr.update(choices=choices, value=default_value),
        imgs,
        imgs[safe_idx],
        "Document ready. Select a page and ask questions.",
    )


def page_picker_changed(pages_dropdown, images_state):
    """Resolve the dropdown value to a page image; update preview + selected state."""
    if not images_state:
        return None, gr.update()
    idx = parse_page_selection(pages_dropdown, len(images_state))
    selected = images_state[idx]
    return selected, selected  # preview + selected state


def chat(user_text, history_msgs, images_state, selected_img):
    """Stream an answer about the selected page into the chatbot (messages format).

    This is a generator, so every exit path must *yield* its UI update —
    the original `return value` branches were silently dropped by Gradio.
    """
    history_msgs = history_msgs or []
    if not user_text or not user_text.strip():
        yield history_msgs  # FIX: was `return gr.update(), history_msgs` inside a generator
        return

    sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
    if sel_img is None:
        yield history_msgs + [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": "Please upload a document first."},
        ]
        return

    base = history_msgs + [{"role": "user", "content": user_text}]
    for partial in generate_reply([sel_img], user_text, history_msgs):
        yield base + [{"role": "assistant", "content": partial}]


# -----------------------------
# App definition
# -----------------------------
with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
    gr.Markdown(
        "## Invoice Chat • SmolVLM-Instruct-250M\n"
        "Upload a PDF/PNG/JPEG, pick a page, and interrogate the document. "
        "Optimized for CPU-friendly, low-latency insights."
    )

    with gr.Row():
        with gr.Column(scale=1):
            file = gr.File(label="Upload invoice (PDF / PNG / JPEG)")
            pages = gr.Dropdown(
                label="Select page (for PDFs)",
                choices=[],
                value=None,
                allow_custom_value=True,
                info="Type a page number (e.g., 2) or choose from the list.",
            )
            load_btn = gr.Button("Prepare Document", variant="primary")
            # FIX: a real layout component for status — the original passed a
            # freshly constructed gr.Textbox(visible=False) as an event output,
            # which is not part of the Blocks layout.
            status = gr.Markdown("")
        with gr.Column(scale=2):
            image_view = gr.Image(label="Current page/image", interactive=False)
            # ✅ messages mode (no more tuples warnings)
            chatbot = gr.Chatbot(height=400, type="messages")
            user_box = gr.Textbox(
                label="Your question",
                placeholder="e.g., What is the invoice number and total with tax?",
            )
            ask_btn = gr.Button("Ask", variant="primary")

    # Session state
    images_state = gr.State([])
    selected_img_state = gr.State(None)

    # Events
    load_btn.click(
        start_chat,
        inputs=[file, gr.State(0)],
        outputs=[pages, images_state, image_view, status],
    )
    pages.change(
        page_picker_changed,
        inputs=[pages, images_state],
        outputs=[image_view, selected_img_state],
    )
    # FIX: chatbot was listed twice in outputs (and chat yielded a duplicated
    # tuple to match) — a single output is correct.
    ask_btn.click(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot],
    )
    user_box.submit(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot],
    )

if __name__ == "__main__":
    demo.launch()