""" Unified Document Processing API OCR (Groq llama-4-scout) + Classification (RoBERTa) in one endpoint Loads model from HuggingFace Hub """ from dotenv import load_dotenv load_dotenv() import os import re import json import logging import base64 import shutil import torch import torch.nn as nn import joblib from datetime import datetime from contextlib import asynccontextmanager from typing import Optional, List from fastapi import FastAPI, File, UploadFile, HTTPException, Header from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from transformers import AutoTokenizer, RobertaModel import torch.nn.functional as F # ═══════════════════════════════════════════════════════════════ # Config # ═══════════════════════════════════════════════════════════════ class Config: GROQ_API_KEY = os.getenv("GROQ_API_KEY", "YOUR_API_KEY") GROQ_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct" HF_REPO_ID = "manarsaber11/enterprise-classifier" MAX_FILE_SIZE = 50 * 1024 * 1024 ALLOWED_EXT = {"pdf", "jpg", "jpeg", "png", "gif", "bmp"} UPLOAD_FOLDER = "uploads" CLASSIFIER_MAX_LEN = 320 CONFIDENCE_THRESHOLD = 0.85 os.makedirs(UPLOAD_FOLDER, exist_ok=True) logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", handlers=[logging.FileHandler("api.log"), logging.StreamHandler()] ) logger = logging.getLogger(__name__) # ═══════════════════════════════════════════════════════════════ # RoBERTa Model Class # ═══════════════════════════════════════════════════════════════ class RoBertMultiOutput(nn.Module): def __init__(self, num_department, num_priorities, department_weights=None, priority_weights=None): super().__init__() self.bert = RobertaModel.from_pretrained("roberta-base") self.dropout = nn.Dropout(0.3) self.department_classifier = nn.Linear(768, num_department) self.priority_head = nn.Sequential( nn.Linear(768, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_priorities) ) self.department_loss_fn = nn.CrossEntropyLoss(weight=department_weights) self.priority_loss_fn = nn.CrossEntropyLoss(weight=priority_weights) def forward(self, input_ids, attention_mask, department=None, priority=None): output = self.bert(input_ids=input_ids, attention_mask=attention_mask) pooled = self.dropout(output.pooler_output) department_logits = self.department_classifier(pooled) priority_logits = self.priority_head(pooled) loss = None if department is not None and priority is not None: loss = self.department_loss_fn(department_logits, department) + \ 2.0 * self.priority_loss_fn(priority_logits, priority) return {"loss": loss, "department_logits": department_logits, "priority_logits": priority_logits} # ═══════════════════════════════════════════════════════════════ # Global state # ═══════════════════════════════════════════════════════════════ _state: dict = {} # ═══════════════════════════════════════════════════════════════ # Pydantic Schemas # ═══════════════════════════════════════════════════════════════ class Entity(BaseModel): type: str value: str class RecipientInfo(BaseModel): name: Optional[str] = None date: str found: bool class AgentReview(BaseModel): triggered: bool agent_agrees: bool final_department: str reasoning: str class DocumentResult(BaseModel): raw_text: str summary: str language: str entities: List[Entity] = [] recipient: RecipientInfo department: str priority: str department_confidence: float priority_confidence: float agent_review: Optional[AgentReview] = None route: bool pages: int file_type: str file_size_bytes: int processed_at: str model_ocr: str model_classifier: str class SuccessResponse(BaseModel): success: bool = True error: Optional[str] = None data: Optional[DocumentResult] = None class HealthResponse(BaseModel): status: str timestamp: str ocr_model: str classifier_model: str # ═══════════════════════════════════════════════════════════════ # Helpers # ═══════════════════════════════════════════════════════════════ def clean_text(text: str) -> str: text = text.strip().strip('"') text = re.sub(r"[\n\t\r]", " ", text) text = re.sub(r"<[^>]+>", "", text) text = text.encode("ascii", "ignore").decode("ascii") text = re.sub(r" +", " ", text) return text.strip() def classify_text(text: str) -> dict: model = _state["clf_model"] tokenizer = _state["tokenizer"] device = _state["device"] le_dept = _state["le_dept"] le_prio = _state["le_prio"] cleaned = clean_text(text) if not cleaned: return {"department": "unknown", "priority": "unknown", "department_confidence": 0.0, "priority_confidence": 0.0} inputs = tokenizer( cleaned, truncation=True, padding="max_length", max_length=Config.CLASSIFIER_MAX_LEN, return_tensors="pt", ) input_ids = inputs["input_ids"].to(device) attention_mask = inputs["attention_mask"].to(device) with torch.no_grad(): outputs = model(input_ids, attention_mask) dept_probs = F.softmax(outputs["department_logits"], dim=1).cpu().squeeze() prio_probs = F.softmax(outputs["priority_logits"], dim=1).cpu().squeeze() dept_idx = dept_probs.argmax().item() prio_idx = prio_probs.argmax().item() return { "department": le_dept.inverse_transform([dept_idx])[0], "priority": le_prio.inverse_transform([prio_idx])[0], "department_confidence": round(float(dept_probs[dept_idx]), 4), "priority_confidence": round(float(prio_probs[prio_idx]), 4), } # ═══════════════════════════════════════════════════════════════ # OCR + Analysis Processor # ═══════════════════════════════════════════════════════════════ class DocumentProcessor: def __init__(self, api_key: str = None): try: from groq import Groq self.client = Groq(api_key=api_key or Config.GROQ_API_KEY) except ImportError: raise HTTPException(status_code=500, detail="Run: pip install groq") self.document_text = "" self.num_pages = 0 self.file_size = 0 def _pdf_to_images(self, pdf_path: str) -> List[str]: try: import fitz except ImportError: raise HTTPException(status_code=500, detail="Run: pip install pymupdf") doc = fitz.open(pdf_path) self.num_pages = len(doc) images = [] for i in range(len(doc)): pix = doc.load_page(i).get_pixmap() images.append(base64.b64encode(pix.tobytes("png")).decode("utf-8")) doc.close() return images def _image_to_b64(self, path: str) -> str: with open(path, "rb") as f: return base64.b64encode(f.read()).decode("utf-8") def _ocr_page(self, b64_img: str, page_num: int) -> str: response = self.client.chat.completions.create( model=Config.GROQ_MODEL, messages=[{ "role": "user", "content": [ { "type": "text", "text": ( f"You are an expert OCR engine specialized in Arabic and mixed Arabic/English documents. Page {page_num}.\n\n" "STRICT RULES:\n" "1. Extract ALL text exactly as it appears — Arabic, English, and numbers.\n" "2. Arabic text: preserve RIGHT-TO-LEFT order, copy every word exactly.\n" "3. Numbers: copy exactly as shown (Arabic-Indic ١٢٣ or Western 123).\n" "4. Tables: reconstruct each row on one line using | as column separator.\n" "5. Mixed lines (Arabic + English + numbers): preserve the full line as-is.\n" "6. Do NOT translate, summarize, reorder, or skip any text.\n" "7. Do NOT add commentary, headers, or any text not visible on the page.\n" "8. Empty page: output only [NO TEXT].\n\n" "Output the raw extracted text now:" ) }, {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_img}"}} ] }], temperature=0, max_tokens=4000 ) return response.choices[0].message.content or "" def _clean_ocr(self, text: str) -> str: bad = [ r"^###", r"^```", r"لقد قمت", r"النص المستخرج", r"استخلاص", r"^Here is the extracted", r"^I (can see|found|analyzed)", ] lines = text.split("\n") return "\n".join( l for l in lines if not any(re.search(p, l.strip()) for p in bad) ) async def _ocr_all_pages(self, images: List[str]) -> str: all_text = "" for i, img in enumerate(images): page_num = i + 1 logger.info(f"OCR page {page_num}/{len(images)}") try: page_text = self._ocr_page(img, page_num) all_text += f"\n\n=== Page {page_num} ===\n{self._clean_ocr(page_text)}" except Exception as e: logger.error(f"Page {page_num} failed: {e}") all_text += f"\n\n=== Page {page_num} ===\n[EXTRACTION FAILED]" return all_text def _groq(self, system: str, user: str, max_tokens: int = 500) -> str: response = self.client.chat.completions.create( model=Config.GROQ_MODEL, messages=[ {"role": "system", "content": system}, {"role": "user", "content": user} ], temperature=0, max_tokens=max_tokens ) return response.choices[0].message.content.strip() def _parse_json(self, raw: str): for marker in ["```json", "```"]: if marker in raw: raw = raw.split(marker)[1].split("```")[0].strip() break return json.loads(raw) async def get_recipient(self) -> RecipientInfo: today = datetime.now().strftime("%Y-%m-%d") try: answer = self._groq( system="Document analysis assistant. Respond with valid JSON only.", user=( f"Extract recipient and date from this document.\n\n" f"--- TEXT ---\n{self.document_text[:2000]}\n--- END ---\n\n" "RECIPIENT: person/org this is addressed TO. If not found → null\n" f"DATE: document date in YYYY-MM-DD. If not found → {today}\n" 'Return ONLY: {"name": "...", "date": "YYYY-MM-DD"}' ), max_tokens=200 ) info = self._parse_json(answer) name = info.get("name") found = bool(name and name not in [None, "null", "", "غير محدد"]) date = info.get("date", today) if not re.match(r"\d{4}-\d{2}-\d{2}", str(date)): date = today return RecipientInfo(name=name if found else None, date=date, found=found) except Exception as e: logger.warning(f"Recipient failed: {e}") return RecipientInfo(name=None, date=today, found=False) async def get_entities(self) -> List[Entity]: try: answer = self._groq( system="NER expert. Return ONLY a valid JSON array, no extra text.", user=( f"Extract named entities:\n\n{self.document_text[:3000]}\n\n" "Types: PERSON_NAME, ORGANIZATION, LOCATION, DATE, REFERENCE_NUMBER, PHONE, EMAIL, AMOUNT\n" 'Return: [{"type": "TYPE", "value": "value"}, ...]' ) ) data = self._parse_json(answer) return [Entity(**e) for e in data] if isinstance(data, list) else [] except Exception as e: logger.warning(f"Entities failed: {e}") return [] async def get_summary(self, language: str) -> str: try: if language == "arabic": prompt = f"لخّص الوثيقة التالية باللغة العربية الفصحى في فقرة أو اثنتين:\n\n{self.document_text[:5000]}" else: prompt = f"Summarize this document in 1-2 paragraphs:\n\n{self.document_text[:5000]}" return self._groq(system="Document summarizer.", user=prompt, max_tokens=500) except Exception as e: logger.warning(f"Summary failed: {e}") return "" def detect_language(self) -> str: arabic = sum(1 for c in self.document_text if "\u0600" <= c <= "\u06FF") english = sum(1 for c in self.document_text if "a" <= c.lower() <= "z") return "arabic" if arabic > english else "english" async def translate_to_english(self, text: str) -> str: try: return self._groq( system="You are a translator. Return ONLY the English translation, no explanation, no extra text.", user=f"Translate the following text to English:\n\n{text}", max_tokens=600 ) except Exception as e: logger.warning(f"Translation failed: {e}") return text async def agent_review_department(self, clf: dict) -> AgentReview: departments = [ "business_development", "customer_support", "financial_accounting", "hr_department", "it_department", "legal" ] try: dept = clf["department"] conf = clf["department_confidence"] * 100 prompt = ( f"An AI model classified this document as '{dept}' with confidence {conf:.1f}%.\n\n" f"--- DOCUMENT TEXT ---\n{self.document_text[:2000]}\n--- END ---\n\n" f"Available departments: {departments}\n\n" "Do you agree? If not, suggest the correct department.\n" 'Return ONLY: {"agent_agrees": true, "final_department": "...", "reasoning": "..."}' ) answer = self._groq( system=( "You are a document routing expert. Verify or correct the department classification. " "Respond with valid JSON only, no extra text." ), user=prompt, max_tokens=300 ) data = self._parse_json(answer) agrees = bool(data.get("agent_agrees", True)) final = data.get("final_department", dept) if final not in departments: final = dept return AgentReview( triggered=True, agent_agrees=agrees, final_department=final, reasoning=data.get("reasoning", "") ) except Exception as e: logger.warning(f"Agent review failed: {e}") return AgentReview( triggered=True, agent_agrees=True, final_department=clf["department"], reasoning="Agent review failed, keeping model decision." ) async def process(self, file_path: str) -> DocumentResult: self.file_size = os.path.getsize(file_path) ext = os.path.splitext(file_path)[1].lower() # Step 1: OCR if ext == ".pdf": images = self._pdf_to_images(file_path) else: self.num_pages = 1 images = [self._image_to_b64(file_path)] if not images: raise HTTPException(status_code=400, detail="CONVERSION_FAILED") self.document_text = await self._ocr_all_pages(images) if not self.document_text or len(self.document_text.strip()) < 10: raise HTTPException(status_code=400, detail="EXTRACTION_FAILED") # Step 2: Analyze language = self.detect_language() recipient = await self.get_recipient() entities = await self.get_entities() summary = await self.get_summary(language) # Step 3: Translate if Arabic then classify clf_input = self.document_text[:500] if language == "arabic": logger.info("[translate] Arabic detected, translating before classification...") clf_input = await self.translate_to_english(clf_input) logger.info("[translate] Done.") clf = classify_text(clf_input) # Step 4: Agent review if confidence is low agent_review = None final_department = clf["department"] if clf["department_confidence"] < Config.CONFIDENCE_THRESHOLD: logger.info(f"[agent] Low confidence ({clf['department_confidence']:.2f}), triggering agent review...") agent_review = await self.agent_review_department(clf) final_department = agent_review.final_department logger.info(f"[agent] {clf['department']} → {final_department} (agrees: {agent_review.agent_agrees})") else: logger.info(f"[agent] High confidence ({clf['department_confidence']:.2f}), skipping.") return DocumentResult( raw_text = self.document_text, summary = summary, language = language, entities = entities, recipient = recipient, department = final_department, priority = clf["priority"], department_confidence = clf["department_confidence"], priority_confidence = clf["priority_confidence"], agent_review = agent_review, route = not recipient.found, pages = self.num_pages, file_type = ext.upper().replace(".", ""), file_size_bytes = self.file_size, processed_at = datetime.now().isoformat(), model_ocr = Config.GROQ_MODEL, model_classifier = "RoBERTa fine-tuned" ) # ═══════════════════════════════════════════════════════════════ # Lifespan — load model from HuggingFace Hub # ═══════════════════════════════════════════════════════════════ @asynccontextmanager async def lifespan(app: FastAPI): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"[startup] device = {device}") BASE_DIR = os.path.dirname(os.path.abspath(__file__)) model_path = os.path.join(BASE_DIR, "model_last.pt") le_dept_path = os.path.join(BASE_DIR, "label_encoder.pkl") le_prio_path = os.path.join(BASE_DIR, "priority_encoder.pkl") logger.info(f"[startup] loading model from local files...") tokenizer = AutoTokenizer.from_pretrained(BASE_DIR) le_dept = joblib.load(le_dept_path) le_prio = joblib.load(le_prio_path) ckpt = torch.load(model_path, map_location=device, weights_only=False) model = RoBertMultiOutput(len(le_dept.classes_), len(le_prio.classes_)) model.load_state_dict(ckpt["model_state_dict"], strict=False) model.to(device).eval() _state.update( clf_model=model, tokenizer=tokenizer, le_dept=le_dept, le_prio=le_prio, device=device, ) logger.info(f"[startup] departments : {list(le_dept.classes_)}") logger.info(f"[startup] priorities : {list(le_prio.classes_)}") yield _state.clear() logger.info("[shutdown] resources released.") # ═══════════════════════════════════════════════════════════════ # FastAPI App # ═══════════════════════════════════════════════════════════════ app = FastAPI( title="Document Processing API", description=( "**One endpoint** combining:\n\n" "1. OCR — extract text from PDF/images using Groq llama-4-scout\n" "2. Classification — department + priority using fine-tuned RoBERTa\n" "3. Routing — decides if manual routing is needed\n\n" "Upload any PDF or image and get a unified JSON response." ), version="1.0.0", lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ═══════════════════════════════════════════════════════════════ # Routes # ═══════════════════════════════════════════════════════════════ @app.get("/", tags=["Info"]) def root(): return {"message": "Document Processing API", "docs": "/docs"} @app.get("/health", response_model=HealthResponse, tags=["Info"]) def health(): return HealthResponse( status="healthy", timestamp=datetime.now().isoformat(), ocr_model=Config.GROQ_MODEL, classifier_model="RoBERTa fine-tuned" ) @app.post("/api/v1/process", response_model=SuccessResponse, tags=["Process"]) async def process_document( file: UploadFile = File(...), x_groq_api_key: Optional[str] = Header(None, alias="X-Groq-Api-Key") ): """ Upload a PDF or image → returns unified JSON with: raw_text, summary, entities, recipient, department, priority, route **Header required:** `X-Groq-Api-Key: your_groq_api_key` """ temp_path = None try: if not x_groq_api_key: raise HTTPException(status_code=401, detail="MISSING_GROQ_API_KEY: Add X-Groq-Api-Key header") if not file.filename: raise HTTPException(status_code=400, detail="EMPTY_FILENAME") ext = os.path.splitext(file.filename)[1].lower().replace(".", "") if ext not in Config.ALLOWED_EXT: raise HTTPException(status_code=400, detail="INVALID_FILE_TYPE") ts = datetime.now().strftime("%Y%m%d_%H%M%S") temp_path = os.path.join(Config.UPLOAD_FOLDER, f"{ts}_{file.filename}") with open(temp_path, "wb") as buf: shutil.copyfileobj(file.file, buf) if os.path.getsize(temp_path) > Config.MAX_FILE_SIZE: raise HTTPException(status_code=413, detail="FILE_TOO_LARGE") processor = DocumentProcessor(api_key=x_groq_api_key) result = await processor.process(temp_path) return SuccessResponse(success=True, data=result) except HTTPException: raise except Exception as e: logger.error(f"Unexpected error: {e}") raise HTTPException(status_code=500, detail="INTERNAL_SERVER_ERROR") finally: if temp_path and os.path.exists(temp_path): try: os.remove(temp_path) except Exception: pass @app.exception_handler(HTTPException) async def http_exception_handler(request, exc): return JSONResponse( status_code=exc.status_code, content={"success": False, "error": exc.detail, "data": None} ) # ═══════════════════════════════════════════════════════════════ # Run # ═══════════════════════════════════════════════════════════════ if __name__ == "__main__": import uvicorn import sys import pathlib if sys.platform == "win32": sys.stdout.reconfigure(encoding="utf-8") module_name = pathlib.Path(__file__).stem uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=7860, reload=True)