Qasim-Dost committed on
Commit
ce9f3ac
·
verified ·
1 Parent(s): 7209782

Upload 5 files

Browse files
Files changed (5) hide show
  1. Dockerfile +27 -20
  2. app.py +363 -0
  3. pdf_to_image.py +93 -0
  4. requirements.txt +15 -3
  5. smolvlm_classifier.py +227 -0
Dockerfile CHANGED
@@ -1,20 +1,27 @@
1
- FROM python:3.13.5-slim
2
-
3
- WORKDIR /app
4
-
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- git \
9
- && rm -rf /var/lib/apt/lists/*
10
-
11
- COPY requirements.txt ./
12
- COPY src/ ./src/
13
-
14
- RUN pip3 install -r requirements.txt
15
-
16
- EXPOSE 8501
17
-
18
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
-
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
 
 
 
 
 
 
 
1
# Runtime image for the SmolVLM document-classifier Streamlit app.
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies for PyMuPDF
RUN apt-get update && apt-get install -y \
    libmupdf-dev \
    mupdf-tools \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
# (done before copying sources so this layer is cached across code edits)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files
COPY app.py .
COPY smolvlm_classifier.py .
COPY pdf_to_image.py .

# Expose Streamlit port
EXPOSE 7860

# Disable torch.compile for HF Spaces compatibility
# (this variable is read by smolvlm_classifier.py at model-load time)
ENV DISABLE_TORCH_COMPILE=1

# Run Streamlit
# NOTE(review): CORS and XSRF protection are disabled, presumably so the app
# can be embedded in an HF Spaces iframe — confirm before deploying elsewhere.
CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.enableCORS=false", "--server.enableXsrfProtection=false"]
app.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streamlit UI for Document Classification
3
+ Upload PDFs and classify them using SmolVLM.
4
+ Optimized with pre-loading and concurrent processing.
5
+ """
6
+
7
+ import streamlit as st
8
+ import pandas as pd
9
+ import json
10
+ from pathlib import Path
11
+ from datetime import datetime
12
+ import tempfile
13
+ import os
14
+ from concurrent.futures import ThreadPoolExecutor, as_completed
15
+ import threading
16
+
17
+ # Import our classifier modules
18
+ from pdf_to_image import pdf_to_images
19
+ from smolvlm_classifier import SmolVLMClassifier
20
+
21
+
22
# Page config — Streamlit requires this before any other st.* call.
st.set_page_config(
    page_title="Document Classifier",
    page_icon="📄",
    layout="wide"
)

# Custom CSS for better styling of the header, per-file result boxes,
# metadata lines, and the model-status banner.
st.markdown("""
<style>
.main-header {
    font-size: 2.5rem;
    font-weight: bold;
    color: #1f77b4;
    margin-bottom: 1rem;
}
.result-box {
    background-color: #f0f8ff;
    padding: 0.8rem 1rem;
    border-radius: 8px;
    border-left: 4px solid #1f77b4;
    margin: 0.5rem 0;
    display: inline-block;
}
.doc-type {
    font-size: 1.2rem;
    font-weight: bold;
    color: #2e7d32;
    margin: 0;
}
.file-info {
    font-size: 0.9rem;
    color: #555;
    margin: 0.2rem 0;
}
.model-status {
    padding: 0.5rem;
    border-radius: 5px;
    margin-bottom: 1rem;
}
</style>
""", unsafe_allow_html=True)
64
+
65
+
66
@st.cache_resource
def load_classifier():
    """Instantiate the SmolVLM classifier; Streamlit caches it across reruns."""
    classifier = SmolVLMClassifier()
    return classifier
70
+
71
+
72
def load_history():
    """Load classification history from the JSON file in the working directory.

    Returns:
        The history as a list of entry dicts (most recent first). Returns an
        empty list when the file is missing, unreadable, contains invalid
        JSON, or holds something other than a list — a corrupt history file
        must never crash the app on startup.
    """
    history_file = Path("classification_history.json")
    if history_file.exists():
        try:
            with open(history_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError):
            # Treat a broken or unreadable history file as empty history.
            return []
        # Guard against valid JSON that is not the expected list shape.
        if isinstance(data, list):
            return data
    return []
79
+
80
+
81
def save_history(history):
    """Persist the classification history as pretty-printed UTF-8 JSON."""
    serialized = json.dumps(history, indent=2, ensure_ascii=False)
    with open("classification_history.json", "w", encoding="utf-8") as fp:
        fp.write(serialized)
85
+
86
+
87
def add_to_history(filename, doc_type, num_pages):
    """Prepend a classification result to the persisted history.

    Only the 100 most recent entries are kept on disk; the trimmed
    history is returned for the caller's convenience.
    """
    entry = {
        "filename": filename,
        "document_type": doc_type,
        "num_pages": num_pages,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    # Newest first, capped at 100 entries.
    updated = ([entry] + load_history())[:100]
    save_history(updated)
    return updated
100
+
101
+
102
def convert_pdf_to_images(uploaded_file):
    """Render one uploaded PDF into a list of page images.

    Spills the upload to a temporary file (pdf_to_images takes a path),
    converts it, and removes the temp file no matter what. Returns a
    (filename, images) tuple so worker-thread results can be matched
    back to their source file.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    try:
        tmp.write(uploaded_file.getvalue())
        tmp.close()
        pages = pdf_to_images(tmp.name, dpi=100)
        return uploaded_file.name, pages
    finally:
        os.unlink(tmp.name)
112
+
113
+
114
def main():
    """Render the Streamlit app: upload PDFs, preview them, classify with
    SmolVLM, and show per-file results plus timing, with persisted history."""
    # Header
    st.markdown('<div class="main-header">📄 Document Classifier</div>', unsafe_allow_html=True)
    st.markdown("Upload PDF documents to classify them using SmolVLM AI.")

    # PRE-LOAD MODEL AT APP START (not on button click)
    # This runs once when the app starts; @st.cache_resource makes reruns free.
    with st.spinner("🔄 Loading AI model (one-time setup)..."):
        classifier = load_classifier()
    st.success("✅ Model ready!")

    # Sidebar for history
    with st.sidebar:
        st.header("📋 Classification History")
        history = load_history()

        if history:
            # Show as table
            df_history = pd.DataFrame(history)
            st.dataframe(
                df_history[["filename", "document_type", "timestamp"]],
                hide_index=True,
                width="stretch"
            )

            # Clear history button
            if st.button("🗑️ Clear History"):
                save_history([])
                st.rerun()
        else:
            st.info("No classification history yet. Upload a document to get started!")

    # Main content - two columns: uploads on the left, results on the right
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("📤 Upload Documents")

        # File uploader - MULTIPLE FILES
        uploaded_files = st.file_uploader(
            "Choose PDF files",
            type=["pdf"],
            accept_multiple_files=True,
            help="Upload one or more PDF documents to classify"
        )

        if uploaded_files:
            st.success(f"✅ Uploaded {len(uploaded_files)} file(s)")

            # Store images for preview, keyed by filename, so reruns and the
            # classify step can reuse them without reconverting.
            if "pdf_previews" not in st.session_state:
                st.session_state["pdf_previews"] = {}

            # Show file list with preview option
            for f in uploaded_files:
                with st.expander(f"📄 {f.name} ({f.size / 1024:.1f} KB)", expanded=False):
                    # Generate preview if not cached; pdf_to_images needs a
                    # real path, so spill the upload to a temp file first.
                    if f.name not in st.session_state["pdf_previews"]:
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                            tmp_file.write(f.getvalue())
                            tmp_path = tmp_file.name
                        try:
                            images = pdf_to_images(tmp_path, dpi=100)
                            st.session_state["pdf_previews"][f.name] = images
                        finally:
                            os.unlink(tmp_path)

                    # Show preview (page selector for multi-page PDFs)
                    images = st.session_state["pdf_previews"].get(f.name, [])
                    if images:
                        if len(images) > 1:
                            page_num = st.selectbox(
                                f"Page",
                                range(1, len(images) + 1),
                                key=f"page_{f.name}"
                            )
                            st.image(images[page_num - 1], caption=f"Page {page_num} of {len(images)}", width="stretch")
                        else:
                            st.image(images[0], caption="Page 1", width="stretch")
                    else:
                        st.error("Could not load PDF preview")

            # Classify button
            if st.button("🔍 Classify All Documents", type="primary", width="stretch"):
                import time

                all_results = []
                progress_bar = st.progress(0)
                status_text = st.empty()

                total_start_time = time.time()

                # STEP 1: Pre-convert all PDFs to images using threading
                status_text.text("📄 Converting PDFs to images (parallel)...")
                pdf_conversion_start = time.time()

                pdf_images = {}

                # Use ThreadPoolExecutor for parallel PDF conversion
                with ThreadPoolExecutor(max_workers=4) as executor:
                    # Submit conversion tasks only for files whose preview
                    # was not already cached in session_state.
                    future_to_file = {
                        executor.submit(convert_pdf_to_images, f): f
                        for f in uploaded_files
                        if f.name not in st.session_state.get("pdf_previews", {})
                    }

                    # Also add cached previews
                    for f in uploaded_files:
                        if f.name in st.session_state.get("pdf_previews", {}):
                            pdf_images[f.name] = st.session_state["pdf_previews"][f.name]

                    # Collect results as worker threads finish
                    for future in as_completed(future_to_file):
                        filename, images = future.result()
                        pdf_images[filename] = images

                pdf_conversion_time = time.time() - pdf_conversion_start
                print(f"\n📄 PDF Conversion: {pdf_conversion_time:.2f}s (parallel)")

                progress_bar.progress(0.2)
                status_text.text("🤖 Classifying documents...")

                # STEP 2: Classify each document with timing
                # (sequential — a single shared model instance is used)
                classification_start = time.time()

                for idx, uploaded_file in enumerate(uploaded_files):
                    doc_start_time = time.time()  # NOTE(review): never read
                    images = pdf_images.get(uploaded_file.name, [])

                    if not images:
                        # Conversion failed earlier; surface it as a result row.
                        result = {
                            "filename": uploaded_file.name,
                            "document_type": "Error: Could not extract pages",
                            "num_pages": 0,
                            "classify_time": 0
                        }
                    else:
                        status_text.text(f"🤖 Classifying {idx + 1}/{len(uploaded_files)}: {uploaded_file.name}")

                        # Classify with timing
                        classify_start = time.time()
                        classification = classifier.classify_document(images)
                        classify_time = time.time() - classify_start

                        result = {
                            "filename": uploaded_file.name,
                            "document_type": classification["document_type"],
                            "num_pages": classification["num_pages"],
                            "classify_time": round(classify_time, 2)
                        }

                        # Terminal output
                        print(f" 📄 {uploaded_file.name}")
                        print(f" Pages: {classification['num_pages']}")
                        print(f" Type: {classification['document_type']}")
                        print(f" Classification time: {classify_time:.2f}s")

                        # Add to history (persisted to the JSON file on disk)
                        add_to_history(
                            uploaded_file.name,
                            classification["document_type"],
                            classification["num_pages"]
                        )

                    all_results.append(result)

                    # Update progress: 20% was conversion, remaining 80% is
                    # split evenly across the documents.
                    progress_bar.progress(0.2 + 0.8 * (idx + 1) / len(uploaded_files))

                total_classification_time = time.time() - classification_start
                total_time = time.time() - total_start_time

                # Print summary to terminal
                print(f"\n{'='*50}")
                print("TIMING SUMMARY")
                print(f"{'='*50}")
                print(f"Documents processed: {len(all_results)}")
                print(f"PDF conversion (parallel): {pdf_conversion_time:.2f}s")
                print(f"Classification (sequential): {total_classification_time:.2f}s")
                print(f"Average per document: {total_classification_time/len(all_results):.2f}s")
                print(f"Total time: {total_time:.2f}s ({total_time/60:.1f} min)")
                print(f"{'='*50}\n")

                # Store timing info for display in the results column
                st.session_state["timing"] = {
                    "pdf_conversion": round(pdf_conversion_time, 2),
                    "classification": round(total_classification_time, 2),
                    "total": round(total_time, 2),
                    "total_min": round(total_time / 60, 2),
                    "avg_per_doc": round(total_classification_time / len(all_results), 2)
                }

                status_text.text(f"✅ Complete! Total: {total_time:.1f}s ({total_time/60:.1f} min)")
                st.session_state["results"] = all_results

    with col2:
        st.subheader("📊 Classification Results")

        # Show results (kept in session_state so they survive reruns)
        if "results" in st.session_state and st.session_state["results"]:
            results = st.session_state["results"]

            # Show as compact table with timing
            df_results = pd.DataFrame(results)
            st.dataframe(
                df_results,
                hide_index=True,
                width="stretch",
                column_config={
                    "filename": st.column_config.TextColumn("File", width="medium"),
                    "document_type": st.column_config.TextColumn("Type", width="medium"),
                    "num_pages": st.column_config.NumberColumn("Pages", width="small"),
                    "classify_time": st.column_config.NumberColumn("Time (s)", width="small")
                }
            )

            # Show timing summary if available
            if "timing" in st.session_state:
                timing = st.session_state["timing"]
                st.markdown("---")
                st.markdown("**⏱️ Timing Summary**")
                col_t1, col_t2, col_t3 = st.columns(3)
                with col_t1:
                    st.metric("PDF Conversion", f"{timing['pdf_conversion']}s")
                with col_t2:
                    st.metric("Classification", f"{timing['classification']}s")
                with col_t3:
                    st.metric("Avg per Doc", f"{timing['avg_per_doc']}s")

                st.info(f"**Total Time:** {timing['total']}s ({timing['total_min']} min)")

            # Summary
            st.success(f"✅ Classified {len(results)} document(s)")

            # Show individual result boxes (compact)
            for result in results:
                st.markdown(f"""
<div class="result-box">
<p class="file-info"><strong>{result['filename']}</strong> ({result['num_pages']} pages)</p>
<p class="doc-type">📑 {result['document_type']}</p>
</div>
""", unsafe_allow_html=True)

        else:
            st.info("👆 Upload and classify documents to see results here.")


if __name__ == "__main__":
    main()
pdf_to_image.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PDF to Image Conversion using PyMuPDF (fitz)
3
+ Converts all pages of a PDF to PIL Images.
4
+ """
5
+
6
+ import fitz # PyMuPDF
7
+ from PIL import Image
8
+ from pathlib import Path
9
+ from typing import List, Tuple
10
+ import io
11
+
12
+
13
def pdf_to_images(pdf_path: str, dpi: int = 150) -> List[Image.Image]:
    """
    Convert all pages of a PDF to PIL Images.

    Args:
        pdf_path: Path to the PDF file
        dpi: Resolution for rendering (default 150, balance of quality/speed)

    Returns:
        List of RGB PIL Images, one per page. An empty list is returned
        if the PDF cannot be opened or a page fails to render.
    """
    images = []

    try:
        # Context manager guarantees the document handle is released even
        # if rendering a page raises (the previous version only closed it
        # on the success path).
        with fitz.open(pdf_path) as doc:
            # 72 is the native PDF DPI; the matrix is loop-invariant.
            zoom = dpi / 72
            matrix = fitz.Matrix(zoom, zoom)

            for page in doc:
                # Render the page at the requested resolution.
                pix = page.get_pixmap(matrix=matrix)

                # Round-trip through PNG bytes to obtain a PIL Image.
                img_data = pix.tobytes("png")
                img = Image.open(io.BytesIO(img_data))
                images.append(img.convert("RGB"))

    except Exception as e:
        # Best-effort contract: report the failure and signal it with [].
        print(f"Error converting {pdf_path}: {e}")
        return []

    return images
48
+
49
+
50
def get_pdf_page_count(pdf_path: str) -> int:
    """Get the number of pages in a PDF.

    Returns 0 if the file cannot be opened as a PDF.
    """
    try:
        # Context manager closes the handle even if len() raises.
        with fitz.open(pdf_path) as doc:
            return len(doc)
    except Exception:
        # Narrowed from a bare `except:`, which would also have swallowed
        # KeyboardInterrupt and SystemExit.
        return 0
59
+
60
+
61
def collect_pdfs(folder_path: str, recursive: bool = True) -> List[Path]:
    """
    Collect all PDF files from a folder.

    Args:
        folder_path: Path to folder containing PDFs
        recursive: Whether to search subfolders

    Returns:
        List of Path objects for each PDF
    """
    root = Path(folder_path)
    # rglob descends into subdirectories; glob stays at the top level.
    finder = root.rglob if recursive else root.glob
    return list(finder("*.pdf"))
78
+
79
+
80
+ if __name__ == "__main__":
81
+ # Quick test
82
+ import sys
83
+
84
+ if len(sys.argv) > 1:
85
+ pdf_path = sys.argv[1]
86
+ print(f"Converting: {pdf_path}")
87
+ images = pdf_to_images(pdf_path)
88
+ print(f"Extracted {len(images)} pages")
89
+
90
+ if images:
91
+ print(f"First page size: {images[0].size}")
92
+ else:
93
+ print("Usage: python pdf_to_image.py <path_to_pdf>")
requirements.txt CHANGED
@@ -1,3 +1,15 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core ML dependencies
2
+ torch
3
+ transformers
4
+ accelerate
5
+
6
+ # PDF processing
7
+ PyMuPDF
8
+ Pillow
9
+
10
+ # Data handling
11
+ pandas
12
+ tqdm
13
+
14
+ # Web framework
15
+ streamlit
smolvlm_classifier.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SmolVLM-256M-Instruct Document Classifier
3
+ Uses instruction-following VLM for zero-shot document classification.
4
+ """
5
+
6
+ import torch
7
+ from transformers import AutoProcessor, AutoModelForImageTextToText
8
+ from transformers.image_utils import load_image
9
+ from PIL import Image
10
+ from typing import List, Dict
11
+
12
+
13
class SmolVLMClassifier:
    """
    SmolVLM-based document classifier.
    Uses instruction-following to directly ask about document type.

    Attributes set in __init__:
    - device: "cuda" when available, else "cpu"
    - torch_dtype: bfloat16 on GPU, float32 on CPU
    - processor / model: HuggingFace processor and image-text-to-text model
    """

    def __init__(self, model_name: str = "HuggingFaceTB/SmolVLM-256M-Instruct"):
        """
        Initialize the SmolVLM model.

        Args:
            model_name: HuggingFace model name
        """
        print(f"Loading {model_name}...")

        # CPU with float32 for compatibility; bfloat16 only when CUDA exists.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        # Load processor and model
        self.processor = AutoProcessor.from_pretrained(model_name)

        # NOTE(review): `dtype=` is the newer transformers kwarg; older
        # transformers releases expect `torch_dtype=` — confirm against the
        # pinned (currently unpinned) transformers version in requirements.
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            dtype=self.torch_dtype,
            _attn_implementation="eager"  # CPU compatible
        ).to(self.device)

        # Compile model for faster inference (optional - can cause issues on some platforms)
        # Set DISABLE_TORCH_COMPILE=1 to skip compilation
        import os
        if os.environ.get("DISABLE_TORCH_COMPILE", "0") != "1":
            try:
                print("Compiling model with torch.compile (first run will be slow)...")
                self.model = torch.compile(self.model, mode="reduce-overhead")
                print(f"Model loaded and compiled on {self.device}")
            except Exception as e:
                # Fall back silently to the eager model on compile failure.
                print(f"torch.compile failed ({e}), using uncompiled model")
                print(f"Model loaded on {self.device}")
        else:
            print(f"Model loaded on {self.device} (torch.compile disabled)")

    def ask_about_image(self, image: Image.Image, question: str) -> str:
        """
        Ask a question about an image.

        Args:
            image: PIL Image
            question: Question to ask about the image

        Returns:
            Answer string (greedy-decoded, at most 30 new tokens)
        """
        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Create chat message format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question}
                ]
            }
        ]

        # Apply chat template
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)

        # Process inputs
        inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
        inputs = inputs.to(self.device)

        # Generate response (limited tokens for speed - only need short answer)
        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=30,  # Reduced from 150 for faster inference
            do_sample=False  # greedy decoding for deterministic answers
        )

        # Decode response
        generated_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0]

        # Extract just the assistant's response (after the prompt)
        if "Assistant:" in generated_text:
            response = generated_text.split("Assistant:")[-1].strip()
        else:
            response = generated_text.strip()

        return response

    def classify_document(self, images: List[Image.Image]) -> Dict:
        """
        Classify a document by analyzing the first page only.
        First page typically contains header/title which identifies document type.

        Args:
            images: List of PIL Images (one per page)

        Returns:
            Dict with document_type and num_pages
        """
        # Empty page list: nothing to classify.
        if not images:
            return {
                "document_type": "Unknown",
                "num_pages": 0
            }

        print(f" Classifying document ({len(images)} pages, analyzing first page)...")

        # Classification question with 12-class system
        # Tier 1: Main business documents (7 classes)
        # Tier 2: Grouped categories (5 classes)
        classification_question = """What type of document is this?

Choose ONE from these categories:
- Invoice (factura, bill for payment)
- PurchaseOrder (order form, purchase request)
- DeliveryNote (delivery slip, shipping document)
- CreditNote (credit memo, refund document)
- DebitNote (debit memo, additional charge)
- OrderConfirmation (order acknowledgment)
- QuotationOffer (quote, price proposal)
- IdentityDocument (ID card, passport, DNI, NIE)
- PayrollDocument (salary slip, work contract)
- VehicleDocument (car papers, registration, insurance, ITV)
- EmployeeDocument (employee records, HR documents)
- Other (anything else)

Answer with just the category name, nothing else."""

        # Get document type from first page only (fastest approach)
        doc_type = self.ask_about_image(images[0], classification_question)

        # Clean up and normalize response: keep only the first line of output.
        doc_type = doc_type.strip().split('\n')[0].strip()
        doc_type = self._normalize_category(doc_type)

        print(f" → Document type: {doc_type}")

        return {
            "document_type": doc_type,
            "num_pages": len(images)
        }

    def _normalize_category(self, raw_type: str) -> str:
        """
        Normalize VLM output to standard category names.
        Maps variations and translations to canonical names.

        NOTE(review): matching is plain substring search, so short keys can
        false-positive (e.g. 'car' also matches 'card', 'hr' matches inside
        longer words); the ordering of the checks below is what keeps the
        common cases correct — verify with real model outputs.
        """
        raw_lower = raw_type.lower().strip().rstrip('.')

        # Main business documents (Tier 1)
        if any(x in raw_lower for x in ['invoice', 'factura', 'bill']):
            # An answer like "credit invoice" should map to the note types.
            if 'credit' in raw_lower:
                return 'CreditNote'
            if 'debit' in raw_lower:
                return 'DebitNote'
            return 'Invoice'

        if any(x in raw_lower for x in ['purchase', 'order form', 'compra']):
            return 'PurchaseOrder'

        if any(x in raw_lower for x in ['delivery', 'shipping', 'albarán', 'entrega']):
            return 'DeliveryNote'

        if any(x in raw_lower for x in ['credit note', 'credit memo', 'refund']):
            return 'CreditNote'

        if any(x in raw_lower for x in ['debit note', 'debit memo']):
            return 'DebitNote'

        if any(x in raw_lower for x in ['order confirmation', 'confirmation', 'confirmación']):
            return 'OrderConfirmation'

        if any(x in raw_lower for x in ['quotation', 'quote', 'offer', 'presupuesto', 'oferta']):
            return 'QuotationOffer'

        # Grouped categories (Tier 2)
        if any(x in raw_lower for x in ['identity', 'passport', 'dni', 'nie', 'id card', 'identificación']):
            return 'IdentityDocument'

        if any(x in raw_lower for x in ['payroll', 'salary', 'wage', 'nómina', 'work contract', 'contrato']):
            return 'PayrollDocument'

        if any(x in raw_lower for x in ['vehicle', 'car', 'registration', 'insurance', 'itv', 'circulación', 'seguro', 'ficha técnica']):
            return 'VehicleDocument'

        if any(x in raw_lower for x in ['employee', 'hr', 'personnel', 'empleado']):
            return 'EmployeeDocument'

        if any(x in raw_lower for x in ['receipt', 'recibo', 'ticket']):
            return 'Invoice'  # Map receipts to Invoice

        if any(x in raw_lower for x in ['utility', 'electric', 'gas', 'water', 'luz', 'agua']):
            return 'Invoice'  # Utility bills are invoices

        # Default
        return 'Other'
217
+
218
+
219
+ if __name__ == "__main__":
220
+ # Quick test
221
+ print("Initializing SmolVLM classifier...")
222
+ classifier = SmolVLMClassifier()
223
+
224
+ # Test with a simple image
225
+ test_img = Image.new("RGB", (400, 300), color="white")
226
+ response = classifier.ask_about_image(test_img, "What do you see in this image?")
227
+ print(f"Test response: {response}")