Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import base64 | |
| import json | |
| import re | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from openai import OpenAI | |
# ----------------------------
# Known templates (mirror your main system)
# ----------------------------
# One entry per document template the classifier is allowed to return.
#   template_id  - identifier returned to callers (the only accepted non-UNKNOWN output)
#   name         - human-readable label; _normalize_template_id also accepts it as an alias
#   keywords_all / keywords_any - cue phrases. The whole table is serialized
#       into the classification prompt, so the model sees them verbatim; the
#       all/any split is presumably "must appear" vs. "supporting" cues, but
#       nothing in this file enforces that.
# NOTE(review): "sales order" appears in keywords_any of T2, T3 and T7, so it
# does not help the model tell those three apart.
KNOWN_TEMPLATES: List[Dict[str, Any]] = [
    dict(
        template_id="T1_IFACTOR_DELIVERED_ORDER",
        name="I-FACTOR Delivered Order Form",
        keywords_all=["delivered order form"],
        keywords_any=["i-factor", "cerapedics", "product information", "stickers", "bill to", "delivered to"],
    ),
    dict(
        template_id="T2_SEASPINE_DELIVERED_GOODS_FORM",
        name="SeaSpine Delivered Goods Form",
        keywords_all=["delivered goods form"],
        keywords_any=["seaspine", "isotis", "handling fee", "sales order", "invoice"],
    ),
    dict(
        template_id="T3_ASTURA_SALES_ORDER_FORM",
        name="Astura Sales Order Form",
        keywords_all=[],
        keywords_any=["astura", "dc141", "ca200", "cbba", "sales order"],
    ),
    dict(
        template_id="T4_MEDICAL_ESTIMATION_OF_CHARGES",
        name="Medical Estimation of Charges",
        keywords_all=[],
        keywords_any=["estimation of charges", "good faith estimate", "patient responsibility", "insurance"],
    ),
    dict(
        template_id="T5_CLINICAL_PROGRESS_NOTE_POSTOP",
        name="Clinical Progress Note Postop",
        keywords_all=[],
        keywords_any=["clinical progress note", "progress note", "post-op", "assessment", "plan"],
    ),
    dict(
        template_id="T6_CUSTOMER_CHARGE_SHEET_SPINE",
        name="Customer Charge Sheet Spine",
        keywords_all=[],
        keywords_any=["customer charge sheet", "charge sheet", "spine", "qty", "unit price", "total"],
    ),
    dict(
        template_id="T7_SALES_ORDER_ZIMMER",
        name="Zimmer Sales Order",
        keywords_all=[],
        keywords_any=["zimmer", "zimmer biomet", "biomet", "sales order", "purchase order", "po number"],
    ),
]
# ----------------------------
# Public API (EXPLICIT key/model)
# ----------------------------
def classify_with_openai(
    image_paths: List[str],
    *,
    api_key: str,
    model: str,
    max_pages: int = 2,
) -> Dict[str, Any]:
    """
    Classify rendered document pages against KNOWN_TEMPLATES with an OpenAI vision model.

    Args:
        image_paths: PNG file paths (page renders), in page order. Only the
            first ``max_pages`` are sent; a non-positive ``max_pages`` sends one.
        api_key: OpenAI API key (surrounding whitespace is stripped).
        model: chat-completions model name (surrounding whitespace is stripped).
        max_pages: maximum number of pages to attach to the request.

    Returns:
        {
          "template_id": "T1_..." OR "UNKNOWN",
          "confidence": 0..1,          # capped at 0.6 when template_id is UNKNOWN
          "reason": "short string",    # at most 500 characters
          "trainer_schema": {}         # reserved for later
        }
        An empty ``image_paths`` returns UNKNOWN without contacting the API.

    Raises:
        RuntimeError: if ``api_key`` or ``model`` is empty after stripping.
        OSError: if an image file cannot be read.
        Errors raised by the OpenAI client propagate unchanged.

    Hard guarantees:
        - this function does NOT read environment variables
          (NOTE(review): the OpenAI client itself may still consult its own
          env vars such as a base URL -- confirm if that matters)
        - does NOT guess api keys
        - strict normalization to known template_ids
    """
    api_key = (api_key or "").strip()
    model = (model or "").strip()
    if not api_key:
        raise RuntimeError("classify_with_openai: api_key is empty")
    if not model:
        raise RuntimeError("classify_with_openai: model is empty")
    if not image_paths:
        return {
            "template_id": "UNKNOWN",
            "confidence": 0.0,
            "reason": "No rendered images provided.",
            "trainer_schema": {},
        }

    # Encode first N pages (keep small + deterministic).
    page_limit = max_pages if max_pages > 0 else 1
    pages_b64: List[str] = [_png_file_to_b64(Path(p)) for p in image_paths[:page_limit]]

    system = (
        "You are a strict document template classifier.\n"
        "You will be shown PNG images of PDF pages (scanned forms).\n"
        "Your job is to decide which known template matches.\n\n"
        "Hard rules:\n"
        "1) Output VALID JSON only. No markdown. No extra text.\n"
        "2) Choose ONE template_id from the provided list OR return template_id='UNKNOWN'.\n"
        "3) If uncertain, return UNKNOWN.\n"
        "4) Use printed headers, vendor branding, and distinctive layout cues.\n"
        "5) confidence must be 0..1.\n"
    )
    prompt_payload = {
        "known_templates": KNOWN_TEMPLATES,
        "output_schema": {
            "template_id": "string (one of known template_ids) OR 'UNKNOWN'",
            "confidence": "number 0..1",
            "reason": "short string",
        },
    }
    user_text = (
        "Classify the attached document images against known_templates.\n"
        "Return JSON matching output_schema.\n\n"
        f"{json.dumps(prompt_payload, indent=2)}"
    )

    # Multi-modal message: text first, then one image part per page.
    content: List[Dict[str, Any]] = [{"type": "text", "text": user_text}]
    content.extend(
        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64png}"}}
        for b64png in pages_b64
    )

    # The context manager closes the client's HTTP connection pool right after
    # this single request instead of leaving it for garbage collection.
    with OpenAI(api_key=api_key) as client:
        resp = client.chat.completions.create(
            model=model,
            temperature=0.0,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": content},
            ],
        )

    # Guard against an empty choices list rather than raising IndexError.
    if not resp.choices:
        return {
            "template_id": "UNKNOWN",
            "confidence": 0.0,
            "reason": "Model returned no choices.",
            "trainer_schema": {},
        }

    raw = (resp.choices[0].message.content or "").strip()
    parsed = _parse_json_object(raw)
    if not parsed:
        # Make the UNKNOWN verdict explainable instead of returning an empty reason.
        why = (
            f"Model output was not a JSON object: {raw[:200]!r}"
            if raw
            else "Model returned an empty response."
        )
        return {
            "template_id": "UNKNOWN",
            "confidence": 0.0,
            "reason": why[:500],
            "trainer_schema": {},
        }

    # Normalize: only allow known template ids or UNKNOWN.
    template_id = _normalize_template_id(str(parsed.get("template_id") or "").strip())
    confidence = max(0.0, min(1.0, _to_float(parsed.get("confidence"), default=0.0)))
    reason = str(parsed.get("reason") or "").strip()

    # An UNKNOWN verdict (including an unrecognised id) must not look highly confident.
    if template_id == "UNKNOWN" and confidence > 0.6:
        confidence = 0.6

    return {
        "template_id": template_id,
        "confidence": confidence,
        "reason": reason[:500],
        "trainer_schema": {},
    }
# ----------------------------
# Legacy wrapper (ENV-based) - keep only if you want
# ----------------------------
def classify_with_openai_from_env(image_paths: List[str]) -> Dict[str, Any]:
    """
    Backwards-compatible, environment-driven front end for classify_with_openai.

    The API key is the first *non-empty* raw value among OPENAI_API_KEY_TEST and
    OPENAI_API_KEY, stripped afterwards. A whitespace-only OPENAI_API_KEY_TEST
    therefore shadows OPENAI_API_KEY and triggers the missing-key error. The
    model comes from OPENAI_MODEL, falling back to "gpt-4o-mini".
    Only use this from old code that has not been migrated yet.

    Raises:
        RuntimeError: if no usable API key is found.
    """
    import os

    key_source = ""
    for env_name in ("OPENAI_API_KEY_TEST", "OPENAI_API_KEY"):
        candidate = os.getenv(env_name)
        if candidate:
            key_source = candidate
            break

    resolved_key = key_source.strip()
    if not resolved_key:
        raise RuntimeError("Missing OPENAI_API_KEY_TEST (or OPENAI_API_KEY)")

    resolved_model = (os.environ.get("OPENAI_MODEL") or "gpt-4o-mini").strip()

    # Delegate to the explicit implementation so there is exactly one code path.
    return classify_with_openai(image_paths, api_key=resolved_key, model=resolved_model)
| # ---------------------------- | |
| # Helpers | |
| # ---------------------------- | |
| def _normalize_template_id(template_id: str) -> str: | |
| tid = (template_id or "").strip() | |
| if not tid: | |
| return "UNKNOWN" | |
| known_ids = {t["template_id"] for t in KNOWN_TEMPLATES} | |
| if tid in known_ids: | |
| return tid | |
| # common garbage patterns (model returns name instead of id, etc.) | |
| low = tid.lower() | |
| for t in KNOWN_TEMPLATES: | |
| if t["name"].lower() == low: | |
| return t["template_id"] | |
| return "UNKNOWN" | |
| def _png_file_to_b64(path: Path) -> str: | |
| data = path.read_bytes() | |
| return base64.b64encode(data).decode("utf-8") | |
# Greedy span from the first "{" to the LAST "}" -- only used as a fallback,
# because prose after the object that contains "}" makes this span invalid JSON.
_JSON_BLOCK_RE = re.compile(r"\{.*\}", re.DOTALL)


def _parse_json_object(text: str) -> Dict[str, Any]:
    """
    Extract and parse the first JSON object from model output.

    Handles pure JSON, JSON embedded in prose (including prose that itself
    contains braces), and ``` fenced code blocks (the fences are stripped).

    Strategy, first success wins:
      1. decode exactly one JSON value starting at the first "{"
         (json.JSONDecoder.raw_decode ignores whatever follows it);
      2. take the greedy "{...}" span and retry after removing trailing
         commas (a common model mistake);
      3. try every later "{" position, for output whose leading prose
         contains stray braces.

    Returns:
        The parsed dict, or {} when no JSON object can be recovered.
        Never raises for malformed input.
    """
    if not text:
        return {}
    s = _strip_code_fences(text.strip())

    first = s.find("{")
    if first == -1:
        return {}

    decoder = json.JSONDecoder()

    def _decode_from(index: int) -> Optional[Dict[str, Any]]:
        # Decode one JSON value at `index`; accept it only if it is an object.
        try:
            value, _end = decoder.raw_decode(s, index)
        except (ValueError, RecursionError):
            return None
        return value if isinstance(value, dict) else None

    # 1) First "{": covers pure JSON and JSON followed by trailing text.
    found = _decode_from(first)
    if found is not None:
        return found

    # 2) Greedy span with trailing commas removed.
    m = _JSON_BLOCK_RE.search(s)
    if m:
        try:
            value = json.loads(_remove_trailing_commas(m.group(0)))
        except (ValueError, RecursionError):
            value = None
        if isinstance(value, dict):
            return value

    # 3) Later "{" positions.
    pos = s.find("{", first + 1)
    while pos != -1:
        found = _decode_from(pos)
        if found is not None:
            return found
        pos = s.find("{", pos + 1)
    return {}
| def _strip_code_fences(s: str) -> str: | |
| # remove leading ```json / ``` and trailing ``` | |
| if s.startswith("```"): | |
| s = re.sub(r"^```[a-zA-Z0-9]*\s*", "", s) | |
| s = re.sub(r"\s*```$", "", s) | |
| return s.strip() | |
| def _remove_trailing_commas(s: str) -> str: | |
| # naive but effective: remove ",}" and ",]" patterns repeatedly | |
| prev = None | |
| cur = s | |
| while prev != cur: | |
| prev = cur | |
| cur = re.sub(r",\s*}", "}", cur) | |
| cur = re.sub(r",\s*]", "]", cur) | |
| return cur | |
| def _to_float(x: Any, default: float = 0.0) -> float: | |
| try: | |
| return float(x) | |
| except Exception: | |
| return default | |
| # ---------------------------- | |
| # Optional: quick self-check (manual) | |
| # ---------------------------- | |
| def _debug_summarize_result(res: Dict[str, Any]) -> str: | |
| return f"template_id={res.get('template_id')} conf={res.get('confidence')} reason={str(res.get('reason') or '')[:80]}" | |
def _validate_known_templates() -> Tuple[bool, str]:
    """Sanity-check KNOWN_TEMPLATES; return (ok, message).

    Missing (falsy) template_ids are reported before duplicates.
    """
    template_ids = [entry.get("template_id") for entry in KNOWN_TEMPLATES]
    if not all(template_ids):
        return False, "One or more templates missing template_id"
    has_duplicates = len(template_ids) != len(set(template_ids))
    if has_duplicates:
        return False, "Duplicate template_id in KNOWN_TEMPLATES"
    return True, "ok"