from __future__ import annotations import base64 import json import re from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from openai import OpenAI # ---------------------------- # Known templates (mirror your main system) # ---------------------------- KNOWN_TEMPLATES: List[Dict[str, Any]] = [ { "template_id": "T1_IFACTOR_DELIVERED_ORDER", "name": "I-FACTOR Delivered Order Form", "keywords_all": ["delivered order form"], "keywords_any": ["i-factor", "cerapedics", "product information", "stickers", "bill to", "delivered to"], }, { "template_id": "T2_SEASPINE_DELIVERED_GOODS_FORM", "name": "SeaSpine Delivered Goods Form", "keywords_all": ["delivered goods form"], "keywords_any": ["seaspine", "isotis", "handling fee", "sales order", "invoice"], }, { "template_id": "T3_ASTURA_SALES_ORDER_FORM", "name": "Astura Sales Order Form", "keywords_all": [], "keywords_any": ["astura", "dc141", "ca200", "cbba", "sales order"], }, { "template_id": "T4_MEDICAL_ESTIMATION_OF_CHARGES", "name": "Medical Estimation of Charges", "keywords_all": [], "keywords_any": ["estimation of charges", "good faith estimate", "patient responsibility", "insurance"], }, { "template_id": "T5_CLINICAL_PROGRESS_NOTE_POSTOP", "name": "Clinical Progress Note Postop", "keywords_all": [], "keywords_any": ["clinical progress note", "progress note", "post-op", "assessment", "plan"], }, { "template_id": "T6_CUSTOMER_CHARGE_SHEET_SPINE", "name": "Customer Charge Sheet Spine", "keywords_all": [], "keywords_any": ["customer charge sheet", "charge sheet", "spine", "qty", "unit price", "total"], }, { "template_id": "T7_SALES_ORDER_ZIMMER", "name": "Zimmer Sales Order", "keywords_all": [], "keywords_any": ["zimmer", "zimmer biomet", "biomet", "sales order", "purchase order", "po number"], }, ] # ---------------------------- # Public API (EXPLICIT key/model) # ---------------------------- def classify_with_openai( image_paths: List[str], *, api_key: str, model: str, max_pages: int = 2, ) -> Dict[str, Any]: """ Input: list of PNG file paths (page renders). Output: { "template_id": "T1_..." OR "UNKNOWN", "confidence": 0..1, "reason": "short string", "trainer_schema": {} # reserved for later } Hard guarantees: - does NOT read environment variables - does NOT guess api keys - strict normalization to known template_ids """ api_key = (api_key or "").strip() model = (model or "").strip() if not api_key: raise RuntimeError("classify_with_openai: api_key is empty") if not model: raise RuntimeError("classify_with_openai: model is empty") if not image_paths: return { "template_id": "UNKNOWN", "confidence": 0.0, "reason": "No rendered images provided.", "trainer_schema": {}, } # Encode first N pages (keep small + deterministic) pages_b64: List[str] = [] for p in image_paths[: max_pages if max_pages > 0 else 1]: pages_b64.append(_png_file_to_b64(Path(p))) client = OpenAI(api_key=api_key) system = ( "You are a strict document template classifier.\n" "You will be shown PNG images of PDF pages (scanned forms).\n" "Your job is to decide which known template matches.\n\n" "Hard rules:\n" "1) Output VALID JSON only. No markdown. No extra text.\n" "2) Choose ONE template_id from the provided list OR return template_id='UNKNOWN'.\n" "3) If uncertain, return UNKNOWN.\n" "4) Use printed headers, vendor branding, and distinctive layout cues.\n" "5) confidence must be 0..1.\n" ) prompt_payload = { "known_templates": KNOWN_TEMPLATES, "output_schema": { "template_id": "string (one of known template_ids) OR 'UNKNOWN'", "confidence": "number 0..1", "reason": "short string", }, } user_text = ( "Classify the attached document images against known_templates.\n" "Return JSON matching output_schema.\n\n" f"{json.dumps(prompt_payload, indent=2)}" ) # Multi-modal message: text + images content: List[Dict[str, Any]] = [{"type": "text", "text": user_text}] for b64png in pages_b64: content.append( { "type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64png}"}, } ) resp = client.chat.completions.create( model=model, temperature=0.0, messages=[ {"role": "system", "content": system}, {"role": "user", "content": content}, ], ) raw = (resp.choices[0].message.content or "").strip() parsed = _parse_json_object(raw) template_id = str(parsed.get("template_id") or "").strip() confidence = _to_float(parsed.get("confidence"), default=0.0) confidence = max(0.0, min(1.0, confidence)) reason = str(parsed.get("reason") or "").strip() # Normalize: only allow known template ids or UNKNOWN template_id = _normalize_template_id(template_id) # If model returns UNKNOWN but gives high confidence, clamp confidence. if template_id == "UNKNOWN" and confidence > 0.6: confidence = 0.6 return { "template_id": template_id, "confidence": confidence, "reason": reason[:500], "trainer_schema": {}, } # ---------------------------- # Legacy wrapper (ENV-based) - keep only if you want # ---------------------------- def classify_with_openai_from_env(image_paths: List[str]) -> Dict[str, Any]: """ Backwards compatible wrapper. Reads env vars, then calls classify_with_openai(api_key=..., model=...). Use this only if you have old code you haven't updated yet. """ import os api_key = (os.getenv("OPENAI_API_KEY_TEST") or os.getenv("OPENAI_API_KEY") or "").strip() if not api_key: raise RuntimeError("Missing OPENAI_API_KEY_TEST (or OPENAI_API_KEY)") model = (os.getenv("OPENAI_MODEL") or "gpt-4o-mini").strip() # IMPORTANT: call the explicit version (one implementation only) return classify_with_openai( image_paths, api_key=api_key, model=model, ) # ---------------------------- # Helpers # ---------------------------- def _normalize_template_id(template_id: str) -> str: tid = (template_id or "").strip() if not tid: return "UNKNOWN" known_ids = {t["template_id"] for t in KNOWN_TEMPLATES} if tid in known_ids: return tid # common garbage patterns (model returns name instead of id, etc.) low = tid.lower() for t in KNOWN_TEMPLATES: if t["name"].lower() == low: return t["template_id"] return "UNKNOWN" def _png_file_to_b64(path: Path) -> str: data = path.read_bytes() return base64.b64encode(data).decode("utf-8") _JSON_BLOCK_RE = re.compile(r"\{.*\}", re.DOTALL) def _parse_json_object(text: str) -> Dict[str, Any]: """ Extract and parse the first {...} JSON object from model output. Handles: - pure JSON - JSON embedded in text - fenced code blocks (we strip fences) """ if not text: return {} s = text.strip() # Strip ```json fences if present s = _strip_code_fences(s) # Fast path: starts with "{" if s.startswith("{"): try: return json.loads(s) except Exception: pass # Try to find a JSON-looking block m = _JSON_BLOCK_RE.search(s) if not m: return {} chunk = m.group(0) try: return json.loads(chunk) except Exception: # last attempt: remove trailing commas (common model mistake) cleaned = _remove_trailing_commas(chunk) try: return json.loads(cleaned) except Exception: return {} def _strip_code_fences(s: str) -> str: # remove leading ```json / ``` and trailing ``` if s.startswith("```"): s = re.sub(r"^```[a-zA-Z0-9]*\s*", "", s) s = re.sub(r"\s*```$", "", s) return s.strip() def _remove_trailing_commas(s: str) -> str: # naive but effective: remove ",}" and ",]" patterns repeatedly prev = None cur = s while prev != cur: prev = cur cur = re.sub(r",\s*}", "}", cur) cur = re.sub(r",\s*]", "]", cur) return cur def _to_float(x: Any, default: float = 0.0) -> float: try: return float(x) except Exception: return default # ---------------------------- # Optional: quick self-check (manual) # ---------------------------- def _debug_summarize_result(res: Dict[str, Any]) -> str: return f"template_id={res.get('template_id')} conf={res.get('confidence')} reason={str(res.get('reason') or '')[:80]}" def _validate_known_templates() -> Tuple[bool, str]: ids = [t.get("template_id") for t in KNOWN_TEMPLATES] if any(not i for i in ids): return False, "One or more templates missing template_id" if len(set(ids)) != len(ids): return False, "Duplicate template_id in KNOWN_TEMPLATES" return True, "ok"