"""Document classification using BERT-tiny model."""
import os
from pathlib import Path
from typing import List, Dict, Optional

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import numpy as np

# Model configuration
MODEL_NAME = "prajjwal1/bert-tiny"

# Models directory: use /app/Model in Docker, or project_root/Model locally
# Check if we're in Docker by looking for /app directory
if Path("/app").exists() and Path("/app/backend").exists():
    # Docker environment
    MODELS_DIR = Path("/app/Model")
else:
    # Local development - go up from backend/app/classifier.py to project root
    MODELS_DIR = Path(__file__).resolve().parent.parent.parent / "Model"

MODEL_PATH = MODELS_DIR / "bert-tiny"

# Common document types with descriptions and keywords for better classification.
# The description + keywords are embedded once per type and compared against
# document embeddings; the keywords are also matched literally (case-insensitive).
DOCUMENT_TYPES = {
    "invoice": {
        "description": "A document requesting payment for goods or services provided, containing itemized charges, totals, and payment terms.",
        "keywords": ["invoice", "bill", "amount due", "total", "subtotal", "tax", "payment terms", "invoice number", "invoice date", "due date", "itemized", "charges", "balance", "payable", "vendor", "billing"]
    },
    "receipt": {
        "description": "A document confirming payment has been received, showing transaction details and proof of purchase.",
        "keywords": ["receipt", "payment received", "paid", "thank you", "transaction", "purchase", "payment confirmation", "receipt number", "date of purchase", "amount paid"]
    },
    "contract": {
        "description": "A legally binding agreement between parties outlining terms, conditions, obligations, and signatures.",
        "keywords": ["contract", "agreement", "terms", "party", "signature", "effective date", "parties", "whereas", "hereby", "obligations", "rights", "termination", "breach"]
    },
    "resume": {
        "description": "A document summarizing a person's work experience, education, skills, and qualifications for job applications.",
        "keywords": ["resume", "cv", "curriculum vitae", "experience", "education", "skills", "employment", "work history", "qualifications", "objective", "references", "contact information"]
    },
    "letter": {
        "description": "A formal or informal written correspondence addressed to a recipient with greetings and closing.",
        "keywords": ["dear", "sincerely", "yours", "letter", "correspondence", "regards", "best regards", "yours truly", "to whom it may concern", "date:", "subject:"]
    },
    "report": {
        "description": "A structured document presenting analysis, findings, conclusions, and recommendations on a specific topic.",
        "keywords": ["report", "summary", "findings", "conclusion", "analysis", "recommendations", "executive summary", "introduction", "methodology", "results", "discussion"]
    },
    "memo": {
        "description": "An internal business communication document with headers like To, From, Subject, and Date.",
        "keywords": ["memo", "memorandum", "to:", "from:", "subject:", "date:", "re:", "internal", "interoffice"]
    },
    "email": {
        "description": "Electronic mail correspondence with headers showing sender, recipient, subject, and message content.",
        "keywords": ["from:", "to:", "subject:", "sent:", "email", "cc:", "bcc:", "reply to", "message id", "date sent"]
    },
    "form": {
        "description": "A structured document with fields to be filled out, often requiring signatures and dates.",
        "keywords": ["form", "application", "please fill", "signature", "date", "please print", "complete", "fill out", "applicant", "fields"]
    },
    "certificate": {
        "description": "An official document certifying completion, achievement, or qualification with certification details.",
        "keywords": ["certificate", "certified", "awarded", "this certifies", "certification", "certificate of", "issued", "certificate number"]
    },
    "license": {
        "description": "An official document granting permission to perform certain activities, with license numbers and expiration dates.",
        "keywords": ["license", "licensed", "expires", "license number", "licensee", "licensing authority", "valid until", "license type", "permit"]
    },
    "passport": {
        "description": "An official government document for international travel containing personal identification and nationality information.",
        "keywords": ["passport", "nationality", "date of birth", "passport number", "passport no", "country of issue", "expiry date", "place of birth", "issuing authority"]
    },
    "medical record": {
        "description": "Healthcare documentation containing patient information, diagnoses, treatments, and medical history.",
        "keywords": ["medical", "diagnosis", "patient", "treatment", "prescription", "doctor", "physician", "symptoms", "medication", "health", "medical history", "patient id"]
    },
    "bank statement": {
        "description": "A financial document from a bank showing account transactions, balances, deposits, and withdrawals.",
        "keywords": ["bank statement", "account statement", "statement of account", "account number", "account balance", "opening balance", "closing balance", "available balance", "statement period", "statement date", "start date balance", "transaction", "transactions", "deposit", "withdrawal", "debit", "credit", "checking account", "savings account", "account summary", "bank name", "routing number", "ending balance", "beginning balance", "total deposits", "total withdrawals", "service charge", "interest earned", "atm", "check", "checks", "transfer", "fee"]
    },
    "tax document": {
        "description": "Tax-related paperwork such as W-2 forms, 1099 forms, tax returns, or IRS correspondence.",
        "keywords": ["tax", "irs", "income", "deduction", "w-2", "1099", "tax return", "federal tax", "social security", "withholding", "adjusted gross income", "taxable income"]
    },
    "legal document": {
        "description": "Court documents, legal filings, contracts, or other documents related to legal proceedings or matters.",
        "keywords": ["legal", "court", "plaintiff", "defendant", "attorney", "lawyer", "case number", "filing", "petition", "motion", "order", "judgment", "legal counsel"]
    },
    "academic paper": {
        "description": "A scholarly document with abstract, introduction, methodology, results, references, and citations.",
        "keywords": ["abstract", "introduction", "methodology", "references", "citation", "research", "study", "literature review", "hypothesis", "data analysis", "conclusion", "bibliography"]
    },
    "presentation": {
        "description": "A document with slides, bullet points, or structured content for presenting information to an audience.",
        "keywords": ["slide", "presentation", "agenda", "overview", "bullet points", "powerpoint", "key points", "summary slide", "title slide"]
    },
    "manual": {
        "description": "An instructional document providing step-by-step procedures, guidelines, or how-to information.",
        "keywords": ["manual", "instructions", "how to", "procedure", "steps", "guide", "tutorial", "user guide", "operation", "setup", "installation"]
    },
    "quote": {
        "description": "A document providing a price estimate or quotation for goods or services before purchase.",
        "keywords": ["quote", "quotation", "estimate", "pricing", "quote number", "valid until", "quote date", "estimated cost", "price quote", "proposal"]
    },
    "purchase order": {
        "description": "A commercial document issued by a buyer to a seller indicating types, quantities, and agreed prices for products or services.",
        "keywords": ["purchase order", "po number", "po#", "order number", "purchase", "order date", "ship to", "bill to", "quantity", "unit price", "po"]
    },
    "insurance policy": {
        "description": "A document outlining insurance coverage, terms, premiums, and policy details.",
        "keywords": ["insurance", "policy", "policy number", "premium", "coverage", "insured", "beneficiary", "policyholder", "deductible", "claim", "insurance company"]
    },
    "other": {
        "description": "A document that does not clearly fit into any of the above categories.",
        "keywords": []
    }
}


class DocumentClassifier:
    """Class for classifying documents using BERT-tiny.

    Classification is a hybrid of literal keyword matching against
    DOCUMENT_TYPES and cosine similarity between document and
    type-description embeddings.
    """

    def __init__(self):
        # tokenizer / model are populated by _load_model(); device is chosen
        # once so the model and all inputs land on the same device.
        self.tokenizer = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()
        self._precompute_type_embeddings()

    def _load_model(self):
        """Load the BERT-tiny model, downloading if necessary.

        Prefers the locally cached copy at MODEL_PATH; on first run the
        model is downloaded from the hub and saved there for later runs.

        Raises:
            Exception: re-raises any load/download failure after logging it.
        """
        try:
            # Check if model exists locally, otherwise download
            if MODEL_PATH.exists():
                print(f"Loading model from local path: {MODEL_PATH}")
                model_path = str(MODEL_PATH)
            else:
                print(f"Downloading model {MODEL_NAME}...")
                model_path = MODEL_NAME
                # Create models directory so save_pretrained below can write into it
                MODELS_DIR.mkdir(parents=True, exist_ok=True)

            # Load tokenizer and model (using AutoModel for embeddings)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(model_path)
            self.model.to(self.device)
            self.model.eval()

            # Save model locally if downloaded
            if not MODEL_PATH.exists():
                print(f"Saving model to {MODEL_PATH}...")
                self.tokenizer.save_pretrained(str(MODEL_PATH))
                self.model.save_pretrained(str(MODEL_PATH))
                print("Model saved successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _get_embedding(self, text: str, max_length: int = 512) -> torch.Tensor:
        """Get embedding for a text using BERT-tiny.

        Returns a (batch, hidden) tensor produced by mean pooling of the
        token states, weighted by the attention mask so padding positions
        (if any) do not dilute the average. For a single unpadded input
        this is identical to a plain mean over the sequence dimension.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Mask-aware mean pooling: zero out padding tokens, then divide by
        # the count of real tokens (clamped to avoid division by zero).
        hidden = outputs.last_hidden_state
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _precompute_type_embeddings(self):
        """Precompute embeddings for each document type description."""
        print("Precomputing document type embeddings...")
        self.type_embeddings = {}
        for doc_type, doc_info in DOCUMENT_TYPES.items():
            # Combine type name, description, and keywords for better representation
            description = doc_info["description"]
            keywords = " ".join(doc_info.get("keywords", []))
            text = f"{doc_type}: {description} Keywords: {keywords}"
            self.type_embeddings[doc_type] = self._get_embedding(text)
        print("Document type embeddings computed!")
def _calculate_keyword_score(self, text: str, doc_type: str) -> float: """Calculate keyword matching score for a document type.""" text_lower = text.lower() doc_info = DOCUMENT_TYPES.get(doc_type, {}) keywords = doc_info.get("keywords", []) if not keywords: return 0.0 # Count keyword matches matches = sum(1 for keyword in keywords if keyword.lower() in text_lower) # Calculate score: matches / total keywords, with bonus for multiple matches base_score = matches / len(keywords) if keywords else 0.0 # Boost score if multiple keywords found (indicates stronger match) if matches > 0: boost = min(0.3, matches * 0.05) # Up to 30% boost base_score = min(1.0, base_score + boost) return base_score def classify_document(self, text: str, max_length: int = 512) -> Dict[str, any]: """ Classify a document based on its text content using hybrid keyword + semantic similarity. Args: text: Document text content max_length: Maximum token length for the model Returns: Dictionary with classification results """ if not text or not text.strip(): return { "document_type": "unknown", "confidence": 0.0, "error": "No text extracted from document" } try: # Truncate text if too long (keep first part which usually has most relevant info) if len(text) > max_length * 4: # Rough estimate: 4 chars per token # Take first part and last part for better context first_part = text[:max_length * 2] last_part = text[-max_length * 2:] text = first_part + " " + last_part # Get embedding for the document text doc_embedding = self._get_embedding(text, max_length) # Calculate scores using hybrid approach scores = {} for doc_type in DOCUMENT_TYPES.keys(): # 1. Keyword matching score (0-1) keyword_score = self._calculate_keyword_score(text, doc_type) # 2. Semantic similarity score (0-1, normalized) type_embedding = self.type_embeddings[doc_type] similarity = F.cosine_similarity(doc_embedding, type_embedding, dim=1) semantic_score = (similarity.item() + 1) / 2 # Normalize from [-1, 1] to [0, 1] # 3. 
Combine scores: 60% keyword, 40% semantic # This gives more weight to explicit keyword matches combined_score = (keyword_score * 0.6) + (semantic_score * 0.4) scores[doc_type] = combined_score # Find the best match best_type = max(scores.items(), key=lambda x: x[1]) # Normalize confidence to percentage (scale to make it more meaningful) # Use sigmoid-like scaling for better confidence representation max_score = best_type[1] if max_score > 0.5: # High confidence: scale from 0.5-1.0 to 50%-95% confidence = 50 + (max_score - 0.5) * 90 elif max_score > 0.3: # Medium confidence: scale from 0.3-0.5 to 30%-50% confidence = 30 + (max_score - 0.3) * 100 else: # Low confidence: scale from 0-0.3 to 0%-30% confidence = max_score * 100 confidence = min(95, max(5, confidence)) # Clamp between 5% and 95% # Get top 5 classifications top_5 = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:5] # Convert scores to percentages for display top_5_percentages = {} for doc_type, score in top_5: if score > 0.5: percent = 50 + (score - 0.5) * 90 elif score > 0.3: percent = 30 + (score - 0.3) * 100 else: percent = score * 100 top_5_percentages[doc_type] = min(95, max(5, percent)) return { "document_type": best_type[0], "confidence": round(confidence / 100, 3), # Return as 0-1 for consistency "all_scores": {k: round(v / 100, 3) for k, v in top_5_percentages.items()}, "text_preview": text[:200] + "..." if len(text) > 200 else text } except Exception as e: print(f"Error classifying document: {e}") import traceback traceback.print_exc() return { "document_type": "unknown", "confidence": 0.0, "error": str(e) } # Global classifier instance _classifier_instance = None def get_classifier() -> DocumentClassifier: """Get or create the global classifier instance.""" global _classifier_instance if _classifier_instance is None: _classifier_instance = DocumentClassifier() return _classifier_instance