vithacocf commited on
Commit
9becdf5
·
verified ·
1 Parent(s): e2a9259

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +244 -0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import os, io, re, json, time, mimetypes, tempfile, string
3
+ from typing import List, Union, Tuple, Any, Iterable
4
+
5
+ from PIL import Image
6
+ import pandas as pd
7
+ import gradio as gr
8
+ import google.generativeai as genai
9
+ import requests
10
+ import pdfplumber
11
+
12
+ # ================== CONFIG ==================
13
+ DEFAULT_API_KEY = "AIzaSyBbK-1P3JD6HPyE3QLhkOps6_-Xo3wUFbs"
14
+
15
+ INTERNAL_MODEL_MAP = {
16
+ "Gemini 2.5 Flash": "gemini-2.5-flash",
17
+ "Gemini 2.5 Pro": "gemini-2.5-pro",
18
+ }
19
+ EXTERNAL_MODEL_NAME = "prithivMLmods/Camel-Doc-OCR-062825 (External)"
20
+
21
+ try:
22
+ RESAMPLE = Image.Resampling.LANCZOS
23
+ except AttributeError:
24
+ RESAMPLE = Image.LANCZOS
25
+
26
+ PROMPT_FREIGHT_JSON = """
27
+ Please analyze the freight rate table in the file I provide and convert it into JSON in the following structure:
28
+ { ... } # (rút gọn lại vì bạn đã có)
29
+ """
30
+
31
+ # ================== HELPERS ==================
32
+ import fitz # PyMuPDF
33
+
34
+ def pdf_to_images(pdf_bytes: bytes) -> list[Image.Image]:
35
+ doc = fitz.open(stream=pdf_bytes, filetype="pdf")
36
+ pages = []
37
+ for p in doc:
38
+ pix = p.get_pixmap(dpi=200)
39
+ img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
40
+ pages.append(img)
41
+ return pages
42
+
43
+ def ensure_rgb(im: Image.Image) -> Image.Image:
44
+ return im.convert("RGB") if im.mode != "RGB" else im
45
+
46
+ def _read_file_bytes(upload: Union[str, os.PathLike, dict, object] | None) -> bytes:
47
+ if upload is None:
48
+ raise ValueError("No file uploaded.")
49
+ if isinstance(upload, (str, os.PathLike)):
50
+ with open(upload, "rb") as f:
51
+ return f.read()
52
+ if isinstance(upload, dict) and "path" in upload:
53
+ with open(upload["path"], "rb") as f:
54
+ return f.read()
55
+ if hasattr(upload, "read"):
56
+ return upload.read()
57
+ raise TypeError(f"Unsupported file object: {type(upload)}")
58
+
59
+ def _guess_name_and_mime(file, file_bytes: bytes) -> Tuple[str, str]:
60
+ if isinstance(file, (str, os.PathLike)):
61
+ filename = os.path.basename(str(file))
62
+ elif isinstance(file, dict) and "name" in file:
63
+ filename = os.path.basename(file["name"])
64
+ elif isinstance(file, dict) and "path" in file:
65
+ filename = os.path.basename(file["path"])
66
+ else:
67
+ filename = "upload.bin"
68
+ mime, _ = mimetypes.guess_type(filename)
69
+ if not mime:
70
+ if len(file_bytes) >= 4 and file_bytes[:4] == b"%PDF":
71
+ mime = "application/pdf"
72
+ if not filename.lower().endswith(".pdf"):
73
+ filename += ".pdf"
74
+ else:
75
+ mime = "image/png"
76
+ return filename, mime
77
+
78
+ # ================== PDF CHECK STEP ==================
79
+ def check_pdf_structure(file_bytes: bytes) -> str:
80
+ """Kiểm tra nhanh file PDF có phải bảng nhiều cột, nhiều trang không."""
81
+ try:
82
+ with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
83
+ if len(pdf.pages) <= 2:
84
+ return "không"
85
+ table_pages = 0
86
+ for page in pdf.pages[:3]:
87
+ tables = page.find_tables()
88
+ if tables and len(tables) > 0:
89
+ table_pages += 1
90
+ if table_pages >= 1:
91
+ return "có"
92
+ text = "\n".join([(p.extract_text() or "") for p in pdf.pages[:2]])
93
+ num_tokens = sum(ch.isdigit() for ch in text)
94
+ line_count = len(text.splitlines())
95
+ if num_tokens > 100 and line_count > 20:
96
+ return "có"
97
+ return "không"
98
+ except Exception as e:
99
+ print("PDF check error:", e)
100
+ return "không"
101
+
102
+ # ================== OCR CORE (Gemini) ==================
103
+ def run_process_internal_base_v2(file_bytes, filename, mime, question, model_choice, temperature, top_p, batch_size=3):
104
+ api_key = os.environ.get("GOOGLE_API_KEY", DEFAULT_API_KEY)
105
+ if not api_key:
106
+ return "ERROR: Missing GOOGLE_API_KEY.", None
107
+ genai.configure(api_key=api_key)
108
+ model_name = INTERNAL_MODEL_MAP.get(model_choice, "gemini-2.5-flash")
109
+ model = genai.GenerativeModel(model_name=model_name,
110
+ generation_config={"temperature": float(temperature), "top_p": float(top_p)})
111
+
112
+ if file_bytes[:4] == b"%PDF":
113
+ pages = pdf_to_images(file_bytes)
114
+ else:
115
+ pages = [Image.open(io.BytesIO(file_bytes))]
116
+
117
+ user_prompt = (question or "").strip() or PROMPT_FREIGHT_JSON
118
+ all_json_results, all_text_results = [], []
119
+ previous_header_json = None
120
+
121
+ def _safe_text(resp):
122
+ try:
123
+ return resp.text
124
+ except:
125
+ return ""
126
+
127
+ for i in range(0, len(pages), batch_size):
128
+ batch = pages[i:i+batch_size]
129
+ uploaded = []
130
+ for im in batch:
131
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
132
+ im.save(tmp.name)
133
+ up = genai.upload_file(path=tmp.name, mime_type="image/png")
134
+ up = genai.get_file(up.name)
135
+ uploaded.append(up)
136
+
137
+ context_prompt = user_prompt
138
+ resp = model.generate_content([context_prompt] + uploaded)
139
+ text = _safe_text(resp)
140
+ all_text_results.append(text)
141
+ for up in uploaded:
142
+ try:
143
+ genai.delete_file(up.name)
144
+ except:
145
+ pass
146
+
147
+ return "\n\n".join(all_text_results), None
148
+
149
+ # ================== EXTERNAL API (nếu có) ==================
150
+ def run_process_external(file_bytes, filename, mime, question, api_url, temperature, top_p):
151
+ if not api_url:
152
+ return "ERROR: Missing external API endpoint.", None
153
+ data = {"prompt": question or "", "temperature": str(temperature), "top_p": str(top_p)}
154
+ files = {"file": (filename, file_bytes, mime)}
155
+ r = requests.post(api_url, files=files, data=data, timeout=60)
156
+ if r.status_code >= 400:
157
+ return f"ERROR: External API HTTP {r.status_code}: {r.text[:200]}", None
158
+ return r.text, None
159
+
160
+ # ================== MAIN ROUTER (đã thêm STEP CHECK) ==================
161
+ def run_process(file, question, model_choice, temperature, top_p, external_api_url):
162
+ """
163
+ Router (có bước kiểm tra PDF/table trước khi xử lý):
164
+ - Nếu PDF nhiều trang/nhiều bảng -> extract trước (pdfplumber)
165
+ - Ngược lại -> OCR trực tiếp Gemini
166
+ """
167
+ try:
168
+ if file is None:
169
+ return "ERROR: No file uploaded.", None
170
+
171
+ file_bytes = _read_file_bytes(file)
172
+ filename, mime = _guess_name_and_mime(file, file_bytes)
173
+
174
+ # STEP 1️⃣: Check PDF structure
175
+ if mime == "application/pdf" or file_bytes[:4] == b"%PDF":
176
+ check_result = check_pdf_structure(file_bytes)
177
+ print(f"[PDF Check] {filename}: {check_result}")
178
+
179
+ if check_result == "có":
180
+ print("➡️ PDF có nhiều cột/nhiều trang → dùng pdfplumber extract trước rồi Gemini.")
181
+ try:
182
+ tables_all = []
183
+ with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
184
+ for page in pdf.pages:
185
+ for tb in page.extract_tables():
186
+ if not tb or len(tb) < 2:
187
+ continue
188
+ header = tb[0]
189
+ df = pd.DataFrame(tb[1:], columns=header)
190
+ tables_all.append(df)
191
+ if tables_all:
192
+ df_all = pd.concat(tables_all, ignore_index=True)
193
+ table_text = df_all.to_csv(index=False)
194
+ question = (
195
+ f"{PROMPT_FREIGHT_JSON}\n"
196
+ "Below is the table text extracted from the PDF (CSV format):\n"
197
+ f"{table_text}\n\n"
198
+ "Please convert this into valid JSON as per the schema."
199
+ )
200
+ except Exception as e:
201
+ print("pdfplumber extract failed:", e)
202
+
203
+ # STEP 2️⃣: Route model
204
+ if model_choice == EXTERNAL_MODEL_NAME:
205
+ return run_process_external(
206
+ file_bytes=file_bytes, filename=filename, mime=mime,
207
+ question=question, api_url=external_api_url,
208
+ temperature=temperature, top_p=top_p
209
+ )
210
+
211
+ return run_process_internal_base_v2(
212
+ file_bytes=file_bytes, filename=filename, mime=mime,
213
+ question=question, model_choice=model_choice,
214
+ temperature=temperature, top_p=top_p
215
+ )
216
+
217
+ except Exception as e:
218
+ return f"ERROR: {type(e).__name__}: {str(e)}", None
219
+
220
+ # ================== UI ==================
221
+ def main():
222
+ with gr.Blocks(title="OCR Multi-Agent System") as demo:
223
+ file = gr.File(label="Upload PDF/Image")
224
+ question = gr.Textbox(label="Prompt", lines=2)
225
+ model_choice = gr.Dropdown(choices=[*INTERNAL_MODEL_MAP.keys(), EXTERNAL_MODEL_NAME],
226
+ value="Gemini 2.5 Flash", label="Model")
227
+ temperature = gr.Slider(0.0, 2.0, value=0.2, step=0.05)
228
+ top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01)
229
+ external_api_url = gr.Textbox(label="External API URL", visible=False)
230
+ output_text = gr.Code(label="Output", language="json")
231
+ run_btn = gr.Button("🚀 Process")
232
+
233
+ run_btn.click(
234
+ run_process,
235
+ inputs=[file, question, model_choice, temperature, top_p, external_api_url],
236
+ outputs=[output_text, gr.State()]
237
+ )
238
+
239
+ return demo
240
+
241
+ demo = main()
242
+
243
+ if __name__ == "__main__":
244
+ demo.launch()