File size: 18,527 Bytes
3d1c323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
import streamlit as st
import fitz  # PyMuPDF
import pytesseract
from PIL import Image
import io
import time
import os
from transformers import pipeline, AutoTokenizer

# Try to import fasttext, handle gracefully if not available
try:
    import fasttext
    FASTTEXT_AVAILABLE = True
except ImportError:
    FASTTEXT_AVAILABLE = False
    fasttext = None

# Configure the Tesseract executable path so pytesseract can invoke it.
# Windows: Tesseract installs to a fixed location that is usually not on PATH.
# Linux/Unix (e.g. Hugging Face Spaces): tesseract is normally on PATH after
# installation via packages.txt, but we also probe common install locations.
# (Fix: the old loop re-ran shutil.which('tesseract') on every iteration and
# imported shutil inside the loop — the lookup only needs to happen once.)
if os.name == 'nt':  # Windows
    tesseract_path = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
    if os.path.exists(tesseract_path):
        pytesseract.pytesseract.tesseract_cmd = tesseract_path
else:
    import shutil

    # Prefer whatever the system PATH resolves; otherwise probe common
    # absolute locations. If nothing is found, pytesseract keeps its
    # default 'tesseract' command (same net effect as the old fallback).
    tesseract_cmd = shutil.which('tesseract')
    if tesseract_cmd is None:
        for candidate in ('/usr/bin/tesseract', '/usr/local/bin/tesseract'):
            if os.path.exists(candidate):
                tesseract_cmd = candidate
                break
    if tesseract_cmd:
        pytesseract.pytesseract.tesseract_cmd = tesseract_cmd

def get_hf_token():
    """
    Return a Hugging Face API token, or None if none is configured.

    Lookup order: Streamlit secrets first (HF_TOKEN, then the nested
    HF.TOKEN layout) — this covers Hugging Face Spaces deployments — then
    the HF_TOKEN / HUGGINGFACEHUB_API_TOKEN environment variables for
    local development.
    """
    try:
        secrets = getattr(st, 'secrets', None)
        if secrets is not None:
            if 'HF_TOKEN' in secrets:
                return secrets['HF_TOKEN']
            if 'HF' in secrets and 'TOKEN' in secrets['HF']:
                return secrets['HF']['TOKEN']
    except Exception:
        # st.secrets can raise when no secrets file exists; fall through
        # to the environment-variable lookup.
        pass

    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Configure page
# NOTE: per Streamlit docs, set_page_config must be the first st.* call
# in the script.
st.set_page_config(
    page_title="Document Classification Performance Testing",
    page_icon="📄",
    layout="centered"
)

# Initialize session state
# models_loaded caches classifiers/tokenizers across Streamlit reruns so
# each model is loaded at most once per session.
if 'models_loaded' not in st.session_state:
    st.session_state.models_loaded = {}

def extract_text_from_pdf(pdf_file):
    """Extract the full text of an uploaded PDF using PyMuPDF (fitz).

    Returns the stripped concatenated page text on success, or None
    (after showing a Streamlit error) if the PDF cannot be parsed.
    """
    try:
        raw = pdf_file.read()
        pdf_file.seek(0)  # rewind so the upload can be read again later
        document = fitz.open(stream=raw, filetype="pdf")
        pages = [page.get_text() for page in document]
        document.close()
        return "".join(pages).strip()
    except Exception as e:
        st.error(f"Error extracting text from PDF: {str(e)}")
        return None

def extract_text_from_image(image_file):
    """OCR an uploaded image with Tesseract and return the stripped text.

    Returns None (after showing a Streamlit error) if the image cannot be
    opened or OCR fails.
    """
    try:
        raw = image_file.read()
        image_file.seek(0)  # rewind so the upload can be read again later
        pil_image = Image.open(io.BytesIO(raw))
        ocr_text = pytesseract.image_to_string(pil_image)
        return ocr_text.strip()
    except Exception as e:
        st.error(f"Error extracting text from image: {str(e)}")
        return None

def load_distilbert_model():
    """Load and cache a DistilBERT zero-shot classification pipeline.

    The pipeline and its tokenizer are cached in
    st.session_state.models_loaded so they survive Streamlit reruns.
    Returns the pipeline, or None if loading failed.

    NOTE(review): 'distilbert-base-uncased' is a base masked-LM checkpoint,
    not an NLI-finetuned model; zero-shot-classification quality with it is
    questionable — confirm whether an NLI variant (e.g. an *-mnli model)
    was intended.
    """
    if 'distilbert' not in st.session_state.models_loaded:
        with st.spinner("Loading DistilBERT model (first time may take a while)..."):
            try:
                model_name = "distilbert-base-uncased"
                classifier = pipeline(
                    "zero-shot-classification",
                    model=model_name,
                    truncation=True,  # Enable automatic truncation
                    max_length=512    # Set max length
                )
                # Also store tokenizer for accurate truncation if needed
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                st.session_state.models_loaded['distilbert'] = classifier
                st.session_state.models_loaded['distilbert_tokenizer'] = tokenizer
            except Exception as e:
                st.error(f"Error loading DistilBERT model: {str(e)}")
                return None
    return st.session_state.models_loaded['distilbert']

def load_tinybert_model():
    """Load and cache a TinyBERT zero-shot classification pipeline.

    Tries 'huawei-noah/TinyBERT_General_6L_768D' first (authenticated if a
    Hugging Face token is configured, then anonymously), falling back to
    'google/tinybert-6L-384D'. The pipeline and its tokenizer are cached in
    st.session_state.models_loaded across Streamlit reruns.

    Returns the pipeline, or None if every attempt failed.

    NOTE(review): these checkpoints do not appear to be NLI-finetuned;
    zero-shot quality with them is questionable — confirm the intended
    model identifiers.
    """
    if 'tinybert' not in st.session_state.models_loaded:
        with st.spinner("Loading TinyBERT model (first time may take a while)..."):
            try:
                model_name = "huawei-noah/TinyBERT_General_6L_768D"

                # Get Hugging Face token (works for both local and HF Spaces)
                hf_token = get_hf_token()

                try:
                    # First attempt: pass the token if we have one.
                    classifier = pipeline(
                        "zero-shot-classification",
                        model=model_name,
                        token=hf_token if hf_token else None,
                        truncation=True,
                        max_length=512
                    )
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_name,
                        token=hf_token if hf_token else None
                    )
                except Exception:
                    # Retry anonymously: public models need no token, and a
                    # stale/invalid token would otherwise block them.
                    # (Bugfix: a later unconditional AutoTokenizer load WITH
                    # the token used to overwrite this fallback's tokenizer
                    # and could fail after the fallback had succeeded.)
                    classifier = pipeline(
                        "zero-shot-classification",
                        model=model_name,
                        truncation=True,
                        max_length=512
                    )
                    tokenizer = AutoTokenizer.from_pretrained(model_name)

                st.session_state.models_loaded['tinybert'] = classifier
                st.session_state.models_loaded['tinybert_tokenizer'] = tokenizer
            except Exception as e:
                error_msg = str(e)
                # Try the alternative model identifier as a fallback.
                if "huawei-noah" in error_msg.lower() or "not a valid model identifier" in error_msg.lower():
                    try:
                        st.info("Trying alternative TinyBERT model identifier...")
                        model_name = "google/tinybert-6L-384D"
                        hf_token = get_hf_token()
                        classifier = pipeline(
                            "zero-shot-classification",
                            model=model_name,
                            token=hf_token if hf_token else None,
                            truncation=True,
                            max_length=512
                        )
                        tokenizer = AutoTokenizer.from_pretrained(
                            model_name,
                            token=hf_token if hf_token else None
                        )
                        st.session_state.models_loaded['tinybert'] = classifier
                        st.session_state.models_loaded['tinybert_tokenizer'] = tokenizer
                    except Exception as e2:
                        st.error(
                            f"Error loading TinyBERT model. Tried both 'huawei-noah/TinyBERT_General_6L_768D' "
                            f"and 'google/tinybert-6L-384D'. "
                            f"If these models are private, ensure your Hugging Face token is set in the "
                            f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
                            f"Details: {str(e2)}"
                        )
                        return None
                else:
                    st.error(
                        f"Error loading TinyBERT model. "
                        f"If this model is private, ensure your Hugging Face token is set in the "
                        f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
                        f"Details: {error_msg}"
                    )
                    return None
    return st.session_state.models_loaded['tinybert']

def load_fasttext_model():
    """Load and cache the FastText model (language-ID demo checkpoint).

    Requires the optional `fasttext` package and a local 'lid.176.bin'
    model file. Returns the cached model, or None if unavailable.
    """
    if not FASTTEXT_AVAILABLE:
        st.error("FastText is not installed. Please install it using: pip install fasttext")
        return None

    cache = st.session_state.models_loaded
    if 'fasttext' not in cache:
        with st.spinner("Loading FastText model (first time may take a while)..."):
            # Using a pre-trained language identification model as an example.
            # For production document classification you would use a custom
            # trained FastText model instead.
            model_path = "lid.176.bin"
            if not os.path.exists(model_path):
                st.warning("FastText model file not found. Please download 'lid.176.bin' from "
                          "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin "
                          "and place it in the current directory.")
                return None
            try:
                cache['fasttext'] = fasttext.load_model(model_path)
            except Exception as e:
                st.error(f"Error loading FastText model: {str(e)}")
                return None
    return cache['fasttext']

def truncate_text_for_model(text, tokenizer=None, max_length=500):
    """
    Truncate text to fit within a model's maximum sequence length.

    Args:
        text: The input string.
        tokenizer: Optional Hugging Face tokenizer for exact token counting.
            When None, a character-based approximation (~4 chars per token)
            is used instead.
        max_length: Maximum number of tokens allowed (default 500, leaving
            headroom below the usual 512-token model limit).

    Returns:
        A (text, was_truncated) tuple; was_truncated is True only when the
        text actually had to be shortened.
    """
    if tokenizer is not None:
        # Count real tokens first, then cut only if the text exceeds the
        # limit. (Bugfix: the old code encoded with truncation=True and
        # tested `len(tokens) >= max_length`, which falsely flagged
        # exactly-max_length inputs as truncated and round-tripped them
        # through a lossy decode even when nothing was cut.)
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) > max_length:
            truncated_text = tokenizer.decode(tokens[:max_length], skip_special_tokens=True)
            return truncated_text, True
        return text, False
    else:
        # Fallback: character-based approximation (roughly 4 chars per
        # token), with a small buffer reserved for special tokens.
        max_chars = (max_length - 12) * 4

        if len(text) <= max_chars:
            return text, False

        # Truncate at a word boundary so we don't cut mid-word.
        truncated = text[:max_chars].rsplit(' ', 1)[0]
        return truncated + "...", True

def classify_with_distilbert(classifier, text, candidate_labels):
    """Zero-shot classify text with DistilBERT.

    Returns (top_label, top_score), or (None, None) if the classifier is
    missing or inference fails.
    """
    if classifier is None:
        return None, None

    # Pre-truncate with the cached tokenizer when available; DistilBERT's
    # maximum sequence length is 512 tokens.
    tok = st.session_state.models_loaded.get('distilbert_tokenizer')
    prepared, clipped = truncate_text_for_model(text, tokenizer=tok, max_length=500)
    if clipped:
        st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")

    try:
        output = classifier(prepared, candidate_labels, truncation=True)
    except Exception as e:
        st.error(f"Classification error: {str(e)}")
        return None, None
    return output['labels'][0], output['scores'][0]

def classify_with_tinybert(classifier, text, candidate_labels):
    """Zero-shot classify text with TinyBERT.

    Returns (top_label, top_score), or (None, None) if the classifier is
    missing or inference fails.
    """
    if classifier is None:
        return None, None

    # Pre-truncate with the cached tokenizer when available; TinyBERT's
    # maximum sequence length is 512 tokens.
    tok = st.session_state.models_loaded.get('tinybert_tokenizer')
    prepared, clipped = truncate_text_for_model(text, tokenizer=tok, max_length=500)
    if clipped:
        st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")

    try:
        output = classifier(prepared, candidate_labels, truncation=True)
    except Exception as e:
        st.error(f"Classification error: {str(e)}")
        return None, None
    return output['labels'][0], output['scores'][0]

def classify_with_fasttext(model, text):
    """Classify text with a FastText model.

    Returns (label, score) from the top-1 prediction with FastText's
    '__label__' prefix stripped, or (None, None) when no model is given.
    """
    if model is None:
        return None, None
    # This demo uses FastText language identification; swap in a custom
    # supervised model for real document classification.
    labels, scores = model.predict(text, k=1)
    return labels[0].replace('__label__', ''), float(scores[0])

# Main UI
# Page heading for the single-page Streamlit app.
st.title("📄 Document Classification Performance Testing")
st.markdown("---")

# File upload
# Accepts one PDF or image; text is extracted later, when the user clicks
# the "Classify Document" button.
uploaded_file = st.file_uploader(
    "Upload a document",
    type=['pdf', 'png', 'jpg', 'jpeg'],
    help="Upload a PDF or Image file (PNG/JPG/JPEG)"
)

# Model selection
# These option strings are matched exactly in the classify handler below —
# keep them in sync if either side changes.
model_options = [
    "distilbert-base-uncased",
    "google/tinybert-6L-384D"
]

# FastText option commented out
# if FASTTEXT_AVAILABLE:
#     model_options.append("FastText")
# else:
#     model_options.append("FastText (Not Available)")

model_option = st.selectbox(
    "Select Model",
    options=model_options,
    help="Choose a model for document classification"
)

# FastText warning commented out
# if model_option == "FastText (Not Available)":
#     st.warning("⚠️ FastText is not installed. To use FastText, you can try installing a pre-built wheel: "
#               "`pip install fasttext-wheel` or use conda: `conda install -c conda-forge fasttext`. "
#               "Alternatively, use DistilBERT or TinyBERT models.")

# Classification button
# Orchestrates the full flow: extract text from the upload, preview it,
# run the selected model, and report label / confidence / latency.
if st.button("Classify Document", type="primary"):
    if uploaded_file is None:
        st.warning("Please upload a file first.")
    else:
        # Extract text based on file type (routed by filename extension).
        file_extension = uploaded_file.name.split('.')[-1].lower()

        if file_extension == 'pdf':
            extracted_text = extract_text_from_pdf(uploaded_file)
        elif file_extension in ['png', 'jpg', 'jpeg']:
            extracted_text = extract_text_from_image(uploaded_file)
        else:
            st.error("Unsupported file type.")
            extracted_text = None

        # Check if text extraction was successful
        if extracted_text is None:
            st.error("Failed to extract text from the document.")
        elif len(extracted_text.strip()) == 0:
            st.error("No text could be extracted from the document. The document may be empty or contain only images.")
        else:
            # Display extracted text preview
            with st.expander("Extracted Text Preview", expanded=False):
                st.text(extracted_text[:500] + "..." if len(extracted_text) > 500 else extracted_text)

            # Define candidate labels for zero-shot classification
            # These are example labels - adjust based on your use case
            candidate_labels = [
                "invoice",
                "contract",
                "report",
                "letter",
                "receipt",
                "form",
                "memo",
                "other"
            ]

            # Load model and classify
            start_time = time.time()

            # Defensive defaults (bugfix): previously these names were only
            # assigned inside the branches below, so any unmatched option
            # (e.g. re-enabling the FastText selectbox entry) raised
            # NameError when displaying results.
            predicted_label, confidence = None, None
            model_name = model_option

            if model_option == "distilbert-base-uncased":
                model_name = "DistilBERT Base Uncased"
                classifier = load_distilbert_model()
                if classifier:
                    predicted_label, confidence = classify_with_distilbert(
                        classifier, extracted_text, candidate_labels
                    )

            elif model_option == "google/tinybert-6L-384D":
                model_name = "TinyBERT 6L-384D"
                classifier = load_tinybert_model()
                if classifier:
                    predicted_label, confidence = classify_with_tinybert(
                        classifier, extracted_text, candidate_labels
                    )

            # FastText branch intentionally disabled; see load_fasttext_model
            # and classify_with_fasttext if it is re-enabled.
            # elif model_option in ["FastText", "FastText (Not Available)"]:
            #     if FASTTEXT_AVAILABLE:
            #         model = load_fasttext_model()
            #         if model:
            #             predicted_label, confidence = classify_with_fasttext(model, extracted_text)
            #         model_name = "FastText"
            #     else:
            #         st.error("FastText is not available. Please install it or select a different model.")

            # Wall-clock time including model loading on the first run.
            inference_time = time.time() - start_time

            # Display results
            if predicted_label is not None:
                st.success("Classification Complete!")
                st.markdown("---")

                col1, col2, col3 = st.columns(3)

                with col1:
                    st.metric("Predicted Label", predicted_label)

                with col2:
                    st.metric("Confidence", f"{confidence:.2%}" if confidence else "N/A")

                with col3:
                    st.metric("Inference Time", f"{inference_time:.4f}s")

                st.info(f"**Model Used:** {model_name}")
            else:
                st.error("Classification failed. Please check the model loading status above.")

# Footer
# Static footer rendered at the bottom of the page on every rerun.
st.markdown("---")
st.caption("Document Classification Performance Testing Tool")