Spaces:
Sleeping
Sleeping
| from flask import Flask, request, jsonify, render_template_string | |
| from flask_cors import CORS | |
| import os | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import logging | |
| import io | |
| import pandas as pd | |
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application object, with flask_cors applied using its defaults.
app = Flask(__name__)
CORS(app)

# Opt out of Hugging Face Hub telemetry.
# NOTE(review): this assignment runs after `transformers` has already been
# imported at the top of the file; confirm the variable is still honoured when
# set this late, otherwise it has to move above that import.
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

# Filled in by load_model(); both remain None until it has run.
tokenizer = model = None
def load_model():
    """Populate the module-level ``tokenizer`` and ``model`` globals.

    Both objects are created with ``from_pretrained`` from the same checkpoint
    name, and the model is switched to evaluation mode. Any exception raised
    along the way is logged and then re-raised to the caller.
    """
    global tokenizer, model
    checkpoint = "AntiSpamInstitute/bert-MoE-Phishing-detection-v2.4"
    try:
        logger.info("Loading phishing detection model...")
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        model.eval()  # inference only
    except Exception as exc:
        logger.error(f"Error loading model: {exc}")
        raise
    else:
        logger.info("Model loaded successfully!")
def predict_phishing(text):
    """Classify ``text`` and return ``(label, confidence_percent)``.

    ``label`` is "Phishing" when the highest-probability class index is 1 and
    "Safe" otherwise. ``confidence_percent`` is that class's softmax
    probability expressed as a percentage, rounded to one decimal place.
    Any exception is logged and re-raised.
    """
    try:
        # Inputs longer than 512 tokens are truncated by the tokenizer.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )

        with torch.no_grad():  # inference only, no gradients needed
            logits = model(**encoded).logits
            probs = torch.nn.functional.softmax(logits, dim=1)
            top_prob, top_index = torch.max(probs, dim=1)

        is_phishing = top_index.item() == 1
        percent = round(top_prob.item() * 100, 1)
        return ("Phishing" if is_phishing else "Safe"), percent
    except Exception as exc:
        logger.error(f"Error in prediction: {exc}")
        raise
# Without a route registration this view is unreachable: nothing else in the
# file attaches it to `app`.
# NOTE(review): "/" is assumed from the function name; confirm the intended path.
@app.route("/", methods=["GET"])
def home():
    """Return a JSON index of the API.

    The payload carries a static "healthy" status, the service name, and a
    mapping of endpoint paths to a short description of each.
    """
    return jsonify({
        "status": "healthy",
        "message": "Anti-Phishing Scanner API",
        "endpoints": {
            "/analyze": "POST - Analyze text for phishing",
            "/health": "GET - Health check",
            "/evaluate": "GET/POST - Upload CSV and evaluate model accuracy",
        },
    })
# Registered explicitly; without this decorator the view is never attached to
# `app` and the path would answer 404.
@app.route("/health", methods=["GET"])
def health():
    """Return a JSON health report.

    ``model_loaded`` is True once the module-level ``model`` global has been
    set to something other than None.
    """
    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
    })
# Registered explicitly; without this decorator the view is never attached to
# `app` and the path would answer 404.
@app.route("/analyze", methods=["POST"])
def analyze():
    """Classify the ``message`` field of a JSON request body.

    Request body: a JSON object with a non-empty string ``message``.

    Returns:
        200 with ``result`` (label), ``confidence`` (percentage string) and
        the original ``message``;
        400 when the body is not a JSON object, lacks ``message``, or
        ``message`` is empty or not a string;
        500 when the prediction itself fails.
    """
    # silent=True: a wrong Content-Type or malformed JSON yields None instead
    # of raising, so those client errors are answered with 400 rather than
    # being swallowed into a 500 by a blanket exception handler.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "message" not in data:
        return jsonify({"error": "Missing 'message' field"}), 400

    message = data["message"]
    # A truthy non-string (number, list, object) has no .strip(); reject it
    # as a client error instead of crashing.
    if message and not isinstance(message, str):
        return jsonify({"error": "'message' must be a string"}), 400
    if not message or not message.strip():
        return jsonify({"error": "Message cannot be empty"}), 400

    # Only the model call is expected to fail for server-side reasons.
    try:
        label, confidence = predict_phishing(message.strip())
    except Exception:
        logger.exception("Error in analyze endpoint")
        return jsonify({"error": "Internal server error"}), 500

    return jsonify({
        "result": label,
        "confidence": f"{confidence}%",
        "message": message,
    })
# =============================
# /evaluate (GET form + POST CSV)
# =============================

# Upload form served on GET. It contains no template variables.
_EVALUATE_FORM_HTML = """
<!DOCTYPE html>
<html>
<head>
  <meta charset='utf-8'/>
  <title>Model Evaluation</title>
  <style>
    body { font-family: Arial, sans-serif; margin: 2rem; background: #f9f9f9; }
    h2 { color: #333; }
    form { margin-top: 1rem; padding: 1rem; background: #fff; border-radius: 8px; box-shadow: 0 2px 6px rgba(0,0,0,0.1); }
    input[type=file] { margin: 1rem 0; }
    button { background: #4CAF50; color: white; border: none; padding: 0.5rem 1rem; border-radius: 5px; cursor: pointer; }
    button:hover { background: #45a049; }
    .hint { color: #555; font-size: 0.95rem; }
  </style>
</head>
<body>
  <h2>Upload a CSV to Evaluate Model Accuracy</h2>
  <p class="hint">Expected columns: <code>text</code> (or <code>message</code>) and <code>label</code> (values: <em>phishing</em>/<em>safe</em> or 1/0)</p>
  <form action="/evaluate" method="post" enctype="multipart/form-data">
    <input type="file" name="file" accept=".csv" required><br>
    <button type="submit">Run Evaluation</button>
  </form>
</body>
</html>
"""

# Results page. This is a fixed Jinja template: every dynamic value arrives as
# template context and is escaped on output. Uploaded text must never be
# spliced into the template source itself, because the source is compiled as
# Jinja and would execute any template syntax found in the CSV.
_EVALUATE_RESULTS_HTML = """
<!DOCTYPE html>
<html>
<head>
  <meta charset='utf-8'/>
  <title>Evaluation Results</title>
  <style>
    body { font-family: Arial, sans-serif; margin: 2rem; background: #f9f9f9; }
    h2 { color: #333; }
    .results { margin-top: 1rem; padding: 1rem; background: #fff; border-radius: 8px; box-shadow: 0 2px 6px rgba(0,0,0,0.1); }
    p { margin: 0.3rem 0; }
    .small { color: #666; font-size: 0.9rem; }
    a.button { display:inline-block; margin-top:1rem; padding:0.5rem 0.8rem; background:#4CAF50; color:#fff; text-decoration:none; border-radius:6px; }
  </style>
</head>
<body>
  <h2>Evaluation Results</h2>
  <div class="results">
    <p><b>Samples Tested:</b> {{ total }}</p>
    <p><b>Accuracy:</b> {{ accuracy }}</p>
    <p><b>Precision:</b> {{ precision }}</p>
    <p><b>Recall:</b> {{ recall }}</p>
    <p><b>F1 Score:</b> {{ f1 }}</p>
    <p class="small">TP: {{ tp }} • TN: {{ tn }} • FP: {{ fp }} • FN: {{ fn }} • Skipped rows: {{ skipped }}</p>
  </div>
  <h3>❌ False Negatives (Phishing predicted as Safe)</h3>
  <table>
    <tr><th>Text</th><th>True Label</th><th>Predicted</th></tr>
    {% for text in false_negatives %}<tr><td>{{ text | e }}</td><td>phishing</td><td>safe</td></tr>{% endfor %}
  </table>
  <h3>⚠️ False Positives (Safe predicted as Phishing)</h3>
  <table>
    <tr><th>Text</th><th>True Label</th><th>Predicted</th></tr>
    {% for text in false_positives %}<tr><td>{{ text | e }}</td><td>safe</td><td>phishing</td></tr>{% endfor %}
  </table>
  <a class="button" href="/evaluate">← Run another test</a>
</body>
</html>
"""

# Accepted textual spellings of each class ('spam' counts as phishing).
_PHISHING_LABELS = frozenset({"phishing", "spam", "1"})
_SAFE_LABELS = frozenset({"safe", "ham", "0"})


def _to_int_label(value):
    """Map a raw CSV label to 1 (phishing), 0 (safe) or None (unrecognised)."""
    if isinstance(value, str):
        key = value.strip().lower()
        if key in _PHISHING_LABELS:
            return 1
        if key in _SAFE_LABELS:
            return 0
    # NaN, infinities and fractional floats are not labels; int() would
    # otherwise raise on them or silently truncate them.
    if isinstance(value, float) and not value.is_integer():
        return None
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        return None
    # Only 0 and 1 are labels; anything else is skipped rather than being
    # counted as "safe", which would skew the metrics.
    return number if number in (0, 1) else None


def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


# Registered explicitly; without this decorator the view is never attached to
# `app` and the upload form's own action="/evaluate" would answer 404.
@app.route("/evaluate", methods=["GET", "POST"])
def evaluate():
    """Serve the CSV upload form (GET) or score an uploaded CSV (POST).

    The CSV needs a ``text`` (or ``message``) column and a ``label`` column
    whose values are phishing/spam/1 or safe/ham/0. Rows with an unrecognised
    label or a missing text cell are skipped and reported as skipped.

    Returns:
        GET: the HTML upload form.
        POST: an HTML page with accuracy, precision, recall, F1, the
        confusion-matrix counts and the misclassified texts;
        400 (JSON) when the file is missing, is not parseable CSV, lacks a
        required column, or has no usable rows;
        500 (JSON) on any other failure.
    """
    if request.method == "GET":
        return render_template_string(_EVALUATE_FORM_HTML)

    try:
        if "file" not in request.files:
            return jsonify({"error": "No file uploaded. Please upload a CSV with 'text' or 'message' and 'label' columns."}), 400
        upload = request.files["file"]

        # Undecodable bytes are dropped rather than failing the whole upload.
        content = upload.stream.read().decode("utf-8", errors="ignore")
        try:
            df = pd.read_csv(io.StringIO(content))
        except (pd.errors.EmptyDataError, pd.errors.ParserError):
            # An empty or malformed upload is the client's fault, not a 500.
            return jsonify({"error": "Could not parse the uploaded file as CSV."}), 400

        # 'text' takes precedence over 'message' when both columns exist.
        text_col = next((c for c in ("text", "message") if c in df.columns), None)
        if text_col is None:
            return jsonify({"error": "CSV must have a 'text' or 'message' column."}), 400
        if "label" not in df.columns:
            return jsonify({"error": "CSV must have a 'label' column."}), 400

        raw_texts = df[text_col].tolist()
        labels = [_to_int_label(v) for v in df["label"].tolist()]
        # Keep rows that have both a recognised label and an actual text
        # value; a missing cell would otherwise be scored as the string "nan".
        rows = [
            (str(text), truth)
            for text, truth in zip(raw_texts, labels)
            if truth is not None and not pd.isna(text)
        ]
        if not rows:
            return jsonify({"error": "No valid rows. Ensure 'label' values are 'phishing'/'safe' or 1/0."}), 400

        # One pass: predict, tally the confusion matrix, and collect the
        # misclassified texts.
        tp = tn = fp = fn = 0
        false_positives = []
        false_negatives = []
        for text, truth in rows:
            pred_label, _conf = predict_phishing(text)
            predicted = 1 if pred_label.lower() == "phishing" else 0
            if predicted and truth:
                tp += 1
            elif predicted:
                fp += 1
                false_positives.append(text)
            elif truth:
                fn += 1
                false_negatives.append(text)
            else:
                tn += 1

        total = len(rows)
        precision = _ratio(tp, tp + fp)
        recall = _ratio(tp, tp + fn)

        return render_template_string(
            _EVALUATE_RESULTS_HTML,
            total=total,
            accuracy=f"{_ratio(tp + tn, total):.4f}",
            precision=f"{precision:.4f}",
            recall=f"{recall:.4f}",
            f1=f"{_ratio(2 * precision * recall, precision + recall):.4f}",
            tp=tp,
            tn=tn,
            fp=fp,
            fn=fn,
            skipped=len(raw_texts) - total,
            false_negatives=false_negatives,
            false_positives=false_positives,
        )
    except Exception:
        logger.exception("Error in evaluate endpoint")
        return jsonify({"error": "Evaluation failed"}), 500
# Registered explicitly; without this decorator Flask never calls the handler
# and unknown paths get the framework's default 404 page instead of JSON.
@app.errorhandler(404)
def not_found(error):
    """Return a JSON 404 body; ``error`` is the triggering exception (unused)."""
    return jsonify({"error": "Endpoint not found"}), 404
# Registered explicitly; without this decorator Flask never calls the handler
# and unhandled failures get the framework's default 500 page instead of JSON.
@app.errorhandler(500)
def internal_error(error):
    """Return a JSON 500 body; ``error`` is the triggering exception (unused)."""
    return jsonify({"error": "Internal server error"}), 500
# Load the model on startup, whether this file is executed directly or
# imported by a host process (the Hugging Face Spaces case).
load_model()

if __name__ == "__main__":
    # Direct execution: serve on all interfaces, port 7860, debug off.
    app.run(debug=False, host="0.0.0.0", port=7860)