Spaces:
Sleeping
Sleeping
| """ | |
| Unified Document Processing API | |
| OCR (Groq llama-4-scout) + Classification (RoBERTa) in one endpoint | |
| Loads the fine-tuned classifier and label encoders from local files next to this script | |
| """ | |
from dotenv import load_dotenv

# Load .env before anything else reads environment variables
# (Config reads GROQ_API_KEY at import time).
load_dotenv()

import asyncio
import base64
import json
import logging
import os
import re
import shutil
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List, Optional

import joblib
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, File, Header, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, RobertaModel
| # ═══════════════════════════════════════════════════════════════ | |
| # Config | |
| # ═══════════════════════════════════════════════════════════════ | |
class Config:
    """Application settings, evaluated once when the module is imported."""

    # Groq credentials / model. The key comes from the environment; the
    # placeholder default is only used when no per-request key is supplied.
    GROQ_API_KEY = os.getenv("GROQ_API_KEY", "YOUR_API_KEY")
    GROQ_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"
    # NOTE(review): not referenced by the model-loading code in this file,
    # which reads local artifacts; possibly stale.
    HF_REPO_ID = "manarsaber11/enterprise-classifier"

    # Upload limits and scratch directory.
    MAX_FILE_SIZE = 50 * 1024 ** 2  # 50 MiB
    ALLOWED_EXT = set("pdf jpg jpeg png gif bmp".split())
    UPLOAD_FOLDER = "uploads"

    # Classifier settings.
    CLASSIFIER_MAX_LEN = 320  # tokenizer max_length, in tokens
    CONFIDENCE_THRESHOLD = 0.85  # below this department confidence the LLM agent reviews


# Create the upload scratch directory at import time.
os.makedirs(Config.UPLOAD_FOLDER, exist_ok=True)
# Log to both ./api.log (relative to the working directory) and the console.
_log_handlers = [logging.FileHandler("api.log"), logging.StreamHandler()]
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
| # ═══════════════════════════════════════════════════════════════ | |
| # RoBERTa Model Class | |
| # ═══════════════════════════════════════════════════════════════ | |
class RoBertMultiOutput(nn.Module):
    """RoBERTa encoder with two classification heads: department and priority.

    Attribute names and layer shapes must stay in sync with the checkpoint
    loaded in ``lifespan`` (``model_last.pt``): they define the state-dict keys.

    Args:
        num_department: number of department labels (``len(le_dept.classes_)`` in lifespan).
        num_priorities: number of priority labels (``len(le_prio.classes_)`` in lifespan).
        department_weights: optional per-class weights for the department loss.
        priority_weights: optional per-class weights for the priority loss.
    """

    def __init__(self, num_department, num_priorities, department_weights=None, priority_weights=None):
        super().__init__()
        # Starts from the public roberta-base weights; lifespan then loads the
        # fine-tuned checkpoint over it (strict=False, so only matching keys).
        self.bert = RobertaModel.from_pretrained("roberta-base")
        self.dropout = nn.Dropout(0.3)
        # 768 = size of roberta-base's pooled output.
        self.department_classifier = nn.Linear(768, num_department)
        # Priority uses a small MLP head (768 -> 256 -> num_priorities).
        self.priority_head = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_priorities)
        )
        # Losses are only evaluated when labels are passed to forward().
        self.department_loss_fn = nn.CrossEntropyLoss(weight=department_weights)
        self.priority_loss_fn = nn.CrossEntropyLoss(weight=priority_weights)

    def forward(self, input_ids, attention_mask, department=None, priority=None):
        """Return logits for both heads, plus the combined loss when both targets are given.

        Args:
            input_ids: token ids; presumably (batch, seq_len) — classify_text passes a batch of one.
            attention_mask: mask matching ``input_ids``.
            department: optional department class indices (training only).
            priority: optional priority class indices (training only).

        Returns:
            dict with ``loss`` (None unless both targets are given),
            ``department_logits`` and ``priority_logits``.
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Dropout is a no-op once the model is in eval() mode (set in lifespan).
        pooled = self.dropout(output.pooler_output)
        department_logits = self.department_classifier(pooled)
        priority_logits = self.priority_head(pooled)
        loss = None
        if department is not None and priority is not None:
            # Priority loss is weighted 2x relative to the department loss.
            loss = self.department_loss_fn(department_logits, department) + \
                2.0 * self.priority_loss_fn(priority_logits, priority)
        return {"loss": loss, "department_logits": department_logits, "priority_logits": priority_logits}
| # ═══════════════════════════════════════════════════════════════ | |
| # Global state | |
| # ═══════════════════════════════════════════════════════════════ | |
# Process-wide model registry: filled by `lifespan` at startup and cleared at
# shutdown. Keys: clf_model, tokenizer, le_dept, le_prio, device.
_state: dict = dict()
| # ═══════════════════════════════════════════════════════════════ | |
| # Pydantic Schemas | |
| # ═══════════════════════════════════════════════════════════════ | |
class Entity(BaseModel):
    """One named entity extracted by the LLM in DocumentProcessor.get_entities."""

    type: str  # label requested in the prompt, e.g. PERSON_NAME, ORGANIZATION, DATE
    value: str  # entity text as returned by the LLM
class RecipientInfo(BaseModel):
    """Addressee and date of the document, as extracted by DocumentProcessor.get_recipient."""

    name: Optional[str] = None  # None when no recipient was found
    date: str  # YYYY-MM-DD; get_recipient falls back to today's date
    found: bool  # False drives DocumentResult.route = True
class AgentReview(BaseModel):
    """Outcome of the LLM second opinion on a low-confidence department prediction."""

    triggered: bool  # always True when a review object is produced
    agent_agrees: bool  # whether the LLM agreed with the classifier
    final_department: str  # department used for routing after the review
    reasoning: str  # LLM explanation (or a fallback message on failure)
class DocumentResult(BaseModel):
    """Full analysis of one uploaded document, returned as SuccessResponse.data."""

    raw_text: str  # OCR output with "=== Page N ===" headers between pages
    summary: str  # LLM summary in the document's language ("" if the call failed)
    language: str  # "arabic" or "english", from DocumentProcessor.detect_language
    entities: List[Entity] = []  # LLM-extracted named entities
    recipient: RecipientInfo
    department: str  # final department (agent-reviewed when confidence was low)
    priority: str  # classifier's priority label
    department_confidence: float  # classifier probability for its own prediction, not for an agent override
    priority_confidence: float  # classifier probability for the priority label
    agent_review: Optional[AgentReview] = None  # set only when the agent review ran
    route: bool  # True when no recipient was found (manual routing needed)
    pages: int
    file_type: str  # upper-case extension without the dot, e.g. "PDF"
    file_size_bytes: int
    processed_at: str  # ISO-8601 local timestamp (naive datetime.now())
    model_ocr: str
    model_classifier: str
class SuccessResponse(BaseModel):
    """Response envelope for process_document.

    The HTTPException handler emits the same three keys
    (success / error / data) for error responses.
    """

    success: bool = True
    error: Optional[str] = None  # unused on success
    data: Optional[DocumentResult] = None
class HealthResponse(BaseModel):
    """Payload of the health-check handler."""

    status: str  # the handler always reports "healthy"
    timestamp: str  # ISO-8601 local time of the check
    ocr_model: str  # Config.GROQ_MODEL
    classifier_model: str  # human-readable classifier description
| # ═══════════════════════════════════════════════════════════════ | |
| # Helpers | |
| # ═══════════════════════════════════════════════════════════════ | |
def clean_text(text: str) -> str:
    """Normalise text before it is tokenised for the classifier.

    Trims whitespace and surrounding double quotes, turns newlines/tabs/CRs
    into spaces, removes HTML-like tags, drops every non-ASCII character
    (Arabic is translated to English before it gets here), and collapses
    runs of spaces.
    """
    result = text.strip().strip('"')
    for pattern, replacement in ((r"[\n\t\r]", " "), (r"<[^>]+>", "")):
        result = re.sub(pattern, replacement, result)
    result = result.encode("ascii", "ignore").decode("ascii")
    return re.sub(r" +", " ", result).strip()
def classify_text(text: str) -> dict:
    """Predict department and priority for ``text`` with the fine-tuned RoBERTa model.

    Reads the model, tokenizer, device and both label encoders from ``_state``
    (populated by ``lifespan``). The text is normalised with ``clean_text``
    first; if nothing survives cleaning, both labels are ``"unknown"`` with
    confidence 0.0.

    Returns:
        dict with ``department`` and ``priority`` (label strings) and
        ``department_confidence`` / ``priority_confidence`` (softmax
        probability of the chosen label, rounded to 4 decimals).
    """
    model = _state["clf_model"]
    tokenizer = _state["tokenizer"]
    device = _state["device"]
    le_dept = _state["le_dept"]
    le_prio = _state["le_prio"]

    cleaned = clean_text(text)
    if not cleaned:
        return {"department": "unknown", "priority": "unknown",
                "department_confidence": 0.0, "priority_confidence": 0.0}

    inputs = tokenizer(
        cleaned,
        truncation=True,
        padding="max_length",
        max_length=Config.CLASSIFIER_MAX_LEN,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)

    # Take row 0 of the single-item batch rather than .squeeze(): squeeze also
    # drops the class axis when a head has exactly one class, leaving a 0-d
    # tensor that cannot be indexed.
    dept_probs = F.softmax(outputs["department_logits"], dim=1)[0].cpu()
    prio_probs = F.softmax(outputs["priority_logits"], dim=1)[0].cpu()
    dept_idx = int(dept_probs.argmax())
    prio_idx = int(prio_probs.argmax())

    # str(): label encoders may return numpy scalars (or ints), while the
    # response schema declares plain str fields.
    return {
        "department": str(le_dept.inverse_transform([dept_idx])[0]),
        "priority": str(le_prio.inverse_transform([prio_idx])[0]),
        "department_confidence": round(float(dept_probs[dept_idx]), 4),
        "priority_confidence": round(float(prio_probs[prio_idx]), 4),
    }
| # ═══════════════════════════════════════════════════════════════ | |
| # OCR + Analysis Processor | |
| # ═══════════════════════════════════════════════════════════════ | |
class DocumentProcessor:
    """Run OCR and LLM analysis on one uploaded document, then classify it.

    ``process`` is the entry point. It converts the file to base64 page
    images, OCRs every page with ``Config.GROQ_MODEL``, and asks the same
    model for a recipient, entities, a summary and (for Arabic) an English
    translation. It then classifies the text with ``classify_text``. When the
    department confidence is below ``Config.CONFIDENCE_THRESHOLD``, it asks
    the LLM to review the department.

    An instance holds per-document state (``document_text``, ``num_pages``,
    ``file_size``, ``image_mime``), so create one per request.

    The Groq SDK client is synchronous. Every network call, plus PDF rendering
    and model inference, runs through ``asyncio.to_thread`` so one slow
    document does not stall the event loop.
    """

    # Data-URL media type per accepted image extension (PDF pages are rendered to PNG).
    _IMAGE_MIME = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".bmp": "image/bmp",
    }

    # Agent-review choices used only when no department label encoder is loaded.
    _DEFAULT_DEPARTMENTS = [
        "business_development", "customer_support", "financial_accounting",
        "hr_department", "it_department", "legal",
    ]

    # Lines that look like model commentary instead of page content. The Arabic
    # phrases ("I have done", "the extracted text", "extraction") are anchored
    # to the line start so ordinary document sentences containing them survive.
    _OCR_NOISE = [
        re.compile(p)
        for p in (
            r"^###", r"^```",
            r"^لقد قمت", r"^النص المستخرج", r"^استخلاص",
            r"^Here is the extracted", r"^I (can see|found|analyzed)",
        )
    ]

    # Page headers and placeholders inserted by _ocr_all_pages.
    _PAGE_MARKERS = re.compile(r"=== Page \d+ ===|\[NO TEXT\]|\[EXTRACTION FAILED\]")

    def __init__(self, api_key: str = None):
        """Create the Groq client for this request.

        Args:
            api_key: Groq API key; ``Config.GROQ_API_KEY`` is used when falsy.

        Raises:
            HTTPException: 500 when the ``groq`` package is not installed.
        """
        try:
            from groq import Groq
        except ImportError:
            raise HTTPException(status_code=500, detail="Run: pip install groq") from None
        self.client = Groq(api_key=api_key or Config.GROQ_API_KEY)
        self.document_text = ""
        self.num_pages = 0
        self.file_size = 0
        # Media type declared to the vision model for a single-image upload.
        self.image_mime = "image/png"

    # ── File → base64 images ──────────────────────────────────────────

    def _pdf_to_images(self, pdf_path: str) -> List[str]:
        """Render every PDF page to a base64 PNG and record ``num_pages``.

        Raises:
            HTTPException: 500 if PyMuPDF is missing; 400 ``CONVERSION_FAILED``
                if the file cannot be opened or rendered.
        """
        try:
            import fitz
        except ImportError:
            raise HTTPException(status_code=500, detail="Run: pip install pymupdf") from None
        try:
            # The context manager closes the document even if rendering fails.
            with fitz.open(pdf_path) as doc:
                self.num_pages = len(doc)
                return [
                    base64.b64encode(page.get_pixmap().tobytes("png")).decode("utf-8")
                    for page in doc
                ]
        except Exception as e:
            logger.error(f"PDF conversion failed: {e}")
            raise HTTPException(status_code=400, detail="CONVERSION_FAILED") from e

    def _image_to_b64(self, path: str) -> str:
        """Return the file's bytes base64-encoded as text."""
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    # ── OCR ───────────────────────────────────────────────────────────

    def _ocr_page(self, b64_img: str, page_num: int, mime: str = "image/png") -> str:
        """OCR one page image with the Groq vision model (blocking network call).

        Args:
            b64_img: base64-encoded image bytes.
            page_num: 1-based page number, quoted in the prompt.
            mime: media type declared in the data URL.

        Returns:
            The model's text output, or ``""`` when it returned no content.
        """
        response = self.client.chat.completions.create(
            model=Config.GROQ_MODEL,
            messages=[{
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            f"You are an expert OCR engine specialized in Arabic and mixed Arabic/English documents. Page {page_num}.\n\n"
                            "STRICT RULES:\n"
                            "1. Extract ALL text exactly as it appears — Arabic, English, and numbers.\n"
                            "2. Arabic text: preserve RIGHT-TO-LEFT order, copy every word exactly.\n"
                            "3. Numbers: copy exactly as shown (Arabic-Indic ١٢٣ or Western 123).\n"
                            "4. Tables: reconstruct each row on one line using | as column separator.\n"
                            "5. Mixed lines (Arabic + English + numbers): preserve the full line as-is.\n"
                            "6. Do NOT translate, summarize, reorder, or skip any text.\n"
                            "7. Do NOT add commentary, headers, or any text not visible on the page.\n"
                            "8. Empty page: output only [NO TEXT].\n\n"
                            "Output the raw extracted text now:"
                        ),
                    },
                    {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64_img}"}},
                ],
            }],
            temperature=0,
            max_tokens=4000,
        )
        return response.choices[0].message.content or ""

    def _clean_ocr(self, text: str) -> str:
        """Drop lines that match one of the ``_OCR_NOISE`` commentary patterns."""
        return "\n".join(
            line for line in text.split("\n")
            if not any(p.search(line.strip()) for p in self._OCR_NOISE)
        )

    async def _ocr_all_pages(self, images: List[str]) -> str:
        """OCR pages one after another and join them under ``=== Page N ===`` headers.

        A page whose OCR call fails is logged and recorded as
        ``[EXTRACTION FAILED]``, so the remaining pages are still processed.
        """
        total = len(images)
        parts = []
        for page_num, img in enumerate(images, start=1):
            logger.info(f"OCR page {page_num}/{total}")
            try:
                raw = await asyncio.to_thread(self._ocr_page, img, page_num, self.image_mime)
                body = self._clean_ocr(raw)
            except Exception as e:
                logger.error(f"Page {page_num} failed: {e}")
                body = "[EXTRACTION FAILED]"
            parts.append(f"\n\n=== Page {page_num} ===\n{body}")
        return "".join(parts)

    def _strip_page_markers(self, text: str) -> str:
        """Remove the page headers and placeholders added by ``_ocr_all_pages``."""
        return self._PAGE_MARKERS.sub("", text)

    # ── LLM helpers ───────────────────────────────────────────────────

    def _groq(self, system: str, user: str, max_tokens: int = 500) -> str:
        """Run one deterministic chat completion (blocking) and return the stripped text."""
        response = self.client.chat.completions.create(
            model=Config.GROQ_MODEL,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=0,
            max_tokens=max_tokens,
        )
        # content may be None (empty completion); avoid AttributeError on .strip().
        return (response.choices[0].message.content or "").strip()

    async def _agroq(self, system: str, user: str, max_tokens: int = 500) -> str:
        """Run ``_groq`` in a worker thread so the event loop stays responsive."""
        return await asyncio.to_thread(self._groq, system, user, max_tokens)

    def _parse_json(self, raw: str):
        """Decode a JSON value from an LLM reply.

        Accepts a bare JSON document or one wrapped in a ```json / ``` fence.
        As a fallback, it tries the outermost ``{...}`` / ``[...]`` span, for
        replies that surround the JSON with prose.

        Raises:
            json.JSONDecodeError: when no JSON value can be recovered.
        """
        text = (raw or "").strip()
        for marker in ("```json", "```"):
            if marker in text:
                text = text.split(marker)[1].split("```")[0].strip()
                break
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            starts = [i for i in (text.find("{"), text.find("[")) if i != -1]
            if not starts:
                raise
            start = min(starts)
            end = text.rfind("}" if text[start] == "{" else "]")
            if end <= start:
                raise
            return json.loads(text[start:end + 1])

    @staticmethod
    def _normalize_date(value, fallback: str) -> str:
        """Return ``value`` as a valid ``YYYY-MM-DD`` calendar date, else ``fallback``.

        Only the first 10 characters are considered, so ``2024-05-01T10:00``
        becomes ``2024-05-01``. Impossible dates such as ``2024-13-45`` are rejected.
        """
        candidate = str(value).strip()[:10] if value is not None else ""
        try:
            return datetime.strptime(candidate, "%Y-%m-%d").strftime("%Y-%m-%d")
        except ValueError:
            return fallback

    @staticmethod
    def _as_bool(value, default: bool = True) -> bool:
        """Interpret an LLM-provided flag; plain bool() would treat "false" as True."""
        if isinstance(value, bool):
            return value
        if value is None:
            return default
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in ("true", "yes", "1"):
                return True
            if lowered in ("false", "no", "0"):
                return False
            return default
        return bool(value)

    def _department_choices(self) -> List[str]:
        """Return the department labels the review agent may choose from.

        They come from the loaded department label encoder, so an override is
        always a label the classifier itself can emit.
        """
        le_dept = _state.get("le_dept")
        if le_dept is None:
            return list(self._DEFAULT_DEPARTMENTS)
        return [str(c) for c in le_dept.classes_]

    # ── Analysis steps ────────────────────────────────────────────────

    async def get_recipient(self) -> RecipientInfo:
        """Ask the LLM for the addressee and document date.

        Never raises. On any failure the recipient is reported as not found
        and the date falls back to today.
        """
        today = datetime.now().strftime("%Y-%m-%d")
        try:
            answer = await self._agroq(
                system="Document analysis assistant. Respond with valid JSON only.",
                user=(
                    f"Extract recipient and date from this document.\n\n"
                    f"--- TEXT ---\n{self.document_text[:2000]}\n--- END ---\n\n"
                    "RECIPIENT: person/org this is addressed TO. If not found → null\n"
                    f"DATE: document date in YYYY-MM-DD. If not found → {today}\n"
                    'Return ONLY: {"name": "...", "date": "YYYY-MM-DD"}'
                ),
                max_tokens=200,
            )
            info = self._parse_json(answer)
            raw_name = info.get("name")
            name = raw_name.strip() if isinstance(raw_name, str) else None
            # "غير محدد" = "not specified", a common Arabic placeholder answer.
            found = bool(name) and name.lower() not in ("null", "none") and name != "غير محدد"
            date = self._normalize_date(info.get("date"), today)
            return RecipientInfo(name=name if found else None, date=date, found=found)
        except Exception as e:
            logger.warning(f"Recipient failed: {e}")
            return RecipientInfo(name=None, date=today, found=False)

    async def get_entities(self) -> List[Entity]:
        """Ask the LLM for named entities; malformed items are skipped, not fatal.

        Returns an empty list if the call or JSON parsing fails.
        """
        try:
            answer = await self._agroq(
                system="NER expert. Return ONLY a valid JSON array, no extra text.",
                user=(
                    f"Extract named entities:\n\n{self.document_text[:3000]}\n\n"
                    "Types: PERSON_NAME, ORGANIZATION, LOCATION, DATE, REFERENCE_NUMBER, PHONE, EMAIL, AMOUNT\n"
                    'Return: [{"type": "TYPE", "value": "value"}, ...]'
                ),
            )
            data = self._parse_json(answer)
            if not isinstance(data, list):
                return []
            entities = []
            for item in data:
                if not isinstance(item, dict):
                    continue
                etype, value = item.get("type"), item.get("value")
                if etype is None or value is None:
                    continue
                # Values such as AMOUNT often come back as numbers; the schema wants str.
                value_text = str(value).strip()
                if value_text:
                    entities.append(Entity(type=str(etype).strip(), value=value_text))
            return entities
        except Exception as e:
            logger.warning(f"Entities failed: {e}")
            return []

    async def get_summary(self, language: str) -> str:
        """Summarise the document in Arabic or English; returns "" on failure."""
        try:
            if language == "arabic":
                prompt = f"لخّص الوثيقة التالية باللغة العربية الفصحى في فقرة أو اثنتين:\n\n{self.document_text[:5000]}"
            else:
                prompt = f"Summarize this document in 1-2 paragraphs:\n\n{self.document_text[:5000]}"
            return await self._agroq(system="Document summarizer.", user=prompt, max_tokens=500)
        except Exception as e:
            logger.warning(f"Summary failed: {e}")
            return ""

    def detect_language(self) -> str:
        """Return "arabic" if Arabic letters outnumber ASCII letters, else "english".

        The page headers and placeholders are excluded so they do not add
        Latin letters to the count.
        """
        text = self._strip_page_markers(self.document_text)
        arabic = sum(1 for ch in text if "\u0600" <= ch <= "\u06FF")
        english = sum(1 for ch in text if "a" <= ch.lower() <= "z")
        return "arabic" if arabic > english else "english"

    async def translate_to_english(self, text: str) -> str:
        """Translate ``text`` to English with the LLM; returns the input unchanged on failure."""
        try:
            return await self._agroq(
                system="You are a translator. Return ONLY the English translation, no explanation, no extra text.",
                user=f"Translate the following text to English:\n\n{text}",
                max_tokens=600,
            )
        except Exception as e:
            logger.warning(f"Translation failed: {e}")
            return text

    async def agent_review_department(self, clf: dict) -> AgentReview:
        """Ask the LLM to confirm or correct the classifier's department.

        The suggested department is accepted only if it is a known label.
        A changed department is always reported as a disagreement. On any
        failure the classifier's decision is kept.
        """
        departments = self._department_choices()
        dept = clf["department"]
        try:
            conf = clf["department_confidence"] * 100
            prompt = (
                f"An AI model classified this document as '{dept}' with confidence {conf:.1f}%.\n\n"
                f"--- DOCUMENT TEXT ---\n{self.document_text[:2000]}\n--- END ---\n\n"
                f"Available departments: {departments}\n\n"
                "Do you agree? If not, suggest the correct department.\n"
                'Return ONLY: {"agent_agrees": true, "final_department": "...", "reasoning": "..."}'
            )
            answer = await self._agroq(
                system=(
                    "You are a document routing expert. Verify or correct the department classification. "
                    "Respond with valid JSON only, no extra text."
                ),
                user=prompt,
                max_tokens=300,
            )
            data = self._parse_json(answer)
            if not isinstance(data, dict):
                raise ValueError(f"expected a JSON object, got {type(data).__name__}")
            final = data.get("final_department", dept)
            if not isinstance(final, str) or final not in departments:
                final = dept
            agrees = self._as_bool(data.get("agent_agrees"), default=True) and final == dept
            return AgentReview(
                triggered=True,
                agent_agrees=agrees,
                final_department=final,
                reasoning=str(data.get("reasoning") or ""),
            )
        except Exception as e:
            logger.warning(f"Agent review failed: {e}")
            return AgentReview(
                triggered=True,
                agent_agrees=True,
                final_department=dept,
                reasoning="Agent review failed, keeping model decision.",
            )

    # ── Pipeline ──────────────────────────────────────────────────────

    async def process(self, file_path: str) -> DocumentResult:
        """Run the full pipeline on a saved upload and build the API result.

        Args:
            file_path: path of the uploaded file. Its extension selects PDF
                rendering or single-image OCR.

        Raises:
            HTTPException: 400 ``CONVERSION_FAILED`` when no page image could be
                produced; 400 ``EXTRACTION_FAILED`` when OCR yielded fewer than
                10 characters of actual text.
        """
        self.file_size = os.path.getsize(file_path)
        ext = os.path.splitext(file_path)[1].lower()

        # Step 1: OCR
        if ext == ".pdf":
            # Rendering is CPU-bound; keep it off the event loop.
            images = await asyncio.to_thread(self._pdf_to_images, file_path)
        else:
            self.num_pages = 1
            self.image_mime = self._IMAGE_MIME.get(ext, "image/png")
            images = [self._image_to_b64(file_path)]
        if not images:
            raise HTTPException(status_code=400, detail="CONVERSION_FAILED")

        self.document_text = await self._ocr_all_pages(images)
        # Measure only extracted content. The "=== Page N ===" header alone is
        # longer than 10 characters, so checking the raw text let documents
        # whose pages all failed or were empty slip through.
        if len(self._strip_page_markers(self.document_text).strip()) < 10:
            raise HTTPException(status_code=400, detail="EXTRACTION_FAILED")

        # Step 2: Analyze
        language = self.detect_language()
        recipient = await self.get_recipient()
        entities = await self.get_entities()
        summary = await self.get_summary(language)

        # Step 3: Translate if Arabic, then classify.
        # NOTE(review): the classifier input still includes the "=== Page 1 ==="
        # header; confirm this matches what the classifier was trained on.
        clf_input = self.document_text[:500]
        if language == "arabic":
            logger.info("[translate] Arabic detected, translating before classification...")
            clf_input = await self.translate_to_english(clf_input)
            logger.info("[translate] Done.")
        clf = await asyncio.to_thread(classify_text, clf_input)

        # Step 4: Agent review if confidence is low
        agent_review = None
        final_department = clf["department"]
        if clf["department_confidence"] < Config.CONFIDENCE_THRESHOLD:
            logger.info(f"[agent] Low confidence ({clf['department_confidence']:.2f}), triggering agent review...")
            agent_review = await self.agent_review_department(clf)
            final_department = agent_review.final_department
            logger.info(f"[agent] {clf['department']} → {final_department} (agrees: {agent_review.agent_agrees})")
        else:
            logger.info(f"[agent] High confidence ({clf['department_confidence']:.2f}), skipping.")

        return DocumentResult(
            raw_text=self.document_text,
            summary=summary,
            language=language,
            entities=entities,
            recipient=recipient,
            department=final_department,
            priority=clf["priority"],
            department_confidence=clf["department_confidence"],
            priority_confidence=clf["priority_confidence"],
            agent_review=agent_review,
            route=not recipient.found,
            pages=self.num_pages,
            file_type=ext.upper().replace(".", ""),
            file_size_bytes=self.file_size,
            processed_at=datetime.now().isoformat(),
            model_ocr=Config.GROQ_MODEL,
            model_classifier="RoBERTa fine-tuned",
        )
| # ═══════════════════════════════════════════════════════════════ | |
| # Lifespan — load the model and label encoders from local files | |
| # ═══════════════════════════════════════════════════════════════ | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer, label encoders and fine-tuned RoBERTa checkpoint once.

    Everything is read from the directory that contains this file: tokenizer
    files, ``label_encoder.pkl``, ``priority_encoder.pkl`` and
    ``model_last.pt``. The results are stored in ``_state`` for
    ``classify_text`` and removed on shutdown, even if the app exits with an error.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"[startup] device = {device}")

    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(base_dir, "model_last.pt")
    le_dept_path = os.path.join(base_dir, "label_encoder.pkl")
    le_prio_path = os.path.join(base_dir, "priority_encoder.pkl")

    logger.info("[startup] loading model from local files...")
    tokenizer = AutoTokenizer.from_pretrained(base_dir)
    # SECURITY: joblib.load and torch.load(weights_only=False) unpickle arbitrary
    # objects. Only point these paths at trusted, locally shipped artifacts.
    le_dept = joblib.load(le_dept_path)
    le_prio = joblib.load(le_prio_path)
    ckpt = torch.load(model_path, map_location=device, weights_only=False)
    # Accept both {"model_state_dict": ...} training checkpoints and bare state dicts.
    state_dict = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt

    model = RoBertMultiOutput(len(le_dept.classes_), len(le_prio.classes_))
    # strict=False tolerates extra/missing keys, but it would also silently
    # keep randomly initialised heads if the checkpoint does not match, so report any mismatch.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        logger.warning(
            f"[startup] checkpoint is missing {len(incompatible.missing_keys)} keys: "
            f"{incompatible.missing_keys[:10]}"
        )
    if incompatible.unexpected_keys:
        logger.warning(
            f"[startup] checkpoint has {len(incompatible.unexpected_keys)} unexpected keys: "
            f"{incompatible.unexpected_keys[:10]}"
        )
    model.to(device).eval()

    _state.update(
        clf_model=model,
        tokenizer=tokenizer,
        le_dept=le_dept,
        le_prio=le_prio,
        device=device,
    )
    logger.info(f"[startup] departments : {list(le_dept.classes_)}")
    logger.info(f"[startup] priorities  : {list(le_prio.classes_)}")
    try:
        yield
    finally:
        _state.clear()
        logger.info("[shutdown] resources released.")
| # ═══════════════════════════════════════════════════════════════ | |
| # FastAPI App | |
| # ═══════════════════════════════════════════════════════════════ | |
# Markdown shown at the top of the interactive docs page.
_API_DESCRIPTION = (
    "**One endpoint** combining:\n\n"
    "1. OCR — extract text from PDF/images using Groq llama-4-scout\n"
    "2. Classification — department + priority using fine-tuned RoBERTa\n"
    "3. Routing — decides if manual routing is needed\n\n"
    "Upload any PDF or image and get a unified JSON response."
)

app = FastAPI(
    title="Document Processing API",
    version="1.0.0",
    description=_API_DESCRIPTION,
    lifespan=lifespan,
)

# Fully open CORS: any origin, method and header may call the API from a browser.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
| # ═══════════════════════════════════════════════════════════════ | |
| # Routes | |
| # ═══════════════════════════════════════════════════════════════ | |
# The route decorator was missing, so this handler was never registered.
@app.get("/")
def root():
    """Return a short pointer to the interactive API docs."""
    return {"message": "Document Processing API", "docs": "/docs"}
# NOTE(review): the route decorator was missing from this copy of the file;
# the "/health" path is a reconstruction — confirm against monitoring config.
@app.get("/health", response_model=HealthResponse)
def health():
    """Report liveness plus the OCR and classifier models in use."""
    return HealthResponse(
        status="healthy",
        timestamp=datetime.now().isoformat(),
        ocr_model=Config.GROQ_MODEL,
        classifier_model="RoBERTa fine-tuned",
    )
# NOTE(review): the route decorator was missing from this copy of the file;
# the "/process" path is a reconstruction — confirm it matches what clients call.
@app.post("/process", response_model=SuccessResponse)
async def process_document(
    file: UploadFile = File(...),
    x_groq_api_key: Optional[str] = Header(None, alias="X-Groq-Api-Key"),
):
    """
    Upload a PDF or image → returns unified JSON with:
    raw_text, summary, entities, recipient, department, priority, route

    **Header required:** `X-Groq-Api-Key: your_groq_api_key`
    """
    temp_path = None
    try:
        if not x_groq_api_key:
            raise HTTPException(status_code=401, detail="MISSING_GROQ_API_KEY: Add X-Groq-Api-Key header")
        if not file.filename:
            raise HTTPException(status_code=400, detail="EMPTY_FILENAME")

        ext = os.path.splitext(file.filename)[1].lower().replace(".", "")
        if ext not in Config.ALLOWED_EXT:
            raise HTTPException(status_code=400, detail="INVALID_FILE_TYPE")

        # Build the temp name from a random id, not the client's filename.
        # A per-second timestamp plus the filename collides for concurrent
        # uploads (one request would overwrite and then delete the other's
        # file), and the filename itself is untrusted input.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        temp_path = os.path.join(Config.UPLOAD_FOLDER, f"{ts}_{uuid.uuid4().hex}.{ext}")

        # Stream to disk in chunks and stop as soon as the size limit is
        # exceeded, instead of writing the whole upload before checking.
        chunk_size = 1024 * 1024
        written = 0
        with open(temp_path, "wb") as buf:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                written += len(chunk)
                if written > Config.MAX_FILE_SIZE:
                    raise HTTPException(status_code=413, detail="FILE_TOO_LARGE")
                buf.write(chunk)

        processor = DocumentProcessor(api_key=x_groq_api_key)
        result = await processor.process(temp_path)
        return SuccessResponse(success=True, data=result)
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail="INTERNAL_SERVER_ERROR")
    finally:
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temp file {temp_path}: {cleanup_error}")
# Registration was missing, so FastAPI's default {"detail": ...} body was
# returned instead of this envelope.
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an HTTPException with the same success/error/data keys as SuccessResponse.

    Headers attached to the exception (if any) are forwarded to the response.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"success": False, "error": exc.detail, "data": None},
        headers=getattr(exc, "headers", None),
    )
| # ═══════════════════════════════════════════════════════════════ | |
| # Run | |
| # ═══════════════════════════════════════════════════════════════ | |
if __name__ == "__main__":
    import pathlib
    import sys

    import uvicorn

    # On Windows, stdout may use a legacy code page; force UTF-8 so non-ASCII
    # (e.g. Arabic) log output can be printed.
    if sys.platform == "win32":
        sys.stdout.reconfigure(encoding="utf-8")

    # Pass "module:app" as an import string, which uvicorn's reload mode requires.
    app_target = f"{pathlib.Path(__file__).stem}:app"
    uvicorn.run(app_target, host="0.0.0.0", port=7860, reload=True)