# DocClassify — backend/app/classifier.py
"""Document classification using BERT-tiny model."""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Model configuration
MODEL_NAME = "prajjwal1/bert-tiny"
# Models directory: use /app/Model in Docker, or project_root/Model locally
# Check if we're in Docker by looking for /app directory
if Path("/app").exists() and Path("/app/backend").exists():
# Docker environment
MODELS_DIR = Path("/app/Model")
else:
# Local development - go up from backend/app/classifier.py to project root
MODELS_DIR = Path(__file__).resolve().parent.parent.parent / "Model"
MODEL_PATH = MODELS_DIR / "bert-tiny"
# Common document types with descriptions and keywords for better classification
DOCUMENT_TYPES = {
"invoice": {
"description": "A document requesting payment for goods or services provided, containing itemized charges, totals, and payment terms.",
"keywords": ["invoice", "bill", "amount due", "total", "subtotal", "tax", "payment terms", "invoice number", "invoice date", "due date", "itemized", "charges", "balance", "payable", "vendor", "billing"]
},
"receipt": {
"description": "A document confirming payment has been received, showing transaction details and proof of purchase.",
"keywords": ["receipt", "payment received", "paid", "thank you", "transaction", "purchase", "payment confirmation", "receipt number", "date of purchase", "amount paid"]
},
"contract": {
"description": "A legally binding agreement between parties outlining terms, conditions, obligations, and signatures.",
"keywords": ["contract", "agreement", "terms", "party", "signature", "effective date", "parties", "whereas", "hereby", "obligations", "rights", "termination", "breach"]
},
"resume": {
"description": "A document summarizing a person's work experience, education, skills, and qualifications for job applications.",
"keywords": ["resume", "cv", "curriculum vitae", "experience", "education", "skills", "employment", "work history", "qualifications", "objective", "references", "contact information"]
},
"letter": {
"description": "A formal or informal written correspondence addressed to a recipient with greetings and closing.",
"keywords": ["dear", "sincerely", "yours", "letter", "correspondence", "regards", "best regards", "yours truly", "to whom it may concern", "date:", "subject:"]
},
"report": {
"description": "A structured document presenting analysis, findings, conclusions, and recommendations on a specific topic.",
"keywords": ["report", "summary", "findings", "conclusion", "analysis", "recommendations", "executive summary", "introduction", "methodology", "results", "discussion"]
},
"memo": {
"description": "An internal business communication document with headers like To, From, Subject, and Date.",
"keywords": ["memo", "memorandum", "to:", "from:", "subject:", "date:", "re:", "internal", "interoffice"]
},
"email": {
"description": "Electronic mail correspondence with headers showing sender, recipient, subject, and message content.",
"keywords": ["from:", "to:", "subject:", "sent:", "email", "cc:", "bcc:", "reply to", "message id", "date sent"]
},
"form": {
"description": "A structured document with fields to be filled out, often requiring signatures and dates.",
"keywords": ["form", "application", "please fill", "signature", "date", "please print", "complete", "fill out", "applicant", "fields"]
},
"certificate": {
"description": "An official document certifying completion, achievement, or qualification with certification details.",
"keywords": ["certificate", "certified", "awarded", "this certifies", "certification", "certificate of", "issued", "certificate number"]
},
"license": {
"description": "An official document granting permission to perform certain activities, with license numbers and expiration dates.",
"keywords": ["license", "licensed", "expires", "license number", "licensee", "licensing authority", "valid until", "license type", "permit"]
},
"passport": {
"description": "An official government document for international travel containing personal identification and nationality information.",
"keywords": ["passport", "nationality", "date of birth", "passport number", "passport no", "country of issue", "expiry date", "place of birth", "issuing authority"]
},
"medical record": {
"description": "Healthcare documentation containing patient information, diagnoses, treatments, and medical history.",
"keywords": ["medical", "diagnosis", "patient", "treatment", "prescription", "doctor", "physician", "symptoms", "medication", "health", "medical history", "patient id"]
},
"bank statement": {
"description": "A financial document from a bank showing account transactions, balances, deposits, and withdrawals.",
"keywords": ["bank statement", "account statement", "statement of account", "account number", "account balance", "opening balance", "closing balance", "available balance", "statement period", "statement date", "start date balance", "transaction", "transactions", "deposit", "withdrawal", "debit", "credit", "checking account", "savings account", "account summary", "bank name", "routing number", "ending balance", "beginning balance", "total deposits", "total withdrawals", "service charge", "interest earned", "atm", "check", "checks", "transfer", "fee"]
},
"tax document": {
"description": "Tax-related paperwork such as W-2 forms, 1099 forms, tax returns, or IRS correspondence.",
"keywords": ["tax", "irs", "income", "deduction", "w-2", "1099", "tax return", "federal tax", "social security", "withholding", "adjusted gross income", "taxable income"]
},
"legal document": {
"description": "Court documents, legal filings, contracts, or other documents related to legal proceedings or matters.",
"keywords": ["legal", "court", "plaintiff", "defendant", "attorney", "lawyer", "case number", "filing", "petition", "motion", "order", "judgment", "legal counsel"]
},
"academic paper": {
"description": "A scholarly document with abstract, introduction, methodology, results, references, and citations.",
"keywords": ["abstract", "introduction", "methodology", "references", "citation", "research", "study", "literature review", "hypothesis", "data analysis", "conclusion", "bibliography"]
},
"presentation": {
"description": "A document with slides, bullet points, or structured content for presenting information to an audience.",
"keywords": ["slide", "presentation", "agenda", "overview", "bullet points", "powerpoint", "key points", "summary slide", "title slide"]
},
"manual": {
"description": "An instructional document providing step-by-step procedures, guidelines, or how-to information.",
"keywords": ["manual", "instructions", "how to", "procedure", "steps", "guide", "tutorial", "user guide", "operation", "setup", "installation"]
},
"quote": {
"description": "A document providing a price estimate or quotation for goods or services before purchase.",
"keywords": ["quote", "quotation", "estimate", "pricing", "quote number", "valid until", "quote date", "estimated cost", "price quote", "proposal"]
},
"purchase order": {
"description": "A commercial document issued by a buyer to a seller indicating types, quantities, and agreed prices for products or services.",
"keywords": ["purchase order", "po number", "po#", "order number", "purchase", "order date", "ship to", "bill to", "quantity", "unit price", "po"]
},
"insurance policy": {
"description": "A document outlining insurance coverage, terms, premiums, and policy details.",
"keywords": ["insurance", "policy", "policy number", "premium", "coverage", "insured", "beneficiary", "policyholder", "deductible", "claim", "insurance company"]
},
"other": {
"description": "A document that does not clearly fit into any of the above categories.",
"keywords": []
}
}
class DocumentClassifier:
    """Classify documents with a hybrid of keyword matching and BERT-tiny embeddings.

    At construction time the model is loaded (downloading and caching it if
    needed) and one embedding per DOCUMENT_TYPES entry is precomputed. Each
    document is then scored against every type by combining keyword hits
    (60% weight) with cosine similarity of mean-pooled embeddings (40%).
    """

    def __init__(self):
        self.tokenizer = None  # set by _load_model
        self.model = None      # set by _load_model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()
        self._precompute_type_embeddings()

    def _load_model(self):
        """Load BERT-tiny from MODEL_PATH, downloading and caching it if absent.

        Raises:
            Exception: re-raises whatever the tokenizer/model load raises.
        """
        try:
            if MODEL_PATH.exists():
                print(f"Loading model from local path: {MODEL_PATH}")
                model_path = str(MODEL_PATH)
            else:
                print(f"Downloading model {MODEL_NAME}...")
                model_path = MODEL_NAME
                MODELS_DIR.mkdir(parents=True, exist_ok=True)
            # AutoModel (no task head): only embeddings are needed.
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(model_path)
            self.model.to(self.device)
            self.model.eval()
            # Cache freshly downloaded weights so future runs load locally.
            if not MODEL_PATH.exists():
                print(f"Saving model to {MODEL_PATH}...")
                self.tokenizer.save_pretrained(str(MODEL_PATH))
                self.model.save_pretrained(str(MODEL_PATH))
                print("Model saved successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _get_embedding(self, text: str, max_length: int = 512) -> torch.Tensor:
        """Return a (1, hidden_size) mean-pooled BERT embedding for *text*."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=True,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
            # Mean pooling over the token dimension.
            embeddings = outputs.last_hidden_state.mean(dim=1)
        return embeddings

    def _precompute_type_embeddings(self) -> None:
        """Embed each DOCUMENT_TYPES entry once so classification stays cheap."""
        print("Precomputing document type embeddings...")
        self.type_embeddings = {}
        for doc_type, doc_info in DOCUMENT_TYPES.items():
            # Type name + description + keywords gives a richer anchor text.
            description = doc_info["description"]
            keywords = " ".join(doc_info.get("keywords", []))
            text = f"{doc_type}: {description} Keywords: {keywords}"
            self.type_embeddings[doc_type] = self._get_embedding(text)
        print("Document type embeddings computed!")

    def _calculate_keyword_score(self, text: str, doc_type: str) -> float:
        """Return a 0-1 keyword-match score of *text* against *doc_type*."""
        text_lower = text.lower()
        keywords = DOCUMENT_TYPES.get(doc_type, {}).get("keywords", [])
        if not keywords:
            return 0.0
        matches = sum(1 for keyword in keywords if keyword.lower() in text_lower)
        base_score = matches / len(keywords)
        # Multiple keyword hits indicate a stronger match: add up to a 30%
        # boost (5% per hit), keeping the score within [0, 1].
        if matches > 0:
            base_score = min(1.0, base_score + min(0.3, matches * 0.05))
        return base_score

    @staticmethod
    def _score_to_percent(score: float) -> float:
        """Map a combined 0-1 score onto a 5-95 display percentage.

        Piecewise-linear: (0.5, 1.0] -> (50, 95], (0.3, 0.5] -> (30, 50],
        [0, 0.3] -> [0, 30]; the result is clamped to [5, 95] so the
        classifier never claims total certainty or total ignorance.
        """
        if score > 0.5:
            percent = 50 + (score - 0.5) * 90
        elif score > 0.3:
            percent = 30 + (score - 0.3) * 100
        else:
            percent = score * 100
        return min(95, max(5, percent))

    def classify_document(self, text: str, max_length: int = 512) -> Dict[str, Any]:
        """Classify a document's text via hybrid keyword + semantic scoring.

        Args:
            text: Document text content.
            max_length: Maximum token length fed to the model.

        Returns:
            Dict with "document_type", "confidence" (0-1), "all_scores"
            (top-5 types mapped to 0-1 scores) and "text_preview"; or an
            "unknown" result carrying an "error" message on empty input or
            classification failure.
        """
        if not text or not text.strip():
            return {
                "document_type": "unknown",
                "confidence": 0.0,
                "error": "No text extracted from document"
            }
        try:
            # Long documents: keep head and tail (roughly 4 chars per token),
            # which usually carry the most type-identifying content.
            if len(text) > max_length * 4:
                first_part = text[:max_length * 2]
                last_part = text[-max_length * 2:]
                text = first_part + " " + last_part
            doc_embedding = self._get_embedding(text, max_length)
            scores = {}
            for doc_type in DOCUMENT_TYPES:
                keyword_score = self._calculate_keyword_score(text, doc_type)
                similarity = F.cosine_similarity(
                    doc_embedding, self.type_embeddings[doc_type], dim=1
                )
                # Normalize cosine similarity from [-1, 1] to [0, 1].
                semantic_score = (similarity.item() + 1) / 2
                # Weight explicit keyword evidence (60%) over semantics (40%).
                scores[doc_type] = (keyword_score * 0.6) + (semantic_score * 0.4)
            best_type, best_score = max(scores.items(), key=lambda x: x[1])
            confidence = self._score_to_percent(best_score)
            top_5 = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:5]
            top_5_percentages = {
                doc_type: self._score_to_percent(score) for doc_type, score in top_5
            }
            return {
                "document_type": best_type,
                "confidence": round(confidence / 100, 3),  # 0-1 for consistency
                "all_scores": {k: round(v / 100, 3) for k, v in top_5_percentages.items()},
                "text_preview": text[:200] + "..." if len(text) > 200 else text
            }
        except Exception as e:
            print(f"Error classifying document: {e}")
            import traceback
            traceback.print_exc()
            return {
                "document_type": "unknown",
                "confidence": 0.0,
                "error": str(e)
            }
# Module-level singleton: the classifier is expensive to build (model load +
# type-embedding precompute), so it is created once and shared.
_classifier_instance = None


def get_classifier() -> DocumentClassifier:
    """Return the shared DocumentClassifier, constructing it on first call."""
    global _classifier_instance
    if _classifier_instance is not None:
        return _classifier_instance
    _classifier_instance = DocumentClassifier()
    return _classifier_instance