Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoTokenizer, AutoModel | |
| import pandas as pd | |
| import json | |
| from datetime import datetime | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
class BERTScamClassifier(nn.Module):
    """BERT-based classifier for scam detection.

    A pretrained transformer encoder (loaded with ``AutoModel``) is followed by
    dropout and a single linear layer producing ``n_classes`` logits.

    Args:
        model_name: Hugging Face model identifier passed to
            ``AutoModel.from_pretrained``.
        n_classes: Number of output logits produced by the linear head.
        dropout: Dropout probability applied to the pooled representation.
    """

    def __init__(self, model_name='bert-base-multilingual-cased', n_classes=2, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (un-normalised) classification logits.

        Args:
            input_ids: Token ids produced by the tokenizer.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            Tensor of logits with ``n_classes`` entries per example.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Encoders without a pooling layer either omit ``pooler_output`` or set
        # it to None; feeding None to dropout would crash. Fall back to the
        # first-position hidden state instead (assumes last_hidden_state is
        # (batch, seq, hidden), as for Hugging Face encoders).
        pooled_output = getattr(outputs, 'pooler_output', None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]
        return self.classifier(self.dropout(pooled_output))
class GradioScamDetector:
    """Gradio web app backend for scam detection.

    Loads a trained ``BERTScamClassifier`` checkpoint and serves single-message,
    batch-file and JSON/API predictions. It also keeps a bounded in-memory
    history of recent predictions for the UI.
    """

    # Cap on stored history entries. Without it, a long-running server receiving
    # API/webhook traffic would grow ``prediction_history`` forever. The UI only
    # shows the most recent 20 entries.
    max_history = 1000

    # Maximum number of rows taken from an uploaded batch file.
    batch_limit = 100

    def __init__(self, model_path='bert_scam_detector.pth'):
        """Create the detector and try to load the checkpoint at ``model_path``.

        If loading fails, ``self.model`` stays ``None``. The prediction methods
        then report "model not loaded" instead of raising.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None
        self.id2label = {0: 'trust', 1: 'scam'}
        self.max_length = 128
        self.prediction_history = []
        # Load model
        self.load_model(model_path)

    def load_model(self, model_path):
        """Load the tokenizer and trained weights from a checkpoint file.

        The checkpoint must be a dict containing ``model_state_dict``. It may
        also contain ``model_name``, ``max_length`` and ``id2label``.

        Returns:
            True on success. False if any step failed; the error is printed.
        """
        try:
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=self.device)
            model_name = checkpoint.get('model_name', 'bert-base-multilingual-cased')
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = BERTScamClassifier(model_name)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.to(self.device)
            model.eval()
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            return False
        # Publish only after every step succeeded. A failed load_state_dict must
        # never leave a partially/randomly initialised model serving predictions.
        self.tokenizer = tokenizer
        self.model = model
        self.max_length = checkpoint.get('max_length', 128)
        self.id2label = checkpoint.get('id2label', {0: 'trust', 1: 'scam'})
        print("✅ Model loaded successfully for Gradio app!")
        return True

    def _model_ready(self):
        """Return True when both the tokenizer and the model are loaded."""
        return self.model is not None and self.tokenizer is not None

    def _run_inference(self, message):
        """Classify one already-stripped, non-empty message.

        Returns:
            ``(predicted_label, confidence, trust_prob, scam_prob)``.
            ``confidence`` is the softmax probability of the predicted class.
        """
        encoding = self.tokenizer(
            message,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, prediction = torch.max(outputs, dim=1)
        index = prediction.item()
        predicted_label = self.id2label[index]
        confidence = probabilities[0][index].item()
        # Columns 0 / 1 are read as trust / scam, matching the default id2label.
        trust_prob = probabilities[0][0].item()
        scam_prob = probabilities[0][1].item()
        return predicted_label, confidence, trust_prob, scam_prob

    def _record_history(self, message, predicted_label, confidence, trust_prob, scam_prob, source=None):
        """Append a prediction to the history and trim it to ``max_history``."""
        entry = {
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'message': message[:50] + "..." if len(message) > 50 else message,
            'prediction': predicted_label,
            'confidence': confidence,
            'trust_prob': trust_prob,
            'scam_prob': scam_prob,
        }
        if source is not None:
            entry['source'] = source
        self.prediction_history.append(entry)
        overflow = len(self.prediction_history) - self.max_history
        if overflow > 0:
            del self.prediction_history[:overflow]

    @staticmethod
    def _result_text(predicted_label):
        """Map a label to the user-facing verdict string."""
        return "🚫 SCAM DETECTED" if predicted_label == 'scam' else "✅ TRUSTED MESSAGE"

    @staticmethod
    def _confidence_description(confidence):
        """Bucket a probability into a coarse human-readable level."""
        if confidence >= 0.9:
            return "Very High"
        if confidence >= 0.75:
            return "High"
        if confidence >= 0.6:
            return "Medium"
        return "Low"

    def predict_message(self, message):
        """Predict whether a message is scam or trust (used by the UI).

        Returns:
            ``(result_text, confidence, details_markdown, probability_figure)``.
            The figure is None when no prediction was made.
        """
        if not message or not message.strip():
            # None (not {}) so the Plot component simply shows nothing.
            return "⚠️ Please enter a message", 0.0, "No prediction", None
        if not self._model_ready():
            return "❌ Model not loaded", 0.0, "The model could not be loaded; check the server logs.", None
        message = message.strip()
        predicted_label, confidence, trust_prob, scam_prob = self._run_inference(message)
        result_text = self._result_text(predicted_label)
        conf_desc = self._confidence_description(confidence)
        self._record_history(message, predicted_label, confidence, trust_prob, scam_prob)
        prob_chart = self.create_probability_chart(trust_prob, scam_prob)
        # Blank-line separated so Markdown renders each field on its own line.
        details = "\n\n".join([
            f"**Prediction:** {result_text}",
            f"**Confidence:** {confidence:.1%} ({conf_desc})",
            f"**Device:** {self.device}",
            f"**Message Length:** {len(message)} characters",
        ])
        return result_text, confidence, details, prob_chart

    def predict_api(self, message):
        """API-friendly prediction for webhooks.

        Always returns a JSON-serialisable dict. ``status`` is ``"success"`` or
        ``"error"``; errors are reported in the dict rather than raised.
        """
        if not message or not message.strip():
            return {
                "status": "error",
                "message": "Empty message",
                "prediction": "unknown",
                "confidence": 0.0
            }
        if not self._model_ready():
            return {
                "status": "error",
                "message": "Model not loaded",
                "prediction": "unknown",
                "confidence": 0.0
            }
        message = message.strip()
        try:
            predicted_label, confidence, trust_prob, scam_prob = self._run_inference(message)
            alert_level = "HIGH" if predicted_label == 'scam' else "LOW"
            self._record_history(message, predicted_label, confidence, trust_prob, scam_prob, source='API')
            return {
                "status": "success",
                "message": message[:100] + "..." if len(message) > 100 else message,
                "prediction": predicted_label,
                "result_text": self._result_text(predicted_label),
                "confidence": round(confidence, 4),
                "trust_probability": round(trust_prob, 4),
                "scam_probability": round(scam_prob, 4),
                "alert_level": alert_level,
                "timestamp": datetime.now().isoformat()
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"Prediction failed: {str(e)}",
                "prediction": "unknown",
                "confidence": 0.0
            }

    def create_probability_chart(self, trust_prob, scam_prob):
        """Return a Plotly bar chart of the trust/scam probabilities."""
        fig = go.Figure(data=[
            go.Bar(
                x=['Trust', 'Scam'],
                y=[trust_prob, scam_prob],
                marker_color=['green', 'red'],
                text=[f'{trust_prob:.1%}', f'{scam_prob:.1%}'],
                textposition='auto',
            )
        ])
        fig.update_layout(
            title="Prediction Probabilities",
            yaxis_title="Probability",
            xaxis_title="Classification",
            showlegend=False,
            height=300,
            margin=dict(l=20, r=20, t=40, b=20)
        )
        return fig

    @staticmethod
    def _read_messages(path):
        """Read raw message values from a CSV/TXT file; None if unsupported.

        For CSV, the ``message`` column is used if present, else the first column.
        """
        lowered = str(path).lower()
        if lowered.endswith('.csv'):
            df = pd.read_csv(path)
            column = df['message'] if 'message' in df.columns else df.iloc[:, 0]
            return column.tolist()
        if lowered.endswith('.txt'):
            with open(path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f if line.strip()]
        return None

    @staticmethod
    def _as_text(value):
        """Return a cell value as stripped text; '' for missing (NaN/None) cells."""
        if not isinstance(value, str):
            if pd.isna(value):
                return ""
            value = str(value)
        return value.strip()

    def batch_predict(self, file):
        """Batch prediction from an uploaded CSV or TXT file.

        Args:
            file: Either an uploaded-file object with a ``.name`` path or a
                plain path string. Which one arrives depends on the Gradio version.

        Returns:
            ``(summary_markdown, results_dataframe)``. The DataFrame is None on error.
        """
        if file is None:
            return "⚠️ Please upload a file", None
        if not self._model_ready():
            return "❌ Model not loaded. Cannot process batch.", None
        try:
            path = getattr(file, 'name', file)
            messages = self._read_messages(path)
            if messages is None:
                return "❌ Unsupported file format. Use CSV or TXT files.", None
            results = []
            for raw in messages[:self.batch_limit]:
                message = self._as_text(raw)
                if not message:
                    continue
                predicted_label, confidence, trust_prob, scam_prob = self._run_inference(message)
                self._record_history(message, predicted_label, confidence, trust_prob, scam_prob)
                results.append({
                    'Message': message[:100] + "..." if len(message) > 100 else message,
                    'Prediction': self._result_text(predicted_label),
                    'Confidence': f"{confidence:.1%}"
                })
            results_df = pd.DataFrame(results)
            total = len(results)
            scam_count = sum(1 for r in results if 'SCAM' in r['Prediction'])
            trust_count = total - scam_count
            # Guard against an empty/all-blank file (previously ZeroDivisionError).
            scam_rate = scam_count / total if total else 0.0
            summary = "\n".join([
                "📊 **Batch Processing Complete**",
                "",
                f"- Total Messages: {total}",
                f"- 🚫 Scam Messages: {scam_count}",
                f"- ✅ Trusted Messages: {trust_count}",
                f"- 📈 Scam Rate: {scam_rate:.1%}",
            ])
            return summary, results_df
        except Exception as e:
            return f"❌ Error processing file: {str(e)}", None

    def get_prediction_history(self):
        """Return the last 20 predictions as a display DataFrame."""
        if not self.prediction_history:
            return pd.DataFrame({'Message': ['No predictions yet']})
        df = pd.DataFrame(self.prediction_history[-20:])  # Last 20 predictions
        df['Confidence'] = df['confidence'].apply(lambda x: f"{x:.1%}")
        df['Prediction'] = df['prediction'].apply(lambda x: f"🚫 {x.upper()}" if x == 'scam' else f"✅ {x.upper()}")
        # 'source' only exists if some entry set it (API calls). Entries without
        # it are NaN there, so fill them explicitly rather than showing NaN.
        if 'source' in df.columns:
            df['Source'] = df['source'].fillna('Manual')
        else:
            df['Source'] = 'Manual'
        return df[['timestamp', 'message', 'Prediction', 'Confidence', 'Source']].rename(columns={
            'timestamp': 'Time',
            'message': 'Message',
        })

    def clear_history(self):
        """Clear prediction history."""
        self.prediction_history = []
        return pd.DataFrame({'Message': ['History cleared']})

    def get_sample_messages(self):
        """Return sample messages (display name -> text) for quick testing."""
        return {
            "Swahili Scam": "Hongera! Umeshinda Sh 5,000,000. Tuma PIN yako sasa kupokea zawadi yako!",
            "English Scam": "CONGRATULATIONS! You've won $1,000,000. Send your bank details immediately!",
            "Swahili Trust": "Habari za leo? Natumai uko salama na kila kitu ni sawa",
            "English Trust": "Hi there! How was your day today? Hope everything is going well",
            "Mixed Language": "Hi, kikao kitafanyika kesho at 2 PM. Don't forget!",
            "Suspicious": "URGENT: Your account will be suspended. Click link to verify now!"
        }
# Module-wide handle to the active GradioScamDetector. It starts empty,
# create_gradio_app() fills it in, and the API test handler reads it.
detector = None
def create_gradio_app():
    """Build the Gradio Blocks interface around a new ``GradioScamDetector``.

    The detector is stored in the module-level ``detector`` global, the same
    instance the API handler below reads.

    Returns:
        The configured, not-yet-launched ``gr.Blocks`` app.
    """
    global detector
    # Initialize detector
    detector = GradioScamDetector()

    def api_predict(message):
        """JSON prediction used by the API tab and the ``predict`` endpoint."""
        if detector is None:
            return {"error": "Model not loaded"}
        return detector.predict_api(message)

    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .result-box {
        font-size: 18px !important;
        font-weight: bold !important;
        text-align: center !important;
        padding: 20px !important;
        border-radius: 10px !important;
    }
    .scam-result {
        background-color: #ffebee !important;
        color: #c62828 !important;
        border: 2px solid #f44336 !important;
    }
    .trust-result {
        background-color: #e8f5e8 !important;
        color: #2e7d32 !important;
        border: 2px solid #4caf50 !important;
    }
    """

    # Create Gradio interface
    with gr.Blocks(css=css, title="🛡️ BERT Scam Detector", theme=gr.themes.Soft()) as demo:
        # Header
        gr.Markdown("""
        # 🛡️ BERT Scam Detector
        ### Intelligent SMS Scam Detection for Swahili & English

        This AI system uses advanced BERT language models to detect scam messages in both Swahili and English.
        Simply enter a message below to check if it's legitimate or potentially fraudulent.
        """)

        # API Information Tab
        with gr.Tab("🔌 API Integration"):
            gr.Markdown("""
            ## 📡 API Endpoints for IFTTT/Zapier Integration

            ### For IFTTT Webhook:
            ```
            URL: https://jacksonwambali-bert.hf.space/api/predict
            Method: POST
            Content-Type: application/json
            Body: {"data": ["Your SMS message here"]}
            ```

            ### For Zapier Webhook:
            ```
            URL: https://jacksonwambali-bert.hf.space/api/predict
            Method: POST
            Content-Type: application/json
            Payload: {"data": ["{{sms_text}}"]}
            ```

            ### Response Format:
            ```json
            {
              "data": [
                {
                  "status": "success",
                  "prediction": "scam" or "trust",
                  "result_text": "🚫 SCAM DETECTED" or "✅ TRUSTED MESSAGE",
                  "confidence": 0.95,
                  "alert_level": "HIGH" or "LOW"
                }
              ]
            }
            ```

            ### Quick Test:
            Use the form below to test your API integration:
            """)
            with gr.Row():
                with gr.Column():
                    api_test_input = gr.Textbox(
                        label="📱 Test SMS Message",
                        placeholder="Enter SMS to test API response...",
                        lines=3
                    )
                    api_test_btn = gr.Button("🧪 Test API Response", variant="primary")
                with gr.Column():
                    api_response = gr.JSON(label="📊 API Response")
            api_test_btn.click(
                fn=api_predict,
                inputs=api_test_input,
                outputs=api_response,
                # Named so the documented /api/predict endpoint maps to this
                # handler. NOTE(review): exact route shape depends on Gradio version.
                api_name="predict",
            )

        # Main prediction interface
        with gr.Tab("🔍 Single Message Detection"):
            with gr.Row():
                with gr.Column(scale=2):
                    message_input = gr.Textbox(
                        label="📝 Enter SMS Message",
                        placeholder="Type or paste your SMS message here...",
                        lines=4,
                        max_lines=8
                    )
                    with gr.Row():
                        predict_btn = gr.Button("🔍 Analyze Message", variant="primary", size="lg")
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                    # Sample messages
                    gr.Markdown("### 📋 Quick Test Samples:")
                    sample_messages = detector.get_sample_messages()
                    sample_items = list(sample_messages.items())
                    with gr.Row():
                        for name, msg in sample_items[:3]:
                            # msg=msg binds the current sample (avoids late binding).
                            gr.Button(name, size="sm").click(
                                lambda msg=msg: msg, outputs=message_input
                            )
                    with gr.Row():
                        for name, msg in sample_items[3:]:
                            gr.Button(name, size="sm").click(
                                lambda msg=msg: msg, outputs=message_input
                            )
                with gr.Column(scale=2):
                    # Results
                    result_text = gr.Textbox(
                        label="🎯 Prediction Result",
                        interactive=False,
                        elem_classes=["result-box"]
                    )
                    confidence_slider = gr.Slider(
                        label="📊 Confidence Level",
                        minimum=0,
                        maximum=1,
                        interactive=False,
                        show_label=True
                    )
                    details_md = gr.Markdown(label="📋 Detailed Analysis")
                    prob_chart = gr.Plot(label="📈 Probability Distribution")

        # Batch processing tab
        with gr.Tab("📁 Batch Processing"):
            gr.Markdown("### Upload a file with multiple messages for batch analysis")
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="📄 Upload File (CSV or TXT)",
                        file_types=[".csv", ".txt"]
                    )
                    batch_btn = gr.Button("🚀 Process Batch", variant="primary")
                with gr.Column():
                    batch_summary = gr.Markdown(label="📊 Summary")
            batch_results = gr.Dataframe(
                label="📋 Batch Results",
                interactive=False,
                wrap=True
            )

        # History tab
        with gr.Tab("📚 Prediction History"):
            with gr.Row():
                refresh_btn = gr.Button("🔄 Refresh History", variant="secondary")
                clear_history_btn = gr.Button("🗑️ Clear History", variant="secondary")
            history_df = gr.Dataframe(
                label="📋 Recent Predictions",
                interactive=False,
                wrap=True
            )

        # About tab
        with gr.Tab("ℹ️ About"):
            gr.Markdown("""
            ## 🤖 About This System

            ### How It Works
            - **Model**: BERT (Bidirectional Encoder Representations from Transformers)
            - **Languages**: Swahili and English
            - **Training**: Fine-tuned on SMS scam detection dataset
            - **Accuracy**: High precision scam detection

            ### Features
            - ✅ Real-time message analysis
            - 🌍 Multilingual support (Swahili & English)
            - 📊 Confidence scoring
            - 📁 Batch processing
            - 📚 Prediction history
            - 🔌 API integration for IFTTT/Zapier

            ### SMS Integration
            - Connect with IFTTT for automatic SMS scanning
            - Webhook support for real-time alerts
            - Batch processing for multiple messages

            ### Usage Tips
            - Enter complete SMS messages for best results
            - The system works with both languages simultaneously
            - Higher confidence scores indicate more reliable predictions
            - Check the probability distribution for detailed insights

            ### Safety Notice
            - This is an AI assistant - use your judgment
            - Report suspicious messages to authorities
            - Never share personal information with untrusted sources

            ---

            **Powered by BERT & Gradio** | Made with ❤️ for SMS security
            """)

        # Event handlers
        predict_btn.click(
            fn=detector.predict_message,
            inputs=message_input,
            outputs=[result_text, confidence_slider, details_md, prob_chart]
        ).then(
            # Refresh history only after the prediction has been recorded. Two
            # independent click handlers are not guaranteed to run in order.
            fn=detector.get_prediction_history,
            outputs=history_df
        )
        clear_btn.click(
            # Also reset the verdict box, which the old handler left stale.
            fn=lambda: ("", "", 0, "", None),
            outputs=[message_input, result_text, confidence_slider, details_md, prob_chart]
        )
        batch_btn.click(
            fn=detector.batch_predict,
            inputs=file_upload,
            outputs=[batch_summary, batch_results]
        )
        refresh_btn.click(
            fn=detector.get_prediction_history,
            outputs=history_df
        )
        clear_history_btn.click(
            fn=detector.clear_history,
            outputs=history_df
        )

    return demo
def main():
    """Entry point: build the scam-detector UI and start serving it with Gradio."""
    print("🚀 Starting BERT Scam Detector Web App...")
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,       # Gradio's conventional default port
        "share": True,             # also request a public share link
        "debug": False,
        "show_error": False,
        "quiet": False,
        "inbrowser": True,         # try to open a local browser tab on start
    }
    create_gradio_app().launch(**launch_options)


if __name__ == "__main__":
    main()