Spaces:
Paused
Paused
| from __future__ import annotations | |
| import os, io, re, json, time, mimetypes, tempfile, string | |
| from typing import List, Union, Tuple, Any, Iterable | |
| from PIL import Image | |
| import pandas as pd | |
| import gradio as gr | |
| import google.generativeai as genai | |
| import requests | |
# ================== CONFIG ==================
# Never hardcode an API key in this file: it ends up in version control and in
# every copy of the source. Provide the key through the GOOGLE_API_KEY
# environment variable instead. DEFAULT_API_KEY is only the fallback used when
# that variable is unset, and it must stay empty.
# NOTE(review): a real key was previously committed on this line; treat it as
# compromised and revoke it.
DEFAULT_API_KEY = ""

# Maps the model labels shown in the UI dropdown to Gemini model identifiers.
INTERNAL_MODEL_MAP = {
    "Gemini 2.5 Flash": "gemini-2.5-flash",
    "Gemini 2.5 Pro": "gemini-2.5-pro",
}

# Dropdown label that routes a request to the user-supplied external endpoint.
EXTERNAL_MODEL_NAME = "prithivMLmods/Camel-Doc-OCR-062825 (External)"
# Pick the LANCZOS filter from wherever the installed Pillow exposes it: the
# Image.Resampling namespace when it exists, otherwise the legacy constant on
# the Image module itself.
RESAMPLE = getattr(Image, "Resampling", Image).LANCZOS
| # ================== HELPERS ================== | |
| import fitz # PyMuPDF | |
def pdf_to_images(pdf_bytes: bytes) -> list[Image.Image]:
    """Render every page of an in-memory PDF to an RGB PIL image.

    Args:
        pdf_bytes: Raw content of the PDF file.

    Returns:
        One image per page, in document order, rendered at 200 DPI.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        pages = []
        for page in doc:
            # alpha=False keeps the pixmap at 3 bytes per pixel, matching "RGB".
            pix = page.get_pixmap(dpi=200, alpha=False)
            pages.append(
                Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            )
        return pages
    finally:
        # Release the document handle even when rendering a page fails.
        doc.close()
def ensure_rgb(im: Image.Image) -> Image.Image:
    """Return *im* itself when it is already RGB, else an RGB conversion of it."""
    if im.mode == "RGB":
        return im
    return im.convert("RGB")
def _read_file_bytes(upload: Union[str, os.PathLike, dict, object] | None) -> bytes:
    """Return the raw content of an uploaded file.

    Accepted inputs are a filesystem path, a dict carrying a ``"path"`` key,
    or any object exposing a ``read()`` method.

    Raises:
        ValueError: If *upload* is ``None``.
        TypeError: If *upload* is none of the supported shapes.
    """
    def _slurp(location):
        with open(location, "rb") as handle:
            return handle.read()

    if upload is None:
        raise ValueError("No file uploaded.")
    if isinstance(upload, (str, os.PathLike)):
        return _slurp(upload)
    if isinstance(upload, dict) and "path" in upload:
        return _slurp(upload["path"])
    if not hasattr(upload, "read"):
        raise TypeError(f"Unsupported file object: {type(upload)}")
    return upload.read()
def _make_previews(file_bytes: bytes, max_side: int = 2000) -> List[Image.Image]:
    """Return RGB preview images, downscaled so no side exceeds *max_side*.

    Content starting with the ``%PDF`` magic bytes is rendered page by page;
    anything else is opened as a single image. A falsy *max_side* disables
    resizing. Images are never upscaled.
    """
    looks_like_pdf = len(file_bytes) >= 4 and file_bytes[:4] == b"%PDF"
    if looks_like_pdf:
        frames = pdf_to_images(file_bytes)
    else:
        frames = [Image.open(io.BytesIO(file_bytes))]

    previews = []
    for frame in frames:
        frame = ensure_rgb(frame)
        if max_side:
            width, height = frame.size
            # Cap the factor at 1.0 so small images keep their original size.
            factor = min(max_side / float(width), max_side / float(height), 1.0)
            if factor < 1.0:
                target = (max(1, int(width * factor)), max(1, int(height * factor)))
                frame = frame.resize(target, RESAMPLE)
        previews.append(frame)
    return previews
def _guess_name_and_mime(file, file_bytes: bytes) -> Tuple[str, str]:
    """Derive a filename and MIME type for an upload.

    The filename comes from the path (or the ``"name"`` / ``"path"`` entry of
    a dict) and falls back to ``"upload.bin"``. The MIME type is guessed from
    the filename extension. When that yields nothing useful, the content is
    sniffed: ``%PDF`` magic bytes mean PDF (and ``.pdf`` is appended to the
    name if missing), anything else is assumed to be a PNG image.

    Args:
        file: The upload as received from the UI (path, dict or other object).
        file_bytes: Raw content of the upload, used for sniffing.

    Returns:
        A ``(filename, mime)`` tuple.
    """
    if isinstance(file, (str, os.PathLike)):
        filename = os.path.basename(str(file))
    elif isinstance(file, dict) and "name" in file:
        filename = os.path.basename(file["name"])
    elif isinstance(file, dict) and "path" in file:
        filename = os.path.basename(file["path"])
    else:
        filename = "upload.bin"

    mime, _ = mimetypes.guess_type(filename)
    # "application/octet-stream" only says "some bytes". The fallback name
    # "upload.bin" maps to it, which used to bypass content sniffing for
    # uploads that carry no usable filename.
    if not mime or mime == "application/octet-stream":
        if len(file_bytes) >= 4 and file_bytes[:4] == b"%PDF":
            mime = "application/pdf"
            if not filename.lower().endswith(".pdf"):
                filename += ".pdf"
        else:
            mime = "image/png"
    return filename, mime
def _extract_json_from_message(msg: str):
    """Strip a surrounding Markdown code fence and try to parse the rest as JSON.

    Returns:
        ``(obj, cleaned)`` where *cleaned* is the message without a leading
        fence (optionally tagged ``json``) and trailing fence, and *obj* is the
        parsed JSON value, or ``None`` when *cleaned* is not valid JSON.
    """
    cleaned = (msg or "").strip()
    fence_rules = (
        (r"^\s*```(?:json)?\s*", re.IGNORECASE),  # opening fence
        (r"\s*```\s*$", 0),                       # closing fence
    )
    for pattern, flags in fence_rules:
        cleaned = re.sub(pattern, "", cleaned, flags=flags)
    try:
        parsed = json.loads(cleaned)
    except Exception:
        return None, cleaned
    return parsed, cleaned
def _pretty_message(msg: str) -> str:
    """Return *msg* pretty-printed as indented JSON when it parses, else cleaned text."""
    parsed, cleaned = _extract_json_from_message(msg)
    if parsed is None:
        return cleaned
    return json.dumps(parsed, ensure_ascii=False, indent=2)
def _safe_text_from_gemini(resp):
    """Extract the text of a model response without raising.

    ``resp.text`` is returned when reading it succeeds. Otherwise the
    non-empty ``text`` attributes of every part of every candidate are joined
    with newlines.
    """
    try:
        return resp.text
    except Exception:
        # The shortcut accessor failed; collect the text parts by hand below.
        pieces = []

    for candidate in getattr(resp, "candidates", []) or []:
        content = getattr(candidate, "content", None)
        parts = getattr(content, "parts", None) if content else None
        for part in parts or ():
            text = getattr(part, "text", None)
            if text:
                pieces.append(text)
    return "\n".join(pieces).strip()
def _wait_file_active(file_obj, timeout_s: int = 60) -> object:
    """Poll an uploaded file until it leaves the PROCESSING state.

    The file is re-fetched with ``genai.get_file`` between sleeps; the sleep
    starts at 0.5 s and grows by a factor of 1.5 up to 2 s.

    Returns:
        The latest file object once its state is ACTIVE.

    Raises:
        TimeoutError: If the file is still PROCESSING after *timeout_s* seconds.
        RuntimeError: If the file ends in any state other than ACTIVE.
    """
    began = time.time()
    pause = 0.5
    while True:
        if not hasattr(file_obj, "state"):
            break
        if getattr(file_obj.state, "name", "") != "PROCESSING":
            break
        if time.time() - began > timeout_s:
            raise TimeoutError("Upload processing timeout.")
        time.sleep(pause)
        pause = min(pause * 1.5, 2.0)
        file_obj = genai.get_file(file_obj.name)

    if hasattr(file_obj, "state") and file_obj.state.name == "ACTIVE":
        return file_obj
    final_state = getattr(file_obj, "state", None)
    raise RuntimeError(
        f"Upload failed or not active. State={getattr(final_state, 'name', 'UNKNOWN')}"
    )
| # ---------- JSON → Excel (schema-agnostic) ---------- | |
def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
    """Collapse nested dicts into one level: ``{'a': {'b': 1}}`` -> ``{'a.b': 1}``.

    Keys are joined with *sep*; *parent_key* is the prefix accumulated so far.
    A falsy *d* yields an empty dict.
    """
    flat = {}
    for key, value in (d or {}).items():
        path = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            flat.update(_flatten_dict(value, path, sep=sep))
        else:
            flat[path] = value
    return flat
def _sanitize_sheet_name(name: str, used: set[str]) -> str:
    """Turn *name* into a worksheet name that is valid and not yet in *used*.

    The characters ``[ ] : * ? / \\ ' "`` are removed, surrounding whitespace
    is stripped, an empty result becomes ``"sheet"`` and the name is cut to
    31 characters. Clashes with names already in *used* are resolved by
    appending ``_1``, ``_2``, ... while staying within 31 characters.

    Args:
        name: Requested sheet name.
        used: Names already taken; the returned name is added to this set.

    Returns:
        The sanitized, unique sheet name.
    """
    invalid = set('[]:*?/\\\'"')
    clean = "".join(ch for ch in name if ch not in invalid).strip()
    if not clean:
        clean = "sheet"
    clean = clean[:31]

    # Excel compares sheet names case-insensitively, so "Data" and "data"
    # must count as the same name when checking for clashes.
    taken = {existing.casefold() for existing in used}
    base, idx = clean, 1
    while clean.casefold() in taken:
        suffix = f"_{idx}"
        clean = base[: 31 - len(suffix)] + suffix
        idx += 1
    used.add(clean)
    return clean
def _list_to_frame(values: list) -> pd.DataFrame:
    """Convert one JSON list into the DataFrame written to its worksheet."""
    if not values:
        return pd.DataFrame()
    if all(isinstance(item, dict) for item in values):
        try:
            return pd.json_normalize(values, sep=".")
        except Exception:
            # Normalisation failed; fall back to one JSON string per row.
            return pd.DataFrame(
                {"value": [json.dumps(item, ensure_ascii=False) for item in values]}
            )
    if not any(isinstance(item, (list, dict)) for item in values):
        return pd.DataFrame({"value": values})
    # Mixed or nested content: keep every item, serialised as JSON text.
    return pd.DataFrame(
        {"value": [json.dumps(item, ensure_ascii=False) for item in values]}
    )


def _to_excel_generic(data: Any, path: str) -> str:
    """Write arbitrary parsed JSON to an Excel workbook at *path*.

    Rules:
      - A list becomes one sheet "data":
          * list of dicts   -> one column per (flattened) key
          * list of scalars -> a single column "value"
          * anything else   -> a single column "value" of JSON strings
      - A dict becomes:
          * one sheet "summary" holding its scalar fields and its nested
            dicts (flattened with dotted keys), when there are any
          * one sheet per list-valued field, named after the key and filled
            by the list rules above
          * a single empty sheet "data" when the dict produced no sheet at
            all, so the workbook can still be saved
      - Any other value becomes one sheet "data" with its JSON text in a
        single "value" cell.

    Lists are classified by looking at all of their items, not only the first
    one, so mixed lists are never handed to ``pd.json_normalize``.

    Args:
        data: The parsed JSON value to export.
        path: Destination file path of the workbook.

    Returns:
        *path*, unchanged.
    """
    with pd.ExcelWriter(path) as writer:
        used_names: set[str] = set()

        def add_df(df: pd.DataFrame, sheet: str):
            sheetname = _sanitize_sheet_name(sheet, used_names)
            df.to_excel(writer, index=False, sheet_name=sheetname)

        if isinstance(data, list):
            add_df(_list_to_frame(data), "data")
        elif isinstance(data, dict):
            scalars = {}
            list_sheets: list[tuple[str, pd.DataFrame]] = []
            for k, v in data.items():
                if isinstance(v, list):
                    list_sheets.append((k, _list_to_frame(v)))
                elif isinstance(v, dict):
                    scalars.update(_flatten_dict({k: v}))
                else:
                    scalars[k] = v
            # Summary sheet first, then one sheet per list field.
            if scalars:
                add_df(pd.DataFrame([scalars]), "summary")
            for k, df in list_sheets:
                add_df(df, k if k else "list")
            # A workbook with no sheet cannot be saved; add a placeholder.
            if not used_names:
                add_df(pd.DataFrame(), "data")
        else:
            # Any other type: store its JSON text in a single cell.
            add_df(pd.DataFrame({"value": [json.dumps(data, ensure_ascii=False)]}), "data")
    return path
| # ================== HANDLERS ================== | |
def preview_process(file):
    """Render an upload to temporary PNG files for the preview gallery.

    Returns:
        A list of PNG file paths, one per page/image. The list is empty when
        *file* is ``None`` or when anything goes wrong (the error is printed).
    """
    if file is None:
        return []
    try:
        raw = _read_file_bytes(file)
        rendered = _make_previews(raw, max_side=2000)
        saved = []
        for index, image in enumerate(rendered):
            handle, png_path = tempfile.mkstemp(suffix=f"_preview_{index}.png")
            # mkstemp returns an open descriptor; only the path is needed.
            os.close(handle)
            image.save(png_path, format="PNG")
            saved.append(png_path)
        return saved
    except Exception as e:
        print(f"Preview error: {e}")
        return []
# -------- Internal (Gemini) - Base (single pass, no thinking) --------
def run_process_internal_base(file_bytes, filename, mime, question, model_choice,
                              temperature, top_p):
    """Run one Gemini request on the (optional) file and prompt.

    The API key is read from the GOOGLE_API_KEY environment variable, falling
    back to ``DEFAULT_API_KEY``. When *file_bytes* is given it is written to a
    temporary file, uploaded and awaited before the model is called. An empty
    *question* is replaced by a default OCR prompt.

    Returns:
        ``(message, parsed_obj)``: the model answer (pretty-printed when it is
        JSON) and the parsed JSON value, or ``None`` when the answer is not
        JSON. A missing API key yields an ``"ERROR: ..."`` message and ``None``.
    """
    api_key = os.environ.get("GOOGLE_API_KEY", DEFAULT_API_KEY)
    if not api_key:
        return "ERROR: Missing GOOGLE_API_KEY.", None
    genai.configure(api_key=api_key)

    model = genai.GenerativeModel(
        model_name=INTERNAL_MODEL_MAP.get(model_choice, INTERNAL_MODEL_MAP["Gemini 2.5 Flash"]),
        generation_config={"temperature": float(temperature), "top_p": float(top_p)},
    )

    uploaded = None
    tmp_path = None
    try:
        if file_bytes:
            extension = os.path.splitext(filename)[1] or ".bin"
            with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
                tmp.write(file_bytes)
                tmp_path = tmp.name
            uploaded = genai.upload_file(path=tmp_path, mime_type=mime)
            uploaded = _wait_file_active(uploaded, timeout_s=60)

        user_prompt = (question or "").strip() or (
            "Perform high-quality OCR on the provided file. If PDF: read all pages in order. "
            "Return clean plain text. If structure is obvious (tables, key:value), preserve it. "
            "If you can, output JSON that captures the structure."
        )

        # Call the model, attaching the uploaded file when there is one.
        request = [user_prompt, uploaded] if uploaded else user_prompt
        resp = model.generate_content(request)

        # Take the model message as-is (pretty-printed when it is JSON).
        answer_raw = _safe_text_from_gemini(resp)
        message = _pretty_message(answer_raw)
        # Parse JSON (if any) for the export; no schema validation.
        parsed_obj, _ = _extract_json_from_message(answer_raw)
        return message, parsed_obj
    finally:
        # Cleanup is best-effort: failures here must not hide the result.
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except Exception:
                pass
        try:
            if uploaded and hasattr(uploaded, "name"):
                genai.delete_file(uploaded.name)
        except Exception:
            pass
| # -------- External API -------- | |
def run_process_external(file_bytes, filename, mime, question, api_url,
                         temperature, top_p):
    """Send the prompt (and optional file) to a user-supplied HTTP endpoint.

    With a file the request is a multipart POST (field ``file`` plus form
    data); without one the same fields are posted as JSON. The answer is
    taken from the ``message``, ``text`` or ``data`` key of a JSON response,
    or from the raw response body when the response cannot be read that way.

    Returns:
        ``(message, parsed_obj)`` as for the internal handler. Any failure is
        reported as an ``"ERROR: ..."`` message with ``None``.
    """
    if not (api_url and str(api_url).strip()):
        return "ERROR: Missing external API endpoint (hãy dán URL).", None
    try:
        user_prompt = (question or "").strip() or (
            "Perform high-quality OCR on the provided file. If PDF: read all pages in order. "
            "Return clean plain text. If structure is obvious (tables, key:value), preserve it. "
            "If you can, output JSON that captures the structure."
        )
        data = {"prompt": user_prompt, "temperature": str(temperature), "top_p": str(top_p)}

        if file_bytes:
            attachment = {"file": (filename, file_bytes, mime)}
            r = requests.post(api_url, files=attachment, data=data, timeout=60)
        else:
            r = requests.post(api_url, json=data, timeout=60)

        if r.status_code >= 400:
            return f"ERROR: External API HTTP {r.status_code}: {r.text[:300]}", None

        try:
            body = r.json()
            answer = body.get("message") or body.get("text") or body.get("data")
            if isinstance(answer, (dict, list)):
                answer = json.dumps(answer, ensure_ascii=False)
        except Exception:
            # Not JSON (or not an object): use the raw body instead.
            answer = r.text

        answer = (answer or "").strip()
        message = _pretty_message(answer)
        parsed_obj, _ = _extract_json_from_message(answer)
        return message, parsed_obj
    except Exception as e:
        return f"ERROR: {type(e).__name__}: {str(e) or repr(e)}", None
| # -------- Router -------- | |
def run_process(file, question, model_choice, temperature, top_p, external_api_url):
    """Route a request to the right backend (no agent, no thinking).

    - The external model choice goes to ``run_process_external``.
    - Everything else goes to the internal single-pass Gemini handler.

    Returns:
        ``(message, parsed_obj)`` from the chosen handler, or an
        ``"ERROR: ..."`` message with ``None`` when anything raises.
    """
    try:
        file_bytes = filename = mime = None
        if file is not None:
            file_bytes = _read_file_bytes(file)
            filename, mime = _guess_name_and_mime(file, file_bytes)

        # Arguments both handlers take under the same names.
        shared = dict(
            file_bytes=file_bytes, filename=filename, mime=mime,
            question=question, temperature=temperature, top_p=top_p,
        )
        if model_choice == EXTERNAL_MODEL_NAME:
            return run_process_external(api_url=external_api_url, **shared)
        return run_process_internal_base(model_choice=model_choice, **shared)
    except Exception as e:
        return f"ERROR: {type(e).__name__}: {str(e) or repr(e)}", None
def on_export_excel(parsed_obj):
    """Export the last parsed JSON to a temporary .xlsx and expose it for download.

    Args:
        parsed_obj: The parsed JSON value kept in the UI state, or a falsy
            value when there is nothing to export.

    Returns:
        A ``gr.update`` for the download component: the file path with
        ``visible=True`` on success, otherwise no value and ``visible=False``.
    """
    tmp_path = None
    try:
        if not parsed_obj:
            # Nothing to export: keep the download link hidden.
            return gr.update(value=None, visible=False)
        # The file must outlive this call so it can be downloaded.
        fd, tmp_path = tempfile.mkstemp(suffix=".xlsx")
        os.close(fd)
        _to_excel_generic(parsed_obj, tmp_path)
        return gr.update(value=tmp_path, visible=True)
    except Exception as e:
        print(f"Export error: {e}")
        # Do not leave a useless (empty or partial) workbook behind.
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return gr.update(value=None, visible=False)
def clear_all():
    """Return the reset value for every component wired to the Clear button."""
    return (
        None,                # file
        [],                  # preview
        "",                  # output_text
        "",                  # question
        "Gemini 2.5 Flash",  # model
        None,                # parsed_state
        None,                # download
        0.2,                 # temperature
        0.95,                # top_p
        "",                  # external_api_url
    )
| # ================== UI ================== | |
def _toggle_external_visibility(selected: str):
    """Show the external endpoint textbox only for the external model choice."""
    show_url_box = selected == EXTERNAL_MODEL_NAME
    return gr.update(visible=show_url_box)
def main():
    """Build the Gradio Blocks UI, wire its events and return it (not launched)."""
    # Page styling; the selectors match the elem_id / elem_classes set below.
    custom_css = """
    .gradio-container { max-width: 1400px !important; margin: 0 auto; }
    #main-row { display: flex; gap: 20px; align-items: flex-start; }
    #left-column { flex: 1; min-width: 400px; max-width: 600px; }
    #right-column { flex: 1; min-width: 400px; }
    #file-upload { border: 2px dashed #d1d5db; border-radius: 12px; padding: 20px; text-align: center; transition: border-color 0.3s ease; }
    #file-upload:hover { border-color: #3b82f6; }
    #preview-gallery { max-height: 600px; overflow-y: auto; border: 1px solid #e5e7eb; border-radius: 12px; background: #f9fafb; padding: 10px; }
    #preview-gallery .grid { grid-template-columns: 1fr !important; gap: 10px !important; }
    #preview-gallery img { width: 100% !important; height: auto !important; object-fit: contain !important; background: white; }
    #controls-section { background: #f8fafc; padding: 20px; border-radius: 12px; margin-bottom: 20px; }
    #results-section { background: #ffffff; border: 1px solid #e5e7eb; border-radius: 12px; padding: 20px; }
    #llm-output { max-height: 500px; overflow-y: auto; font-family: monospace; font-size: 13px; }
    .primary-button { background: linear-gradient(90deg, #3b82f6, #1d4ed8) !important; color: white !important; border: none !important; border-radius: 8px !important; padding: 10px 20px !important; font-weight: 500 !important; }
    .primary-button:hover { transform: translateY(-1px) !important; box-shadow: 0 4px 12px rgba(59, 130, 246, 0.3) !important; }
    .secondary-button { background: #f3f4f6 !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-radius: 8px !important; padding: 8px 16px !important; }
    @media (max-width: 1024px) { #main-row { flex-direction: column; } #left-column, #right-column { min-width: 100%; max-width: 100%; } }
    """
    with gr.Blocks(title="OCR Multi-Agent System", css=custom_css, theme=gr.themes.Soft()) as demo:
        # Page header.
        gr.HTML("""
        <div style="text-align: center; padding: 20px 0; margin-bottom: 30px;">
        <h1 style="color:#1f2937; font-size: 2.5rem; font-weight: bold; margin-bottom: 8px;">📄 OCR Extraction (LLM-first)</h1>
        <p style="color:#6b7280; font-size: 1.1rem; margin: 0;">Upload PDF/images → LLM produces raw text/JSON → Export Excel (schema-agnostic)</p>
        </div>
        """)
        # Holds the parsed JSON of the last run so the export button can use it.
        last_parsed_state = gr.State(value=None)
        with gr.Row(elem_id="main-row"):
            # Left: upload and preview
            with gr.Column(elem_id="left-column"):
                gr.Markdown("### 📁 Upload Document")
                file = gr.File(
                    label="Choose PDF or Image file",
                    file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
                    type="filepath",
                    elem_id="file-upload"
                )
                gr.Markdown("### 👁️ Document Preview")
                preview = gr.Gallery(columns=1, height=None, show_label=False, elem_id="preview-gallery", allow_preview=True)
            # Right: options and results
            with gr.Column(elem_id="right-column"):
                with gr.Group(elem_id="controls-section"):
                    gr.Markdown("### ⚙️ Processing Options")
                    with gr.Row():
                        model_choice = gr.Dropdown(
                            choices=[*INTERNAL_MODEL_MAP.keys(), EXTERNAL_MODEL_NAME],
                            value="Gemini 2.5 Flash",
                            label="Model"
                        )
                    with gr.Row():
                        temperature = gr.Slider(0.0, 2.0, value=0.2, step=0.05, label="temperature")
                        top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="top_p")
                    # Hidden until the external model is selected (see the
                    # model_choice.change handler below).
                    external_api_url = gr.Textbox(
                        label="External API endpoint (URL)",
                        placeholder="https://your-host/path/to/ocr",
                        visible=False
                    )
                    question = gr.Textbox(
                        label="Custom Prompt (optional)",
                        placeholder="Leave blank for default OCR; or ask model to output JSON by your own schema...",
                        lines=3
                    )
                    with gr.Row():
                        run_btn = gr.Button("🚀 Process Document", elem_classes=["primary-button"])
                        clear_btn = gr.Button("🗑️ Clear All", elem_classes=["secondary-button"])
                with gr.Group(elem_id="results-section"):
                    gr.Markdown("### 📊 LLM Message (raw/pretty)")
                    output_text = gr.Code(label="LLM Message", language="json", elem_id="llm-output")
                    with gr.Row():
                        export_btn = gr.Button("⬇️ Export to Excel", elem_classes=["secondary-button"])
                    # Created hidden; on_export_excel returns the update that shows it.
                    download_file = gr.File(label="Download Excel", interactive=False, visible=False)
        # Events
        # A new upload refreshes the preview gallery.
        file.change(preview_process, inputs=[file], outputs=[preview])
        # The endpoint textbox is only shown for the external model.
        model_choice.change(_toggle_external_visibility, inputs=[model_choice], outputs=[external_api_url])
        # Run: fills the message box and stores the parsed JSON in the state.
        run_btn.click(
            run_process,
            inputs=[file, question, model_choice, temperature, top_p, external_api_url],
            outputs=[output_text, last_parsed_state]
        )
        # Export: turns the stored JSON into a downloadable workbook.
        export_btn.click(on_export_excel, inputs=[last_parsed_state], outputs=[download_file])
        # Clear: the outputs order must match the tuple returned by clear_all.
        clear_btn.click(
            clear_all,
            inputs=[],
            outputs=[file, preview, output_text, question, model_choice, last_parsed_state,
                     download_file, temperature, top_p, external_api_url]
        )
    return demo
# Build the UI at import time so the Blocks object is available as the
# module-level name `demo`; it is only launched when run as a script.
demo = main()
if __name__ == "__main__":
    demo.launch()