import json
from datetime import datetime

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel


class BERTScamClassifier(nn.Module):
    """BERT-based classifier for scam detection.

    A pretrained transformer encoder whose pooled output is passed through
    dropout and a single linear layer that produces ``n_classes`` logits.
    """

    def __init__(self, model_name='bert-base-multilingual-cased', n_classes=2, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (un-normalised) class logits for a tokenised batch."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # NOTE(review): relies on the encoder exposing ``pooler_output``;
        # verify this still holds if ``model_name`` is switched to an encoder
        # without a pooling head.
        pooled_output = outputs.pooler_output
        return self.classifier(self.dropout(pooled_output))


class GradioScamDetector:
    """Gradio web app backend for scam detection.

    Holds the loaded model and tokenizer plus an in-memory prediction history.
    NOTE(review): the history lives on this single instance, so every visitor
    of the app sees (and clears) the same list.
    """

    # Cap on stored history entries so a long-running server does not grow
    # memory without bound; only the most recent HISTORY_DISPLAY are shown.
    MAX_HISTORY = 1000
    # Number of most recent predictions shown in the history table.
    HISTORY_DISPLAY = 20
    # Maximum number of messages processed from one uploaded batch file.
    MAX_BATCH_MESSAGES = 100

    def __init__(self, model_path='bert_scam_detector.pth'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None
        self.id2label = {0: 'trust', 1: 'scam'}
        self.max_length = 128
        self.prediction_history = []

        # Load model
        self.load_model(model_path)

    def load_model(self, model_path):
        """Load the trained model checkpoint from ``model_path``.

        The checkpoint must contain ``'model_state_dict'``; the optional keys
        ``'model_name'``, ``'max_length'`` and ``'id2label'`` override the
        defaults. Attributes on ``self`` are replaced only after every step
        has succeeded, so a failed load never leaves a half-initialised
        (randomly weighted) classifier in place.

        Returns:
            True on success, False if loading failed (the error is printed).
        """
        try:
            # Depending on the PyTorch version, torch.load may unpickle
            # arbitrary objects: only load checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=self.device)
            model_name = checkpoint.get('model_name', 'bert-base-multilingual-cased')
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = BERTScamClassifier(model_name)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.to(self.device)
            model.eval()
            max_length = checkpoint.get('max_length', 128)
            raw_id2label = checkpoint.get('id2label', {0: 'trust', 1: 'scam'})
            # Label maps that went through JSON/HF configs have string keys;
            # normalise so integer class indices always resolve.
            id2label = {int(k): v for k, v in raw_id2label.items()}
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            return False

        self.tokenizer = tokenizer
        self.model = model
        self.max_length = max_length
        self.id2label = id2label
        print("✅ Model loaded successfully for Gradio app!")
        return True

    def _is_ready(self):
        """Return True when both the model and tokenizer are loaded."""
        return self.model is not None and self.tokenizer is not None

    @staticmethod
    def _truncate(text, limit):
        """Return ``text`` cut to ``limit`` characters, with "..." appended if cut."""
        return text[:limit] + "..." if len(text) > limit else text

    @staticmethod
    def _result_text(label):
        """Return the user-facing verdict string for a predicted label."""
        return "🚫 SCAM DETECTED" if label == 'scam' else "✅ TRUSTED MESSAGE"

    @staticmethod
    def _confidence_description(confidence):
        """Map a probability in [0, 1] to a coarse human-readable level."""
        for threshold, description in ((0.9, "Very High"), (0.75, "High"), (0.6, "Medium")):
            if confidence >= threshold:
                return description
        return "Low"

    def _infer(self, message):
        """Classify one stripped, non-empty message.

        Returns:
            ``(predicted_label, confidence, trust_prob, scam_prob)``, where
            ``confidence`` is the softmax probability of the predicted class.
            Trust/scam probabilities are read from class indices 0 and 1,
            matching the default ``id2label`` mapping.

        Raises:
            RuntimeError: if the model or tokenizer failed to load.
        """
        if not self._is_ready():
            raise RuntimeError("Model is not loaded")

        encoding = self.tokenizer(
            message,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
        predicted_index = int(torch.argmax(logits, dim=1).item())

        return (
            self.id2label[predicted_index],
            probabilities[predicted_index].item(),
            probabilities[0].item(),
            probabilities[1].item(),
        )

    def _record_prediction(self, message, label, confidence, trust_prob, scam_prob, source):
        """Append one entry to the prediction history, dropping the oldest beyond MAX_HISTORY."""
        self.prediction_history.append({
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'message': self._truncate(message, 50),
            'prediction': label,
            'confidence': confidence,
            'trust_prob': trust_prob,
            'scam_prob': scam_prob,
            'source': source,
        })
        overflow = len(self.prediction_history) - self.MAX_HISTORY
        if overflow > 0:
            del self.prediction_history[:overflow]

    def predict_message(self, message):
        """Predict if a message is scam or trust (single-message UI handler).

        Returns:
            ``(result_text, confidence, details_markdown, probability_figure)``.
            For empty input, or when the model is not loaded, confidence is
            0.0 and the figure is None (which clears the plot).
        """
        if not message or not message.strip():
            return "⚠️ Please enter a message", 0.0, "No prediction", None
        if not self._is_ready():
            return "❌ Model not loaded", 0.0, "No prediction", None

        message = message.strip()
        label, confidence, trust_prob, scam_prob = self._infer(message)
        result_text = self._result_text(label)
        conf_desc = self._confidence_description(confidence)

        self._record_prediction(message, label, confidence, trust_prob, scam_prob, 'Manual')

        prob_chart = self.create_probability_chart(trust_prob, scam_prob)

        # Blank lines keep each field on its own line when rendered as Markdown.
        details = (
            f"**Prediction:** {result_text}\n\n"
            f"**Confidence:** {confidence:.1%} ({conf_desc})\n\n"
            f"**Device:** {self.device}\n\n"
            f"**Message Length:** {len(message)} characters\n"
        )

        return result_text, confidence, details, prob_chart

    def predict_api(self, message):
        """API-friendly prediction function for webhooks.

        Returns:
            A JSON-serialisable dict. It has ``"status": "success"`` and the
            prediction fields, or ``"status": "error"`` with a message when
            the input is empty or inference fails.
        """
        if not message or not message.strip():
            return {
                "status": "error",
                "message": "Empty message",
                "prediction": "unknown",
                "confidence": 0.0
            }

        message = message.strip()

        try:
            label, confidence, trust_prob, scam_prob = self._infer(message)
        except Exception as e:
            return {
                "status": "error",
                "message": f"Prediction failed: {str(e)}",
                "prediction": "unknown",
                "confidence": 0.0
            }

        self._record_prediction(message, label, confidence, trust_prob, scam_prob, 'API')

        return {
            "status": "success",
            "message": self._truncate(message, 100),
            "prediction": label,
            "result_text": self._result_text(label),
            "confidence": round(confidence, 4),
            "trust_probability": round(trust_prob, 4),
            "scam_probability": round(scam_prob, 4),
            "alert_level": "HIGH" if label == 'scam' else "LOW",
            "timestamp": datetime.now().isoformat()
        }

    def create_probability_chart(self, trust_prob, scam_prob):
        """Create a two-bar plotly chart of the trust and scam probabilities."""
        fig = go.Figure(data=[
            go.Bar(
                x=['Trust', 'Scam'],
                y=[trust_prob, scam_prob],
                marker_color=['green', 'red'],
                text=[f'{trust_prob:.1%}', f'{scam_prob:.1%}'],
                textposition='auto',
            )
        ])

        fig.update_layout(
            title="Prediction Probabilities",
            yaxis_title="Probability",
            xaxis_title="Classification",
            showlegend=False,
            height=300,
            margin=dict(l=20, r=20, t=40, b=20)
        )

        return fig

    @staticmethod
    def _read_batch_messages(path):
        """Return the non-empty, stripped messages in a CSV/TXT file, or None if unsupported.

        CSV files use the ``message`` column when present, otherwise the first
        column; missing cells are skipped and values are converted to ``str``.
        TXT files (read as UTF-8) yield one message per non-blank line.
        """
        lowered = str(path).lower()
        if lowered.endswith('.csv'):
            df = pd.read_csv(path)
            column = df['message'] if 'message' in df.columns else df.iloc[:, 0]
            raw_messages = [str(value) for value in column.dropna()]
        elif lowered.endswith('.txt'):
            with open(path, 'r', encoding='utf-8') as f:
                raw_messages = list(f)
        else:
            return None
        return [text.strip() for text in raw_messages if text.strip()]

    def batch_predict(self, file):
        """Batch prediction from an uploaded CSV or TXT file.

        At most MAX_BATCH_MESSAGES non-empty messages are classified, and each
        one is recorded in the history with source ``'Batch'``.

        Returns:
            ``(summary_markdown, results_dataframe)``, or
            ``(error_markdown, None)`` when there is no file, the format is
            unsupported, the file has no messages, or processing fails.
        """
        if file is None:
            return "⚠️ Please upload a file", None

        try:
            # Depending on the Gradio version the upload arrives as a tempfile
            # wrapper exposing ``.name`` or as a plain path string.
            path = getattr(file, 'name', file)
            messages = self._read_batch_messages(path)
            if messages is None:
                return "❌ Unsupported file format. Use CSV or TXT files.", None
            if not messages:
                return "⚠️ No messages found in the uploaded file", None

            results = []
            scam_count = 0
            for message in messages[:self.MAX_BATCH_MESSAGES]:
                label, confidence, trust_prob, scam_prob = self._infer(message)
                self._record_prediction(message, label, confidence, trust_prob, scam_prob, 'Batch')
                if label == 'scam':
                    scam_count += 1
                results.append({
                    'Message': self._truncate(message, 100),
                    'Prediction': self._result_text(label),
                    'Confidence': f"{confidence:.1%}"
                })

            total = len(results)
            trust_count = total - scam_count
            summary = (
                "📊 **Batch Processing Complete**\n\n"
                f"- Total Messages: {total}\n"
                f"- 🚫 Scam Messages: {scam_count}\n"
                f"- ✅ Trusted Messages: {trust_count}\n"
                f"- 📈 Scam Rate: {scam_count / total:.1%}\n"
            )

            return summary, pd.DataFrame(results)

        except Exception as e:
            return f"❌ Error processing file: {str(e)}", None

    def get_prediction_history(self):
        """Return the most recent predictions as a display-ready DataFrame."""
        if not self.prediction_history:
            return pd.DataFrame({'Message': ['No predictions yet']})

        df = pd.DataFrame(self.prediction_history[-self.HISTORY_DISPLAY:])
        df['Confidence'] = df['confidence'].map(lambda x: f"{x:.1%}")
        df['Prediction'] = df['prediction'].map(
            lambda x: f"🚫 {x.upper()}" if x == 'scam' else f"✅ {x.upper()}"
        )
        # Rows without an explicit source are treated as manual entries; a plain
        # ``df.get('source', ...)`` would leave NaN for them once any API row exists.
        if 'source' in df.columns:
            df['Source'] = df['source'].fillna('Manual')
        else:
            df['Source'] = 'Manual'

        return df[['timestamp', 'message', 'Prediction', 'Confidence', 'Source']].rename(columns={
            'timestamp': 'Time',
            'message': 'Message',
        })

    def clear_history(self):
        """Clear the prediction history and return a placeholder table."""
        self.prediction_history = []
        return pd.DataFrame({'Message': ['History cleared']})

    def get_sample_messages(self):
        """Return a name-to-message mapping of sample inputs for quick testing."""
        return {
            "Swahili Scam": "Hongera! Umeshinda Sh 5,000,000. Tuma PIN yako sasa kupokea zawadi yako!",
            "English Scam": "CONGRATULATIONS! You've won $1,000,000. Send your bank details immediately!",
            "Swahili Trust": "Habari za leo? Natumai uko salama na kila kitu ni sawa",
            "English Trust": "Hi there! How was your day today? Hope everything is going well",
            "Mixed Language": "Hi, kikao kitafanyika kesho at 2 PM. Don't forget!",
            "Suspicious": "URGENT: Your account will be suspended. Click link to verify now!"
        }


# Global detector instance for API endpoints
detector = None


# Custom CSS for better styling.
# NOTE(review): .scam-result / .trust-result are defined but no component
# currently applies them.
_APP_CSS = """
.gradio-container {
    max-width: 1200px !important;
}
.result-box {
    font-size: 18px !important;
    font-weight: bold !important;
    text-align: center !important;
    padding: 20px !important;
    border-radius: 10px !important;
}
.scam-result {
    background-color: #ffebee !important;
    color: #c62828 !important;
    border: 2px solid #f44336 !important;
}
.trust-result {
    background-color: #e8f5e8 !important;
    color: #2e7d32 !important;
    border: 2px solid #4caf50 !important;
}
"""

_HEADER_MD = """
# 🛡️ BERT Scam Detector
### Intelligent SMS Scam Detection for Swahili & English

This AI system uses advanced BERT language models to detect scam messages in both Swahili and English.
Simply enter a message below to check if it's legitimate or potentially fraudulent.
"""

_API_DOCS_MD = """
## 📡 API Endpoints for IFTTT/Zapier Integration

### For IFTTT Webhook:
```
URL: https://jacksonwambali-bert.hf.space/api/predict
Method: POST
Content-Type: application/json
Body: {"data": ["Your SMS message here"]}
```

### For Zapier Webhook:
```
URL: https://jacksonwambali-bert.hf.space/api/predict
Method: POST
Content-Type: application/json
Payload: {"data": ["{{sms_text}}"]}
```

### Response Format:
```json
{
  "data": [
    {
      "status": "success",
      "prediction": "scam" or "trust",
      "result_text": "🚫 SCAM DETECTED" or "✅ TRUSTED MESSAGE",
      "confidence": 0.95,
      "alert_level": "HIGH" or "LOW"
    }
  ]
}
```

### Quick Test:
Use the form below to test your API integration:
"""

_ABOUT_MD = """
## 🤖 About This System

### How It Works
- **Model**: BERT (Bidirectional Encoder Representations from Transformers)
- **Languages**: Swahili and English
- **Training**: Fine-tuned on SMS scam detection dataset
- **Accuracy**: High precision scam detection

### Features
- ✅ Real-time message analysis
- 🌍 Multilingual support (Swahili & English)
- 📊 Confidence scoring
- 📁 Batch processing
- 📚 Prediction history
- 🔌 API integration for IFTTT/Zapier

### SMS Integration
- Connect with IFTTT for automatic SMS scanning
- Webhook support for real-time alerts
- Batch processing for multiple messages

### Usage Tips
- Enter complete SMS messages for best results
- The system works with both languages simultaneously
- Higher confidence scores indicate more reliable predictions
- Check the probability distribution for detailed insights

### Safety Notice
- This is an AI assistant - use your judgment
- Report suspicious messages to authorities
- Never share personal information with untrusted sources

---
**Powered by BERT & Gradio** | Made with ❤️ for SMS security
"""


def create_gradio_app():
    """Create and configure the Gradio interface.

    Instantiates the module-level ``detector`` (which loads the model) and
    wires every tab's widgets to its methods.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    global detector

    # Initialize detector
    detector = GradioScamDetector()

    with gr.Blocks(css=_APP_CSS, title="🛡️ BERT Scam Detector", theme=gr.themes.Soft()) as demo:

        # Header
        gr.Markdown(_HEADER_MD)

        # API Information Tab
        with gr.Tab("🔌 API Integration"):
            gr.Markdown(_API_DOCS_MD)

            with gr.Row():
                with gr.Column():
                    api_test_input = gr.Textbox(
                        label="📱 Test SMS Message",
                        placeholder="Enter SMS to test API response...",
                        lines=3
                    )
                    api_test_btn = gr.Button("🧪 Test API Response", variant="primary")

                with gr.Column():
                    api_response = gr.JSON(label="📊 API Response")

            # api_name="predict" exposes this event under the /api/predict route
            # advertised in the documentation above.
            api_test_btn.click(
                fn=lambda msg: detector.predict_api(msg) if detector else {"error": "Model not loaded"},
                inputs=api_test_input,
                outputs=api_response,
                api_name="predict"
            )

        # Main prediction interface
        with gr.Tab("🔍 Single Message Detection"):
            with gr.Row():
                with gr.Column(scale=2):
                    message_input = gr.Textbox(
                        label="📝 Enter SMS Message",
                        placeholder="Type or paste your SMS message here...",
                        lines=4,
                        max_lines=8
                    )

                    with gr.Row():
                        predict_btn = gr.Button("🔍 Analyze Message", variant="primary", size="lg")
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")

                    # Sample messages, three buttons per row
                    gr.Markdown("### 📋 Quick Test Samples:")
                    sample_items = list(detector.get_sample_messages().items())
                    for row_items in (sample_items[:3], sample_items[3:]):
                        with gr.Row():
                            for name, msg in row_items:
                                # Bind msg as a default so each button keeps its own text.
                                gr.Button(name, size="sm").click(
                                    lambda msg=msg: msg,
                                    outputs=message_input
                                )

                with gr.Column(scale=2):
                    # Results
                    result_text = gr.Textbox(
                        label="🎯 Prediction Result",
                        interactive=False,
                        elem_classes=["result-box"]
                    )

                    confidence_slider = gr.Slider(
                        label="📊 Confidence Level",
                        minimum=0,
                        maximum=1,
                        interactive=False,
                        show_label=True
                    )

                    details_md = gr.Markdown(label="📋 Detailed Analysis")

                    prob_chart = gr.Plot(label="📈 Probability Distribution")

        # Batch processing tab
        with gr.Tab("📁 Batch Processing"):
            gr.Markdown("### Upload a file with multiple messages for batch analysis")

            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="📄 Upload File (CSV or TXT)",
                        file_types=[".csv", ".txt"]
                    )
                    batch_btn = gr.Button("🚀 Process Batch", variant="primary")

                with gr.Column():
                    batch_summary = gr.Markdown(label="📊 Summary")

            batch_results = gr.Dataframe(
                label="📋 Batch Results",
                interactive=False,
                wrap=True
            )

        # History tab
        with gr.Tab("📚 Prediction History"):
            with gr.Row():
                refresh_btn = gr.Button("🔄 Refresh History", variant="secondary")
                clear_history_btn = gr.Button("🗑️ Clear History", variant="secondary")

            history_df = gr.Dataframe(
                label="📋 Recent Predictions",
                interactive=False,
                wrap=True
            )

        # About tab
        with gr.Tab("ℹ️ About"):
            gr.Markdown(_ABOUT_MD)

        # Event handlers.
        # The history refresh is chained with .then() so it runs after the
        # prediction has been recorded; separate click listeners are not
        # guaranteed to run in order.
        predict_btn.click(
            fn=detector.predict_message,
            inputs=message_input,
            outputs=[result_text, confidence_slider, details_md, prob_chart]
        ).then(
            fn=detector.get_prediction_history,
            outputs=history_df
        )

        clear_btn.click(
            fn=lambda: ("", "", 0, "", None),
            outputs=[message_input, result_text, confidence_slider, details_md, prob_chart]
        )

        batch_btn.click(
            fn=detector.batch_predict,
            inputs=file_upload,
            outputs=[batch_summary, batch_results]
        )

        refresh_btn.click(
            fn=detector.get_prediction_history,
            outputs=history_df
        )

        clear_history_btn.click(
            fn=detector.clear_history,
            outputs=history_df
        )

    return demo


def main():
    """Build and launch the Gradio app."""
    print("🚀 Starting BERT Scam Detector Web App...")

    app = create_gradio_app()

    app.launch(
        server_name="0.0.0.0",  # Allow external access
        server_port=7860,  # Default Gradio port
        share=True,  # Create a public share link
        debug=False,
        show_error=False,
        quiet=False,
        inbrowser=True  # Auto-open browser
    )


if __name__ == "__main__":
    main()