File size: 10,197 Bytes
1e56cd8
 
3149ed3
1e56cd8
 
638e61c
1e56cd8
 
3149ed3
638e61c
 
 
 
 
 
 
1fcca49
3149ed3
638e61c
 
 
1e56cd8
1fcca49
1e56cd8
 
 
 
638e61c
1fcca49
638e61c
3149ed3
638e61c
 
1e56cd8
 
638e61c
 
1e56cd8
 
 
 
638e61c
1e56cd8
 
 
 
 
 
 
 
 
 
 
 
638e61c
 
 
1e56cd8
638e61c
 
 
 
1e56cd8
638e61c
3149ed3
638e61c
1e56cd8
638e61c
1e56cd8
 
638e61c
1e56cd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3149ed3
1e56cd8
 
 
 
 
 
 
 
 
 
 
 
 
 
638e61c
3149ed3
638e61c
3149ed3
1e56cd8
 
 
 
3149ed3
1e56cd8
 
 
62c1db6
1e56cd8
3149ed3
62c1db6
3149ed3
 
1e56cd8
 
3149ed3
62c1db6
 
1e56cd8
 
 
638e61c
1e56cd8
62c1db6
1e56cd8
 
 
9037c59
 
 
3149ed3
1e56cd8
 
3149ed3
 
 
1e56cd8
3149ed3
638e61c
3149ed3
3a1ba6d
638e61c
 
3149ed3
3a1ba6d
 
3149ed3
9037c59
638e61c
3149ed3
1e56cd8
638e61c
9037c59
 
 
 
 
638e61c
1fcca49
638e61c
 
 
 
3149ed3
638e61c
 
 
1e56cd8
 
638e61c
 
 
 
 
 
1e56cd8
 
 
 
 
638e61c
1e56cd8
 
 
 
 
 
638e61c
1e56cd8
 
 
 
 
 
 
 
 
638e61c
 
1e56cd8
 
 
9037c59
1e56cd8
62c1db6
1e56cd8
62c1db6
 
1e56cd8
 
62c1db6
 
3149ed3
62c1db6
 
 
 
1e56cd8
 
 
62c1db6
 
 
 
 
 
 
 
3149ed3
62c1db6
 
1e56cd8
 
 
 
 
 
 
1fcca49
1e56cd8
638e61c
 
1e56cd8
 
 
 
 
9037c59
3149ed3
1e56cd8
 
638e61c
 
 
3149ed3
 
1e56cd8
 
 
 
 
 
1fcca49
638e61c
 
 
1fcca49
638e61c
 
 
3149ed3
1e56cd8
 
 
 
3149ed3
638e61c
 
 
1e56cd8
3149ed3
638e61c
 
 
1e56cd8
3149ed3
638e61c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
# app.py
# ------------------------------------------------------------
# Invoice Chat • SmolVLM-Instruct-250M (messages-mode, streaming)
# ------------------------------------------------------------

import io
import os
import re
from typing import List, Optional, Union

import gradio as gr
import torch
from PIL import Image
import fitz  # PyMuPDF
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoModelForImageTextToText,  # modern replacement for AutoModelForVision2Seq
    TextIteratorStreamer,
)

# -----------------------------
# Model bootstrap
# -----------------------------
# Small (250M-param) instruct-tuned vision-language model; chosen for CPU-friendly latency.
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct-250M"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 only on GPU; CPU inference stays in fp32 (half precision on CPU is slow/partially unsupported).
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Module-level load: weights are downloaded on first run, so importing this module has side effects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# .eval() disables dropout/batch-norm training behavior for deterministic inference.
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype=DTYPE).to(DEVICE).eval()

# Grounding instruction prepended to every conversation (see build_messages).
SYSTEM_PROMPT = (
    "You are an invoice assistant. Respond ONLY using details visible in the uploaded document. "
    "If a field (invoice number, date, totals, tax, vendor, etc.) is not clearly visible, say so."
)

# -----------------------------
# Utilities
# -----------------------------
def pdf_to_images_from_bytes(pdf_bytes: bytes, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    """Render up to *max_pages* pages of an in-memory PDF to RGB PIL images.

    Args:
        pdf_bytes: Raw bytes of a PDF file.
        max_pages: Hard cap on rendered pages (guards against huge uploads).
        dpi: Render resolution; PDF-native resolution is 72, so dpi/72 is the zoom.

    Returns:
        One PIL image per rendered page (possibly empty for a zero-page PDF).
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        # Zoom matrix is loop-invariant — build it once.
        zoom = fitz.Matrix(dpi / 72, dpi / 72)
        images: List[Image.Image] = []
        for i, page in enumerate(doc):
            if i >= max_pages:
                break
            pix = page.get_pixmap(matrix=zoom)
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
        return images
    finally:
        # BUGFIX: the Document handle was previously leaked (never closed).
        doc.close()

def pdf_to_images_from_path(path: str, max_pages: int = 8, dpi: int = 216) -> List[Image.Image]:
    """Render up to *max_pages* pages of a PDF file on disk to RGB PIL images.

    Args:
        path: Filesystem path to a PDF document.
        max_pages: Hard cap on rendered pages (guards against huge documents).
        dpi: Render resolution; PDF-native resolution is 72, so dpi/72 is the zoom.

    Returns:
        One PIL image per rendered page (possibly empty for a zero-page PDF).
    """
    doc = fitz.open(path)
    try:
        # Zoom matrix is loop-invariant — build it once.
        zoom = fitz.Matrix(dpi / 72, dpi / 72)
        images: List[Image.Image] = []
        for i, page in enumerate(doc):
            if i >= max_pages:
                break
            pix = page.get_pixmap(matrix=zoom)
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
        return images
    finally:
        # BUGFIX: the Document handle (and its open file descriptor) was previously leaked.
        doc.close()

def ensure_images(file_val: Optional[Union[str, dict, bytes, io.BytesIO]]) -> List[Image.Image]:
    """
    Normalize an upload (path string, gradio file dict, raw bytes, or BytesIO)
    into a list of PIL images; PDFs yield one image per page.
    Returns [] when nothing usable was provided.
    """
    if not file_val:
        return []

    # Resolve the input to either a filesystem path or an in-memory payload.
    path: Optional[str] = None
    raw_bytes: Optional[bytes] = None

    if isinstance(file_val, str):
        if os.path.exists(file_val):
            path = file_val
    elif isinstance(file_val, dict):
        candidate = file_val.get("name") or file_val.get("path")
        if isinstance(candidate, str) and os.path.exists(candidate):
            path = candidate
        elif isinstance(file_val.get("data"), (bytes, bytearray)):
            raw_bytes = bytes(file_val["data"])
    elif isinstance(file_val, (bytes, bytearray)):
        raw_bytes = bytes(file_val)
    elif isinstance(file_val, io.BytesIO):
        raw_bytes = file_val.getvalue()

    if path is not None:
        if path.lower().endswith(".pdf"):
            return pdf_to_images_from_path(path)
        with open(path, "rb") as fh:
            payload = fh.read()
        return [Image.open(io.BytesIO(payload)).convert("RGB")]

    if raw_bytes:
        # Sniff the PDF magic number rather than trusting any extension.
        if raw_bytes.startswith(b"%PDF-"):
            return pdf_to_images_from_bytes(raw_bytes)
        return [Image.open(io.BytesIO(raw_bytes)).convert("RGB")]

    return []

def parse_page_selection(value, num_pages: int) -> int:
    """
    Normalize a page selection ('Page 3', '3', 3, 'pg-2', ...) into a 0-based
    index clamped to the valid range [0, num_pages - 1].
    """
    if num_pages <= 0 or value is None:
        return 0

    if isinstance(value, int):
        zero_based = value - 1
    else:
        # Take the first run of digits anywhere in the string form.
        match = re.search(r"(\d+)", str(value).strip())
        zero_based = (int(match.group(1)) - 1) if match else 0

    # Clamp into range with guard clauses instead of nested max/min.
    if zero_based < 0:
        return 0
    if zero_based >= num_pages:
        return num_pages - 1
    return zero_based

def build_messages(history_msgs: list, user_text: str, images: List[Image.Image]):
    """
    Compose the model prompt using OpenAI-style messages:
      - system prompt
      - trimmed prior messages
      - current user turn (images + text)

    BUGFIX: the current turn previously appended raw PIL Image objects and bare
    strings into ``content``. Chat templates for image-text models expect the
    content-parts schema instead: ``{"type": "image"}`` placeholders (one per
    image, aligned with the pixel_values passed to the processor) plus
    ``{"type": "text", "text": ...}``. Raw PIL objects cannot be rendered by
    the Jinja template.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    # Keep the context window tight: only the last 8 prior messages.
    if history_msgs:
        messages.extend(history_msgs[-8:])

    # One image placeholder per attached page, then the question text (if any).
    content = [{"type": "image"} for _ in images]
    text = user_text.strip()
    if text:
        content.append({"type": "text", "text": text})

    messages.append({"role": "user", "content": content})
    return messages

# -----------------------------
# Core generation (streaming)
# -----------------------------
def generate_reply(images: List[Image.Image], user_text: str, history_msgs: list):
    """
    Stream a model reply grounded on provided images + user question + compact chat history.
    - Build prompt text (chat template) -> tokenize (dict)
    - Vision tensors via processor (dict)
    - Allow-list kwargs to model.generate

    Yields:
        Cumulative partial reply text — each yield is the full text so far,
        not a delta (the caller can render it directly).
    """
    messages = build_messages(history_msgs, user_text, images)

    # 1) Build prompt as TEXT (not tokens)
    prompt_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # 2) Tokenize → mapping with input_ids/attention_mask
    text_inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)

    # 3) Vision tensors (pixel_values)
    # NOTE(review): text and images are encoded by two separate objects here
    # (tokenizer vs. processor). The canonical SmolVLM path is a single
    # processor(text=..., images=...) call so that <image> placeholder tokens
    # and pixel_values stay aligned — confirm this split actually works end-to-end.
    vision_inputs = processor(images=images, return_tensors="pt").to(DEVICE)

    # Allow-list only the kwargs generate() understands; attention_mask and
    # pixel_values are optional depending on what the encoders returned.
    model_inputs = {
        "input_ids": text_inputs["input_ids"],
        **({"attention_mask": text_inputs["attention_mask"]} if "attention_mask" in text_inputs else {}),
        **({"pixel_values": vision_inputs["pixel_values"]} if "pixel_values" in vision_inputs else {}),
    }

    # skip_prompt=True: only newly generated tokens are streamed back, not the prompt echo.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=False,  # keep deterministic for enterprise-grade UX
    )

    # Generation runs on a worker thread; the streamer feeds decoded tokens
    # back to this (consumer) thread as they are produced.
    import threading
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    partial = ""
    for token in streamer:
        partial += token
        yield partial

# -----------------------------
# Gradio UI Orchestration
# -----------------------------
def start_chat(file_val, page_index):
    """
    Load the uploaded document and prime the UI: populate the page dropdown,
    stash the page images in state, and preview the requested page.

    Returns (dropdown update, images list, preview image, status message).
    """
    images = ensure_images(file_val)
    if not images:
        empty_dropdown = gr.update(choices=[], value=None)
        return empty_dropdown, [], None, "No file loaded. Please upload a PDF/PNG/JPEG."

    labels = [f"Page {n}" for n in range(1, len(images) + 1)]
    # Clamp the requested page index into range; default to the first page.
    if page_index is None:
        idx = 0
    else:
        idx = max(0, min(len(images) - 1, int(page_index)))

    return (
        gr.update(choices=labels, value=labels[idx]),
        images,
        images[idx],
        "Document ready. Select a page and ask questions.",
    )

def page_picker_changed(pages_dropdown, images_state):
    """
    Resolve the dropdown selection to a page image and return it twice:
    once for the preview component, once for the selected-page state.
    """
    if not images_state:
        return None, gr.update()
    page = images_state[parse_page_selection(pages_dropdown, len(images_state))]
    return page, page

def chat(user_text, history_msgs, images_state, selected_img):
    """
    Stream an answer about the selected page into the chatbot.

    Yields (chatbot_value, history_state) pairs for Gradio's streaming events.

    BUGFIX: this function contains ``yield`` and is therefore a generator, so
    the previous early ``return <value>`` statements were silently swallowed —
    a generator's return value only becomes the StopIteration payload and
    Gradio never shows it. Early exits must ``yield`` their outputs, then
    ``return`` bare.
    """
    history_msgs = history_msgs or []

    if not user_text or not user_text.strip():
        # Nothing to ask — leave the chatbot unchanged.
        yield gr.update(), history_msgs
        return

    # Fall back to the first page when no explicit selection has been made.
    sel_img = selected_img if selected_img is not None else (images_state[0] if images_state else None)
    if sel_img is None:
        history_msgs = history_msgs + [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": "Please upload a document first."},
        ]
        yield gr.update(value=history_msgs), history_msgs
        return

    acc = ""
    for chunk in generate_reply([sel_img], user_text, history_msgs):
        acc = chunk  # each chunk is the cumulative reply so far
        updated = history_msgs + [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": acc},
        ]
        # Same value feeds both outputs (chatbot display + history state).
        yield updated, updated

# -----------------------------
# App definition
# -----------------------------
with gr.Blocks(title="Invoice Chat • SmolVLM-250M") as demo:
    gr.Markdown(
        "## Invoice Chat • SmolVLM-Instruct-250M\n"
        "Upload a PDF/PNG/JPEG, pick a page, and interrogate the document. "
        "Optimized for CPU-friendly, low-latency insights."
    )
    with gr.Row():
        with gr.Column(scale=1):
            file = gr.File(label="Upload invoice (PDF / PNG / JPEG)")
            pages = gr.Dropdown(
                label="Select page (for PDFs)",
                choices=[],
                value=None,
                allow_custom_value=True,  # lets users type a page number directly
                info="Type a page number (e.g., 2) or choose from the list.",
            )
            load_btn = gr.Button("Prepare Document", variant="primary")
        with gr.Column(scale=2):
            image_view = gr.Image(label="Current page/image", interactive=False)

    # ✅ messages mode (no more tuples warnings)
    chatbot = gr.Chatbot(height=400, type="messages")
    user_box = gr.Textbox(
        label="Your question",
        placeholder="e.g., What is the invoice number and total with tax?",
    )
    ask_btn = gr.Button("Ask", variant="primary")

    # Session state
    images_state = gr.State([])          # list of PIL images (all pages)
    selected_img_state = gr.State(None)  # currently selected page image

    # Events
    load_btn.click(
        start_chat,
        inputs=[file, gr.State(0)],  # inline State provides a constant default page index
        outputs=[pages, images_state, image_view, gr.Textbox(visible=False)],  # hidden sink for the status string
    )
    pages.change(
        page_picker_changed,
        inputs=[pages, images_state],
        outputs=[image_view, selected_img_state],
    )
    # Both the button and Enter-in-textbox trigger the same streaming handler.
    ask_btn.click(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot, chatbot],
    )
    user_box.submit(
        chat,
        inputs=[user_box, chatbot, images_state, selected_img_state],
        outputs=[chatbot, chatbot],
    )

if __name__ == "__main__":
    demo.launch()