File size: 16,650 Bytes
f6e574f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ac95db
f6e574f
9ac95db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfd15d5
9ac95db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6e574f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ac95db
 
 
 
 
f6e574f
 
 
 
 
9ac95db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6e574f
 
9ac95db
f6e574f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ac95db
f6e574f
9ac95db
 
 
 
 
 
 
f6e574f
9ac95db
 
 
 
 
 
f6e574f
9ac95db
 
f6e574f
9ac95db
 
 
 
 
 
 
 
 
 
 
 
f6e574f
9ac95db
f6e574f
 
9ac95db
 
 
 
 
 
 
 
 
 
 
 
f6e574f
 
 
9ac95db
 
f6e574f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""Document classification using BERT-tiny model."""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Model configuration
# Hugging Face model id used when no local copy is cached.
MODEL_NAME = "prajjwal1/bert-tiny"
# Models directory: use /app/Model in Docker, or project_root/Model locally
# Check if we're in Docker by looking for /app directory
# NOTE(review): this heuristic fires for ANY host that happens to have
# /app and /app/backend directories — confirm it is specific enough for
# the deployment environments in use.
if Path("/app").exists() and Path("/app/backend").exists():
    # Docker environment
    MODELS_DIR = Path("/app/Model")
else:
    # Local development - go up from backend/app/classifier.py to project root
    MODELS_DIR = Path(__file__).resolve().parent.parent.parent / "Model"
# Directory where the tokenizer + weights are cached by DocumentClassifier.
MODEL_PATH = MODELS_DIR / "bert-tiny"

# Common document types with descriptions and keywords for better classification.
# Each entry maps a type label to:
#   "description" — natural-language summary embedded by the classifier and
#                   compared against document embeddings (semantic signal);
#   "keywords"    — lowercase phrases matched by substring against document
#                   text (keyword signal). "other" is the catch-all and has
#                   no keywords, so it can only win on semantic similarity.
DOCUMENT_TYPES = {
    "invoice": {
        "description": "A document requesting payment for goods or services provided, containing itemized charges, totals, and payment terms.",
        "keywords": ["invoice", "bill", "amount due", "total", "subtotal", "tax", "payment terms", "invoice number", "invoice date", "due date", "itemized", "charges", "balance", "payable", "vendor", "billing"]
    },
    "receipt": {
        "description": "A document confirming payment has been received, showing transaction details and proof of purchase.",
        "keywords": ["receipt", "payment received", "paid", "thank you", "transaction", "purchase", "payment confirmation", "receipt number", "date of purchase", "amount paid"]
    },
    "contract": {
        "description": "A legally binding agreement between parties outlining terms, conditions, obligations, and signatures.",
        "keywords": ["contract", "agreement", "terms", "party", "signature", "effective date", "parties", "whereas", "hereby", "obligations", "rights", "termination", "breach"]
    },
    "resume": {
        "description": "A document summarizing a person's work experience, education, skills, and qualifications for job applications.",
        "keywords": ["resume", "cv", "curriculum vitae", "experience", "education", "skills", "employment", "work history", "qualifications", "objective", "references", "contact information"]
    },
    "letter": {
        "description": "A formal or informal written correspondence addressed to a recipient with greetings and closing.",
        "keywords": ["dear", "sincerely", "yours", "letter", "correspondence", "regards", "best regards", "yours truly", "to whom it may concern", "date:", "subject:"]
    },
    "report": {
        "description": "A structured document presenting analysis, findings, conclusions, and recommendations on a specific topic.",
        "keywords": ["report", "summary", "findings", "conclusion", "analysis", "recommendations", "executive summary", "introduction", "methodology", "results", "discussion"]
    },
    "memo": {
        "description": "An internal business communication document with headers like To, From, Subject, and Date.",
        "keywords": ["memo", "memorandum", "to:", "from:", "subject:", "date:", "re:", "internal", "interoffice"]
    },
    "email": {
        "description": "Electronic mail correspondence with headers showing sender, recipient, subject, and message content.",
        "keywords": ["from:", "to:", "subject:", "sent:", "email", "cc:", "bcc:", "reply to", "message id", "date sent"]
    },
    "form": {
        "description": "A structured document with fields to be filled out, often requiring signatures and dates.",
        "keywords": ["form", "application", "please fill", "signature", "date", "please print", "complete", "fill out", "applicant", "fields"]
    },
    "certificate": {
        "description": "An official document certifying completion, achievement, or qualification with certification details.",
        "keywords": ["certificate", "certified", "awarded", "this certifies", "certification", "certificate of", "issued", "certificate number"]
    },
    "license": {
        "description": "An official document granting permission to perform certain activities, with license numbers and expiration dates.",
        "keywords": ["license", "licensed", "expires", "license number", "licensee", "licensing authority", "valid until", "license type", "permit"]
    },
    "passport": {
        "description": "An official government document for international travel containing personal identification and nationality information.",
        "keywords": ["passport", "nationality", "date of birth", "passport number", "passport no", "country of issue", "expiry date", "place of birth", "issuing authority"]
    },
    "medical record": {
        "description": "Healthcare documentation containing patient information, diagnoses, treatments, and medical history.",
        "keywords": ["medical", "diagnosis", "patient", "treatment", "prescription", "doctor", "physician", "symptoms", "medication", "health", "medical history", "patient id"]
    },
    "bank statement": {
        "description": "A financial document from a bank showing account transactions, balances, deposits, and withdrawals.",
        "keywords": ["bank statement", "account statement", "statement of account", "account number", "account balance", "opening balance", "closing balance", "available balance", "statement period", "statement date", "start date balance", "transaction", "transactions", "deposit", "withdrawal", "debit", "credit", "checking account", "savings account", "account summary", "bank name", "routing number", "ending balance", "beginning balance", "total deposits", "total withdrawals", "service charge", "interest earned", "atm", "check", "checks", "transfer", "fee"]
    },
    "tax document": {
        "description": "Tax-related paperwork such as W-2 forms, 1099 forms, tax returns, or IRS correspondence.",
        "keywords": ["tax", "irs", "income", "deduction", "w-2", "1099", "tax return", "federal tax", "social security", "withholding", "adjusted gross income", "taxable income"]
    },
    "legal document": {
        "description": "Court documents, legal filings, contracts, or other documents related to legal proceedings or matters.",
        "keywords": ["legal", "court", "plaintiff", "defendant", "attorney", "lawyer", "case number", "filing", "petition", "motion", "order", "judgment", "legal counsel"]
    },
    "academic paper": {
        "description": "A scholarly document with abstract, introduction, methodology, results, references, and citations.",
        "keywords": ["abstract", "introduction", "methodology", "references", "citation", "research", "study", "literature review", "hypothesis", "data analysis", "conclusion", "bibliography"]
    },
    "presentation": {
        "description": "A document with slides, bullet points, or structured content for presenting information to an audience.",
        "keywords": ["slide", "presentation", "agenda", "overview", "bullet points", "powerpoint", "key points", "summary slide", "title slide"]
    },
    "manual": {
        "description": "An instructional document providing step-by-step procedures, guidelines, or how-to information.",
        "keywords": ["manual", "instructions", "how to", "procedure", "steps", "guide", "tutorial", "user guide", "operation", "setup", "installation"]
    },
    "quote": {
        "description": "A document providing a price estimate or quotation for goods or services before purchase.",
        "keywords": ["quote", "quotation", "estimate", "pricing", "quote number", "valid until", "quote date", "estimated cost", "price quote", "proposal"]
    },
    "purchase order": {
        "description": "A commercial document issued by a buyer to a seller indicating types, quantities, and agreed prices for products or services.",
        "keywords": ["purchase order", "po number", "po#", "order number", "purchase", "order date", "ship to", "bill to", "quantity", "unit price", "po"]
    },
    "insurance policy": {
        "description": "A document outlining insurance coverage, terms, premiums, and policy details.",
        "keywords": ["insurance", "policy", "policy number", "premium", "coverage", "insured", "beneficiary", "policyholder", "deductible", "claim", "insurance company"]
    },
    "other": {
        "description": "A document that does not clearly fit into any of the above categories.",
        "keywords": []
    }
}


class DocumentClassifier:
    """Hybrid document classifier built on BERT-tiny embeddings.

    For each candidate type in DOCUMENT_TYPES, the classifier combines:
      * a keyword-matching score against the type's keyword list (weight 0.6), and
      * cosine similarity between the document's mean-pooled embedding and a
        precomputed embedding of the type's description (weight 0.4).

    The model is downloaded on first use and cached under MODEL_PATH so
    subsequent runs load from disk.
    """

    def __init__(self):
        # tokenizer/model are populated by _load_model(); declared up front.
        self.tokenizer = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()
        self._precompute_type_embeddings()

    def _load_model(self):
        """Load BERT-tiny from the local cache, downloading and saving if absent.

        Raises:
            Exception: any tokenizer/model loading failure is logged and re-raised.
        """
        try:
            # Prefer the local cache; fall back to downloading by model id.
            if MODEL_PATH.exists():
                print(f"Loading model from local path: {MODEL_PATH}")
                model_path = str(MODEL_PATH)
            else:
                print(f"Downloading model {MODEL_NAME}...")
                model_path = MODEL_NAME
                # Ensure the cache directory exists for save_pretrained below.
                MODELS_DIR.mkdir(parents=True, exist_ok=True)

            # AutoModel (no task head): only raw hidden states are needed.
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(model_path)
            self.model.to(self.device)
            self.model.eval()

            # Persist a freshly downloaded model so later runs load locally.
            if not MODEL_PATH.exists():
                print(f"Saving model to {MODEL_PATH}...")
                self.tokenizer.save_pretrained(str(MODEL_PATH))
                self.model.save_pretrained(str(MODEL_PATH))
                print("Model saved successfully!")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _get_embedding(self, text: str, max_length: int = 512) -> torch.Tensor:
        """Return a (1, hidden_size) mean-pooled BERT-tiny embedding for `text`."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            # Mean pooling over the token dimension collapses to one vector.
            embeddings = outputs.last_hidden_state.mean(dim=1)

        return embeddings

    def _precompute_type_embeddings(self):
        """Embed every DOCUMENT_TYPES entry once so classification stays cheap."""
        print("Precomputing document type embeddings...")
        self.type_embeddings = {}

        for doc_type, doc_info in DOCUMENT_TYPES.items():
            # Name + description + keywords yields a richer anchor text than
            # the description alone.
            description = doc_info["description"]
            keywords = " ".join(doc_info.get("keywords", []))
            text = f"{doc_type}: {description} Keywords: {keywords}"
            self.type_embeddings[doc_type] = self._get_embedding(text)

        print("Document type embeddings computed!")

    def _calculate_keyword_score(self, text: str, doc_type: str) -> float:
        """Return a 0-1 keyword-match score of `text` against `doc_type`'s keywords."""
        doc_info = DOCUMENT_TYPES.get(doc_type, {})
        keywords = doc_info.get("keywords", [])

        if not keywords:
            # e.g. the catch-all "other" type has no keywords.
            return 0.0

        # Case-insensitive substring matching against the full text.
        text_lower = text.lower()
        matches = sum(1 for keyword in keywords if keyword.lower() in text_lower)

        base_score = matches / len(keywords)

        # Multiple hits indicate a stronger match: add up to a 0.3 bonus
        # (0.05 per matched keyword), capped at 1.0 overall.
        if matches > 0:
            base_score = min(1.0, base_score + min(0.3, matches * 0.05))

        return base_score

    @staticmethod
    def _scale_score(score: float) -> float:
        """Map a raw 0-1 combined score to a display percentage clamped to [5, 95].

        Piecewise-linear scaling: scores above 0.5 spread over 50-95%,
        scores in (0.3, 0.5] over 30-50%, and lower scores map directly
        to 0-30% before clamping.
        """
        if score > 0.5:
            percent = 50 + (score - 0.5) * 90
        elif score > 0.3:
            percent = 30 + (score - 0.3) * 100
        else:
            percent = score * 100
        return min(95, max(5, percent))

    def classify_document(self, text: str, max_length: int = 512) -> Dict[str, Any]:
        """
        Classify a document based on its text content using hybrid keyword + semantic similarity.

        Args:
            text: Document text content
            max_length: Maximum token length for the model

        Returns:
            Dictionary with keys "document_type", "confidence" (0-1),
            "all_scores" (top 5, 0-1 each), and "text_preview"; on failure,
            "document_type" is "unknown" and an "error" message is included.
        """
        if not text or not text.strip():
            return {
                "document_type": "unknown",
                "confidence": 0.0,
                "error": "No text extracted from document"
            }

        try:
            # Very long documents: keep head + tail, which usually carry the
            # most type-identifying content (~4 chars per token heuristic).
            if len(text) > max_length * 4:
                first_part = text[:max_length * 2]
                last_part = text[-max_length * 2:]
                text = first_part + " " + last_part

            doc_embedding = self._get_embedding(text, max_length)

            scores = {}
            for doc_type in DOCUMENT_TYPES:
                # 1. Keyword matching score (0-1).
                keyword_score = self._calculate_keyword_score(text, doc_type)

                # 2. Semantic similarity, normalized from [-1, 1] to [0, 1].
                type_embedding = self.type_embeddings[doc_type]
                similarity = F.cosine_similarity(doc_embedding, type_embedding, dim=1)
                semantic_score = (similarity.item() + 1) / 2

                # 3. 60/40 blend favours explicit keyword evidence.
                scores[doc_type] = (keyword_score * 0.6) + (semantic_score * 0.4)

            # Best match and its display confidence (shared scaling helper).
            best_type = max(scores.items(), key=lambda x: x[1])
            confidence = self._scale_score(best_type[1])

            # Top 5 candidates, scaled with the same piecewise mapping.
            top_5 = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:5]
            top_5_percentages = {
                doc_type: self._scale_score(score) for doc_type, score in top_5
            }

            return {
                "document_type": best_type[0],
                "confidence": round(confidence / 100, 3),  # Return as 0-1 for consistency
                "all_scores": {k: round(v / 100, 3) for k, v in top_5_percentages.items()},
                "text_preview": text[:200] + "..." if len(text) > 200 else text
            }

        except Exception as e:
            print(f"Error classifying document: {e}")
            import traceback
            traceback.print_exc()
            return {
                "document_type": "unknown",
                "confidence": 0.0,
                "error": str(e)
            }


# Module-level singleton so model loading happens at most once per process.
_classifier_instance = None


def get_classifier() -> DocumentClassifier:
    """Return the shared DocumentClassifier, instantiating it lazily on first call."""
    global _classifier_instance
    if _classifier_instance is None:
        # First call pays the model-load cost; later calls reuse the instance.
        _classifier_instance = DocumentClassifier()
    return _classifier_instance