Seth committed on
Commit
f6e574f
·
1 Parent(s): b434cd3
.gitignore ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ env/
8
+ venv/
9
+ ENV/
10
+ .venv
11
+ *.egg-info/
12
+ dist/
13
+ build/
14
+
15
+ # Virtual environment
16
+ venv/
17
+
18
+ # Models (large files)
19
+ Model/
20
+ models/
21
+ *.bin
22
+ *.safetensors
23
+
24
+ # Node
25
+ node_modules/
26
+ npm-debug.log*
27
+ yarn-debug.log*
28
+ yarn-error.log*
29
+
30
+ # Frontend build
31
+ frontend/dist/
32
+ frontend/.vite/
33
+
34
+ # IDE
35
+ .vscode/
36
+ .idea/
37
+ *.swp
38
+ *.swo
39
+ *~
40
+
41
+ # OS
42
+ .DS_Store
43
+ Thumbs.db
44
+
45
+ # Logs
46
+ *.log
README.md CHANGED
@@ -1,10 +1,98 @@
1
  ---
2
  title: DocClassify
3
- emoji: 🐨
4
  colorFrom: yellow
5
  colorTo: blue
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: DocClassify
3
+ emoji: 📄
4
  colorFrom: yellow
5
  colorTo: blue
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
+ # Document Classifier
11
+
12
+ A web application that uses BERT-tiny to classify PDF documents by type. Upload a PDF file and get instant classification results.
13
+
14
+ ## Features
15
+
16
+ - 📄 PDF file upload and processing
17
+ - 🤖 BERT-tiny model for document classification
18
+ - 🎯 Classifies 20+ document types including:
19
+ - Invoice, Receipt, Contract, Resume
20
+ - Letter, Report, Memo, Email
21
+ - Form, Certificate, License, Passport
22
+ - Medical records, Bank statements, Tax documents
23
+ - Legal documents, Academic papers, and more
24
+ - 💾 Model is downloaded and cached locally on first use
25
+ - 🎨 Modern, user-friendly interface
26
+
27
+ ## How It Works
28
+
29
+ 1. The app uses the `prajjwal1/bert-tiny` model from Hugging Face
30
+ 2. On first run, the model is automatically downloaded to the `models/` directory
31
+ 3. PDF text is extracted using PyPDF2
32
+ 4. Document embeddings are computed using BERT-tiny
33
+ 5. Similarity scores are calculated against pre-computed document type embeddings
34
+ 6. The document is classified with confidence scores
35
+
36
+ ## Setup
37
+
38
+ ### Local Development
39
+
40
+ 1. **Backend Setup:**
41
+ ```bash
42
+ cd backend
43
+ pip install -r requirements.txt
44
+ ```
45
+
46
+ 2. **Frontend Setup:**
47
+ ```bash
48
+ cd frontend
49
+ npm install
50
+ ```
51
+
52
+ 3. **Run Backend:**
53
+ ```bash
54
+ cd backend
55
+ uvicorn app.main:app --reload --port 8000
56
+ ```
57
+
58
+ 4. **Run Frontend:**
59
+ ```bash
60
+ cd frontend
61
+ npm run dev
62
+ ```
63
+
64
+ 5. Open `http://localhost:5173` in your browser
65
+
66
+ ### Docker Deployment
67
+
68
+ ```bash
69
+ docker build -t docclassify .
70
+ docker run -p 7860:7860 docclassify
71
+ ```
72
+
73
+ ## Usage
74
+
75
+ 1. Click "Select PDF File" to choose a PDF document
76
+ 2. Click "Classify Document" to process the file
77
+ 3. View the classification result with confidence scores
78
+ 4. See top 5 document type predictions
79
+
80
+ ## Model Information
81
+
82
+ - **Model:** `prajjwal1/bert-tiny`
83
+ - **Size:** ~4.4M parameters
84
+ - **Architecture:** BERT (L=2, H=128)
85
+ - **Source:** [Hugging Face Model Card](https://huggingface.co/prajjwal1/bert-tiny)
86
+
87
+ ## Technical Stack
88
+
89
+ - **Backend:** FastAPI, PyTorch, Transformers, PyPDF2
90
+ - **Frontend:** React, Vite
91
+ - **Model:** BERT-tiny (prajjwal1/bert-tiny)
92
+
93
+ ## Notes
94
+
95
+ - The model will be automatically downloaded on first use (~17MB)
96
+ - Classification works best with text-based PDFs
97
+ - Image-based PDFs may not work if they don't contain extractable text
98
+ - Processing time depends on document size and system resources
backend/app/classifier.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Document classification using BERT-tiny model."""
2
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
9
+
10
# Model configuration
MODEL_NAME = "prajjwal1/bert-tiny"
# Models directory: use /app/Model in Docker, or project_root/Model locally.
# Heuristic: presence of /app and /app/backend is treated as "running in the
# container image" — NOTE(review): may false-positive on hosts that happen to
# have an /app/backend directory.
if Path("/app").exists() and Path("/app/backend").exists():
    # Docker environment
    MODELS_DIR = Path("/app/Model")
else:
    # Local development - go up from backend/app/classifier.py to project root
    MODELS_DIR = Path(__file__).resolve().parent.parent.parent / "Model"
# On-disk cache location for the downloaded tokenizer + weights
MODEL_PATH = MODELS_DIR / "bert-tiny"

# Common document types with descriptions for better classification.
# Each description is embedded once by DocumentClassifier and used as a
# similarity anchor, so the wording here directly affects classification.
DOCUMENT_TYPES = {
    "invoice": "A document requesting payment for goods or services provided, containing itemized charges, totals, and payment terms.",
    "receipt": "A document confirming payment has been received, showing transaction details and proof of purchase.",
    "contract": "A legally binding agreement between parties outlining terms, conditions, obligations, and signatures.",
    "resume": "A document summarizing a person's work experience, education, skills, and qualifications for job applications.",
    "letter": "A formal or informal written correspondence addressed to a recipient with greetings and closing.",
    "report": "A structured document presenting analysis, findings, conclusions, and recommendations on a specific topic.",
    "memo": "An internal business communication document with headers like To, From, Subject, and Date.",
    "email": "Electronic mail correspondence with headers showing sender, recipient, subject, and message content.",
    "form": "A structured document with fields to be filled out, often requiring signatures and dates.",
    "certificate": "An official document certifying completion, achievement, or qualification with certification details.",
    "license": "An official document granting permission to perform certain activities, with license numbers and expiration dates.",
    "passport": "An official government document for international travel containing personal identification and nationality information.",
    "medical record": "Healthcare documentation containing patient information, diagnoses, treatments, and medical history.",
    "bank statement": "A financial document from a bank showing account transactions, balances, deposits, and withdrawals.",
    "tax document": "Tax-related paperwork such as W-2 forms, 1099 forms, tax returns, or IRS correspondence.",
    "legal document": "Court documents, legal filings, contracts, or other documents related to legal proceedings or matters.",
    "academic paper": "A scholarly document with abstract, introduction, methodology, results, references, and citations.",
    "presentation": "A document with slides, bullet points, or structured content for presenting information to an audience.",
    "manual": "An instructional document providing step-by-step procedures, guidelines, or how-to information.",
    "other": "A document that does not clearly fit into any of the above categories."
}
45
+
46
+
47
class DocumentClassifier:
    """Classify documents via cosine similarity of BERT-tiny embeddings.

    At construction the model is loaded (downloaded and cached under
    MODEL_PATH when absent) and one embedding per DOCUMENT_TYPES entry is
    precomputed. classify_document() embeds the input text, compares it to
    every type anchor with cosine similarity, and softmax-normalizes the
    scores into pseudo-confidences.
    """

    def __init__(self):
        self.tokenizer = None
        self.model = None
        # Prefer GPU when present; everything below also runs on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()
        self._precompute_type_embeddings()

    def _load_model(self):
        """Load the BERT-tiny model, downloading it on first use.

        Raises:
            Exception: re-raised after logging when loading/downloading fails.
        """
        try:
            # Prefer the local cache; otherwise fall back to the hub id.
            if MODEL_PATH.exists():
                print(f"Loading model from local path: {MODEL_PATH}")
                model_path = str(MODEL_PATH)
            else:
                print(f"Downloading model {MODEL_NAME}...")
                model_path = MODEL_NAME
                MODELS_DIR.mkdir(parents=True, exist_ok=True)

            # AutoModel (no task head): only hidden-state embeddings are used.
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(model_path)
            self.model.to(self.device)
            self.model.eval()  # inference only; disables dropout

            # Persist a local copy so subsequent runs can work offline.
            if not MODEL_PATH.exists():
                print(f"Saving model to {MODEL_PATH}...")
                self.tokenizer.save_pretrained(str(MODEL_PATH))
                self.model.save_pretrained(str(MODEL_PATH))
                print("Model saved successfully!")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _get_embedding(self, text: str, max_length: int = 512) -> torch.Tensor:
        """Return a (1, hidden_size) mean-pooled embedding for ``text``."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            # Mean-pool over the sequence (token) dimension.
            embeddings = outputs.last_hidden_state.mean(dim=1)

        return embeddings

    def _precompute_type_embeddings(self):
        """Precompute one anchor embedding per document-type description."""
        print("Precomputing document type embeddings...")
        self.type_embeddings = {}
        for doc_type, description in DOCUMENT_TYPES.items():
            # Label + description gives a richer anchor than the label alone.
            self.type_embeddings[doc_type] = self._get_embedding(
                f"{doc_type}: {description}"
            )
        print("Document type embeddings computed!")

    def classify_document(self, text: str, max_length: int = 512) -> Dict[str, Any]:
        """
        Classify a document based on its text content.

        Args:
            text: Document text content.
            max_length: Maximum token length for the model.

        Returns:
            Dict with ``document_type``, ``confidence`` (softmax over cosine
            similarities, rounded to 3 places), ``all_scores`` (top 5) and
            ``text_preview`` — or ``document_type="unknown"`` plus ``error``
            on empty input or failure.
        """
        if not text or not text.strip():
            return {
                "document_type": "unknown",
                "confidence": 0.0,
                "error": "No text extracted from document"
            }

        try:
            # Truncate very long inputs (~4 chars/token heuristic), keeping
            # the head and tail, which usually carry the identifying content.
            if len(text) > max_length * 4:
                first_part = text[:max_length * 2]
                last_part = text[-max_length * 2:]
                text = first_part + " " + last_part

            doc_embedding = self._get_embedding(text, max_length)

            # Cosine similarity against every precomputed type anchor.
            scores = {}
            for doc_type, type_embedding in self.type_embeddings.items():
                similarity = F.cosine_similarity(doc_embedding, type_embedding, dim=1)
                scores[doc_type] = similarity.item()

            # Softmax turns raw similarities into a probability-like spread.
            score_values = torch.tensor(list(scores.values()))
            normalized_scores = F.softmax(score_values, dim=0)
            normalized_dict = {
                doc_type: normalized_scores[i].item()
                for i, doc_type in enumerate(scores.keys())
            }

            # Best match plus the five highest-scoring candidates.
            best_type = max(normalized_dict.items(), key=lambda x: x[1])
            top_5 = sorted(normalized_dict.items(), key=lambda x: x[1], reverse=True)[:5]

            return {
                "document_type": best_type[0],
                "confidence": round(best_type[1], 3),
                "all_scores": {k: round(v, 3) for k, v in top_5},
                "text_preview": text[:200] + "..." if len(text) > 200 else text
            }

        except Exception as e:
            print(f"Error classifying document: {e}")
            import traceback
            traceback.print_exc()
            return {
                "document_type": "unknown",
                "confidence": 0.0,
                "error": str(e)
            }
+
185
+
186
# Process-wide singleton; the model is loaded at most once per process.
_classifier_instance = None


def get_classifier() -> DocumentClassifier:
    """Return the shared DocumentClassifier, instantiating it lazily."""
    global _classifier_instance
    if _classifier_instance is None:
        _classifier_instance = DocumentClassifier()
    return _classifier_instance
backend/app/main.py CHANGED
@@ -1,8 +1,10 @@
1
- from fastapi import FastAPI
2
- from fastapi.responses import FileResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pathlib import Path
 
 
6
 
7
  app = FastAPI()
8
 
@@ -14,6 +16,16 @@ app.add_middleware(
14
  allow_headers=["*"],
15
  )
16
 
 
 
 
 
 
 
 
 
 
 
17
  # ---- API ----
18
  @app.get("/api/health")
19
  def health():
@@ -23,6 +35,50 @@ def health():
23
  def hello():
24
  return {"message": "Hello from FastAPI"}
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # ---- Frontend static serving ----
27
  FRONTEND_DIST = Path(__file__).resolve().parents[2] / "frontend" / "dist"
28
  INDEX_FILE = FRONTEND_DIST / "index.html"
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.responses import FileResponse, JSONResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pathlib import Path
6
+ from app.pdf_processor import extract_text_from_pdf
7
+ from app.classifier import get_classifier
8
 
9
  app = FastAPI()
10
 
 
16
  allow_headers=["*"],
17
  )
18
 
19
# Cached classifier handle; stays None until the first classification request
# so app startup does not pay the model-loading cost.
classifier = None

def get_classifier_instance():
    """Return the shared classifier, creating it on first use."""
    global classifier
    if classifier is None:
        classifier = get_classifier()
    return classifier
28
+
29
  # ---- API ----
30
  @app.get("/api/health")
31
  def health():
 
35
  def hello():
36
  return {"message": "Hello from FastAPI"}
37
 
38
@app.post("/api/classify")
async def classify_document(file: UploadFile = File(...)):
    """
    Classify a PDF document.

    Args:
        file: Uploaded PDF file

    Returns:
        Classification results with document type and confidence

    Raises:
        HTTPException: 400 for non-PDF or unreadable uploads,
            500 for unexpected processing errors.
    """
    # Validate file type. file.filename may be None for some clients, so
    # guard before calling .lower() — otherwise this raised AttributeError
    # and surfaced as a 500 instead of a clean 400.
    if not file.filename or not file.filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="Only PDF files are supported")

    try:
        # Read the whole upload into memory (PDFs here are expected small).
        contents = await file.read()

        # Extract text from PDF; None means empty/corrupted/image-only.
        text = extract_text_from_pdf(contents)

        if not text:
            raise HTTPException(
                status_code=400,
                detail="Could not extract text from PDF. The file might be empty, corrupted, or image-based."
            )

        # Classify the document (model is loaded lazily on first request).
        classifier_instance = get_classifier_instance()
        result = classifier_instance.classify_document(text)

        return JSONResponse(content={
            "success": True,
            "filename": file.filename,
            "classification": result,
            "text_length": len(text)
        })

    except HTTPException:
        # Re-raise our own 4xx responses untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
81
+
82
  # ---- Frontend static serving ----
83
  FRONTEND_DIST = Path(__file__).resolve().parents[2] / "frontend" / "dist"
84
  INDEX_FILE = FRONTEND_DIST / "index.html"
backend/app/pdf_processor.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PDF text extraction utilities."""
2
+ import io
3
+ from typing import Optional
4
+ from PyPDF2 import PdfReader
5
+
6
+
7
def extract_text_from_pdf(pdf_bytes: bytes) -> Optional[str]:
    """
    Extract text content from a PDF file.

    Args:
        pdf_bytes: PDF file content as bytes

    Returns:
        Extracted text as string, or None if extraction fails
    """
    try:
        reader = PdfReader(io.BytesIO(pdf_bytes))

        # Pages are joined with a blank line; pages yielding no text are
        # skipped entirely.
        page_texts = (page.extract_text() for page in reader.pages)
        full_text = "\n\n".join(t for t in page_texts if t)

        # Whitespace-only output counts as "nothing extracted".
        return full_text if full_text.strip() else None
    except Exception as e:
        # Best-effort: log and signal failure with None rather than raising.
        print(f"Error extracting text from PDF: {e}")
        return None
backend/requirements.txt CHANGED
@@ -1,2 +1,8 @@
1
  fastapi
2
- uvicorn
 
 
 
 
 
 
 
1
  fastapi
2
+ uvicorn
3
+ python-multipart
4
+ transformers
5
+ torch
6
+ PyPDF2
7
+ sentencepiece
8
+ protobuf
download_model.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Download BERT-tiny model upfront to Model folder."""
3
+ from pathlib import Path
4
+ from transformers import AutoTokenizer, AutoModel
5
+
6
MODEL_NAME = "prajjwal1/bert-tiny"
# Cache root lives next to this script: <repo>/Model/bert-tiny
MODELS_DIR = Path(__file__).resolve().parent / "Model"
MODEL_PATH = MODELS_DIR / "bert-tiny"

def download_model():
    """Fetch BERT-tiny from the Hugging Face hub and cache it under Model/."""
    print(f"Downloading model: {MODEL_NAME}")
    print(f"Target directory: {MODEL_PATH}")

    # Create Model directory
    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    # A cached copy already on disk means there is nothing to do.
    if MODEL_PATH.exists():
        print(f"Model already exists at {MODEL_PATH}")
        print("Skipping download.")
        return

    try:
        print("Downloading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        print("Downloading model...")
        model = AutoModel.from_pretrained(MODEL_NAME)

        print(f"Saving model to {MODEL_PATH}...")
        tokenizer.save_pretrained(str(MODEL_PATH))
        model.save_pretrained(str(MODEL_PATH))

        print("✅ Model downloaded and saved successfully!")
        print(f"Location: {MODEL_PATH}")

    except Exception as e:
        print(f"❌ Error downloading model: {e}")
        raise

if __name__ == "__main__":
    download_model()
frontend/index.html CHANGED
@@ -3,7 +3,13 @@
3
  <head>
4
  <meta charset="UTF-8" />
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
- <title>HF React + FastAPI by Seth</title>
 
 
 
 
 
 
7
  </head>
8
  <body>
9
  <div id="root"></div>
 
3
  <head>
4
  <meta charset="UTF-8" />
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <title>Document Classifier - BERT-tiny</title>
7
+ <style>
8
+ @keyframes spin {
9
+ 0% { transform: rotate(0deg); }
10
+ 100% { transform: rotate(360deg); }
11
+ }
12
+ </style>
13
  </head>
14
  <body>
15
  <div id="root"></div>
frontend/src/App.jsx CHANGED
@@ -1,23 +1,322 @@
1
- import React, { useEffect, useState } from "react";
2
 
3
  export default function App() {
4
- const [apiMsg, setApiMsg] = useState("");
 
 
 
5
 
6
- useEffect(() => {
7
- fetch("/api/hello")
8
- .then((r) => r.json())
9
- .then((d) => setApiMsg(d.message))
10
- .catch(() => setApiMsg("API not reachable yet"));
11
- }, []);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  return (
14
- <div style={{ fontFamily: "system-ui", padding: 24, lineHeight: 1.5 }}>
15
- <h1>React + FastAPI (Docker, HF Spaces)</h1>
16
- <p>This is a plain starter page. Customize freely.By Seth</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- <div style={{ marginTop: 16, padding: 12, border: "1px solid #ddd", borderRadius: 8 }}>
19
- <strong>API says:</strong> {apiMsg}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  </div>
22
  );
23
  }
 
1
+ import React, { useState } from "react";
2
 
3
export default function App() {
  // UI state: the chosen File, in-flight flag, server response, error text.
  const [file, setFile] = useState(null);
  const [loading, setLoading] = useState(false);
  const [result, setResult] = useState(null);
  const [error, setError] = useState(null);

  // Validate the picked file client-side (MIME type must be PDF) and clear
  // any stale result/error from a previous run.
  const handleFileChange = (e) => {
    const selectedFile = e.target.files[0];
    if (selectedFile) {
      if (selectedFile.type !== "application/pdf") {
        setError("Please select a PDF file");
        setFile(null);
        return;
      }
      setFile(selectedFile);
      setError(null);
      setResult(null);
    }
  };

  // POST the file as multipart form data to the backend and store the
  // parsed JSON result; FastAPI error bodies carry the message in `detail`.
  const handleClassify = async () => {
    if (!file) {
      setError("Please select a PDF file first");
      return;
    }

    setLoading(true);
    setError(null);
    setResult(null);

    const formData = new FormData();
    formData.append("file", file);

    try {
      const response = await fetch("/api/classify", {
        method: "POST",
        body: formData,
      });

      if (!response.ok) {
        const errorData = await response.json();
        throw new Error(errorData.detail || "Classification failed");
      }

      const data = await response.json();
      setResult(data);
    } catch (err) {
      setError(err.message || "An error occurred during classification");
    } finally {
      setLoading(false);
    }
  };

  // Clear all state and the native <input type="file"> value so the same
  // file can be re-selected.
  const handleReset = () => {
    setFile(null);
    setResult(null);
    setError(null);
    // Reset file input
    const fileInput = document.getElementById("pdf-upload");
    if (fileInput) fileInput.value = "";
  };

  return (
    <div
      style={{
        fontFamily: "system-ui, -apple-system, sans-serif",
        maxWidth: "800px",
        margin: "0 auto",
        padding: "24px",
        lineHeight: 1.6,
      }}
    >
      <div style={{ textAlign: "center", marginBottom: "32px" }}>
        <h1 style={{ margin: "0 0 8px 0", color: "#1a1a1a" }}>
          📄 Document Classifier
        </h1>
        <p style={{ color: "#666", margin: "0" }}>
          Upload a PDF file to classify its document type using BERT-tiny
        </p>
      </div>

      <div
        style={{
          border: "2px dashed #ddd",
          borderRadius: "12px",
          padding: "32px",
          textAlign: "center",
          backgroundColor: "#fafafa",
          marginBottom: "24px",
        }}
      >
        <input
          id="pdf-upload"
          type="file"
          accept=".pdf"
          onChange={handleFileChange}
          style={{ display: "none" }}
        />
        <label
          htmlFor="pdf-upload"
          style={{
            display: "inline-block",
            padding: "12px 24px",
            backgroundColor: "#4f46e5",
            color: "white",
            borderRadius: "8px",
            cursor: "pointer",
            fontSize: "16px",
            fontWeight: "500",
            marginBottom: "16px",
            transition: "background-color 0.2s",
          }}
          onMouseOver={(e) => (e.target.style.backgroundColor = "#4338ca")}
          onMouseOut={(e) => (e.target.style.backgroundColor = "#4f46e5")}
        >
          {file ? "Change PDF File" : "Select PDF File"}
        </label>

        {file && (
          <div style={{ marginTop: "16px" }}>
            <p style={{ margin: "8px 0", color: "#333" }}>
              <strong>Selected:</strong> {file.name}
            </p>
            <p style={{ margin: "4px 0", color: "#666", fontSize: "14px" }}>
              Size: {(file.size / 1024).toFixed(2)} KB
            </p>
          </div>
        )}

        <div style={{ marginTop: "24px" }}>
          <button
            onClick={handleClassify}
            disabled={!file || loading}
            style={{
              padding: "12px 32px",
              fontSize: "16px",
              fontWeight: "600",
              backgroundColor: file && !loading ? "#10b981" : "#9ca3af",
              color: "white",
              border: "none",
              borderRadius: "8px",
              cursor: file && !loading ? "pointer" : "not-allowed",
              transition: "background-color 0.2s",
            }}
            onMouseOver={(e) => {
              if (file && !loading) {
                e.target.style.backgroundColor = "#059669";
              }
            }}
            onMouseOut={(e) => {
              if (file && !loading) {
                e.target.style.backgroundColor = "#10b981";
              }
            }}
          >
            {loading ? "Classifying..." : "Classify Document"}
          </button>

          {file && (
            <button
              onClick={handleReset}
              disabled={loading}
              style={{
                padding: "12px 24px",
                fontSize: "16px",
                fontWeight: "500",
                backgroundColor: "transparent",
                color: "#666",
                border: "1px solid #ddd",
                borderRadius: "8px",
                cursor: loading ? "not-allowed" : "pointer",
                marginLeft: "12px",
              }}
            >
              Reset
            </button>
          )}
        </div>
      </div>

      {error && (
        <div
          style={{
            padding: "16px",
            backgroundColor: "#fee2e2",
            border: "1px solid #fecaca",
            borderRadius: "8px",
            color: "#991b1b",
            marginBottom: "24px",
          }}
        >
          <strong>Error:</strong> {error}
        </div>
      )}

      {result && (
        <div
          style={{
            padding: "24px",
            backgroundColor: "#f0fdf4",
            border: "2px solid #86efac",
            borderRadius: "12px",
            marginBottom: "24px",
          }}
        >
          <h2 style={{ margin: "0 0 16px 0", color: "#166534" }}>
            Classification Result
          </h2>

          <div
            style={{
              backgroundColor: "white",
              padding: "20px",
              borderRadius: "8px",
              marginBottom: "16px",
            }}
          >
            <div style={{ marginBottom: "12px" }}>
              <span style={{ color: "#666", fontSize: "14px" }}>
                Document Type:
              </span>
              <div
                style={{
                  fontSize: "24px",
                  fontWeight: "700",
                  color: "#10b981",
                  marginTop: "4px",
                  textTransform: "capitalize",
                }}
              >
                {result.classification.document_type}
              </div>
            </div>

            <div style={{ marginBottom: "12px" }}>
              <span style={{ color: "#666", fontSize: "14px" }}>
                Confidence:
              </span>
              <div
                style={{
                  fontSize: "20px",
                  fontWeight: "600",
                  color: "#059669",
                  marginTop: "4px",
                }}
              >
                {(result.classification.confidence * 100).toFixed(1)}%
              </div>
            </div>

            <div style={{ marginTop: "16px", paddingTop: "16px", borderTop: "1px solid #e5e7eb" }}>
              <span style={{ color: "#666", fontSize: "14px" }}>
                File: {result.filename}
              </span>
              <br />
              <span style={{ color: "#666", fontSize: "14px" }}>
                Text Length: {result.text_length.toLocaleString()} characters
              </span>
            </div>
          </div>

          {result.classification.all_scores && (
            <div>
              <h3 style={{ margin: "0 0 12px 0", fontSize: "16px", color: "#166534" }}>
                Top 5 Classifications:
              </h3>
              <div style={{ backgroundColor: "white", padding: "16px", borderRadius: "8px" }}>
                {Object.entries(result.classification.all_scores).map(
                  ([type, score]) => (
                    <div
                      key={type}
                      style={{
                        display: "flex",
                        justifyContent: "space-between",
                        alignItems: "center",
                        padding: "8px 0",
                        borderBottom: "1px solid #f3f4f6",
                      }}
                    >
                      <span style={{ textTransform: "capitalize", color: "#374151" }}>
                        {type}
                      </span>
                      <span
                        style={{
                          fontWeight: "600",
                          color: type === result.classification.document_type ? "#10b981" : "#6b7280",
                        }}
                      >
                        {(score * 100).toFixed(1)}%
                      </span>
                    </div>
                  )
                )}
              </div>
            </div>
          )}
        </div>
      )}

      {loading && (
        <div style={{ textAlign: "center", padding: "24px" }}>
          <div
            style={{
              display: "inline-block",
              width: "40px",
              height: "40px",
              border: "4px solid #e5e7eb",
              borderTop: "4px solid #4f46e5",
              borderRadius: "50%",
              animation: "spin 1s linear infinite",
            }}
          />
          <p style={{ marginTop: "16px", color: "#666" }}>
            Processing your document...
          </p>
        </div>
      )}
    </div>
  );
}
test_classifier.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Test the document classifier."""
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ # Add backend to path
7
+ sys.path.insert(0, str(Path(__file__).parent / "backend"))
8
+
9
+ from app.classifier import get_classifier
10
+
11
def test_classifier():
    """Smoke-test the classifier against a few labelled text snippets.

    Loads the real model (network/disk on first run), classifies each sample,
    and prints PASS/mismatch per case. Mismatches are reported but do not
    fail the script — this is a demo, not an assertion suite.
    """
    print("Loading classifier...")
    classifier = get_classifier()

    # (sample text, expected document type) pairs
    test_cases = [
        ("Invoice for services rendered. Total amount due: $500.00. Payment terms: Net 30.", "invoice"),
        ("This is to certify that John Doe has completed the course.", "certificate"),
        ("Dear Sir, I am writing to inform you...", "letter"),
        ("Account Statement - Account #12345. Balance: $1,000.00", "bank statement"),
    ]

    print("\n" + "="*60)
    print("Testing Document Classifier")
    print("="*60 + "\n")

    for i, (text, expected_type) in enumerate(test_cases, 1):
        print(f"Test {i}: Expected type: {expected_type}")
        print(f"Text: {text[:50]}...")

        result = classifier.classify_document(text)

        print(f"✅ Classified as: {result['document_type']}")
        print(f"   Confidence: {result['confidence']:.1%}")
        print(f"   Top 3: {list(result['all_scores'].keys())[:3]}")

        if result['document_type'] == expected_type:
            print("   ✅ PASS - Correct classification!")
        else:
            print(f"   ⚠️ Expected '{expected_type}' but got '{result['document_type']}'")

        print()

    print("="*60)
    print("Test completed!")

if __name__ == "__main__":
    try:
        test_classifier()
    except Exception as e:
        # Any failure (model download, classification) exits non-zero.
        print(f"❌ Error during testing: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
+ sys.exit(1)