BIBLETUM committed on
Commit 5a246b8 (verified)
1 Parent(s): f6c6193

Create app.py

Files changed (1):
  app.py (+304 −0)
app.py ADDED
@@ -0,0 +1,304 @@
import os
import re
import numpy as np
import gradio as gr
import soundfile as sf
from PIL import Image

import fitz  # PyMuPDF

import torch
from transformers import (
    pipeline,
    DonutProcessor,
    VisionEncoderDecoderModel,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
from sentence_transformers import SentenceTransformer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

WHISPER_MODEL = os.getenv("WHISPER_MODEL", "openai/whisper-tiny")
DONUT_MODEL = os.getenv("DONUT_MODEL", "naver-clova-ix/donut-base-finetuned-docvqa")
T5_MODEL = os.getenv("T5_MODEL", "google/flan-t5-small")
EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

MAX_PAGES = int(os.getenv("MAX_PAGES", "2"))
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1600"))  # px
TOPK = int(os.getenv("TOPK", "5"))

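# Note: all four models below load eagerly at import time. The small default
# checkpoints keep startup cost modest; pointing the env vars above at larger
# checkpoints will raise load time and RAM/VRAM use accordingly.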
# ---------- Models ----------
asr = pipeline(
    task="automatic-speech-recognition",
    model=WHISPER_MODEL,
    device=0 if DEVICE == "cuda" else -1,
)

donut_processor = DonutProcessor.from_pretrained(DONUT_MODEL)
# Load Donut weights in DTYPE: donut_docvqa() casts its pixel values to DTYPE,
# and fp16 inputs against fp32 weights would fail on CUDA.
donut_model = VisionEncoderDecoderModel.from_pretrained(DONUT_MODEL, torch_dtype=DTYPE).to(DEVICE)
donut_model.eval()

t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL)
t5_model = AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL).to(DEVICE)
t5_model.eval()

embedder = SentenceTransformer(EMB_MODEL, device=DEVICE)


# ---------- Utils ----------
def _resize_max(im: Image.Image, max_side: int) -> Image.Image:
    w, h = im.size
    m = max(w, h)
    if m <= max_side:
        return im
    scale = max_side / float(m)
    nw, nh = max(1, int(w * scale)), max(1, int(h * scale))
    return im.resize((nw, nh), Image.BICUBIC)


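# Documents arrive as a single file path: images are used as-is, PDFs are
# rasterized page by page with PyMuPDF, and anything else is ignored.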
def load_document_to_images(file_path: str, max_pages: int = MAX_PAGES) -> list[Image.Image]:
    if not file_path:
        return []
    ext = (os.path.splitext(file_path)[1] or "").lower()

    if ext in [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"]:
        im = Image.open(file_path).convert("RGB")
        return [_resize_max(im, MAX_IMAGE_SIZE)]

    if ext == ".pdf":
        doc = fitz.open(file_path)
        imgs = []
        pages = min(len(doc), max_pages)
        for i in range(pages):
            page = doc.load_page(i)
            # Render at 2x (~144 dpi); the 72 dpi default is often too coarse
            # for DocVQA. _resize_max still caps the longest side afterwards.
            pix = page.get_pixmap(matrix=fitz.Matrix(2, 2), alpha=False)
            im = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            imgs.append(_resize_max(im, MAX_IMAGE_SIZE))
        doc.close()
        return imgs

    return []


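# Donut DocVQA prompting: the question is wrapped in the model's task tokens
# (<s_docvqa><s_question>...</s_question><s_answer>) and the decoder continues
# from <s_answer>; the answer is recovered from the generated sequence.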
def donut_docvqa(image: Image.Image, question: str, max_new_tokens: int = 64) -> str:
    if image is None or not (question or "").strip():
        return ""
    q = question.strip()
    prompt = f"<s_docvqa><s_question>{q}</s_question><s_answer>"
    # Canonical Donut inputs: image -> pixel_values via the processor, task
    # prompt -> decoder_input_ids via the tokenizer (no special tokens added).
    pixel_values = donut_processor(image, return_tensors="pt").pixel_values.to(DEVICE, dtype=DTYPE)
    decoder_input_ids = donut_processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(DEVICE)

    with torch.inference_mode():
        out = donut_model.generate(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=donut_processor.tokenizer.pad_token_id,
            eos_token_id=donut_processor.tokenizer.eos_token_id,
            bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
        )

    # Decoding with skip_special_tokens=True would leave the question text
    # glued onto the answer. Instead strip eos/pad and the task-start token,
    # parse with token2json, and keep only the answer field.
    seq = donut_processor.batch_decode(out)[0]
    seq = seq.replace(donut_processor.tokenizer.eos_token, "").replace(donut_processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    parsed = donut_processor.token2json(seq)
    if isinstance(parsed, dict) and "answer" in parsed:
        return re.sub(r"\s+", " ", str(parsed["answer"])).strip()
    return re.sub(r"\s+", " ", re.sub(r"<.*?>", "", seq)).strip()


def t5_summarize(text: str, max_new_tokens: int = 128) -> str:
    t = (text or "").strip()
    if not t:
        return ""
    prompt = f"Summarize this document briefly:\n{t}"
    # Input is truncated at 1024 tokens, so very long extractions get clipped.
    inputs = t5_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
    with torch.inference_mode():
        out = t5_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # deterministic
            num_beams=2,      # light beam search
        )
    return t5_tokenizer.decode(out[0], skip_special_tokens=True).strip()


def embed_text(text: str) -> np.ndarray:
    v = embedder.encode([text or ""], normalize_embeddings=True)[0]
    return np.asarray(v, dtype=np.float32)


def cos_sim_matrix(query_vec: np.ndarray, mat: np.ndarray) -> np.ndarray:
    # vectors already normalized -> dot is cosine
    return mat @ query_vec


def format_kv(items: list[tuple[str, str]]) -> str:
    lines = []
    for k, v in items:
        v = (v or "").strip()
        if v:
            lines.append(f"{k}: {v}")
    return "\n".join(lines).strip()


# ---------- App State ----------
# archive_state: list[dict] where each dict contains:
#   { "name": str, "text": str, "vec": np.ndarray }
def ensure_state(archive_state):
    if archive_state is None:
        return []
    return archive_state


# ---------- Actions ----------
DEFAULT_FIELDS = [
    ("amount", "What is the total amount to pay? Return only the amount."),
    ("due_date", "What is the due date? Return only the date."),
    ("period", "What is the billing period?"),
    ("recipient", "Who is the recipient/payee?"),
    ("account", "What is the account / invoice number?"),
]


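# Extraction issues one Donut generate call per field, so its latency scales
# linearly with len(DEFAULT_FIELDS).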
def act_extract(file_obj, archive_state):
    archive_state = ensure_state(archive_state)
    if not file_obj:
        return None, None, "", "", archive_state

    images = load_document_to_images(file_obj, max_pages=MAX_PAGES)
    if not images:
        return None, None, "", "", archive_state

    # Ask Donut each canned question against page 1 only.
    first = images[0]
    answers = []
    for name, q in DEFAULT_FIELDS:
        a = donut_docvqa(first, q)
        answers.append((name, a))

    extracted = format_kv(answers)
    summary = t5_summarize(extracted)
    return first, images, extracted, summary, archive_state


def act_ask(file_obj, question, use_audio_text, audio_text):
    q = (question or "").strip()
    if use_audio_text and (audio_text or "").strip():
        q = (audio_text or "").strip()
    if not file_obj or not q:
        return ""

    images = load_document_to_images(file_obj, max_pages=MAX_PAGES)
    if not images:
        return ""
    return donut_docvqa(images[0], q)


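# The Whisper pipeline resamples raw audio to 16 kHz itself when given a
# {"raw", "sampling_rate"} dict (that resampling path needs torchaudio).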
def act_transcribe(audio_path):
    if not audio_path:
        return ""
    data, sr = sf.read(audio_path)
    if data.ndim > 1:
        data = data.mean(axis=1)  # downmix to mono
    data = data.astype(np.float32)  # soundfile returns float64 by default
    out = asr({"raw": data, "sampling_rate": sr})
    if isinstance(out, dict) and "text" in out:
        return (out["text"] or "").strip()
    return str(out).strip()


def act_add_to_archive(file_obj, extracted, summary, archive_state):
    archive_state = ensure_state(archive_state)
    if not file_obj:
        # Report the current size rather than resetting the counter to "0".
        return archive_state, str(len(archive_state))
    name = os.path.basename(file_obj)

    payload = "\n".join([t for t in [name, extracted or "", summary or ""] if (t or "").strip()]).strip()
    if not payload:
        payload = name

    vec = embed_text(payload)
    archive_state.append({"name": name, "text": payload, "vec": vec})
    return archive_state, str(len(archive_state))


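# The archive lives only in gr.State, i.e. per browser session; a page reload
# or app restart clears it. A persistent store (e.g. SQLite plus a saved
# embedding matrix) would be the natural next step.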
def act_search_archive(query, archive_state):
    archive_state = ensure_state(archive_state)
    q = (query or "").strip()
    if not q or not archive_state:
        return ""

    qv = embed_text(q)
    mat = np.vstack([it["vec"] for it in archive_state]).astype(np.float32)
    sims = cos_sim_matrix(qv, mat)
    idx = np.argsort(-sims)[: min(TOPK, len(archive_state))]

    lines = []
    for rank, i in enumerate(idx, start=1):
        it = archive_state[int(i)]
        s = float(sims[int(i)])
        lines.append(f"{rank}. [{s:.3f}] {it['name']}\n{it['text'][:600]}")
    return "\n\n".join(lines).strip()


# ---------- UI ----------
with gr.Blocks(title="DocuVoice Assistant (MVP)") as demo:
    archive_state = gr.State([])

    with gr.Row():
        file_in = gr.File(label="PDF/Image", file_types=[".pdf", ".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"])

    with gr.Tabs():
        with gr.Tab("Document"):
            with gr.Row():
                btn_extract = gr.Button("Extract + Summarize", variant="primary")
                btn_add = gr.Button("Add to Archive")
            with gr.Row():
                img_preview = gr.Image(label="Preview (page 1)", type="pil")
                pages_gallery = gr.Gallery(label="Pages", columns=3, height=280, preview=True)
            with gr.Row():
                extracted_out = gr.Textbox(label="Extracted (Donut Q&A)", lines=8)
                summary_out = gr.Textbox(label="Summary (T5)", lines=8)
            with gr.Row():
                question_in = gr.Textbox(label="Question", lines=2, placeholder="Ask about the document...")
            with gr.Row():
                use_audio = gr.Checkbox(label="Use transcribed audio as question", value=False)
            with gr.Row():
                btn_ask = gr.Button("Ask (Donut DocVQA)")
                answer_out = gr.Textbox(label="Answer", lines=6)

            with gr.Row():
                archive_count = gr.Textbox(label="Archive size", value="0", interactive=False)

        with gr.Tab("Voice"):
            audio_in = gr.Audio(label="Audio", sources=["microphone", "upload"], type="filepath")
            btn_asr = gr.Button("Transcribe (Whisper)", variant="primary")
            transcript_out = gr.Textbox(label="Transcript", lines=4)
            btn_asr.click(act_transcribe, inputs=[audio_in], outputs=[transcript_out])

        with gr.Tab("Archive"):
            query_in = gr.Textbox(label="Search query", lines=2, placeholder="e.g., electricity bill October")
            btn_search = gr.Button("Search (Embeddings)")
            results_out = gr.Textbox(label="Results", lines=16)
            btn_search.click(act_search_archive, inputs=[query_in, archive_state], outputs=[results_out])

    btn_extract.click(
        act_extract,
        # pages_gallery is an output here, not an input; act_extract takes
        # only the uploaded file and the archive state.
        inputs=[file_in, archive_state],
        outputs=[img_preview, pages_gallery, extracted_out, summary_out, archive_state],
    )

    btn_ask.click(
        act_ask,
        inputs=[file_in, question_in, use_audio, transcript_out],
        outputs=[answer_out],
    )

    btn_add.click(
        act_add_to_archive,
        inputs=[file_in, extracted_out, summary_out, archive_state],
        outputs=[archive_state, archive_count],
    )

if __name__ == "__main__":
    demo.launch()
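# Assumed dependencies, inferred from the imports above: torch, transformers,
# sentence-transformers, gradio, soundfile, pymupdf, pillow, numpy.
# Run locally with: python app.py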