GoldiSahoo committed on
Commit
032ec0b
·
verified ·
1 Parent(s): 56c8647

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +387 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,389 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import os
3
+ import tempfile
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ import plotly.graph_objects as go
7
+ from document_classifier import DocumentClassifier
8
+ import time
9
+ from typing import List, Dict
10
+ import json
11
+
12
# Page configuration: browser-tab title and icon, wide layout, sidebar open.
st.set_page_config(
    page_title="Document Classifier",
    # A single emoji character. The previous value was a mis-decoded byte
    # sequence, which is not an emoji.
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="expanded",
)
19
+
20
# Custom CSS for better styling. The rules are kept in a named constant so the
# injection call below stays a one-liner.
_CUSTOM_CSS = """
<style>
.main-header {
    font-size: 3rem;
    color: #1f77b4;
    text-align: center;
    margin-bottom: 2rem;
}
.sub-header {
    font-size: 1.5rem;
    color: #2c3e50;
    margin-top: 2rem;
    margin-bottom: 1rem;
}
.metric-card {
    background-color: #f8f9fa;
    padding: 1rem;
    border-radius: 0.5rem;
    border-left: 4px solid #1f77b4;
}
.success-message {
    background-color: #d4edda;
    color: #155724;
    padding: 1rem;
    border-radius: 0.5rem;
    border: 1px solid #c3e6cb;
}
.error-message {
    background-color: #f8d7da;
    color: #721c24;
    padding: 1rem;
    border-radius: 0.5rem;
    border: 1px solid #f5c6cb;
}
</style>
"""

# unsafe_allow_html is required for the <style> element to be emitted as HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
57
+
58
# Seed the session state. A key that is already present is left untouched, so
# values stored by an earlier run of the script are kept.
_SESSION_DEFAULTS = (
    ("classifier", lambda: None),
    ("classification_results", list),
    ("uploaded_files", list),
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        # A factory is called per key so each list default is a fresh object.
        st.session_state[_state_key] = _make_default()
65
+
66
def initialize_classifier():
    """Make sure a ``DocumentClassifier`` is stored in the session state.

    When ``st.session_state.classifier`` is already set nothing is loaded.
    Otherwise a ``DocumentClassifier`` is constructed while a spinner is
    shown, and the outcome is reported with ``st.success`` or ``st.error``.

    Returns:
        bool: ``True`` when a classifier is available afterwards, ``False``
        when constructing it raised.
    """
    if st.session_state.classifier is not None:
        return True

    with st.spinner("Loading Hugging Face models..."):
        try:
            st.session_state.classifier = DocumentClassifier()
        except Exception as e:
            # UI boundary: any construction failure is reported to the user
            # instead of aborting the script run.
            st.error(f"❌ Failed to initialize classifier: {e}")
            return False
        else:
            st.success("✅ Document classifier initialized successfully!")
            return True
78
+
79
def save_uploaded_file(uploaded_file) -> "str | None":
    """Write an uploaded file to a named temporary file and return its path.

    The temporary file is created with ``delete=False``, so the caller is
    responsible for removing it.

    Args:
        uploaded_file: Object exposing ``name`` (the original file name) and
            ``getbuffer()`` (the file content as a bytes-like object).

    Returns:
        The path of the temporary file, or ``None`` when saving failed (the
        failure is reported with ``st.error``).
    """
    try:
        # splitext returns "" for a name without an extension. The previous
        # ``name.split('.')[-1]`` returned the whole name in that case, so
        # "README" was saved with the bogus suffix ".README".
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(uploaded_file.getbuffer())
            return tmp_file.name
    except Exception as e:
        # UI boundary: report the problem and let the caller skip this file.
        st.error(f"Error saving file: {e}")
        return None
88
+
89
def classify_single_file(file_path: str) -> Dict:
    """Classify one file with the session's classifier.

    Args:
        file_path: Path of the file handed to ``classify_document``.

    Returns:
        Whatever ``classify_document`` returns. If no classifier is stored in
        the session, or the call raises, a failure dictionary is returned
        instead with ``success=False``, an ``error`` message and the
        ``file_name`` / ``file_path`` of the file concerned.
    """
    def _failure(message: str) -> Dict:
        # Failures name their file so the batch view can say which one failed
        # instead of falling back to "Unknown".
        return {
            "error": message,
            "success": False,
            "file_name": os.path.basename(file_path),
            "file_path": file_path,
        }

    classifier = st.session_state.classifier
    # Identity check, matching initialize_classifier: a classifier object that
    # happens to be falsy must not be mistaken for "not initialized".
    if classifier is None:
        return _failure("Classifier not initialized")

    try:
        return classifier.classify_document(file_path)
    except Exception as e:
        # Failures are returned as data so one bad file does not abort a batch.
        return _failure(str(e))
99
+
100
def classify_multiple_files(file_paths: List[str]) -> List[Dict]:
    """Run the session's classifier over several files in one call.

    Args:
        file_paths: Paths handed to ``classify_multiple_documents``.

    Returns:
        The list returned by ``classify_multiple_documents``. When no
        classifier is stored in the session, or the call raises, a
        one-element list holding a failure dictionary (``error`` and
        ``success=False``) is returned instead.
    """
    classifier = st.session_state.classifier
    if not classifier:
        return [{"error": "Classifier not initialized", "success": False}]

    try:
        batch = classifier.classify_multiple_documents(file_paths)
    except Exception as exc:
        return [{"error": str(exc), "success": False}]
    return batch
110
+
111
def display_classification_result(result: Dict):
    """Render one classification result.

    A failed result (``success`` missing or falsy) is shown as an error
    message only. A successful result is shown as three headline metrics,
    a details section, an optional text preview, and the per-type scores as
    a bar chart plus a table.

    Args:
        result: Result dictionary. A successful one must provide
            ``file_type``, ``classification``, ``confidence``, ``file_name``,
            ``file_extension``, ``content_length``, ``file_path``,
            ``text_preview`` and ``all_scores`` (label -> score mapping).
    """
    import uuid  # local import: only needed for the per-call element keys

    if not result.get('success', False):
        st.error(f"❌ Classification failed: {result.get('error', 'Unknown error')}")
        return

    # The same result is rendered in the upload tab and again in the results
    # tab during a single script run. Without explicit keys the text area and
    # the chart get identical auto-generated IDs and Streamlit raises a
    # duplicate-element error, so every call gets its own key suffix. Both
    # elements are read-only, so a key that changes between runs loses nothing.
    key_suffix = uuid.uuid4().hex

    col1, col2, col3 = st.columns(3)

    with col1:
        st.metric("Document Type", result['file_type'])

    with col2:
        st.metric("Classification", result['classification'].title())

    with col3:
        st.metric("Confidence", f"{result['confidence']:.2%}")

    # Display detailed information
    st.subheader("📋 Document Details")

    col1, col2 = st.columns(2)

    with col1:
        st.write(f"**File Name:** {result['file_name']}")
        st.write(f"**File Extension:** {result['file_extension']}")
        st.write(f"**Content Length:** {result['content_length']} characters")

    with col2:
        st.write(f"**File Path:** {result['file_path']}")
        st.write(f"**Classification Confidence:** {result['confidence']:.2%}")

    # Display text preview (skipped when the preview is empty)
    if result['text_preview']:
        st.subheader("📖 Text Preview")
        st.text_area(
            "Content Preview",
            result['text_preview'],
            height=150,
            disabled=True,
            key=f"text_preview_{key_suffix}",
        )

    # Display all classification scores, highest first
    st.subheader("📊 Classification Scores")
    scores_df = pd.DataFrame(list(result['all_scores'].items()), columns=['Document Type', 'Score'])
    scores_df['Score'] = scores_df['Score'].round(4)
    scores_df = scores_df.sort_values('Score', ascending=False)

    # Create a bar chart
    fig = px.bar(scores_df, x='Document Type', y='Score',
                 title="Classification Confidence Scores",
                 color='Score',
                 color_continuous_scale='Blues')
    fig.update_layout(xaxis_tickangle=-45)
    st.plotly_chart(fig, use_container_width=True, key=f"scores_chart_{key_suffix}")

    # Display scores table
    st.dataframe(scores_df, use_container_width=True)
163
+
164
def display_batch_results(results: List[Dict]):
    """Render the outcome of a batch classification.

    Shows summary metrics (total, successful, failed, average confidence),
    a pie chart of the predicted classes, a table of the successful results
    and one error line per failed result. An empty ``results`` list produces
    only a warning.

    Args:
        results: Result dictionaries. A result counts as successful when its
            ``success`` entry is truthy; successful results must provide
            ``file_name``, ``file_type``, ``classification``, ``confidence``
            and ``content_length``.
    """
    import uuid  # local import: only needed for the per-call chart key

    if not results:
        st.warning("No results to display.")
        return

    # Summary statistics
    successful_results = [r for r in results if r.get('success', False)]
    failed_results = [r for r in results if not r.get('success', False)]

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        st.metric("Total Files", len(results))

    with col2:
        st.metric("Successful", len(successful_results))

    with col3:
        st.metric("Failed", len(failed_results))

    with col4:
        # The average is only defined when at least one file succeeded.
        if successful_results:
            avg_confidence = sum(r['confidence'] for r in successful_results) / len(successful_results)
            st.metric("Avg Confidence", f"{avg_confidence:.2%}")

    # Classification distribution
    if successful_results:
        st.subheader("📊 Classification Distribution")
        classification_counts = pd.Series(
            [r['classification'] for r in successful_results]
        ).value_counts()

        fig = px.pie(values=classification_counts.values,
                     names=classification_counts.index,
                     title="Document Type Distribution")
        # The batch view is rendered in the upload tab and again in the
        # results tab within one script run. A per-call key keeps the two
        # charts from colliding on the same auto-generated element ID.
        st.plotly_chart(fig, use_container_width=True, key=f"batch_pie_{uuid.uuid4().hex}")

    # Detailed results table
    st.subheader("📋 Detailed Results")

    if successful_results:
        results_df = pd.DataFrame([
            {
                'File Name': r['file_name'],
                'File Type': r['file_type'],
                'Classification': r['classification'].title(),
                'Confidence': f"{r['confidence']:.2%}",
                'Content Length': r['content_length'],
            }
            for r in successful_results
        ])
        st.dataframe(results_df, use_container_width=True)

    # Show failed results
    if failed_results:
        st.subheader("❌ Failed Classifications")
        for result in failed_results:
            st.error(f"**{result.get('file_name', 'Unknown')}**: {result.get('error', 'Unknown error')}")
223
+
224
def _remove_temp_file(path: str) -> None:
    """Delete a temporary file, tolerating filesystem errors only."""
    try:
        os.unlink(path)
    except OSError:
        # Best-effort cleanup: a file that is already gone or cannot be
        # removed must not abort the run. Other errors still propagate.
        pass


def main():
    """Main Streamlit application.

    Builds the page: header, a sidebar with an "Initialize Classifier"
    button and model information, and three tabs for single-file
    classification, batch classification and the stored results with CSV
    and JSON export. The results of the most recent classification are kept
    in ``st.session_state.classification_results``.
    """
    supported_types = ['pdf', 'docx', 'doc', 'txt', 'csv', 'xlsx', 'xls', 'jpg', 'jpeg', 'png']

    # Header
    st.markdown('<h1 class="main-header">📄 Document Classifier</h1>', unsafe_allow_html=True)
    st.markdown("""
    <div style="text-align: center; margin-bottom: 2rem;">
        <p style="font-size: 1.2rem; color: #666;">
            Classify documents using Hugging Face models and content analysis
        </p>
    </div>
    """, unsafe_allow_html=True)

    # Sidebar
    st.sidebar.title("⚙️ Settings")

    # Initialize classifier
    if st.sidebar.button("🔄 Initialize Classifier", type="primary"):
        initialize_classifier()

    # Model information
    st.sidebar.subheader("🤖 Model Information")
    st.sidebar.info("""
    **Models Used:**
    - Cardiff NLP Twitter RoBERTa Base Emotion
    - DistilBERT Base Uncased (fallback)

    **Supported Formats:**
    - PDF, DOCX, DOC
    - TXT, CSV
    - XLSX, XLS
    - Images (JPG, PNG, etc.)
    """)

    # Main content
    tab1, tab2, tab3 = st.tabs(["📁 Single File", "📂 Batch Upload", "📊 Results"])

    with tab1:
        st.subheader("📁 Classify Single Document")

        uploaded_file = st.file_uploader(
            "Choose a document file",
            type=supported_types,
            help="Upload a document to classify its type and content"
        )

        if uploaded_file is not None:
            if st.button("🔍 Classify Document", type="primary"):
                if not initialize_classifier():
                    st.stop()

                # Save uploaded file
                file_path = save_uploaded_file(uploaded_file)
                if file_path:
                    try:
                        with st.spinner("Classifying document..."):
                            result = classify_single_file(file_path)
                    finally:
                        # Clean up temporary file
                        _remove_temp_file(file_path)

                    # The classifier only ever sees the temporary path, so
                    # report the name the user actually uploaded.
                    if isinstance(result, dict):
                        result['file_name'] = uploaded_file.name
                    st.session_state.classification_results = [result]

                    # Display result
                    display_classification_result(result)

    with tab2:
        st.subheader("📂 Batch Document Classification")

        uploaded_files = st.file_uploader(
            "Choose multiple document files",
            type=supported_types,
            accept_multiple_files=True,
            help="Upload multiple documents to classify them in batch"
        )

        if uploaded_files:
            st.write(f"📁 {len(uploaded_files)} files selected")

            if st.button("🔍 Classify All Documents", type="primary"):
                if not initialize_classifier():
                    st.stop()

                # Save uploaded files, remembering each original name next to
                # its temporary path. Files that could not be saved are
                # skipped (save_uploaded_file has already reported them).
                saved = []
                for upload in uploaded_files:
                    temp_path = save_uploaded_file(upload)
                    if temp_path:
                        saved.append((upload.name, temp_path))

                if saved:
                    progress_bar = st.progress(0)
                    status_text = st.empty()

                    results = []
                    for position, (original_name, temp_path) in enumerate(saved, start=1):
                        # Show the uploaded name, not the random temp-file name.
                        status_text.text(f"Processing file {position}/{len(saved)}: {original_name}")
                        try:
                            result = classify_single_file(temp_path)
                        finally:
                            # Clean up temporary file
                            _remove_temp_file(temp_path)

                        if isinstance(result, dict):
                            result['file_name'] = original_name
                        results.append(result)
                        progress_bar.progress(position / len(saved))

                    st.session_state.classification_results = results
                    status_text.text("✅ Classification complete!")

                    # Display batch results
                    display_batch_results(results)

    with tab3:
        st.subheader("📊 Classification Results")

        stored_results = st.session_state.classification_results

        if stored_results:
            if len(stored_results) == 1:
                display_classification_result(stored_results[0])
            else:
                display_batch_results(stored_results)
        else:
            st.info("👆 Upload and classify documents to see results here.")

        # Export results. The download buttons are rendered directly: nested
        # inside a regular button they would disappear on the rerun triggered
        # by the download click, because st.button is only True for one run.
        if stored_results:
            st.subheader("💾 Export Results")

            col1, col2 = st.columns(2)

            with col1:
                successful_results = [r for r in stored_results if r.get('success', False)]
                if successful_results:
                    export_df = pd.DataFrame([
                        {
                            'File Name': r['file_name'],
                            'File Type': r['file_type'],
                            'Classification': r['classification'],
                            'Confidence': r['confidence'],
                            'Content Length': r['content_length'],
                        }
                        for r in successful_results
                    ])
                    st.download_button(
                        label="📄 Export as CSV",
                        data=export_df.to_csv(index=False),
                        file_name="classification_results.csv",
                        mime="text/csv"
                    )

            with col2:
                try:
                    json_data = json.dumps(stored_results, indent=2)
                except (TypeError, ValueError) as e:
                    # Serialisation now runs on every render, so a result
                    # that is not JSON-serialisable must not break the page.
                    st.error(f"Could not export results as JSON: {e}")
                else:
                    st.download_button(
                        label="📋 Export as JSON",
                        data=json_data,
                        file_name="classification_results.json",
                        mime="application/json"
                    )


if __name__ == "__main__":
    main()