# Hugging Face Spaces page residue ("Spaces: Sleeping" status badge) — not code.
# app.py
# ------------------------------------------------------------
# Invoice Chat • SmolVLM-Instruct-250M (messages-mode, streaming)
# ------------------------------------------------------------
import io
import os
import re
import threading
from typing import List, Optional, Union

import fitz  # PyMuPDF
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,  # modern replacement for AutoModelForVision2Seq
    AutoProcessor,
    AutoTokenizer,
    TextIteratorStreamer,
)
# -----------------------------
# Model bootstrap
# -----------------------------
# Small (~250M-param) vision-language model, chosen to stay usable on CPU.
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 only on GPU; CPU fp16 support is patchy, so fall back to fp32 there.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# NOTE(review): `dtype=` is the newer transformers kwarg (older releases used
# `torch_dtype=`) — confirm the pinned transformers version accepts it.
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype=DTYPE).to(DEVICE).eval()

# System prompt used to ground every answer strictly in the uploaded document.
SYSTEM_PROMPT = (
    "You are an invoice assistant. Respond ONLY using details visible in the uploaded document. "
    "If a field (invoice number, date, totals, tax, vendor, etc.) is not clearly visible, say so."
)
# -----------------------------
# Utilities
# -----------------------------
def _render_pdf_pages(doc, max_pages: int, dpi: int) -> List[Image.Image]:
    """Rasterize up to `max_pages` pages of an open PyMuPDF document at `dpi`.

    Closes `doc` when done (the public wrappers below open it and hand
    ownership to this helper), fixing the original leak where the
    Document was never closed.
    """
    # 72 dpi is PDF's native resolution; the Matrix scales relative to it.
    zoom = fitz.Matrix(dpi / 72, dpi / 72)
    images: List[Image.Image] = []
    try:
        for page_no, page in enumerate(doc):
            if page_no >= max_pages:
                break
            pix = page.get_pixmap(matrix=zoom)
            # Assumes the pixmap is alpha-free RGB (get_pixmap default) —
            # same assumption the original code made.
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
    finally:
        doc.close()  # release file handles / memory even if rendering fails
    return images


def pdf_to_images_from_bytes(pdf_bytes: bytes, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    """Render an in-memory PDF (raw bytes) to a list of RGB PIL images."""
    return _render_pdf_pages(fitz.open(stream=pdf_bytes, filetype="pdf"), max_pages, dpi)


def pdf_to_images_from_path(path: str, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    """Render a PDF file on disk to a list of RGB PIL images."""
    return _render_pdf_pages(fitz.open(path), max_pages, dpi)
def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> List[Image.Image]:
    """
    Normalize an uploaded file — PDF/PNG/JPEG given as a path, a Gradio-style
    dict, raw bytes, or a BytesIO — into a list of PIL images.
    """
    if not file_val:
        return []

    resolved_path: Optional[str] = None
    payload: Optional[bytes] = None

    # Work out whether we were handed a filesystem path or an in-memory blob.
    if isinstance(file_val, str) and os.path.exists(file_val):
        resolved_path = file_val
    elif isinstance(file_val, dict):
        candidate = file_val.get("name") or file_val.get("path")
        if isinstance(candidate, str) and os.path.exists(candidate):
            resolved_path = candidate
        elif isinstance(file_val.get("data"), (bytes, bytearray)):
            payload = bytes(file_val["data"])
    elif isinstance(file_val, (bytes, bytearray)):
        payload = bytes(file_val)
    elif isinstance(file_val, io.BytesIO):
        payload = file_val.getvalue()

    if resolved_path:
        if resolved_path.lower().endswith(".pdf"):
            return pdf_to_images_from_path(resolved_path)
        with open(resolved_path, "rb") as fh:
            return [Image.open(io.BytesIO(fh.read())).convert("RGB")]

    if payload:
        # PDFs start with the magic marker "%PDF-"; anything else is an image.
        if payload[:5] == b"%PDF-":
            return pdf_to_images_from_bytes(payload)
        return [Image.open(io.BytesIO(payload)).convert("RGB")]

    return []
def parse_page_selection(value, num_pages: int) -> int:
    """
    Parse a page pick such as 'Page 3', '3', 3, or 'pg-2' into a safe
    0-based index, clamped to the available page range.
    """
    if num_pages <= 0 or value is None:
        return 0
    if isinstance(value, int):
        zero_based = value - 1
    else:
        # Grab the first run of digits anywhere in the string form.
        digits = re.search(r"(\d+)", str(value).strip())
        zero_based = int(digits.group(1)) - 1 if digits else 0
    return min(num_pages - 1, max(0, zero_based))
def build_messages(history_msgs: list, user_text: str, images: List[Image.Image]):
    """
    Assemble OpenAI-style messages for the model:
    a system prompt, a short tail of prior turns, then the current user
    turn with images followed by the question text.
    """
    recent = history_msgs[-8:] if history_msgs else []  # bounded context window
    content = list(images)
    question = user_text.strip()
    if question:
        content.append(question)
    return (
        [{"role": "system", "content": SYSTEM_PROMPT}]
        + recent
        + [{"role": "user", "content": content}]
    )
# -----------------------------
# Core generation (streaming)
# -----------------------------
def generate_reply(images: List[Image.Image], user_text: str, history_msgs: list):
    """
    Stream a model reply grounded on the provided images, the user's
    question, and a compact chat history.

    Yields the accumulated reply text after each newly decoded token.
    Raises whatever exception `model.generate` raised on the worker thread
    (the original version silently swallowed those and could hang forever).
    """
    messages = build_messages(history_msgs, user_text, images)

    # 1) Render the chat template to TEXT (not tokens) so we control
    #    tokenization ourselves below.
    prompt_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # 2) Tokenize → mapping with input_ids / attention_mask.
    text_inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)

    # 3) Vision tensors (pixel_values).
    # NOTE(review): text and images are encoded separately instead of one
    # processor(text=..., images=...) call — confirm the tokenizer's chat
    # template emits the image placeholder tokens SmolVLM expects.
    vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)

    # Allow-list exactly the kwargs generate() needs.
    model_inputs = {
        "input_ids": text_inputs["input_ids"],
        **({"attention_mask": text_inputs["attention_mask"]} if "attention_mask" in text_inputs else {}),
        **({"pixel_values": vision_inputs["pixel_values"]} if "pixel_values" in vision_inputs else {}),
    }

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=False,  # keep deterministic for enterprise-grade UX
    )

    # Run generation on a worker thread; stream decoded text from this one.
    errors: list = []

    def _generate() -> None:
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # re-raised on the consumer side below
            errors.append(exc)
        finally:
            # generate() normally closes the streamer itself, but if it
            # raised before doing so the consumer loop would block forever.
            streamer.end()

    worker = threading.Thread(target=_generate, daemon=True)
    worker.start()

    partial = ""
    for token in streamer:
        partial += token
        yield partial

    worker.join()
    if errors:
        raise errors[0]
# -----------------------------
# Gradio UI Orchestration
# -----------------------------
def start_chat(file_val, page_index):
    """Load the uploaded document and initialize the page picker + preview.

    Returns (dropdown update, list of page images, preview image, status text).
    """
    imgs = ensure_images(file_val)
    if not imgs:
        return (
            gr.update(choices=[], value=None),
            [],
            None,
            "No file loaded. Please upload a PDF/PNG/JPEG.",
        )

    labels = [f"Page {n}" for n in range(1, len(imgs) + 1)]
    # Clamp the requested page into range; default to the first page.
    if page_index is None:
        idx = 0
    else:
        idx = min(len(imgs) - 1, max(0, int(page_index)))

    return (
        gr.update(choices=labels, value=labels[idx]),
        imgs,
        imgs[idx],
        "Document ready. Select a page and ask questions.",
    )
def page_picker_changed(pages_dropdown, images_state):
    """Sync the preview image and selected-page state with the dropdown pick."""
    if not images_state:
        return None, gr.update()
    page = images_state[parse_page_selection(pages_dropdown, len(images_state))]
    # The same image feeds both the preview component and the selection state.
    return page, page
def chat(user_text, history_msgs, images_state, selected_img):
    """Answer a question about the selected page, streaming the reply.

    Yields (chatbot display value, stored history) pairs as tokens arrive.

    NOTE: because this function contains `yield` it is a generator, so every
    exit path must *yield* its outputs — the original `return value`
    statements were discarded (a generator's return value never reaches
    Gradio) and the UI silently received nothing on those paths.
    """
    if not user_text or not user_text.strip():
        yield gr.update(), history_msgs
        return

    # Fall back to the first page when no explicit selection was made.
    sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
    if sel_img is None:
        updated = history_msgs + [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": "Please upload a document first."},
        ]
        yield gr.update(value=updated), updated
        return

    acc = ""
    for chunk in generate_reply([sel_img], user_text, history_msgs):
        acc = chunk  # each chunk is the full accumulated reply so far
        updated = history_msgs + [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": acc},
        ]
        yield updated, updated
# -----------------------------
# App definition
# -----------------------------
# Top-level Gradio UI: layout, session state, and event wiring.
with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
    gr.Markdown(
        "## Invoice Chat • SmolVLM-Instruct-250M\n"
        "Upload a PDF/PNG/JPEG, pick a page, and interrogate the document. "
        "Optimized for CPU-friendly, low-latency insights."
    )
    with gr.Row():
        with gr.Column(scale=1):
            file = gr.File(label="Upload invoice (PDF / PNG / JPEG)")
            pages = gr.Dropdown(
                label="Select page (for PDFs)",
                choices=[],
                value=None,
                allow_custom_value=True,
                info="Type a page number (e.g., 2) or choose from the list.",
            )
            load_btn = gr.Button("Prepare Document", variant="primary")
        with gr.Column(scale=2):
            image_view = gr.Image(label="Current page/image", interactive=False)
            # ✅ messages mode (no more tuples warnings)
            chatbot = gr.Chatbot(height=400, type="messages")
            user_box = gr.Textbox(
                label="Your question",
                placeholder="e.g., What is the invoice number and total with tax?",
            )
            ask_btn = gr.Button("Ask", variant="primary")
    # Session state
    images_state = gr.State([])          # all rendered pages for this session
    selected_img_state = gr.State(None)  # the page currently being chatted about
    # Events
    # NOTE(review): gr.State(0) as an input and the hidden gr.Textbox as an
    # output are created inline here rather than declared in the layout —
    # confirm the pinned Gradio version supports components created this way.
    load_btn.click(
        start_chat,
        inputs=[file, gr.State(0)],
        outputs=[pages, images_state, image_view, gr.Textbox(visible=False)],
    )
    pages.change(
        page_picker_changed,
        inputs=[pages, images_state],
        outputs=[image_view, selected_img_state],
    )
    # NOTE(review): `chat` yields two values and both outputs are the same
    # Chatbot component — verify Gradio tolerates duplicate outputs here.
    ask_btn.click(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot, chatbot],
    )
    # Pressing Enter in the textbox behaves like clicking "Ask".
    user_box.submit(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot, chatbot],
    )

if __name__ == "__main__":
    demo.launch()