manarsaber11 commited on
Commit
4eda73b
·
verified ·
1 Parent(s): ad6b675

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +12 -0
  2. requirements.txt +12 -0
  3. unified_api.py +622 -0
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ EXPOSE 7860
11
+
12
+ CMD ["uvicorn", "unified_api:app", "--host", "0.0.0.0", "--port", "7860"]
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.111.0
2
+ uvicorn[standard]>=0.30.0
3
+ transformers>=4.40.0
4
+ torch>=2.1.0
5
+ scikit-learn>=1.3.0
6
+ joblib>=1.3.0
7
+ pydantic>=2.0.0
8
+ python-multipart
9
+ python-dotenv
10
+ groq
11
+ pymupdf
12
+ huggingface_hub
unified_api.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Document Processing API
3
+ OCR (Groq llama-4-scout) + Classification (RoBERTa) in one endpoint
4
+ Loads model from HuggingFace Hub
5
+ """
6
+ from dotenv import load_dotenv
7
+ load_dotenv()
8
+
9
+ import os
10
+ import re
11
+ import json
12
+ import logging
13
+ import base64
14
+ import shutil
15
+ import torch
16
+ import torch.nn as nn
17
+ import joblib
18
+ from datetime import datetime
19
+ from contextlib import asynccontextmanager
20
+ from typing import Optional, List
21
+
22
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Header
23
+ from fastapi.responses import JSONResponse
24
+ from fastapi.middleware.cors import CORSMiddleware
25
+ from pydantic import BaseModel
26
+ from transformers import AutoTokenizer, RobertaModel
27
+ from huggingface_hub import hf_hub_download
28
+ import torch.nn.functional as F
29
+
30
+
31
+ # ═══════════════════════════════════════════════════════════════
32
+ # Config
33
+ # ═══════════════════════════════════════════════════════════════
34
+ class Config:
35
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY", "YOUR_API_KEY")
36
+ GROQ_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"
37
+ HF_REPO_ID = "manarsaber11/enterprise-classifier"
38
+ MAX_FILE_SIZE = 50 * 1024 * 1024
39
+ ALLOWED_EXT = {"pdf", "jpg", "jpeg", "png", "gif", "bmp"}
40
+ UPLOAD_FOLDER = "uploads"
41
+ CLASSIFIER_MAX_LEN = 320
42
+ CONFIDENCE_THRESHOLD = 0.85
43
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
44
+
45
+
46
+ logging.basicConfig(
47
+ level=logging.INFO,
48
+ format="%(asctime)s - %(levelname)s - %(message)s",
49
+ handlers=[logging.FileHandler("api.log"), logging.StreamHandler()]
50
+ )
51
+ logger = logging.getLogger(__name__)
52
+
53
+
54
+ # ═══════════════════════════════════════════════════════════════
55
+ # RoBERTa Model Class
56
+ # ═══════════════════════════════════════════════════════════════
57
+ class RoBertMultiOutput(nn.Module):
58
+ def __init__(self, num_department, num_priorities, department_weights=None, priority_weights=None):
59
+ super().__init__()
60
+ self.bert = RobertaModel.from_pretrained("roberta-base")
61
+ self.dropout = nn.Dropout(0.3)
62
+ self.department_classifier = nn.Linear(768, num_department)
63
+ self.priority_head = nn.Sequential(
64
+ nn.Linear(768, 256),
65
+ nn.ReLU(),
66
+ nn.Dropout(0.3),
67
+ nn.Linear(256, num_priorities)
68
+ )
69
+ self.department_loss_fn = nn.CrossEntropyLoss(weight=department_weights)
70
+ self.priority_loss_fn = nn.CrossEntropyLoss(weight=priority_weights)
71
+
72
+ def forward(self, input_ids, attention_mask, department=None, priority=None):
73
+ output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
74
+ pooled = self.dropout(output.pooler_output)
75
+ department_logits = self.department_classifier(pooled)
76
+ priority_logits = self.priority_head(pooled)
77
+ loss = None
78
+ if department is not None and priority is not None:
79
+ loss = self.department_loss_fn(department_logits, department) + \
80
+ 2.0 * self.priority_loss_fn(priority_logits, priority)
81
+ return {"loss": loss, "department_logits": department_logits, "priority_logits": priority_logits}
82
+
83
+
84
+ # ═══════════════════════════════════════════════════════════════
85
+ # Global state
86
+ # ═══════════════════════════════════════════════════════════════
87
+ _state: dict = {}
88
+
89
+
90
+ # ═══════════════════════════════════════════════════════════════
91
+ # Pydantic Schemas
92
+ # ═══════════════════════════════════════════════════════════════
93
+ class Entity(BaseModel):
94
+ type: str
95
+ value: str
96
+
97
+ class RecipientInfo(BaseModel):
98
+ name: Optional[str] = None
99
+ date: str
100
+ found: bool
101
+
102
+ class AgentReview(BaseModel):
103
+ triggered: bool
104
+ agent_agrees: bool
105
+ final_department: str
106
+ reasoning: str
107
+
108
+ class DocumentResult(BaseModel):
109
+ raw_text: str
110
+ summary: str
111
+ language: str
112
+ entities: List[Entity] = []
113
+ recipient: RecipientInfo
114
+ department: str
115
+ priority: str
116
+ department_confidence: float
117
+ priority_confidence: float
118
+ agent_review: Optional[AgentReview] = None
119
+ route: bool
120
+ pages: int
121
+ file_type: str
122
+ file_size_bytes: int
123
+ processed_at: str
124
+ model_ocr: str
125
+ model_classifier: str
126
+
127
+ class SuccessResponse(BaseModel):
128
+ success: bool = True
129
+ error: Optional[str] = None
130
+ data: Optional[DocumentResult] = None
131
+
132
+ class HealthResponse(BaseModel):
133
+ status: str
134
+ timestamp: str
135
+ ocr_model: str
136
+ classifier_model: str
137
+
138
+
139
+ # ═══════════════════════════════════════════════════════════════
140
+ # Helpers
141
+ # ═══════════════════════════════════════════════════════════════
142
+ def clean_text(text: str) -> str:
143
+ text = text.strip().strip('"')
144
+ text = re.sub(r"[\n\t\r]", " ", text)
145
+ text = re.sub(r"<[^>]+>", "", text)
146
+ text = text.encode("ascii", "ignore").decode("ascii")
147
+ text = re.sub(r" +", " ", text)
148
+ return text.strip()
149
+
150
+
151
+ def classify_text(text: str) -> dict:
152
+ model = _state["clf_model"]
153
+ tokenizer = _state["tokenizer"]
154
+ device = _state["device"]
155
+ le_dept = _state["le_dept"]
156
+ le_prio = _state["le_prio"]
157
+
158
+ cleaned = clean_text(text)
159
+ if not cleaned:
160
+ return {"department": "unknown", "priority": "unknown",
161
+ "department_confidence": 0.0, "priority_confidence": 0.0}
162
+
163
+ inputs = tokenizer(
164
+ cleaned,
165
+ truncation=True,
166
+ padding="max_length",
167
+ max_length=Config.CLASSIFIER_MAX_LEN,
168
+ return_tensors="pt",
169
+ )
170
+ input_ids = inputs["input_ids"].to(device)
171
+ attention_mask = inputs["attention_mask"].to(device)
172
+
173
+ with torch.no_grad():
174
+ outputs = model(input_ids, attention_mask)
175
+
176
+ dept_probs = F.softmax(outputs["department_logits"], dim=1).cpu().squeeze()
177
+ prio_probs = F.softmax(outputs["priority_logits"], dim=1).cpu().squeeze()
178
+
179
+ dept_idx = dept_probs.argmax().item()
180
+ prio_idx = prio_probs.argmax().item()
181
+
182
+ return {
183
+ "department": le_dept.inverse_transform([dept_idx])[0],
184
+ "priority": le_prio.inverse_transform([prio_idx])[0],
185
+ "department_confidence": round(float(dept_probs[dept_idx]), 4),
186
+ "priority_confidence": round(float(prio_probs[prio_idx]), 4),
187
+ }
188
+
189
+
190
+ # ═══════════════════════════════════════════════════════════════
191
+ # OCR + Analysis Processor
192
+ # ═══════════════════════════════════════════════════════════════
193
+ class DocumentProcessor:
194
+
195
+ def __init__(self, api_key: str = None):
196
+ try:
197
+ from groq import Groq
198
+ self.client = Groq(api_key=api_key or Config.GROQ_API_KEY)
199
+ except ImportError:
200
+ raise HTTPException(status_code=500, detail="Run: pip install groq")
201
+ self.document_text = ""
202
+ self.num_pages = 0
203
+ self.file_size = 0
204
+
205
+ def _pdf_to_images(self, pdf_path: str) -> List[str]:
206
+ try:
207
+ import fitz
208
+ except ImportError:
209
+ raise HTTPException(status_code=500, detail="Run: pip install pymupdf")
210
+ doc = fitz.open(pdf_path)
211
+ self.num_pages = len(doc)
212
+ images = []
213
+ for i in range(len(doc)):
214
+ pix = doc.load_page(i).get_pixmap()
215
+ images.append(base64.b64encode(pix.tobytes("png")).decode("utf-8"))
216
+ doc.close()
217
+ return images
218
+
219
+ def _image_to_b64(self, path: str) -> str:
220
+ with open(path, "rb") as f:
221
+ return base64.b64encode(f.read()).decode("utf-8")
222
+
223
+ def _ocr_page(self, b64_img: str, page_num: int) -> str:
224
+ response = self.client.chat.completions.create(
225
+ model=Config.GROQ_MODEL,
226
+ messages=[{
227
+ "role": "user",
228
+ "content": [
229
+ {
230
+ "type": "text",
231
+ "text": (
232
+ f"You are an expert OCR engine specialized in Arabic and mixed Arabic/English documents. Page {page_num}.\n\n"
233
+ "STRICT RULES:\n"
234
+ "1. Extract ALL text exactly as it appears — Arabic, English, and numbers.\n"
235
+ "2. Arabic text: preserve RIGHT-TO-LEFT order, copy every word exactly.\n"
236
+ "3. Numbers: copy exactly as shown (Arabic-Indic ١٢٣ or Western 123).\n"
237
+ "4. Tables: reconstruct each row on one line using | as column separator.\n"
238
+ "5. Mixed lines (Arabic + English + numbers): preserve the full line as-is.\n"
239
+ "6. Do NOT translate, summarize, reorder, or skip any text.\n"
240
+ "7. Do NOT add commentary, headers, or any text not visible on the page.\n"
241
+ "8. Empty page: output only [NO TEXT].\n\n"
242
+ "Output the raw extracted text now:"
243
+ )
244
+ },
245
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_img}"}}
246
+ ]
247
+ }],
248
+ temperature=0,
249
+ max_tokens=4000
250
+ )
251
+ return response.choices[0].message.content or ""
252
+
253
+ def _clean_ocr(self, text: str) -> str:
254
+ bad = [
255
+ r"^###", r"^```",
256
+ r"لقد قمت", r"النص المستخرج", r"استخلاص",
257
+ r"^Here is the extracted", r"^I (can see|found|analyzed)",
258
+ ]
259
+ lines = text.split("\n")
260
+ return "\n".join(
261
+ l for l in lines if not any(re.search(p, l.strip()) for p in bad)
262
+ )
263
+
264
+ async def _ocr_all_pages(self, images: List[str]) -> str:
265
+ all_text = ""
266
+ for i, img in enumerate(images):
267
+ page_num = i + 1
268
+ logger.info(f"OCR page {page_num}/{len(images)}")
269
+ try:
270
+ page_text = self._ocr_page(img, page_num)
271
+ all_text += f"\n\n=== Page {page_num} ===\n{self._clean_ocr(page_text)}"
272
+ except Exception as e:
273
+ logger.error(f"Page {page_num} failed: {e}")
274
+ all_text += f"\n\n=== Page {page_num} ===\n[EXTRACTION FAILED]"
275
+ return all_text
276
+
277
+ def _groq(self, system: str, user: str, max_tokens: int = 500) -> str:
278
+ response = self.client.chat.completions.create(
279
+ model=Config.GROQ_MODEL,
280
+ messages=[
281
+ {"role": "system", "content": system},
282
+ {"role": "user", "content": user}
283
+ ],
284
+ temperature=0,
285
+ max_tokens=max_tokens
286
+ )
287
+ return response.choices[0].message.content.strip()
288
+
289
+ def _parse_json(self, raw: str):
290
+ for marker in ["```json", "```"]:
291
+ if marker in raw:
292
+ raw = raw.split(marker)[1].split("```")[0].strip()
293
+ break
294
+ return json.loads(raw)
295
+
296
+ async def get_recipient(self) -> RecipientInfo:
297
+ today = datetime.now().strftime("%Y-%m-%d")
298
+ try:
299
+ answer = self._groq(
300
+ system="Document analysis assistant. Respond with valid JSON only.",
301
+ user=(
302
+ f"Extract recipient and date from this document.\n\n"
303
+ f"--- TEXT ---\n{self.document_text[:2000]}\n--- END ---\n\n"
304
+ "RECIPIENT: person/org this is addressed TO. If not found → null\n"
305
+ f"DATE: document date in YYYY-MM-DD. If not found → {today}\n"
306
+ 'Return ONLY: {"name": "...", "date": "YYYY-MM-DD"}'
307
+ ),
308
+ max_tokens=200
309
+ )
310
+ info = self._parse_json(answer)
311
+ name = info.get("name")
312
+ found = bool(name and name not in [None, "null", "", "غير محدد"])
313
+ date = info.get("date", today)
314
+ if not re.match(r"\d{4}-\d{2}-\d{2}", str(date)):
315
+ date = today
316
+ return RecipientInfo(name=name if found else None, date=date, found=found)
317
+ except Exception as e:
318
+ logger.warning(f"Recipient failed: {e}")
319
+ return RecipientInfo(name=None, date=today, found=False)
320
+
321
+ async def get_entities(self) -> List[Entity]:
322
+ try:
323
+ answer = self._groq(
324
+ system="NER expert. Return ONLY a valid JSON array, no extra text.",
325
+ user=(
326
+ f"Extract named entities:\n\n{self.document_text[:3000]}\n\n"
327
+ "Types: PERSON_NAME, ORGANIZATION, LOCATION, DATE, REFERENCE_NUMBER, PHONE, EMAIL, AMOUNT\n"
328
+ 'Return: [{"type": "TYPE", "value": "value"}, ...]'
329
+ )
330
+ )
331
+ data = self._parse_json(answer)
332
+ return [Entity(**e) for e in data] if isinstance(data, list) else []
333
+ except Exception as e:
334
+ logger.warning(f"Entities failed: {e}")
335
+ return []
336
+
337
+ async def get_summary(self, language: str) -> str:
338
+ try:
339
+ if language == "arabic":
340
+ prompt = f"لخّص الوثيقة التالية باللغة العربية الفصحى في فقرة أو اثنتين:\n\n{self.document_text[:5000]}"
341
+ else:
342
+ prompt = f"Summarize this document in 1-2 paragraphs:\n\n{self.document_text[:5000]}"
343
+ return self._groq(system="Document summarizer.", user=prompt, max_tokens=500)
344
+ except Exception as e:
345
+ logger.warning(f"Summary failed: {e}")
346
+ return ""
347
+
348
+ def detect_language(self) -> str:
349
+ arabic = sum(1 for c in self.document_text if "\u0600" <= c <= "\u06FF")
350
+ english = sum(1 for c in self.document_text if "a" <= c.lower() <= "z")
351
+ return "arabic" if arabic > english else "english"
352
+
353
+ async def translate_to_english(self, text: str) -> str:
354
+ try:
355
+ return self._groq(
356
+ system="You are a translator. Return ONLY the English translation, no explanation, no extra text.",
357
+ user=f"Translate the following text to English:\n\n{text}",
358
+ max_tokens=600
359
+ )
360
+ except Exception as e:
361
+ logger.warning(f"Translation failed: {e}")
362
+ return text
363
+
364
+ async def agent_review_department(self, clf: dict) -> AgentReview:
365
+ departments = [
366
+ "business_development", "customer_support", "financial_accounting",
367
+ "hr_department", "it_department", "legal"
368
+ ]
369
+ try:
370
+ dept = clf["department"]
371
+ conf = clf["department_confidence"] * 100
372
+ prompt = (
373
+ f"An AI model classified this document as '{dept}' with confidence {conf:.1f}%.\n\n"
374
+ f"--- DOCUMENT TEXT ---\n{self.document_text[:2000]}\n--- END ---\n\n"
375
+ f"Available departments: {departments}\n\n"
376
+ "Do you agree? If not, suggest the correct department.\n"
377
+ 'Return ONLY: {"agent_agrees": true, "final_department": "...", "reasoning": "..."}'
378
+ )
379
+ answer = self._groq(
380
+ system=(
381
+ "You are a document routing expert. Verify or correct the department classification. "
382
+ "Respond with valid JSON only, no extra text."
383
+ ),
384
+ user=prompt,
385
+ max_tokens=300
386
+ )
387
+ data = self._parse_json(answer)
388
+ agrees = bool(data.get("agent_agrees", True))
389
+ final = data.get("final_department", dept)
390
+ if final not in departments:
391
+ final = dept
392
+ return AgentReview(
393
+ triggered=True,
394
+ agent_agrees=agrees,
395
+ final_department=final,
396
+ reasoning=data.get("reasoning", "")
397
+ )
398
+ except Exception as e:
399
+ logger.warning(f"Agent review failed: {e}")
400
+ return AgentReview(
401
+ triggered=True,
402
+ agent_agrees=True,
403
+ final_department=clf["department"],
404
+ reasoning="Agent review failed, keeping model decision."
405
+ )
406
+
407
+ async def process(self, file_path: str) -> DocumentResult:
408
+ self.file_size = os.path.getsize(file_path)
409
+ ext = os.path.splitext(file_path)[1].lower()
410
+
411
+ # Step 1: OCR
412
+ if ext == ".pdf":
413
+ images = self._pdf_to_images(file_path)
414
+ else:
415
+ self.num_pages = 1
416
+ images = [self._image_to_b64(file_path)]
417
+
418
+ if not images:
419
+ raise HTTPException(status_code=400, detail="CONVERSION_FAILED")
420
+
421
+ self.document_text = await self._ocr_all_pages(images)
422
+
423
+ if not self.document_text or len(self.document_text.strip()) < 10:
424
+ raise HTTPException(status_code=400, detail="EXTRACTION_FAILED")
425
+
426
+ # Step 2: Analyze
427
+ language = self.detect_language()
428
+ recipient = await self.get_recipient()
429
+ entities = await self.get_entities()
430
+ summary = await self.get_summary(language)
431
+
432
+ # Step 3: Translate if Arabic then classify
433
+ clf_input = self.document_text[:500]
434
+ if language == "arabic":
435
+ logger.info("[translate] Arabic detected, translating before classification...")
436
+ clf_input = await self.translate_to_english(clf_input)
437
+ logger.info("[translate] Done.")
438
+ clf = classify_text(clf_input)
439
+
440
+ # Step 4: Agent review if confidence is low
441
+ agent_review = None
442
+ final_department = clf["department"]
443
+
444
+ if clf["department_confidence"] < Config.CONFIDENCE_THRESHOLD:
445
+ logger.info(f"[agent] Low confidence ({clf['department_confidence']:.2f}), triggering agent review...")
446
+ agent_review = await self.agent_review_department(clf)
447
+ final_department = agent_review.final_department
448
+ logger.info(f"[agent] {clf['department']} → {final_department} (agrees: {agent_review.agent_agrees})")
449
+ else:
450
+ logger.info(f"[agent] High confidence ({clf['department_confidence']:.2f}), skipping.")
451
+
452
+ return DocumentResult(
453
+ raw_text = self.document_text,
454
+ summary = summary,
455
+ language = language,
456
+ entities = entities,
457
+ recipient = recipient,
458
+ department = final_department,
459
+ priority = clf["priority"],
460
+ department_confidence = clf["department_confidence"],
461
+ priority_confidence = clf["priority_confidence"],
462
+ agent_review = agent_review,
463
+ route = not recipient.found,
464
+ pages = self.num_pages,
465
+ file_type = ext.upper().replace(".", ""),
466
+ file_size_bytes = self.file_size,
467
+ processed_at = datetime.now().isoformat(),
468
+ model_ocr = Config.GROQ_MODEL,
469
+ model_classifier = "RoBERTa fine-tuned"
470
+ )
471
+
472
+
473
+ # ═══════════════════════════════════════════════════════════════
474
+ # Lifespan — load model from HuggingFace Hub
475
+ # ═══════════════════════════════════════════════════════════════
476
+ @asynccontextmanager
477
+ async def lifespan(app: FastAPI):
478
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
479
+ logger.info(f"[startup] device = {device}")
480
+ logger.info(f"[startup] downloading model from HuggingFace: {Config.HF_REPO_ID}")
481
+
482
+ # Download files from HF Hub
483
+ model_path = hf_hub_download(repo_id=Config.HF_REPO_ID, filename="model_last.pt")
484
+ le_dept_path = hf_hub_download(repo_id=Config.HF_REPO_ID, filename="label_encoder.pkl")
485
+ le_prio_path = hf_hub_download(repo_id=Config.HF_REPO_ID, filename="priority_encoder.pkl")
486
+
487
+ tokenizer = AutoTokenizer.from_pretrained(Config.HF_REPO_ID)
488
+ le_dept = joblib.load(le_dept_path)
489
+ le_prio = joblib.load(le_prio_path)
490
+
491
+ ckpt = torch.load(model_path, map_location=device, weights_only=False)
492
+ model = RoBertMultiOutput(len(le_dept.classes_), len(le_prio.classes_))
493
+ model.load_state_dict(ckpt["model_state_dict"], strict=False)
494
+ model.to(device).eval()
495
+
496
+ _state.update(
497
+ clf_model=model,
498
+ tokenizer=tokenizer,
499
+ le_dept=le_dept,
500
+ le_prio=le_prio,
501
+ device=device,
502
+ )
503
+ logger.info(f"[startup] departments : {list(le_dept.classes_)}")
504
+ logger.info(f"[startup] priorities : {list(le_prio.classes_)}")
505
+ yield
506
+ _state.clear()
507
+ logger.info("[shutdown] resources released.")
508
+
509
+
510
+ # ═══════════════════════════════════════════════════════════════
511
+ # FastAPI App
512
+ # ═══════════════════════════════════════════════════════════════
513
+ app = FastAPI(
514
+ title="Document Processing API",
515
+ description=(
516
+ "**One endpoint** combining:\n\n"
517
+ "1. OCR — extract text from PDF/images using Groq llama-4-scout\n"
518
+ "2. Classification — department + priority using fine-tuned RoBERTa\n"
519
+ "3. Routing — decides if manual routing is needed\n\n"
520
+ "Upload any PDF or image and get a unified JSON response."
521
+ ),
522
+ version="1.0.0",
523
+ lifespan=lifespan,
524
+ )
525
+
526
+ app.add_middleware(
527
+ CORSMiddleware,
528
+ allow_origins=["*"],
529
+ allow_methods=["*"],
530
+ allow_headers=["*"],
531
+ )
532
+
533
+
534
+ # ═══════════════════════════════════════════════════════════════
535
+ # Routes
536
+ # ═══════════════════════════════════════════════════════════════
537
+ @app.get("/", tags=["Info"])
538
+ def root():
539
+ return {"message": "Document Processing API", "docs": "/docs"}
540
+
541
+
542
+ @app.get("/health", response_model=HealthResponse, tags=["Info"])
543
+ def health():
544
+ return HealthResponse(
545
+ status="healthy",
546
+ timestamp=datetime.now().isoformat(),
547
+ ocr_model=Config.GROQ_MODEL,
548
+ classifier_model="RoBERTa fine-tuned"
549
+ )
550
+
551
+
552
+ @app.post("/api/v1/process", response_model=SuccessResponse, tags=["Process"])
553
+ async def process_document(
554
+ file: UploadFile = File(...),
555
+ x_groq_api_key: Optional[str] = Header(None, alias="X-Groq-Api-Key")
556
+ ):
557
+ """
558
+ Upload a PDF or image → returns unified JSON with:
559
+ raw_text, summary, entities, recipient, department, priority, route
560
+
561
+ **Header required:** `X-Groq-Api-Key: your_groq_api_key`
562
+ """
563
+ temp_path = None
564
+ try:
565
+ if not x_groq_api_key:
566
+ raise HTTPException(status_code=401, detail="MISSING_GROQ_API_KEY: Add X-Groq-Api-Key header")
567
+
568
+ if not file.filename:
569
+ raise HTTPException(status_code=400, detail="EMPTY_FILENAME")
570
+
571
+ ext = os.path.splitext(file.filename)[1].lower().replace(".", "")
572
+ if ext not in Config.ALLOWED_EXT:
573
+ raise HTTPException(status_code=400, detail="INVALID_FILE_TYPE")
574
+
575
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
576
+ temp_path = os.path.join(Config.UPLOAD_FOLDER, f"{ts}_{file.filename}")
577
+
578
+ with open(temp_path, "wb") as buf:
579
+ shutil.copyfileobj(file.file, buf)
580
+
581
+ if os.path.getsize(temp_path) > Config.MAX_FILE_SIZE:
582
+ raise HTTPException(status_code=413, detail="FILE_TOO_LARGE")
583
+
584
+ processor = DocumentProcessor(api_key=x_groq_api_key)
585
+ result = await processor.process(temp_path)
586
+
587
+ return SuccessResponse(success=True, data=result)
588
+
589
+ except HTTPException:
590
+ raise
591
+ except Exception as e:
592
+ logger.error(f"Unexpected error: {e}")
593
+ raise HTTPException(status_code=500, detail="INTERNAL_SERVER_ERROR")
594
+ finally:
595
+ if temp_path and os.path.exists(temp_path):
596
+ try:
597
+ os.remove(temp_path)
598
+ except Exception:
599
+ pass
600
+
601
+
602
+ @app.exception_handler(HTTPException)
603
+ async def http_exception_handler(request, exc):
604
+ return JSONResponse(
605
+ status_code=exc.status_code,
606
+ content={"success": False, "error": exc.detail, "data": None}
607
+ )
608
+
609
+
610
+ # ═══════════════════════════════════════════════════════════════
611
+ # Run
612
+ # ═══════════════════════════════════════════════════════════════
613
+ if __name__ == "__main__":
614
+ import uvicorn
615
+ import sys
616
+ import pathlib
617
+
618
+ if sys.platform == "win32":
619
+ sys.stdout.reconfigure(encoding="utf-8")
620
+
621
+ module_name = pathlib.Path(__file__).stem
622
+ uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=7860, reload=True)