raahinaez commited on
Commit
3d1c323
·
verified ·
1 Parent(s): 5cc11aa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +441 -0
app.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import fitz # PyMuPDF
3
+ import pytesseract
4
+ from PIL import Image
5
+ import io
6
+ import time
7
+ import os
8
+ from transformers import pipeline, AutoTokenizer
9
+
10
+ # Try to import fasttext, handle gracefully if not available
11
+ try:
12
+ import fasttext
13
+ FASTTEXT_AVAILABLE = True
14
+ except ImportError:
15
+ FASTTEXT_AVAILABLE = False
16
+ fasttext = None
17
+
18
+ # Configure Tesseract path
19
+ # For Windows: use specific path
20
+ # For Linux (HF Spaces): Tesseract is usually in PATH, but we can check common locations
21
+ if os.name == 'nt': # Windows
22
+ tesseract_path = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
23
+ if os.path.exists(tesseract_path):
24
+ pytesseract.pytesseract.tesseract_cmd = tesseract_path
25
+ else:
26
+ # Linux/Unix: Check common Tesseract locations
27
+ # On Hugging Face Spaces, Tesseract should be in PATH after installing via packages.txt
28
+ common_paths = [
29
+ '/usr/bin/tesseract',
30
+ '/usr/local/bin/tesseract',
31
+ 'tesseract' # Try system PATH
32
+ ]
33
+ for path in common_paths:
34
+ try:
35
+ # Try to find tesseract in PATH or at specific location
36
+ import shutil
37
+ tesseract_cmd = shutil.which('tesseract') or path
38
+ if tesseract_cmd and (os.path.exists(tesseract_cmd) or tesseract_cmd == 'tesseract'):
39
+ pytesseract.pytesseract.tesseract_cmd = tesseract_cmd
40
+ break
41
+ except Exception:
42
+ continue
43
+
44
+ def get_hf_token():
45
+ """
46
+ Get Hugging Face token from Streamlit secrets (for HF Spaces) or environment variables (for local).
47
+ Priority: Streamlit secrets > Environment variables
48
+ """
49
+ # Try Streamlit secrets first (for Hugging Face Spaces deployment)
50
+ try:
51
+ if hasattr(st, 'secrets') and 'HF_TOKEN' in st.secrets:
52
+ return st.secrets['HF_TOKEN']
53
+ # Also check nested structure (HF.TOKEN)
54
+ if hasattr(st, 'secrets') and 'HF' in st.secrets and 'TOKEN' in st.secrets['HF']:
55
+ return st.secrets['HF']['TOKEN']
56
+ except Exception:
57
+ pass
58
+
59
+ # Fallback to environment variables (for local development)
60
+ return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
61
+
62
+ # Configure page
63
+ st.set_page_config(
64
+ page_title="Document Classification Performance Testing",
65
+ page_icon="📄",
66
+ layout="centered"
67
+ )
68
+
69
+ # Initialize session state
70
+ if 'models_loaded' not in st.session_state:
71
+ st.session_state.models_loaded = {}
72
+
73
+ def extract_text_from_pdf(pdf_file):
74
+ """Extract text from PDF using PyMuPDF."""
75
+ try:
76
+ pdf_bytes = pdf_file.read()
77
+ pdf_file.seek(0) # Reset file pointer
78
+ doc = fitz.open(stream=pdf_bytes, filetype="pdf")
79
+ text = ""
80
+ for page in doc:
81
+ text += page.get_text()
82
+ doc.close()
83
+ return text.strip()
84
+ except Exception as e:
85
+ st.error(f"Error extracting text from PDF: {str(e)}")
86
+ return None
87
+
88
+ def extract_text_from_image(image_file):
89
+ """Extract text from image using Tesseract OCR."""
90
+ try:
91
+ image_bytes = image_file.read()
92
+ image_file.seek(0) # Reset file pointer
93
+ image = Image.open(io.BytesIO(image_bytes))
94
+ text = pytesseract.image_to_string(image)
95
+ return text.strip()
96
+ except Exception as e:
97
+ st.error(f"Error extracting text from image: {str(e)}")
98
+ return None
99
+
100
+ def load_distilbert_model():
101
+ """Load DistilBERT model for zero-shot classification."""
102
+ if 'distilbert' not in st.session_state.models_loaded:
103
+ with st.spinner("Loading DistilBERT model (first time may take a while)..."):
104
+ try:
105
+ model_name = "distilbert-base-uncased"
106
+ classifier = pipeline(
107
+ "zero-shot-classification",
108
+ model=model_name,
109
+ truncation=True, # Enable automatic truncation
110
+ max_length=512 # Set max length
111
+ )
112
+ # Also store tokenizer for accurate truncation if needed
113
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
114
+ st.session_state.models_loaded['distilbert'] = classifier
115
+ st.session_state.models_loaded['distilbert_tokenizer'] = tokenizer
116
+ except Exception as e:
117
+ st.error(f"Error loading DistilBERT model: {str(e)}")
118
+ return None
119
+ return st.session_state.models_loaded['distilbert']
120
+
121
+ def load_tinybert_model():
122
+ """Load TinyBERT model for zero-shot classification."""
123
+ if 'tinybert' not in st.session_state.models_loaded:
124
+ with st.spinner("Loading TinyBERT model (first time may take a while)..."):
125
+ try:
126
+ # Try alternative TinyBERT model identifiers
127
+ # google/tinybert-6L-384D may not exist, try huawei-noah version
128
+ model_name = "huawei-noah/TinyBERT_General_6L_768D"
129
+
130
+ # Get Hugging Face token (works for both local and HF Spaces)
131
+ hf_token = get_hf_token()
132
+
133
+ # Try with token first
134
+ try:
135
+ classifier = pipeline(
136
+ "zero-shot-classification",
137
+ model=model_name,
138
+ token=hf_token if hf_token else None,
139
+ truncation=True,
140
+ max_length=512
141
+ )
142
+ tokenizer = AutoTokenizer.from_pretrained(
143
+ model_name,
144
+ token=hf_token if hf_token else None
145
+ )
146
+ except Exception as token_error:
147
+ # If that fails, try without token (for public models)
148
+ classifier = pipeline(
149
+ "zero-shot-classification",
150
+ model=model_name,
151
+ truncation=True,
152
+ max_length=512
153
+ )
154
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
155
+
156
+ # Also store tokenizer for accurate truncation
157
+ tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token if hf_token else None)
158
+ st.session_state.models_loaded['tinybert'] = classifier
159
+ st.session_state.models_loaded['tinybert_tokenizer'] = tokenizer
160
+ except Exception as e:
161
+ error_msg = str(e)
162
+ # Try the original model name as fallback
163
+ if "huawei-noah" in error_msg.lower() or "not a valid model identifier" in error_msg.lower():
164
+ try:
165
+ st.info("Trying alternative TinyBERT model identifier...")
166
+ model_name = "google/tinybert-6L-384D"
167
+ hf_token = get_hf_token()
168
+ classifier = pipeline(
169
+ "zero-shot-classification",
170
+ model=model_name,
171
+ token=hf_token if hf_token else None,
172
+ truncation=True,
173
+ max_length=512
174
+ )
175
+ tokenizer = AutoTokenizer.from_pretrained(
176
+ model_name,
177
+ token=hf_token if hf_token else None
178
+ )
179
+ st.session_state.models_loaded['tinybert'] = classifier
180
+ st.session_state.models_loaded['tinybert_tokenizer'] = tokenizer
181
+ except Exception as e2:
182
+ st.error(
183
+ f"Error loading TinyBERT model. Tried both 'huawei-noah/TinyBERT_General_6L_768D' "
184
+ f"and 'google/tinybert-6L-384D'. "
185
+ f"If these models are private, ensure your Hugging Face token is set in the "
186
+ f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
187
+ f"Details: {str(e2)}"
188
+ )
189
+ return None
190
+ else:
191
+ st.error(
192
+ f"Error loading TinyBERT model. "
193
+ f"If this model is private, ensure your Hugging Face token is set in the "
194
+ f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
195
+ f"Details: {error_msg}"
196
+ )
197
+ return None
198
+ return st.session_state.models_loaded['tinybert']
199
+
200
+ def load_fasttext_model():
201
+ """Load FastText model."""
202
+ if not FASTTEXT_AVAILABLE:
203
+ st.error("FastText is not installed. Please install it using: pip install fasttext")
204
+ return None
205
+
206
+ if 'fasttext' not in st.session_state.models_loaded:
207
+ with st.spinner("Loading FastText model (first time may take a while)..."):
208
+ try:
209
+ # Using a pre-trained language identification model as an example
210
+ # Note: For production document classification, you would use a custom trained FastText model
211
+ # Download the model if not present: https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin
212
+ model_path = "lid.176.bin"
213
+ if not os.path.exists(model_path):
214
+ st.warning("FastText model file not found. Please download 'lid.176.bin' from "
215
+ "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin "
216
+ "and place it in the current directory.")
217
+ return None
218
+ model = fasttext.load_model(model_path)
219
+ st.session_state.models_loaded['fasttext'] = model
220
+ except Exception as e:
221
+ st.error(f"Error loading FastText model: {str(e)}")
222
+ return None
223
+ return st.session_state.models_loaded['fasttext']
224
+
225
+ def truncate_text_for_model(text, tokenizer=None, max_length=500):
226
+ """
227
+ Truncate text to fit within model's maximum sequence length.
228
+ If tokenizer is provided, uses accurate token counting.
229
+ Otherwise, uses character-based approximation.
230
+ """
231
+ if tokenizer is not None:
232
+ # Use tokenizer for accurate truncation
233
+ tokens = tokenizer.encode(text, add_special_tokens=False, max_length=max_length, truncation=True)
234
+ if len(tokens) >= max_length:
235
+ # Decode back to text (this will be properly truncated)
236
+ truncated_text = tokenizer.decode(tokens, skip_special_tokens=True)
237
+ return truncated_text, True
238
+ return text, False
239
+ else:
240
+ # Fallback: character-based approximation (roughly 4 chars per token)
241
+ # Leave buffer for special tokens
242
+ max_chars = (max_length - 12) * 4
243
+
244
+ if len(text) <= max_chars:
245
+ return text, False
246
+
247
+ # Truncate at word boundary
248
+ truncated = text[:max_chars].rsplit(' ', 1)[0]
249
+ return truncated + "...", True
250
+
251
+ def classify_with_distilbert(classifier, text, candidate_labels):
252
+ """Classify text using DistilBERT."""
253
+ if classifier is None:
254
+ return None, None
255
+
256
+ # Get tokenizer if available for accurate truncation
257
+ tokenizer = st.session_state.models_loaded.get('distilbert_tokenizer')
258
+
259
+ # Truncate text if needed (DistilBERT max length is 512)
260
+ truncated_text, was_truncated = truncate_text_for_model(text, tokenizer=tokenizer, max_length=500)
261
+ if was_truncated:
262
+ st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")
263
+
264
+ try:
265
+ result = classifier(truncated_text, candidate_labels, truncation=True)
266
+ return result['labels'][0], result['scores'][0]
267
+ except Exception as e:
268
+ st.error(f"Classification error: {str(e)}")
269
+ return None, None
270
+
271
+ def classify_with_tinybert(classifier, text, candidate_labels):
272
+ """Classify text using TinyBERT."""
273
+ if classifier is None:
274
+ return None, None
275
+
276
+ # Get tokenizer if available for accurate truncation
277
+ tokenizer = st.session_state.models_loaded.get('tinybert_tokenizer')
278
+
279
+ # Truncate text if needed (TinyBERT max length is 512)
280
+ truncated_text, was_truncated = truncate_text_for_model(text, tokenizer=tokenizer, max_length=500)
281
+ if was_truncated:
282
+ st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")
283
+
284
+ try:
285
+ result = classifier(truncated_text, candidate_labels, truncation=True)
286
+ return result['labels'][0], result['scores'][0]
287
+ except Exception as e:
288
+ st.error(f"Classification error: {str(e)}")
289
+ return None, None
290
+
291
+ def classify_with_fasttext(model, text):
292
+ """Classify text using FastText."""
293
+ if model is None:
294
+ return None, None
295
+ # FastText language identification as example
296
+ # In production, use a custom trained model for document classification
297
+ predictions = model.predict(text, k=1)
298
+ label = predictions[0][0].replace('__label__', '')
299
+ score = float(predictions[1][0])
300
+ return label, score
301
+
302
+ # Main UI
303
+ st.title("📄 Document Classification Performance Testing")
304
+ st.markdown("---")
305
+
306
+ # File upload
307
+ uploaded_file = st.file_uploader(
308
+ "Upload a document",
309
+ type=['pdf', 'png', 'jpg', 'jpeg'],
310
+ help="Upload a PDF or Image file (PNG/JPG/JPEG)"
311
+ )
312
+
313
+ # Model selection
314
+ model_options = [
315
+ "distilbert-base-uncased",
316
+ "google/tinybert-6L-384D"
317
+ ]
318
+
319
+ # FastText option commented out
320
+ # if FASTTEXT_AVAILABLE:
321
+ # model_options.append("FastText")
322
+ # else:
323
+ # model_options.append("FastText (Not Available)")
324
+
325
+ model_option = st.selectbox(
326
+ "Select Model",
327
+ options=model_options,
328
+ help="Choose a model for document classification"
329
+ )
330
+
331
+ # FastText warning commented out
332
+ # if model_option == "FastText (Not Available)":
333
+ # st.warning("⚠️ FastText is not installed. To use FastText, you can try installing a pre-built wheel: "
334
+ # "`pip install fasttext-wheel` or use conda: `conda install -c conda-forge fasttext`. "
335
+ # "Alternatively, use DistilBERT or TinyBERT models.")
336
+
337
+ # Classification button
338
+ if st.button("Classify Document", type="primary"):
339
+ if uploaded_file is None:
340
+ st.warning("Please upload a file first.")
341
+ else:
342
+ # Extract text based on file type
343
+ file_extension = uploaded_file.name.split('.')[-1].lower()
344
+
345
+ if file_extension == 'pdf':
346
+ extracted_text = extract_text_from_pdf(uploaded_file)
347
+ elif file_extension in ['png', 'jpg', 'jpeg']:
348
+ extracted_text = extract_text_from_image(uploaded_file)
349
+ else:
350
+ st.error("Unsupported file type.")
351
+ extracted_text = None
352
+
353
+ # Check if text extraction was successful
354
+ if extracted_text is None:
355
+ st.error("Failed to extract text from the document.")
356
+ elif len(extracted_text.strip()) == 0:
357
+ st.error("No text could be extracted from the document. The document may be empty or contain only images.")
358
+ else:
359
+ # Display extracted text preview
360
+ with st.expander("Extracted Text Preview", expanded=False):
361
+ st.text(extracted_text[:500] + "..." if len(extracted_text) > 500 else extracted_text)
362
+
363
+ # Define candidate labels for zero-shot classification
364
+ # These are example labels - adjust based on your use case
365
+ candidate_labels = [
366
+ "invoice",
367
+ "contract",
368
+ "report",
369
+ "letter",
370
+ "receipt",
371
+ "form",
372
+ "memo",
373
+ "other"
374
+ ]
375
+
376
+ # Load model and classify
377
+ start_time = time.time()
378
+
379
+ if model_option == "distilbert-base-uncased":
380
+ classifier = load_distilbert_model()
381
+ if classifier:
382
+ predicted_label, confidence = classify_with_distilbert(
383
+ classifier, extracted_text, candidate_labels
384
+ )
385
+ model_name = "DistilBERT Base Uncased"
386
+ else:
387
+ predicted_label, confidence = None, None
388
+ model_name = "DistilBERT Base Uncased"
389
+
390
+ elif model_option == "google/tinybert-6L-384D":
391
+ classifier = load_tinybert_model()
392
+ if classifier:
393
+ predicted_label, confidence = classify_with_tinybert(
394
+ classifier, extracted_text, candidate_labels
395
+ )
396
+ model_name = "TinyBERT 6L-384D"
397
+ else:
398
+ predicted_label, confidence = None, None
399
+ model_name = "TinyBERT 6L-384D"
400
+
401
+ # FastText option commented out
402
+ # elif model_option in ["FastText", "FastText (Not Available)"]:
403
+ # if not FASTTEXT_AVAILABLE:
404
+ # st.error("FastText is not available. Please install it or select a different model.")
405
+ # predicted_label, confidence = None, None
406
+ # model_name = "FastText"
407
+ # else:
408
+ # model = load_fasttext_model()
409
+ # if model:
410
+ # predicted_label, confidence = classify_with_fasttext(model, extracted_text)
411
+ # model_name = "FastText"
412
+ # else:
413
+ # predicted_label, confidence = None, None
414
+ # model_name = "FastText"
415
+
416
+ inference_time = time.time() - start_time
417
+
418
+ # Display results
419
+ if predicted_label is not None:
420
+ st.success("Classification Complete!")
421
+ st.markdown("---")
422
+
423
+ col1, col2, col3 = st.columns(3)
424
+
425
+ with col1:
426
+ st.metric("Predicted Label", predicted_label)
427
+
428
+ with col2:
429
+ st.metric("Confidence", f"{confidence:.2%}" if confidence else "N/A")
430
+
431
+ with col3:
432
+ st.metric("Inference Time", f"{inference_time:.4f}s")
433
+
434
+ st.info(f"**Model Used:** {model_name}")
435
+ else:
436
+ st.error("Classification failed. Please check the model loading status above.")
437
+
438
+ # Footer
439
+ st.markdown("---")
440
+ st.caption("Document Classification Performance Testing Tool")
441
+