# Source: document-processor / unified_api.py
# Page caption: manarsaber11's picture
# Commit message: Update unified_api.py
# Commit: 236505b (verified)
"""
Unified Document Processing API
OCR (Groq llama-4-scout) + Classification (RoBERTa) in one endpoint
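Loads the fine-tuned checkpoint, tokenizer and label encoders from local files
next to this module (the roberta-base backbone comes from
`RobertaModel.from_pretrained`).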
"""
from dotenv import load_dotenv
load_dotenv()
import os
import re
import json
import logging
import base64
import shutil
import torch
import torch.nn as nn
import joblib
from datetime import datetime
from contextlib import asynccontextmanager
from typing import Optional, List
from fastapi import FastAPI, File, UploadFile, HTTPException, Header
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, RobertaModel
import torch.nn.functional as F
# ═══════════════════════════════════════════════════════════════
# Config
# ═══════════════════════════════════════════════════════════════
class Config:
    """Static settings shared by the OCR, analysis and classification steps."""

    # --- Groq (OCR + LLM helper calls) ---
    # Falls back to a placeholder when the environment variable is absent.
    GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "YOUR_API_KEY")
    GROQ_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"

    # --- Classifier ---
    HF_REPO_ID = "manarsaber11/enterprise-classifier"
    CLASSIFIER_MAX_LEN = 320        # tokenizer max_length used by classify_text
    CONFIDENCE_THRESHOLD = 0.85     # below this, the agent review is triggered

    # --- Uploads ---
    MAX_FILE_SIZE = 50 * 1024 ** 2  # 50 MiB, in bytes
    ALLOWED_EXT = {"pdf", "jpg", "jpeg", "png", "gif", "bmp"}
    UPLOAD_FOLDER = "uploads"

    # Runs once while the class body is evaluated, so the upload directory
    # exists before any request is served.
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Log to api.log and to the console. The file handler is pinned to UTF-8:
# log messages embed exception text produced while processing Arabic
# documents, and the platform default encoding (e.g. cp1252 on Windows)
# cannot represent Arabic characters.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("api.log", encoding="utf-8"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
# ═══════════════════════════════════════════════════════════════
# RoBERTa Model Class
# ═══════════════════════════════════════════════════════════════
class RoBertMultiOutput(nn.Module):
    """RoBERTa encoder feeding two classification heads.

    One linear head predicts the department; a small MLP head predicts the
    priority. Both heads consume the (dropout-regularised) pooled output of
    the ``roberta-base`` encoder.
    """

    def __init__(self, num_department, num_priorities, department_weights=None, priority_weights=None):
        """Build the encoder, both heads and their loss functions.

        Args:
            num_department: number of department classes.
            num_priorities: number of priority classes.
            department_weights: optional class weights for the department loss.
            priority_weights: optional class weights for the priority loss.
        """
        super().__init__()
        self.bert = RobertaModel.from_pretrained("roberta-base")
        self.dropout = nn.Dropout(0.3)
        self.department_classifier = nn.Linear(768, num_department)

        priority_layers = [
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_priorities),
        ]
        self.priority_head = nn.Sequential(*priority_layers)

        self.department_loss_fn = nn.CrossEntropyLoss(weight=department_weights)
        self.priority_loss_fn = nn.CrossEntropyLoss(weight=priority_weights)

    def forward(self, input_ids, attention_mask, department=None, priority=None):
        """Return a dict with ``loss``, ``department_logits`` and ``priority_logits``.

        ``loss`` is ``None`` unless both ``department`` and ``priority``
        targets are given; otherwise it is the department loss plus twice the
        priority loss.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout(encoded.pooler_output)

        dept_logits = self.department_classifier(features)
        prio_logits = self.priority_head(features)

        if department is None or priority is None:
            total_loss = None
        else:
            dept_loss = self.department_loss_fn(dept_logits, department)
            prio_loss = self.priority_loss_fn(prio_logits, priority)
            total_loss = dept_loss + 2.0 * prio_loss

        return {
            "loss": total_loss,
            "department_logits": dept_logits,
            "priority_logits": prio_logits,
        }
# ═══════════════════════════════════════════════════════════════
# Global state
# ═══════════════════════════════════════════════════════════════
# Shared runtime objects. `lifespan` fills this dict at startup with the keys
# "clf_model", "tokenizer", "le_dept", "le_prio" and "device", and clears it
# at shutdown; `classify_text` reads all five.
_state: dict = {}
# ═══════════════════════════════════════════════════════════════
# Pydantic Schemas
# ═══════════════════════════════════════════════════════════════
# One named entity returned by DocumentProcessor.get_entities.
# (Plain comments are used instead of a class docstring so the generated
# OpenAPI schema is unchanged.)
class Entity(BaseModel):
    type: str   # entity label; the NER prompt asks for e.g. PERSON_NAME, ORGANIZATION, DATE
    value: str  # entity text as returned by the model
# Recipient and date produced by DocumentProcessor.get_recipient.
class RecipientInfo(BaseModel):
    name: Optional[str] = None  # addressee; None when no usable name was found
    date: str                   # YYYY-MM-DD; get_recipient falls back to today's date
    found: bool                 # True only when a usable recipient name was found
# Outcome of the LLM second opinion on the classifier's department
# (see DocumentProcessor.agent_review_department).
class AgentReview(BaseModel):
    triggered: bool         # set to True whenever a review object is produced
    agent_agrees: bool      # whether the LLM agreed with the classifier
    final_department: str   # department to use after the review
    reasoning: str          # LLM's explanation, or a fallback message on failure
# Full result assembled by DocumentProcessor.process for one uploaded file.
class DocumentResult(BaseModel):
    raw_text: str                 # OCR output, with "=== Page N ===" separators
    summary: str                  # LLM summary ("" if summarisation failed)
    language: str                 # "arabic" or "english" (see detect_language)
    entities: List[Entity] = []
    recipient: RecipientInfo
    department: str               # final department (after agent review, if any)
    priority: str                 # classifier's priority label
    department_confidence: float  # classifier's confidence in ITS department guess,
                                  # even when the agent review overrides that guess
    priority_confidence: float
    agent_review: Optional[AgentReview] = None  # None when confidence was high enough
    route: bool                   # True when no recipient was found
    pages: int
    file_type: str                # upper-cased extension without the dot, e.g. "PDF"
    file_size_bytes: int
    processed_at: str             # datetime.now().isoformat() at completion
    model_ocr: str
    model_classifier: str
# Response envelope of POST /api/v1/process. The HTTPException handler in this
# file emits the same three keys with success=False.
class SuccessResponse(BaseModel):
    success: bool = True
    error: Optional[str] = None           # error detail when success is False
    data: Optional[DocumentResult] = None # payload when success is True
# Response body of GET /health.
class HealthResponse(BaseModel):
    status: str            # the health route always reports "healthy"
    timestamp: str         # datetime.now().isoformat() at request time
    ocr_model: str         # Config.GROQ_MODEL
    classifier_model: str  # human-readable classifier name
# ═══════════════════════════════════════════════════════════════
# Helpers
# ═══════════════════════════════════════════════════════════════
def clean_text(text: str) -> str:
    """Normalise text into the single-line ASCII form fed to the classifier.

    Steps, in order: trim whitespace and surrounding double quotes, turn
    newlines/tabs/carriage returns into spaces, drop ``<...>`` tags, discard
    every non-ASCII character, collapse runs of spaces, trim again.
    """
    unquoted = text.strip().strip('"')

    # One pass instead of a regex: map each control character to a space.
    single_line = unquoted.translate(str.maketrans({"\n": " ", "\t": " ", "\r": " "}))

    without_tags = re.sub(r"<[^>]+>", "", single_line)
    ascii_only = without_tags.encode("ascii", "ignore").decode("ascii")

    return re.sub(r" +", " ", ascii_only).strip()
def classify_text(text: str) -> dict:
    """Predict department and priority for ``text`` with the loaded classifier.

    Reads the model, tokenizer, device and both label encoders from ``_state``.
    Returns a dict with ``department``, ``priority`` and their confidences
    (rounded to 4 decimals). If nothing is left after ``clean_text``, both
    labels are ``"unknown"`` with confidence 0.0.
    """
    model, tokenizer, device, le_dept, le_prio = (
        _state[key] for key in ("clf_model", "tokenizer", "device", "le_dept", "le_prio")
    )

    cleaned = clean_text(text)
    if not cleaned:
        return {
            "department": "unknown",
            "priority": "unknown",
            "department_confidence": 0.0,
            "priority_confidence": 0.0,
        }

    encoded = tokenizer(
        cleaned,
        truncation=True,
        padding="max_length",
        max_length=Config.CLASSIFIER_MAX_LEN,
        return_tensors="pt",
    )

    def _best(logits):
        # Index of the most probable class and its rounded probability.
        probs = F.softmax(logits, dim=1).cpu().squeeze()
        best_idx = probs.argmax().item()
        return best_idx, round(float(probs[best_idx]), 4)

    with torch.no_grad():
        outputs = model(
            encoded["input_ids"].to(device),
            encoded["attention_mask"].to(device),
        )
        dept_idx, dept_conf = _best(outputs["department_logits"])
        prio_idx, prio_conf = _best(outputs["priority_logits"])

    return {
        "department": le_dept.inverse_transform([dept_idx])[0],
        "priority": le_prio.inverse_transform([prio_idx])[0],
        "department_confidence": dept_conf,
        "priority_confidence": prio_conf,
    }
# ═══════════════════════════════════════════════════════════════
# OCR + Analysis Processor
# ═══════════════════════════════════════════════════════════════
class DocumentProcessor:
    """Run OCR, LLM analysis and classification for a single uploaded file.

    The pipeline (see ``process``) is: render the file to base64 images, OCR
    each page through Groq, detect the language, extract recipient / entities /
    summary through Groq, classify the first 500 characters with the RoBERTa
    model (translating Arabic to English first), and ask the LLM for a second
    opinion when the department confidence is below
    ``Config.CONFIDENCE_THRESHOLD``.

    All blocking work (Groq HTTP calls, PDF rendering, file reads, torch
    inference) is pushed to the default thread pool so the event loop stays
    responsive while a document is being processed.
    """

    # Lines matching any of these are model commentary, not page content.
    _OCR_NOISE_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            r"^###", r"^```",
            r"لقد قمت", r"النص المستخرج", r"استخلاص",
            r"^Here is the extracted", r"^I (can see|found|analyzed)",
        )
    )
    # Scaffolding that `_ocr_all_pages` / the OCR prompt put into the text and
    # that must not count as extracted content.
    _PAGE_HEADER_RE = re.compile(r"^=== Page \d+ ===$", re.MULTILINE)
    _PLACEHOLDERS = ("[EXTRACTION FAILED]", "[NO TEXT]")
    # Minimum number of real characters for the extraction to count as successful.
    _MIN_TEXT_CHARS = 10

    def __init__(self, api_key: Optional[str] = None):
        """Create the Groq client.

        Args:
            api_key: Groq API key; falls back to ``Config.GROQ_API_KEY``.

        Raises:
            HTTPException: 500 if the ``groq`` package is not installed.
        """
        try:
            from groq import Groq
        except ImportError as exc:
            raise HTTPException(status_code=500, detail="Run: pip install groq") from exc
        self.client = Groq(api_key=api_key or Config.GROQ_API_KEY)
        self.document_text = ""
        self.num_pages = 0
        self.file_size = 0

    # ── infrastructure helpers ────────────────────────────────────────────

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run a blocking callable in the default executor and await its result."""
        import asyncio
        import functools

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

    def _pdf_to_images(self, pdf_path: str) -> List[str]:
        """Render every PDF page to a base64-encoded PNG and record the page count.

        Raises:
            HTTPException: 500 if PyMuPDF (``fitz``) is not installed.
        """
        try:
            import fitz
        except ImportError as exc:
            raise HTTPException(status_code=500, detail="Run: pip install pymupdf") from exc
        doc = fitz.open(pdf_path)
        try:
            self.num_pages = len(doc)
            return [
                base64.b64encode(doc.load_page(i).get_pixmap().tobytes("png")).decode("utf-8")
                for i in range(self.num_pages)
            ]
        finally:
            # Always release the file handle, even if rendering a page fails.
            doc.close()

    def _image_to_b64(self, path: str) -> str:
        """Return the file's raw bytes as a base64 string."""
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    # ── OCR ───────────────────────────────────────────────────────────────

    def _ocr_page(self, b64_img: str, page_num: int) -> str:
        """OCR one base64 image through the Groq vision model (blocking call)."""
        # NOTE(review): the data URL always declares image/png, although
        # non-PDF uploads (jpg/gif/bmp) are passed through un-converted.
        response = self.client.chat.completions.create(
            model=Config.GROQ_MODEL,
            messages=[{
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            f"You are an expert OCR engine specialized in Arabic and mixed Arabic/English documents. Page {page_num}.\n\n"
                            "STRICT RULES:\n"
                            "1. Extract ALL text exactly as it appears — Arabic, English, and numbers.\n"
                            "2. Arabic text: preserve RIGHT-TO-LEFT order, copy every word exactly.\n"
                            "3. Numbers: copy exactly as shown (Arabic-Indic ١٢٣ or Western 123).\n"
                            "4. Tables: reconstruct each row on one line using | as column separator.\n"
                            "5. Mixed lines (Arabic + English + numbers): preserve the full line as-is.\n"
                            "6. Do NOT translate, summarize, reorder, or skip any text.\n"
                            "7. Do NOT add commentary, headers, or any text not visible on the page.\n"
                            "8. Empty page: output only [NO TEXT].\n\n"
                            "Output the raw extracted text now:"
                        )
                    },
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_img}"}}
                ]
            }],
            temperature=0,
            max_tokens=4000
        )
        return response.choices[0].message.content or ""

    def _clean_ocr(self, text: str) -> str:
        """Drop lines that look like model commentary rather than page text."""
        return "\n".join(
            line
            for line in text.split("\n")
            if not any(p.search(line.strip()) for p in self._OCR_NOISE_PATTERNS)
        )

    async def _ocr_all_pages(self, images: List[str]) -> str:
        """OCR every page in order and join the results under page headers.

        A page whose OCR call fails is recorded as ``[EXTRACTION FAILED]``
        instead of aborting the whole document.
        """
        total = len(images)
        sections = []
        for page_num, img in enumerate(images, start=1):
            logger.info("OCR page %d/%d", page_num, total)
            try:
                page_text = await self._run_blocking(self._ocr_page, img, page_num)
                body = self._clean_ocr(page_text)
            except Exception as e:  # best effort: keep going with the other pages
                logger.error("Page %d failed: %s", page_num, e)
                body = "[EXTRACTION FAILED]"
            sections.append(f"\n\n=== Page {page_num} ===\n{body}")
        return "".join(sections)

    def _extracted_content(self) -> str:
        """Return ``document_text`` without page headers and failure/empty markers."""
        content = self._PAGE_HEADER_RE.sub("", self.document_text)
        for marker in self._PLACEHOLDERS:
            content = content.replace(marker, "")
        return content.strip()

    # ── LLM helpers ───────────────────────────────────────────────────────

    def _groq(self, system: str, user: str, max_tokens: int = 500) -> str:
        """Send one system+user chat request to Groq and return the stripped reply (blocking)."""
        response = self.client.chat.completions.create(
            model=Config.GROQ_MODEL,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user}
            ],
            temperature=0,
            max_tokens=max_tokens
        )
        # `content` may be None; treat that as an empty reply, as _ocr_page does.
        return (response.choices[0].message.content or "").strip()

    def _parse_json(self, raw: str):
        """Parse a JSON reply, tolerating markdown fences and surrounding prose.

        Raises:
            json.JSONDecodeError: if no valid JSON value can be recovered.
        """
        text = (raw or "").strip()
        for marker in ("```json", "```"):
            if marker in text:
                text = text.split(marker, 1)[1].split("```", 1)[0].strip()
                break
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} / [...] span embedded in prose.
            starts = [pos for pos in (text.find("{"), text.find("[")) if pos != -1]
            if not starts:
                raise
            start = min(starts)
            end = text.rfind("}" if text[start] == "{" else "]")
            if end <= start:
                raise
            return json.loads(text[start:end + 1])

    async def get_recipient(self) -> RecipientInfo:
        """Extract the addressee and document date; never raises.

        On any failure the result is ``name=None``, ``found=False`` and
        today's date.
        """
        today = datetime.now().strftime("%Y-%m-%d")
        try:
            answer = await self._run_blocking(
                self._groq,
                system="Document analysis assistant. Respond with valid JSON only.",
                user=(
                    f"Extract recipient and date from this document.\n\n"
                    f"--- TEXT ---\n{self.document_text[:2000]}\n--- END ---\n\n"
                    "RECIPIENT: person/org this is addressed TO. If not found → null\n"
                    f"DATE: document date in YYYY-MM-DD. If not found → {today}\n"
                    'Return ONLY: {"name": "...", "date": "YYYY-MM-DD"}'
                ),
                max_tokens=200
            )
            info = self._parse_json(answer)
            if not isinstance(info, dict):
                raise ValueError("recipient reply is not a JSON object")

            name = info.get("name")
            if isinstance(name, str):
                name = name.strip()
            found = isinstance(name, str) and name.lower() not in {"", "null", "غير محدد"}

            # Keep only a leading YYYY-MM-DD and make sure it is a real date.
            date = today
            match = re.match(r"\d{4}-\d{2}-\d{2}", str(info.get("date", today)).strip())
            if match:
                try:
                    datetime.strptime(match.group(0), "%Y-%m-%d")
                    date = match.group(0)
                except ValueError:
                    pass  # e.g. 2024-13-45: fall back to today

            return RecipientInfo(name=name if found else None, date=date, found=found)
        except Exception as e:
            logger.warning("Recipient failed: %s", e)
            return RecipientInfo(name=None, date=today, found=False)

    async def get_entities(self) -> List[Entity]:
        """Extract named entities; malformed items are skipped, failures yield ``[]``."""
        try:
            answer = await self._run_blocking(
                self._groq,
                system="NER expert. Return ONLY a valid JSON array, no extra text.",
                user=(
                    f"Extract named entities:\n\n{self.document_text[:3000]}\n\n"
                    "Types: PERSON_NAME, ORGANIZATION, LOCATION, DATE, REFERENCE_NUMBER, PHONE, EMAIL, AMOUNT\n"
                    'Return: [{"type": "TYPE", "value": "value"}, ...]'
                )
            )
            data = self._parse_json(answer)
            if not isinstance(data, list):
                return []
            entities = []
            for item in data:
                # One bad item must not discard the rest of the list.
                if not isinstance(item, dict):
                    continue
                etype, value = item.get("type"), item.get("value")
                if etype is None or value is None:
                    continue
                entities.append(Entity(type=str(etype), value=str(value)))
            return entities
        except Exception as e:
            logger.warning("Entities failed: %s", e)
            return []

    async def get_summary(self, language: str) -> str:
        """Summarise the document in its own language; returns ``""`` on failure."""
        try:
            if language == "arabic":
                prompt = f"لخّص الوثيقة التالية باللغة العربية الفصحى في فقرة أو اثنتين:\n\n{self.document_text[:5000]}"
            else:
                prompt = f"Summarize this document in 1-2 paragraphs:\n\n{self.document_text[:5000]}"
            return await self._run_blocking(
                self._groq, system="Document summarizer.", user=prompt, max_tokens=500
            )
        except Exception as e:
            logger.warning("Summary failed: %s", e)
            return ""

    def detect_language(self) -> str:
        """Return ``"arabic"`` if Arabic letters outnumber Latin letters, else ``"english"``."""
        arabic = sum(1 for c in self.document_text if "\u0600" <= c <= "\u06FF")
        english = sum(1 for c in self.document_text if "a" <= c.lower() <= "z")
        return "arabic" if arabic > english else "english"

    async def translate_to_english(self, text: str) -> str:
        """Translate ``text`` to English; returns the input unchanged on failure."""
        try:
            return await self._run_blocking(
                self._groq,
                system="You are a translator. Return ONLY the English translation, no explanation, no extra text.",
                user=f"Translate the following text to English:\n\n{text}",
                max_tokens=600
            )
        except Exception as e:
            logger.warning("Translation failed: %s", e)
            return text

    async def agent_review_department(self, clf: dict) -> AgentReview:
        """Ask the LLM to confirm or correct the classifier's department.

        Args:
            clf: result of ``classify_text`` (needs ``department`` and
                ``department_confidence``).

        Returns:
            An ``AgentReview`` with ``triggered=True``. A suggested department
            outside the known list, or any failure, keeps the classifier's
            department.
        """
        departments = [
            "business_development", "customer_support", "financial_accounting",
            "hr_department", "it_department", "legal"
        ]
        try:
            dept = clf["department"]
            conf = clf["department_confidence"] * 100
            prompt = (
                f"An AI model classified this document as '{dept}' with confidence {conf:.1f}%.\n\n"
                f"--- DOCUMENT TEXT ---\n{self.document_text[:2000]}\n--- END ---\n\n"
                f"Available departments: {departments}\n\n"
                "Do you agree? If not, suggest the correct department.\n"
                'Return ONLY: {"agent_agrees": true, "final_department": "...", "reasoning": "..."}'
            )
            answer = await self._run_blocking(
                self._groq,
                system=(
                    "You are a document routing expert. Verify or correct the department classification. "
                    "Respond with valid JSON only, no extra text."
                ),
                user=prompt,
                max_tokens=300
            )
            data = self._parse_json(answer)

            agrees_raw = data.get("agent_agrees", True)
            if isinstance(agrees_raw, str):
                # bool("false") is True, so string answers need explicit parsing.
                agrees = agrees_raw.strip().lower() not in {"false", "no", "0", ""}
            else:
                agrees = bool(agrees_raw)

            final = data.get("final_department", dept)
            if final not in departments:
                final = dept
            return AgentReview(
                triggered=True,
                agent_agrees=agrees,
                final_department=final,
                reasoning=str(data.get("reasoning") or "")
            )
        except Exception as e:
            logger.warning("Agent review failed: %s", e)
            return AgentReview(
                triggered=True,
                agent_agrees=True,
                final_department=clf["department"],
                reasoning="Agent review failed, keeping model decision."
            )

    # ── pipeline ──────────────────────────────────────────────────────────

    async def process(self, file_path: str) -> DocumentResult:
        """Run the full pipeline on ``file_path`` and return the unified result.

        Raises:
            HTTPException: 400 ``CONVERSION_FAILED`` if no page image could be
                produced; 400 ``EXTRACTION_FAILED`` if OCR produced fewer than
                10 characters of real text (page headers and failure markers
                are not counted).
        """
        self.file_size = os.path.getsize(file_path)
        ext = os.path.splitext(file_path)[1].lower()

        # Step 1: OCR
        if ext == ".pdf":
            images = await self._run_blocking(self._pdf_to_images, file_path)
        else:
            self.num_pages = 1
            images = [await self._run_blocking(self._image_to_b64, file_path)]
        if not images:
            raise HTTPException(status_code=400, detail="CONVERSION_FAILED")

        self.document_text = await self._ocr_all_pages(images)
        # Measure the real content: the raw text always contains page headers,
        # so its length alone can never reveal a failed extraction.
        if len(self._extracted_content()) < self._MIN_TEXT_CHARS:
            raise HTTPException(status_code=400, detail="EXTRACTION_FAILED")

        # Step 2: Analyze
        language = self.detect_language()
        recipient = await self.get_recipient()
        entities = await self.get_entities()
        summary = await self.get_summary(language)

        # Step 3: Translate if Arabic, then classify
        clf_input = self.document_text[:500]
        if language == "arabic":
            logger.info("[translate] Arabic detected, translating before classification...")
            clf_input = await self.translate_to_english(clf_input)
            logger.info("[translate] Done.")
        clf = await self._run_blocking(classify_text, clf_input)

        # Step 4: Agent review if confidence is low
        agent_review = None
        final_department = clf["department"]
        if clf["department_confidence"] < Config.CONFIDENCE_THRESHOLD:
            logger.info(
                "[agent] Low confidence (%.2f), triggering agent review...",
                clf["department_confidence"],
            )
            agent_review = await self.agent_review_department(clf)
            final_department = agent_review.final_department
            logger.info(
                "[agent] %s -> %s (agrees: %s)",
                clf["department"], final_department, agent_review.agent_agrees,
            )
        else:
            logger.info("[agent] High confidence (%.2f), skipping.", clf["department_confidence"])

        return DocumentResult(
            raw_text=self.document_text,
            summary=summary,
            language=language,
            entities=entities,
            recipient=recipient,
            department=final_department,
            priority=clf["priority"],
            department_confidence=clf["department_confidence"],
            priority_confidence=clf["priority_confidence"],
            agent_review=agent_review,
            route=not recipient.found,
            pages=self.num_pages,
            file_type=ext.upper().replace(".", ""),
            file_size_bytes=self.file_size,
            processed_at=datetime.now().isoformat(),
            model_ocr=Config.GROQ_MODEL,
            model_classifier="RoBERTa fine-tuned"
        )
# ═══════════════════════════════════════════════════════════════
# Lifespan — load classifier checkpoint, tokenizer and label encoders from local files
# ═══════════════════════════════════════════════════════════════
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the classifier stack at startup and release it at shutdown.

    Reads ``model_last.pt``, ``label_encoder.pkl``, ``priority_encoder.pkl``
    and the tokenizer files from the directory containing this module, builds
    ``RoBertMultiOutput`` on GPU when available, and publishes everything
    through ``_state``. ``_state`` is cleared on shutdown even if the
    application exits with an error.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"[startup] device = {device}")

    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(base_dir, "model_last.pt")
    le_dept_path = os.path.join(base_dir, "label_encoder.pkl")
    le_prio_path = os.path.join(base_dir, "priority_encoder.pkl")

    logger.info("[startup] loading model from local files...")
    tokenizer = AutoTokenizer.from_pretrained(base_dir)

    # SECURITY NOTE: joblib.load and torch.load(weights_only=False) unpickle
    # arbitrary Python objects. These files must come from a trusted source.
    le_dept = joblib.load(le_dept_path)
    le_prio = joblib.load(le_prio_path)
    ckpt = torch.load(model_path, map_location=device, weights_only=False)

    model = RoBertMultiOutput(len(le_dept.classes_), len(le_prio.classes_))
    load_result = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    # strict=False hides mismatches; surface them so an untrained head is noticed.
    if load_result.missing_keys:
        logger.warning(
            "[startup] weights missing from checkpoint (left at their initial values): %s",
            load_result.missing_keys,
        )
    if load_result.unexpected_keys:
        logger.info("[startup] ignored checkpoint keys: %s", load_result.unexpected_keys)
    model.to(device).eval()

    _state.update(
        clf_model=model,
        tokenizer=tokenizer,
        le_dept=le_dept,
        le_prio=le_prio,
        device=device,
    )
    logger.info(f"[startup] departments : {list(le_dept.classes_)}")
    logger.info(f"[startup] priorities : {list(le_prio.classes_)}")
    try:
        yield
    finally:
        _state.clear()
        logger.info("[shutdown] resources released.")
# ═══════════════════════════════════════════════════════════════
# FastAPI App
# ═══════════════════════════════════════════════════════════════
# Markdown shown at the top of the generated API docs.
_API_DESCRIPTION = (
    "**One endpoint** combining:\n\n"
    "1. OCR — extract text from PDF/images using Groq llama-4-scout\n"
    "2. Classification — department + priority using fine-tuned RoBERTa\n"
    "3. Routing — decides if manual routing is needed\n\n"
    "Upload any PDF or image and get a unified JSON response."
)

app = FastAPI(
    title="Document Processing API",
    description=_API_DESCRIPTION,
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: any origin, method and header is accepted.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# ═══════════════════════════════════════════════════════════════
# Routes
# ═══════════════════════════════════════════════════════════════
@app.get("/", tags=["Info"])
def root():
    # Landing payload pointing clients at the interactive docs.
    # (A comment, not a docstring, so the OpenAPI description is unchanged.)
    landing = {"message": "Document Processing API", "docs": "/docs"}
    return landing
@app.get("/health", response_model=HealthResponse, tags=["Info"])
def health():
    # Liveness probe: reports a fixed "healthy" status, the current time and
    # the model names. (Comment, not docstring: keeps OpenAPI output unchanged.)
    now = datetime.now().isoformat()
    report = HealthResponse(
        status="healthy",
        timestamp=now,
        ocr_model=Config.GROQ_MODEL,
        classifier_model="RoBERTa fine-tuned",
    )
    return report
@app.post("/api/v1/process", response_model=SuccessResponse, tags=["Process"])
async def process_document(
    file: UploadFile = File(...),
    x_groq_api_key: Optional[str] = Header(None, alias="X-Groq-Api-Key")
):
    """
    Upload a PDF or image → returns unified JSON with:
    raw_text, summary, entities, recipient, department, priority, route
    **Header required:** `X-Groq-Api-Key: your_groq_api_key`
    """
    # Implementation notes (kept out of the docstring, which is published in
    # the OpenAPI docs):
    #   401 if the Groq key header is missing, 400 for an empty filename or a
    #   disallowed extension, 413 once the upload exceeds Config.MAX_FILE_SIZE,
    #   500 for any unexpected error. The temporary file is always removed.
    import uuid  # local import: only this handler needs it

    temp_path = None
    try:
        if not x_groq_api_key:
            raise HTTPException(status_code=401, detail="MISSING_GROQ_API_KEY: Add X-Groq-Api-Key header")
        if not file.filename:
            raise HTTPException(status_code=400, detail="EMPTY_FILENAME")
        ext = os.path.splitext(file.filename)[1].lower().replace(".", "")
        if ext not in Config.ALLOWED_EXT:
            raise HTTPException(status_code=400, detail="INVALID_FILE_TYPE")

        # Server-generated name: the client-supplied filename is untrusted
        # (path separators, "..") and a per-second timestamp alone collides
        # when two uploads with the same name arrive together.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        temp_path = os.path.join(Config.UPLOAD_FOLDER, f"{ts}_{uuid.uuid4().hex}.{ext}")

        # Enforce the size limit while streaming instead of after the whole
        # upload has already been written to disk.
        written = 0
        with open(temp_path, "wb") as buf:
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                written += len(chunk)
                if written > Config.MAX_FILE_SIZE:
                    raise HTTPException(status_code=413, detail="FILE_TOO_LARGE")
                buf.write(chunk)

        processor = DocumentProcessor(api_key=x_groq_api_key)
        result = await processor.process(temp_path)
        return SuccessResponse(success=True, data=result)
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which the message alone loses.
        logger.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail="INTERNAL_SERVER_ERROR")
    finally:
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                # Best effort: never mask the real response, but leave a trace.
                logger.warning("Could not remove temp file %s: %s", temp_path, cleanup_error)
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render every ``HTTPException`` in the API's ``success/error/data`` envelope.

    Headers attached to the exception (e.g. ``WWW-Authenticate``) are passed
    through to the response instead of being dropped.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"success": False, "error": exc.detail, "data": None},
        headers=getattr(exc, "headers", None),
    )
# ═══════════════════════════════════════════════════════════════
# Run
# ═══════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Development entry point: serve this module's `app` with auto-reload.
    import pathlib
    import sys

    import uvicorn

    # Make console output UTF-8 on Windows.
    if sys.platform == "win32":
        sys.stdout.reconfigure(encoding="utf-8")

    # uvicorn needs an import string (not the object) for reload to work,
    # so derive "<this file's stem>:app".
    app_target = f"{pathlib.Path(__file__).stem}:app"
    uvicorn.run(app_target, host="0.0.0.0", port=7860, reload=True)