Commit f6e574f
Seth committed
Parent(s): b434cd3
Update
Browse files
- .gitignore +46 -0
- README.md +90 -2
- backend/app/classifier.py +194 -0
- backend/app/main.py +58 -2
- backend/app/pdf_processor.py +31 -0
- backend/requirements.txt +7 -1
- download_model.py +42 -0
- frontend/index.html +7 -1
- frontend/src/App.jsx +312 -13
- test_classifier.py +55 -0
.gitignore
ADDED
@@ -0,0 +1,46 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+env/
+venv/
+ENV/
+.venv
+*.egg-info/
+dist/
+build/
+
+# Virtual environment
+venv/
+
+# Models (large files)
+Model/
+models/
+*.bin
+*.safetensors
+
+# Node
+node_modules/
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+
+# Frontend build
+frontend/dist/
+frontend/.vite/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Logs
+*.log
README.md
CHANGED
@@ -1,10 +1,98 @@
 ---
 title: DocClassify
-emoji:
+emoji: 📄
 colorFrom: yellow
 colorTo: blue
 sdk: docker
 pinned: false
 ---
 
-
+# Document Classifier
+
+A web application that uses BERT-tiny to classify PDF documents by type. Upload a PDF file and get instant classification results.
+
+## Features
+
+- 📄 PDF file upload and processing
+- 🤖 BERT-tiny model for document classification
+- 🎯 Classifies 20+ document types including:
+  - Invoice, Receipt, Contract, Resume
+  - Letter, Report, Memo, Email
+  - Form, Certificate, License, Passport
+  - Medical records, Bank statements, Tax documents
+  - Legal documents, Academic papers, and more
+- 💾 Model is downloaded and cached locally on first use
+- 🎨 Modern, user-friendly interface
+
+## How It Works
+
+1. The app uses the `prajjwal1/bert-tiny` model from Hugging Face
+2. On first run, the model is automatically downloaded to the `Model/` directory
+3. PDF text is extracted using PyPDF2
+4. Document embeddings are computed using BERT-tiny
+5. Similarity scores are calculated against pre-computed document type embeddings
+6. The document is classified with confidence scores
+
+## Setup
+
+### Local Development
+
+1. **Backend Setup:**
+   ```bash
+   cd backend
+   pip install -r requirements.txt
+   ```
+
+2. **Frontend Setup:**
+   ```bash
+   cd frontend
+   npm install
+   ```
+
+3. **Run Backend:**
+   ```bash
+   cd backend
+   uvicorn app.main:app --reload --port 8000
+   ```
+
+4. **Run Frontend:**
+   ```bash
+   cd frontend
+   npm run dev
+   ```
+
+5. Open `http://localhost:5173` in your browser
+
+### Docker Deployment
+
+```bash
+docker build -t docclassify .
+docker run -p 7860:7860 docclassify
+```
+
+## Usage
+
+1. Click "Select PDF File" to choose a PDF document
+2. Click "Classify Document" to process the file
+3. View the classification result with confidence scores
+4. See top 5 document type predictions
+
+## Model Information
+
+- **Model:** `prajjwal1/bert-tiny`
+- **Size:** ~4.4M parameters
+- **Architecture:** BERT (L=2, H=128)
+- **Source:** [Hugging Face Model Card](https://huggingface.co/prajjwal1/bert-tiny)
+
+## Technical Stack
+
+- **Backend:** FastAPI, PyTorch, Transformers, PyPDF2
+- **Frontend:** React, Vite
+- **Model:** BERT-tiny (prajjwal1/bert-tiny)
+
+## Notes
+
+- The model will be automatically downloaded on first use (~17MB)
+- Classification works best with text-based PDFs
+- Image-based PDFs may not work if they don't contain extractable text
+- Processing time depends on document size and system resources
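The "How It Works" steps above reduce to cosine similarity between a document embedding and each pre-computed type embedding, followed by a softmax over the scores. A minimal sketch of that ranking logic, with toy 3-dimensional vectors standing in for the real BERT-tiny embeddings (all vector values are hypothetical):

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def softmax(xs):
    """Normalize raw scores into values that sum to 1."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Toy embeddings standing in for BERT-tiny outputs (hypothetical values)
type_embeddings = {
    "invoice": [0.9, 0.1, 0.0],
    "resume":  [0.1, 0.9, 0.0],
    "letter":  [0.0, 0.2, 0.9],
}
doc_embedding = [0.8, 0.2, 0.1]

scores = {t: cosine_similarity(doc_embedding, e) for t, e in type_embeddings.items()}
probs = softmax(list(scores.values()))
ranked = sorted(zip(scores.keys(), probs), key=lambda kv: kv[1], reverse=True)
print(ranked[0][0])  # best-matching document type
```

With these toy vectors the document is closest to the "invoice" embedding, so it ranks first; the softmax turns the raw similarities into confidence values that sum to 1, matching the "confidence scores" in step 6.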
backend/app/classifier.py
ADDED
@@ -0,0 +1,194 @@
+"""Document classification using BERT-tiny model."""
+import os
+from pathlib import Path
+from typing import Any, List, Dict, Optional
+from transformers import AutoTokenizer, AutoModel
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+# Model configuration
+MODEL_NAME = "prajjwal1/bert-tiny"
+# Models directory: use /app/Model in Docker, or project_root/Model locally
+# Check if we're in Docker by looking for /app directory
+if Path("/app").exists() and Path("/app/backend").exists():
+    # Docker environment
+    MODELS_DIR = Path("/app/Model")
+else:
+    # Local development - go up from backend/app/classifier.py to project root
+    MODELS_DIR = Path(__file__).resolve().parent.parent.parent / "Model"
+MODEL_PATH = MODELS_DIR / "bert-tiny"
+
+# Common document types with descriptions for better classification
+DOCUMENT_TYPES = {
+    "invoice": "A document requesting payment for goods or services provided, containing itemized charges, totals, and payment terms.",
+    "receipt": "A document confirming payment has been received, showing transaction details and proof of purchase.",
+    "contract": "A legally binding agreement between parties outlining terms, conditions, obligations, and signatures.",
+    "resume": "A document summarizing a person's work experience, education, skills, and qualifications for job applications.",
+    "letter": "A formal or informal written correspondence addressed to a recipient with greetings and closing.",
+    "report": "A structured document presenting analysis, findings, conclusions, and recommendations on a specific topic.",
+    "memo": "An internal business communication document with headers like To, From, Subject, and Date.",
+    "email": "Electronic mail correspondence with headers showing sender, recipient, subject, and message content.",
+    "form": "A structured document with fields to be filled out, often requiring signatures and dates.",
+    "certificate": "An official document certifying completion, achievement, or qualification with certification details.",
+    "license": "An official document granting permission to perform certain activities, with license numbers and expiration dates.",
+    "passport": "An official government document for international travel containing personal identification and nationality information.",
+    "medical record": "Healthcare documentation containing patient information, diagnoses, treatments, and medical history.",
+    "bank statement": "A financial document from a bank showing account transactions, balances, deposits, and withdrawals.",
+    "tax document": "Tax-related paperwork such as W-2 forms, 1099 forms, tax returns, or IRS correspondence.",
+    "legal document": "Court documents, legal filings, contracts, or other documents related to legal proceedings or matters.",
+    "academic paper": "A scholarly document with abstract, introduction, methodology, results, references, and citations.",
+    "presentation": "A document with slides, bullet points, or structured content for presenting information to an audience.",
+    "manual": "An instructional document providing step-by-step procedures, guidelines, or how-to information.",
+    "other": "A document that does not clearly fit into any of the above categories."
+}
+
+
+class DocumentClassifier:
+    """Class for classifying documents using BERT-tiny."""
+
+    def __init__(self):
+        self.tokenizer = None
+        self.model = None
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self._load_model()
+        self._precompute_type_embeddings()
+
+    def _load_model(self):
+        """Load the BERT-tiny model, downloading if necessary."""
+        try:
+            # Check if model exists locally, otherwise download
+            if MODEL_PATH.exists():
+                print(f"Loading model from local path: {MODEL_PATH}")
+                model_path = str(MODEL_PATH)
+            else:
+                print(f"Downloading model {MODEL_NAME}...")
+                model_path = MODEL_NAME
+                # Create models directory
+                MODELS_DIR.mkdir(parents=True, exist_ok=True)
+
+            # Load tokenizer and model (using AutoModel for embeddings)
+            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+            self.model = AutoModel.from_pretrained(model_path)
+            self.model.to(self.device)
+            self.model.eval()
+
+            # Save model locally if downloaded
+            if not MODEL_PATH.exists():
+                print(f"Saving model to {MODEL_PATH}...")
+                self.tokenizer.save_pretrained(str(MODEL_PATH))
+                self.model.save_pretrained(str(MODEL_PATH))
+                print("Model saved successfully!")
+
+        except Exception as e:
+            print(f"Error loading model: {e}")
+            raise
+
+    def _get_embedding(self, text: str, max_length: int = 512) -> torch.Tensor:
+        """Get embedding for a text using BERT-tiny."""
+        inputs = self.tokenizer(
+            text,
+            return_tensors="pt",
+            truncation=True,
+            max_length=max_length,
+            padding=True
+        ).to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+            # Use mean pooling of token embeddings
+            embeddings = outputs.last_hidden_state.mean(dim=1)
+
+        return embeddings
+
+    def _precompute_type_embeddings(self):
+        """Precompute embeddings for each document type description."""
+        print("Precomputing document type embeddings...")
+        self.type_embeddings = {}
+
+        for doc_type, description in DOCUMENT_TYPES.items():
+            # Combine type name and description for better representation
+            text = f"{doc_type}: {description}"
+            embedding = self._get_embedding(text)
+            self.type_embeddings[doc_type] = embedding
+
+        print("Document type embeddings computed!")
+
+    def classify_document(self, text: str, max_length: int = 512) -> Dict[str, Any]:
+        """
+        Classify a document based on its text content using BERT-tiny embeddings.
+
+        Args:
+            text: Document text content
+            max_length: Maximum token length for the model
+
+        Returns:
+            Dictionary with classification results
+        """
+        if not text or not text.strip():
+            return {
+                "document_type": "unknown",
+                "confidence": 0.0,
+                "error": "No text extracted from document"
+            }
+
+        try:
+            # Truncate text if too long (keep first part which usually has most relevant info)
+            if len(text) > max_length * 4:  # Rough estimate: 4 chars per token
+                # Take first part and last part for better context
+                first_part = text[:max_length * 2]
+                last_part = text[-max_length * 2:]
+                text = first_part + " " + last_part
+
+            # Get embedding for the document text
+            doc_embedding = self._get_embedding(text, max_length)
+
+            # Calculate cosine similarity with each document type
+            scores = {}
+            for doc_type, type_embedding in self.type_embeddings.items():
+                # Calculate cosine similarity
+                similarity = F.cosine_similarity(doc_embedding, type_embedding, dim=1)
+                scores[doc_type] = similarity.item()
+
+            # Normalize scores to 0-1 range using softmax
+            score_values = torch.tensor(list(scores.values()))
+            normalized_scores = F.softmax(score_values, dim=0)
+
+            # Update scores with normalized values
+            normalized_dict = {}
+            for i, doc_type in enumerate(scores.keys()):
+                normalized_dict[doc_type] = normalized_scores[i].item()
+
+            # Find the best match
+            best_type = max(normalized_dict.items(), key=lambda x: x[1])
+
+            # Get top 5 classifications
+            top_5 = sorted(normalized_dict.items(), key=lambda x: x[1], reverse=True)[:5]
+
+            return {
+                "document_type": best_type[0],
+                "confidence": round(best_type[1], 3),
+                "all_scores": {k: round(v, 3) for k, v in top_5},
+                "text_preview": text[:200] + "..." if len(text) > 200 else text
+            }
+
+        except Exception as e:
+            print(f"Error classifying document: {e}")
+            import traceback
+            traceback.print_exc()
+            return {
+                "document_type": "unknown",
+                "confidence": 0.0,
+                "error": str(e)
+            }
+
+
+# Global classifier instance
+_classifier_instance = None
+
+def get_classifier() -> DocumentClassifier:
+    """Get or create the global classifier instance."""
+    global _classifier_instance
+    if _classifier_instance is None:
+        _classifier_instance = DocumentClassifier()
+    return _classifier_instance
backend/app/main.py
CHANGED
@@ -1,8 +1,10 @@
-from fastapi import FastAPI
-from fastapi.responses import FileResponse
+from fastapi import FastAPI, UploadFile, File, HTTPException
+from fastapi.responses import FileResponse, JSONResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
 from pathlib import Path
+from app.pdf_processor import extract_text_from_pdf
+from app.classifier import get_classifier
 
 app = FastAPI()
 
@@ -14,6 +16,16 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# Initialize classifier (lazy loading)
+classifier = None
+
+def get_classifier_instance():
+    """Lazy load the classifier."""
+    global classifier
+    if classifier is None:
+        classifier = get_classifier()
+    return classifier
+
 # ---- API ----
 @app.get("/api/health")
 def health():
@@ -23,6 +35,50 @@ def health():
 def hello():
     return {"message": "Hello from FastAPI"}
 
+@app.post("/api/classify")
+async def classify_document(file: UploadFile = File(...)):
+    """
+    Classify a PDF document.
+
+    Args:
+        file: Uploaded PDF file
+
+    Returns:
+        Classification results with document type and confidence
+    """
+    # Validate file type
+    if not file.filename.lower().endswith('.pdf'):
+        raise HTTPException(status_code=400, detail="Only PDF files are supported")
+
+    try:
+        # Read file content
+        contents = await file.read()
+
+        # Extract text from PDF
+        text = extract_text_from_pdf(contents)
+
+        if not text:
+            raise HTTPException(
+                status_code=400,
+                detail="Could not extract text from PDF. The file might be empty, corrupted, or image-based."
+            )
+
+        # Classify the document
+        classifier_instance = get_classifier_instance()
+        result = classifier_instance.classify_document(text)
+
+        return JSONResponse(content={
+            "success": True,
+            "filename": file.filename,
+            "classification": result,
+            "text_length": len(text)
+        })
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
+
 # ---- Frontend static serving ----
 FRONTEND_DIST = Path(__file__).resolve().parents[2] / "frontend" / "dist"
 INDEX_FILE = FRONTEND_DIST / "index.html"
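The `/api/classify` endpoint accepts a multipart form upload (this is why `python-multipart` appears in the new requirements). A stdlib-only sketch of what such a request body looks like on the wire — the field name `"file"` matches the endpoint's parameter, while the filename and payload bytes here are placeholders:

```python
import io
import uuid

def build_multipart(field: str, filename: str, payload: bytes):
    """Build a multipart/form-data body like the browser's FormData POST.

    Returns (body_bytes, content_type_header_value). Sketch only; real
    clients would typically use requests or httpx instead.
    """
    boundary = uuid.uuid4().hex
    body = io.BytesIO()
    body.write(f"--{boundary}\r\n".encode())
    body.write(
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'.encode()
    )
    body.write(b"Content-Type: application/pdf\r\n\r\n")
    body.write(payload)
    body.write(f"\r\n--{boundary}--\r\n".encode())
    return body.getvalue(), f"multipart/form-data; boundary={boundary}"

data, content_type = build_multipart("file", "sample.pdf", b"%PDF-1.4 placeholder")
# POST `data` to /api/classify with the Content-Type header set to `content_type`
```

FastAPI parses this body into the `UploadFile` object the handler receives, so `file.filename` would be `"sample.pdf"` and `await file.read()` would return the payload bytes.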
backend/app/pdf_processor.py
ADDED
@@ -0,0 +1,31 @@
+"""PDF text extraction utilities."""
+import io
+from typing import Optional
+from PyPDF2 import PdfReader
+
+
+def extract_text_from_pdf(pdf_bytes: bytes) -> Optional[str]:
+    """
+    Extract text content from a PDF file.
+
+    Args:
+        pdf_bytes: PDF file content as bytes
+
+    Returns:
+        Extracted text as string, or None if extraction fails
+    """
+    try:
+        pdf_file = io.BytesIO(pdf_bytes)
+        reader = PdfReader(pdf_file)
+
+        text_parts = []
+        for page in reader.pages:
+            text = page.extract_text()
+            if text:
+                text_parts.append(text)
+
+        full_text = "\n\n".join(text_parts)
+        return full_text if full_text.strip() else None
+    except Exception as e:
+        print(f"Error extracting text from PDF: {e}")
+        return None
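The extraction logic above boils down to: collect the non-empty page texts, join them with blank lines, and return None when nothing survives (which is how the API later detects image-only PDFs). The joining rule in isolation:

```python
def join_pages(pages):
    """Join per-page text like extract_text_from_pdf: skip empties, None if no text."""
    parts = [p for p in pages if p]
    full_text = "\n\n".join(parts)
    return full_text if full_text.strip() else None

print(join_pages(["Page one.", "", "Page two."]))
```

Returning None rather than an empty string lets the caller distinguish "extraction produced nothing" from "extraction succeeded" with a single truthiness check.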
backend/requirements.txt
CHANGED
@@ -1,2 +1,8 @@
 fastapi
-uvicorn
+uvicorn
+python-multipart
+transformers
+torch
+PyPDF2
+sentencepiece
+protobuf
download_model.py
ADDED
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+"""Download BERT-tiny model upfront to Model folder."""
+from pathlib import Path
+from transformers import AutoTokenizer, AutoModel
+
+MODEL_NAME = "prajjwal1/bert-tiny"
+MODELS_DIR = Path(__file__).resolve().parent / "Model"
+MODEL_PATH = MODELS_DIR / "bert-tiny"
+
+def download_model():
+    """Download and save the BERT-tiny model."""
+    print(f"Downloading model: {MODEL_NAME}")
+    print(f"Target directory: {MODEL_PATH}")
+
+    # Create Model directory
+    MODELS_DIR.mkdir(parents=True, exist_ok=True)
+
+    if MODEL_PATH.exists():
+        print(f"Model already exists at {MODEL_PATH}")
+        print("Skipping download.")
+        return
+
+    try:
+        print("Downloading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+        print("Downloading model...")
+        model = AutoModel.from_pretrained(MODEL_NAME)
+
+        print(f"Saving model to {MODEL_PATH}...")
+        tokenizer.save_pretrained(str(MODEL_PATH))
+        model.save_pretrained(str(MODEL_PATH))
+
+        print("✅ Model downloaded and saved successfully!")
+        print(f"Location: {MODEL_PATH}")
+
+    except Exception as e:
+        print(f"❌ Error downloading model: {e}")
+        raise
+
+if __name__ == "__main__":
+    download_model()
frontend/index.html
CHANGED
@@ -3,7 +3,13 @@
   <head>
     <meta charset="UTF-8" />
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-    <title>
+    <title>Document Classifier - BERT-tiny</title>
+    <style>
+      @keyframes spin {
+        0% { transform: rotate(0deg); }
+        100% { transform: rotate(360deg); }
+      }
+    </style>
   </head>
   <body>
     <div id="root"></div>
frontend/src/App.jsx
CHANGED
|
@@ -1,23 +1,322 @@
|
|
| 1 |
-
import React, {
|
| 2 |
|
| 3 |
export default function App() {
|
| 4 |
-
const [
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
.
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
return (
|
| 14 |
-
<div
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
</div>
|
| 22 |
);
|
| 23 |
}
|
|
|
|
| 1 |
+
import React, { useState } from "react";
|
| 2 |
|
| 3 |
export default function App() {
|
| 4 |
+
const [file, setFile] = useState(null);
|
| 5 |
+
const [loading, setLoading] = useState(false);
|
| 6 |
+
const [result, setResult] = useState(null);
|
| 7 |
+
const [error, setError] = useState(null);
|
| 8 |
|
| 9 |
+
const handleFileChange = (e) => {
|
| 10 |
+
const selectedFile = e.target.files[0];
|
| 11 |
+
if (selectedFile) {
|
| 12 |
+
if (selectedFile.type !== "application/pdf") {
|
| 13 |
+
setError("Please select a PDF file");
|
| 14 |
+
setFile(null);
|
| 15 |
+
return;
|
| 16 |
+
}
|
| 17 |
+
setFile(selectedFile);
|
| 18 |
+
setError(null);
|
| 19 |
+
setResult(null);
|
| 20 |
+
}
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
const handleClassify = async () => {
|
| 24 |
+
if (!file) {
|
| 25 |
+
setError("Please select a PDF file first");
|
| 26 |
+
return;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
setLoading(true);
|
| 30 |
+
setError(null);
|
| 31 |
+
setResult(null);
|
| 32 |
+
|
| 33 |
+
const formData = new FormData();
|
| 34 |
+
formData.append("file", file);
|
| 35 |
+
|
| 36 |
+
try {
|
| 37 |
+
const response = await fetch("/api/classify", {
|
| 38 |
+
method: "POST",
|
| 39 |
+
body: formData,
|
| 40 |
+
});
|
| 41 |
+
|
| 42 |
+
if (!response.ok) {
|
| 43 |
+
const errorData = await response.json();
|
| 44 |
+
throw new Error(errorData.detail || "Classification failed");
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
const data = await response.json();
|
| 48 |
+
setResult(data);
|
| 49 |
+
} catch (err) {
|
| 50 |
+
setError(err.message || "An error occurred during classification");
|
| 51 |
+
} finally {
|
| 52 |
+
setLoading(false);
|
| 53 |
+
}
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
const handleReset = () => {
|
| 57 |
+
setFile(null);
|
| 58 |
+
setResult(null);
|
| 59 |
+
setError(null);
|
| 60 |
+
// Reset file input
|
| 61 |
+
const fileInput = document.getElementById("pdf-upload");
|
| 62 |
+
if (fileInput) fileInput.value = "";
|
| 63 |
+
};
|
| 64 |
|
| 65 |
return (
|
| 66 |
+
<div
|
| 67 |
+
style={{
|
| 68 |
+
fontFamily: "system-ui, -apple-system, sans-serif",
|
| 69 |
+
maxWidth: "800px",
|
| 70 |
+
margin: "0 auto",
|
| 71 |
+
padding: "24px",
|
| 72 |
+
lineHeight: 1.6,
|
| 73 |
+
}}
|
| 74 |
+
>
|
| 75 |
+
<div style={{ textAlign: "center", marginBottom: "32px" }}>
|
| 76 |
+
<h1 style={{ margin: "0 0 8px 0", color: "#1a1a1a" }}>
|
| 77 |
+
📄 Document Classifier
|
| 78 |
+
</h1>
|
| 79 |
+
<p style={{ color: "#666", margin: "0" }}>
|
| 80 |
+
Upload a PDF file to classify its document type using BERT-tiny
|
| 81 |
+
</p>
|
| 82 |
+
</div>
|
| 83 |
+
|
| 84 |
+
<div
|
| 85 |
+
style={{
|
| 86 |
+
border: "2px dashed #ddd",
|
| 87 |
+
borderRadius: "12px",
|
| 88 |
+
padding: "32px",
|
| 89 |
+
textAlign: "center",
|
| 90 |
+
backgroundColor: "#fafafa",
|
| 91 |
+
marginBottom: "24px",
|
| 92 |
+
}}
|
| 93 |
+
>
|
| 94 |
+
<input
|
| 95 |
+
id="pdf-upload"
|
| 96 |
+
type="file"
|
| 97 |
+
accept=".pdf"
|
| 98 |
+
onChange={handleFileChange}
|
| 99 |
+
style={{ display: "none" }}
|
| 100 |
+
/>
|
| 101 |
+
<label
|
| 102 |
+
htmlFor="pdf-upload"
|
| 103 |
+
style={{
|
| 104 |
+
display: "inline-block",
|
| 105 |
+
padding: "12px 24px",
|
| 106 |
+
backgroundColor: "#4f46e5",
|
| 107 |
+
color: "white",
|
| 108 |
+
borderRadius: "8px",
|
| 109 |
+
            cursor: "pointer",
            fontSize: "16px",
            fontWeight: "500",
            marginBottom: "16px",
            transition: "background-color 0.2s",
          }}
          onMouseOver={(e) => (e.target.style.backgroundColor = "#4338ca")}
          onMouseOut={(e) => (e.target.style.backgroundColor = "#4f46e5")}
        >
          {file ? "Change PDF File" : "Select PDF File"}
        </label>

        {file && (
          <div style={{ marginTop: "16px" }}>
            <p style={{ margin: "8px 0", color: "#333" }}>
              <strong>Selected:</strong> {file.name}
            </p>
            <p style={{ margin: "4px 0", color: "#666", fontSize: "14px" }}>
              Size: {(file.size / 1024).toFixed(2)} KB
            </p>
          </div>
        )}

        <div style={{ marginTop: "24px" }}>
          <button
            onClick={handleClassify}
            disabled={!file || loading}
            style={{
              padding: "12px 32px",
              fontSize: "16px",
              fontWeight: "600",
              backgroundColor: file && !loading ? "#10b981" : "#9ca3af",
              color: "white",
              border: "none",
              borderRadius: "8px",
              cursor: file && !loading ? "pointer" : "not-allowed",
              transition: "background-color 0.2s",
            }}
            onMouseOver={(e) => {
              if (file && !loading) {
                e.target.style.backgroundColor = "#059669";
              }
            }}
            onMouseOut={(e) => {
              if (file && !loading) {
                e.target.style.backgroundColor = "#10b981";
              }
            }}
          >
            {loading ? "Classifying..." : "Classify Document"}
          </button>

          {file && (
            <button
              onClick={handleReset}
              disabled={loading}
              style={{
                padding: "12px 24px",
                fontSize: "16px",
                fontWeight: "500",
                backgroundColor: "transparent",
                color: "#666",
                border: "1px solid #ddd",
                borderRadius: "8px",
                cursor: loading ? "not-allowed" : "pointer",
                marginLeft: "12px",
              }}
            >
              Reset
            </button>
          )}
        </div>
      </div>

      {error && (
        <div
          style={{
            padding: "16px",
            backgroundColor: "#fee2e2",
            border: "1px solid #fecaca",
            borderRadius: "8px",
            color: "#991b1b",
            marginBottom: "24px",
          }}
        >
          <strong>Error:</strong> {error}
        </div>
      )}

      {result && (
        <div
          style={{
            padding: "24px",
            backgroundColor: "#f0fdf4",
            border: "2px solid #86efac",
            borderRadius: "12px",
            marginBottom: "24px",
          }}
        >
          <h2 style={{ margin: "0 0 16px 0", color: "#166534" }}>
            Classification Result
          </h2>

          <div
            style={{
              backgroundColor: "white",
              padding: "20px",
              borderRadius: "8px",
              marginBottom: "16px",
            }}
          >
            <div style={{ marginBottom: "12px" }}>
              <span style={{ color: "#666", fontSize: "14px" }}>
                Document Type:
              </span>
              <div
                style={{
                  fontSize: "24px",
                  fontWeight: "700",
                  color: "#10b981",
                  marginTop: "4px",
                  textTransform: "capitalize",
                }}
              >
                {result.classification.document_type}
              </div>
            </div>

            <div style={{ marginBottom: "12px" }}>
              <span style={{ color: "#666", fontSize: "14px" }}>
                Confidence:
              </span>
              <div
                style={{
                  fontSize: "20px",
                  fontWeight: "600",
                  color: "#059669",
                  marginTop: "4px",
                }}
              >
                {(result.classification.confidence * 100).toFixed(1)}%
              </div>
            </div>

            <div style={{ marginTop: "16px", paddingTop: "16px", borderTop: "1px solid #e5e7eb" }}>
              <span style={{ color: "#666", fontSize: "14px" }}>
                File: {result.filename}
              </span>
              <br />
              <span style={{ color: "#666", fontSize: "14px" }}>
                Text Length: {result.text_length.toLocaleString()} characters
              </span>
            </div>
          </div>

          {result.classification.all_scores && (
            <div>
              <h3 style={{ margin: "0 0 12px 0", fontSize: "16px", color: "#166534" }}>
                Top 5 Classifications:
              </h3>
              <div style={{ backgroundColor: "white", padding: "16px", borderRadius: "8px" }}>
                {Object.entries(result.classification.all_scores).map(
                  ([type, score]) => (
                    <div
                      key={type}
                      style={{
                        display: "flex",
                        justifyContent: "space-between",
                        alignItems: "center",
                        padding: "8px 0",
                        borderBottom: "1px solid #f3f4f6",
                      }}
                    >
                      <span style={{ textTransform: "capitalize", color: "#374151" }}>
                        {type}
                      </span>
                      <span
                        style={{
                          fontWeight: "600",
                          color: type === result.classification.document_type ? "#10b981" : "#6b7280",
                        }}
                      >
                        {(score * 100).toFixed(1)}%
                      </span>
                    </div>
                  )
                )}
              </div>
            </div>
          )}
        </div>
      )}

      {loading && (
        <div style={{ textAlign: "center", padding: "24px" }}>
          <div
            style={{
              display: "inline-block",
              width: "40px",
              height: "40px",
              border: "4px solid #e5e7eb",
              borderTop: "4px solid #4f46e5",
              borderRadius: "50%",
              animation: "spin 1s linear infinite",
            }}
          />
          <p style={{ marginTop: "16px", color: "#666" }}>
            Processing your document...
          </p>
        </div>
      )}
    </div>
  );
}
test_classifier.py
ADDED
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
"""Test the document classifier."""
import sys
from pathlib import Path

# Add backend to path
sys.path.insert(0, str(Path(__file__).parent / "backend"))

from app.classifier import get_classifier

def test_classifier():
    """Test the classifier with sample text."""
    print("Loading classifier...")
    classifier = get_classifier()

    # Test with sample texts
    test_cases = [
        ("Invoice for services rendered. Total amount due: $500.00. Payment terms: Net 30.", "invoice"),
        ("This is to certify that John Doe has completed the course.", "certificate"),
        ("Dear Sir, I am writing to inform you...", "letter"),
        ("Account Statement - Account #12345. Balance: $1,000.00", "bank statement"),
    ]

    print("\n" + "="*60)
    print("Testing Document Classifier")
    print("="*60 + "\n")

    for i, (text, expected_type) in enumerate(test_cases, 1):
        print(f"Test {i}: Expected type: {expected_type}")
        print(f"Text: {text[:50]}...")

        result = classifier.classify_document(text)

        print(f"✅ Classified as: {result['document_type']}")
        print(f"   Confidence: {result['confidence']:.1%}")
        print(f"   Top 3: {list(result['all_scores'].keys())[:3]}")

        if result['document_type'] == expected_type:
            print("   ✅ PASS - Correct classification!")
        else:
            print(f"   ⚠️ Expected '{expected_type}' but got '{result['document_type']}'")

        print()

    print("="*60)
    print("Test completed!")

if __name__ == "__main__":
    try:
        test_classifier()
    except Exception as e:
        print(f"❌ Error during testing: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
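One detail worth noting in the script above: the "Top 3" line takes the first three keys of `all_scores`, which is only the true top 3 if the classifier returns scores already sorted in descending order. A small sketch that sorts explicitly instead of relying on that assumption (labels and scores invented for illustration):

```python
# Sorting all_scores by value rather than trusting insertion order.
all_scores = {"letter": 0.12, "invoice": 0.61, "certificate": 0.27}
top3 = [
    label
    for label, _ in sorted(all_scores.items(), key=lambda kv: kv[1], reverse=True)
][:3]
print(top3)  # ['invoice', 'certificate', 'letter']
```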