jacksonwambali committed on
Commit
8a3639a
·
verified ·
1 Parent(s): 1c506f6

Create app.y

Browse files
Files changed (1) hide show
  1. app.y +455 -0
app.y ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import pandas as pd
6
+ import json
7
+ from datetime import datetime
8
+ import plotly.graph_objects as go
9
+ import plotly.express as px
10
+
11
class BERTScamClassifier(nn.Module):
    """Pretrained transformer encoder topped with dropout and a linear head.

    The encoder is fetched with ``AutoModel.from_pretrained(model_name)``; the
    head maps the encoder's pooled output to ``n_classes`` logits.
    """

    def __init__(self, model_name='bert-base-multilingual-cased', n_classes=2, dropout=0.3):
        super().__init__()
        # Registration order (bert, dropout, classifier) is kept so that
        # state_dict keys line up with previously saved checkpoints.
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return the classification logits for a tokenized batch."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        regularized = self.dropout(encoded.pooler_output)
        return self.classifier(regularized)
29
+
30
class GradioScamDetector:
    """Scam/trust message classifier backing the Gradio web app.

    Loads a ``BERTScamClassifier`` checkpoint and exposes the callbacks the UI
    is wired to: single-message prediction, batch prediction from an uploaded
    file, and a rolling prediction history.

    ``model`` and ``tokenizer`` remain ``None`` when the checkpoint cannot be
    loaded; the prediction methods then report that instead of crashing.

    NOTE(review): ``prediction_history`` lives on the instance, so every
    visitor of the app shares (and can see) the same history.
    """

    # Cap on stored history entries so a long-running server cannot grow
    # without bound; only the last 20 entries are ever displayed.
    MAX_HISTORY = 1000
    # Batch uploads are truncated to this many rows.
    MAX_BATCH_MESSAGES = 100

    _MODEL_MISSING_MSG = "❌ Model is not loaded. Check the server logs."

    def __init__(self, model_path='bert_scam_detector.pth'):
        """Set defaults and try to load the checkpoint at ``model_path``."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None
        self.id2label = {0: 'trust', 1: 'scam'}
        self.max_length = 128
        self.prediction_history = []

        # Load model
        self.load_model(model_path)

    def load_model(self, model_path):
        """Load the trained model.

        Reads ``model_name``, ``model_state_dict``, ``max_length`` and
        ``id2label`` from the checkpoint dictionary.

        Returns:
            True on success. On any failure the error is printed, the
            instance keeps its previous model/tokenizer (``None`` right after
            construction) and False is returned.
        """
        try:
            # NOTE(security): torch.load unpickles the file. Only point this
            # at checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=self.device)

            model_name = checkpoint.get('model_name', 'bert-base-multilingual-cased')
            tokenizer = AutoTokenizer.from_pretrained(model_name)

            model = BERTScamClassifier(model_name)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.to(self.device)
            model.eval()

            # Publish everything together so a failure above never leaves a
            # tokenizer without a matching model.
            self.tokenizer = tokenizer
            self.model = model
            self.max_length = checkpoint.get('max_length', 128)
            self.id2label = checkpoint.get('id2label', {0: 'trust', 1: 'scam'})

            print("✅ Model loaded successfully for Gradio app!")
            return True

        except Exception as e:  # startup boundary: report and keep the UI alive
            print(f"❌ Error loading model: {e}")
            return False

    def _is_ready(self):
        """Return True when both the model and the tokenizer are loaded."""
        return self.model is not None and self.tokenizer is not None

    @staticmethod
    def _describe_confidence(confidence):
        """Map a probability in [0, 1] to a human-readable confidence band."""
        if confidence >= 0.9:
            return "Very High"
        if confidence >= 0.75:
            return "High"
        if confidence >= 0.6:
            return "Medium"
        return "Low"

    def predict_message(self, message):
        """Predict if a message is scam or trust.

        Returns:
            A 4-tuple ``(result_text, confidence, details, prob_chart)``
            matching the UI outputs. For blank input, or when no model is
            loaded, a warning text is returned with confidence 0.0 and
            ``None`` for the chart, and nothing is added to the history.
        """
        if not message or not message.strip():
            # None (not {}) is the empty value for the plot output.
            return "⚠️ Please enter a message", 0.0, "No prediction", None

        if not self._is_ready():
            return self._MODEL_MISSING_MSG, 0.0, "No prediction", None

        message = message.strip()

        # Tokenize message
        encoding = self.tokenizer(
            message,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            _, prediction = torch.max(outputs, dim=1)

        class_id = prediction.item()
        predicted_label = self.id2label[class_id]
        confidence = probabilities[0][class_id].item()
        # Index 0 is reported as "trust" and index 1 as "scam".
        trust_prob = probabilities[0][0].item()
        scam_prob = probabilities[0][1].item()

        # Format result with emoji
        if predicted_label == 'scam':
            result_text = "🚫 SCAM DETECTED"
        else:
            result_text = "✅ TRUSTED MESSAGE"

        conf_desc = self._describe_confidence(confidence)

        # Store prediction history (message preview limited to 50 characters)
        self.prediction_history.append({
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'message': message[:50] + "..." if len(message) > 50 else message,
            'prediction': predicted_label,
            'confidence': confidence,
            'trust_prob': trust_prob,
            'scam_prob': scam_prob,
        })
        if len(self.prediction_history) > self.MAX_HISTORY:
            # Drop the oldest entries in place.
            del self.prediction_history[:-self.MAX_HISTORY]

        # Create probability chart
        prob_chart = self.create_probability_chart(trust_prob, scam_prob)

        # Detailed results
        details = f"""
        **Prediction:** {result_text}
        **Confidence:** {confidence:.1%} ({conf_desc})
        **Device:** {self.device}
        **Message Length:** {len(message)} characters
        """

        return result_text, confidence, details, prob_chart

    def create_probability_chart(self, trust_prob, scam_prob):
        """Create a two-bar plotly figure of the trust/scam probabilities."""
        fig = go.Figure(data=[
            go.Bar(
                x=['Trust', 'Scam'],
                y=[trust_prob, scam_prob],
                marker_color=['green', 'red'],
                text=[f'{trust_prob:.1%}', f'{scam_prob:.1%}'],
                textposition='auto',
            )
        ])

        fig.update_layout(
            title="Prediction Probabilities",
            yaxis_title="Probability",
            xaxis_title="Classification",
            showlegend=False,
            height=300,
            margin=dict(l=20, r=20, t=40, b=20),
        )

        return fig

    @staticmethod
    def _read_messages(path):
        """Read raw messages from a CSV or TXT file.

        CSV: the ``message`` column when present, otherwise the first column.
        TXT: one message per line. Returns ``None`` for any other extension.
        """
        lowered = path.lower()  # accept .CSV / .TXT as well
        if lowered.endswith('.csv'):
            df = pd.read_csv(path)
            if 'message' in df.columns:
                return df['message'].tolist()
            return df.iloc[:, 0].tolist()  # First column
        if lowered.endswith('.txt'):
            with open(path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f if line.strip()]
        return None

    @staticmethod
    def _clean_messages(raw_messages, limit):
        """Return the non-blank messages among the first ``limit`` rows as stripped strings."""
        cleaned = []
        for raw in raw_messages[:limit]:
            # CSV cells may be NaN or numeric rather than text.
            if raw is None or (not isinstance(raw, str) and pd.isna(raw)):
                continue
            text = str(raw).strip()
            if text:
                cleaned.append(text)
        return cleaned

    def batch_predict(self, file):
        """Batch prediction from uploaded file.

        Args:
            file: the upload, either a path string or an object whose
                ``name`` attribute is the path.

        Returns:
            ``(summary_markdown, results_dataframe)``; the dataframe is
            ``None`` whenever nothing could be processed.
        """
        if file is None:
            return "⚠️ Please upload a file", None

        try:
            path = str(getattr(file, 'name', file))

            # Read file based on extension
            raw_messages = self._read_messages(path)
            if raw_messages is None:
                return "❌ Unsupported file format. Use CSV or TXT files.", None

            messages = self._clean_messages(raw_messages, self.MAX_BATCH_MESSAGES)
            if not messages:
                # Avoids a division by zero in the scam-rate computation.
                return "⚠️ No messages found in the uploaded file", None

            if not self._is_ready():
                return self._MODEL_MISSING_MSG, None

            # Process messages
            results = []
            for message in messages:
                pred_label, confidence, _, _ = self.predict_message(message)
                results.append({
                    'Message': message[:100] + "..." if len(message) > 100 else message,
                    'Prediction': pred_label,
                    'Confidence': f"{confidence:.1%}",
                })

            # Create results DataFrame
            results_df = pd.DataFrame(results)

            # Summary
            scam_count = sum('SCAM' in r['Prediction'] for r in results)
            trust_count = len(results) - scam_count

            summary = f"""
            📊 **Batch Processing Complete**
            - Total Messages: {len(results)}
            - 🚫 Scam Messages: {scam_count}
            - ✅ Trusted Messages: {trust_count}
            - 📈 Scam Rate: {scam_count/len(results):.1%}
            """

            return summary, results_df

        except Exception as e:  # UI boundary: show the failure instead of crashing
            return f"❌ Error processing file: {str(e)}", None

    def get_prediction_history(self):
        """Get the last 20 predictions as a display-ready DataFrame."""
        if not self.prediction_history:
            return pd.DataFrame({'Message': ['No predictions yet']})

        df = pd.DataFrame(self.prediction_history[-20:])  # Last 20 predictions
        df['Confidence'] = df['confidence'].apply(lambda x: f"{x:.1%}")
        df['Prediction'] = df['prediction'].apply(
            lambda x: f"🚫 {x.upper()}" if x == 'scam' else f"✅ {x.upper()}"
        )

        return df[['timestamp', 'message', 'Prediction', 'Confidence']].rename(columns={
            'timestamp': 'Time',
            'message': 'Message',
        })

    def clear_history(self):
        """Clear prediction history and return a placeholder table."""
        self.prediction_history = []
        return pd.DataFrame({'Message': ['History cleared']})

    def get_sample_messages(self):
        """Get sample messages for testing, keyed by a short display name."""
        return {
            "Swahili Scam": "Hongera! Umeshinda Sh 5,000,000. Tuma PIN yako sasa kupokea zawadi yako!",
            "English Scam": "CONGRATULATIONS! You've won $1,000,000. Send your bank details immediately!",
            "Swahili Trust": "Habari za leo? Natumai uko salama na kila kitu ni sawa",
            "English Trust": "Hi there! How was your day today? Hope everything is going well",
            "Mixed Language": "Hi, kikao kitafanyika kesho at 2 PM. Don't forget!",
            "Suspicious": "URGENT: Your account will be suspended. Click link to verify now!",
        }
239
+
240
def create_gradio_app():
    """Create and configure the Gradio interface.

    Builds a ``GradioScamDetector`` and lays out four tabs (single message,
    batch processing, prediction history, about), then wires the buttons to
    the detector's callbacks.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """

    # Initialize detector
    detector = GradioScamDetector()

    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .result-box {
        font-size: 18px !important;
        font-weight: bold !important;
        text-align: center !important;
        padding: 20px !important;
        border-radius: 10px !important;
    }
    .scam-result {
        background-color: #ffebee !important;
        color: #c62828 !important;
        border: 2px solid #f44336 !important;
    }
    .trust-result {
        background-color: #e8f5e8 !important;
        color: #2e7d32 !important;
        border: 2px solid #4caf50 !important;
    }
    """

    # Create Gradio interface
    with gr.Blocks(css=css, title="🛡️ BERT Scam Detector", theme=gr.themes.Soft()) as demo:

        # Header
        gr.Markdown("""
        # 🛡️ BERT Scam Detector
        ### Intelligent SMS Scam Detection for Swahili & English

        This AI system uses advanced BERT language models to detect scam messages in both Swahili and English.
        Simply enter a message below to check if it's legitimate or potentially fraudulent.
        """)

        # Main prediction interface
        with gr.Tab("🔍 Single Message Detection"):
            with gr.Row():
                with gr.Column(scale=2):
                    message_input = gr.Textbox(
                        label="📝 Enter SMS Message",
                        placeholder="Type or paste your SMS message here...",
                        lines=4,
                        max_lines=8,
                    )

                    with gr.Row():
                        predict_btn = gr.Button("🔍 Analyze Message", variant="primary", size="lg")
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")

                    # Sample messages: two rows of buttons that copy a
                    # canned message into the input box.
                    gr.Markdown("### 📋 Quick Test Samples:")
                    samples = list(detector.get_sample_messages().items())

                    for row_samples in (samples[:3], samples[3:]):
                        with gr.Row():
                            for name, msg in row_samples:
                                # msg=msg binds the current value; a plain
                                # closure would see only the last sample.
                                gr.Button(name, size="sm").click(
                                    lambda msg=msg: msg, outputs=message_input
                                )

                with gr.Column(scale=2):
                    # Results
                    result_text = gr.Textbox(
                        label="🎯 Prediction Result",
                        interactive=False,
                        elem_classes=["result-box"],
                    )

                    confidence_slider = gr.Slider(
                        label="📊 Confidence Level",
                        minimum=0,
                        maximum=1,
                        interactive=False,
                        show_label=True,
                    )

                    details_md = gr.Markdown(label="📋 Detailed Analysis")

                    prob_chart = gr.Plot(label="📈 Probability Distribution")

        # Batch processing tab
        with gr.Tab("📁 Batch Processing"):
            gr.Markdown("### Upload a file with multiple messages for batch analysis")

            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="📄 Upload File (CSV or TXT)",
                        file_types=[".csv", ".txt"],
                    )

                    batch_btn = gr.Button("🚀 Process Batch", variant="primary")

                with gr.Column():
                    batch_summary = gr.Markdown(label="📊 Summary")

            batch_results = gr.Dataframe(
                label="📋 Batch Results",
                interactive=False,
                wrap=True,
            )

        # History tab
        with gr.Tab("📚 Prediction History"):
            with gr.Row():
                refresh_btn = gr.Button("🔄 Refresh History", variant="secondary")
                clear_history_btn = gr.Button("🗑️ Clear History", variant="secondary")

            history_df = gr.Dataframe(
                label="📋 Recent Predictions",
                interactive=False,
                wrap=True,
            )

        # About tab
        with gr.Tab("ℹ️ About"):
            gr.Markdown("""
            ## 🤖 About This System

            ### How It Works
            - **Model**: BERT (Bidirectional Encoder Representations from Transformers)
            - **Languages**: Swahili and English
            - **Training**: Fine-tuned on SMS scam detection dataset
            - **Accuracy**: High precision scam detection

            ### Features
            - ✅ Real-time message analysis
            - 🌍 Multilingual support (Swahili & English)
            - 📊 Confidence scoring
            - 📁 Batch processing
            - 📚 Prediction history

            ### Usage Tips
            - Enter complete SMS messages for best results
            - The system works with both languages simultaneously
            - Higher confidence scores indicate more reliable predictions
            - Check the probability distribution for detailed insights

            ### Safety Notice
            - This is an AI assistant - use your judgment
            - Report suspicious messages to authorities
            - Never share personal information with untrusted sources

            ---
            **Powered by BERT & Gradio** | Made with ❤️ for SMS security
            """)

        # Event handlers
        # The history refresh is chained with .then() so it runs after the
        # prediction has been appended to the history; two independent click
        # listeners give no such ordering.
        predict_btn.click(
            fn=detector.predict_message,
            inputs=message_input,
            outputs=[result_text, confidence_slider, details_md, prob_chart],
        ).then(
            fn=detector.get_prediction_history,
            outputs=history_df,
        )

        # Reset the input and every result widget, including the verdict box.
        clear_btn.click(
            fn=lambda: ("", "", 0, "", None),
            outputs=[message_input, result_text, confidence_slider, details_md, prob_chart],
        )

        batch_btn.click(
            fn=detector.batch_predict,
            inputs=file_upload,
            outputs=[batch_summary, batch_results],
        )

        refresh_btn.click(
            fn=detector.get_prediction_history,
            outputs=history_df,
        )

        clear_history_btn.click(
            fn=detector.clear_history,
            outputs=history_df,
        )

    return demo
435
+
436
def main():
    """Build the Gradio app and start serving it."""
    print("🚀 Starting BERT Scam Detector Web App...")

    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": 7860,       # Gradio's default port
        "share": True,             # request a public share link
        "debug": False,
        "show_error": False,
        "quiet": False,
        "inbrowser": True,         # open a browser tab on start
    }

    demo = create_gradio_app()
    demo.launch(**launch_options)
453
+
454
# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()