# acc-test / app.py
# Author: adAstra144 — commit 7a0bcf1 ("yah", verified)
from flask import Flask, request, jsonify, render_template_string
from flask_cors import CORS
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import logging
import io
import pandas as pd
# Opt out of Hugging Face Hub telemetry.
# NOTE(review): huggingface_hub presumably reads this variable when it is first
# imported (which happens via the `transformers` import above), so setting it
# here may have no effect; consider moving it above the third-party imports.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'

# Root logging at INFO level, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application with flask_cors applied using its default settings.
app = Flask(__name__)
CORS(app)

# Filled in by load_model(); both stay None until a model has been loaded.
tokenizer = model = None
DEFAULT_MODEL_NAME = "AntiSpamInstitute/bert-MoE-Phishing-detection-v2.4"


def load_model(model_name=DEFAULT_MODEL_NAME):
    """Load the phishing detection tokenizer and model into the module globals.

    Args:
        model_name: identifier passed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the checkpoint this service
            has always used, so existing ``load_model()`` calls are unchanged.

    Raises:
        Whatever ``from_pretrained`` raises; the error is logged first and the
        globals are left untouched.
    """
    global tokenizer, model
    try:
        logger.info("Loading phishing detection model...")
        new_tokenizer = AutoTokenizer.from_pretrained(model_name)
        new_model = AutoModelForSequenceClassification.from_pretrained(model_name)
        new_model.eval()  # Set to evaluation mode
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
    # Publish both objects together so a failed model load can never leave a
    # freshly loaded tokenizer paired with a missing or stale model.
    tokenizer, model = new_tokenizer, new_model
    logger.info("Model loaded successfully!")
def predict_phishing(text):
    """Classify ``text`` with the globally loaded tokenizer and model.

    Input longer than 512 tokens is truncated by the tokenizer.

    Returns:
        tuple: ``(label, confidence)``. ``label`` is "Phishing" when the
        highest-probability class index is 1, otherwise "Safe". ``confidence``
        is that class's softmax probability as a percentage, rounded to one
        decimal place.

    Raises:
        Any exception from tokenization or the forward pass, after logging it.
    """
    try:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = model(**encoded).logits
            probs = torch.nn.functional.softmax(logits, dim=1)
            best_prob, best_idx = probs.max(dim=1)
        if best_idx.item() == 1:
            verdict = "Phishing"
        else:
            verdict = "Safe"
        return verdict, round(best_prob.item() * 100, 1)
    except Exception as e:
        logger.error(f"Error in prediction: {e}")
        raise
@app.route("/", methods=["GET"])
def home():
    """Root endpoint: report service status and list the available routes."""
    route_help = {
        "/analyze": "POST - Analyze text for phishing",
        "/health": "GET - Health check",
        "/evaluate": "GET/POST - Upload CSV and evaluate model accuracy",
    }
    payload = {
        "status": "healthy",
        "message": "Anti-Phishing Scanner API",
        "endpoints": route_help,
    }
    return jsonify(payload)
@app.route("/health", methods=["GET"])
def health():
    """Health check: always "healthy", plus whether the global model is set."""
    is_loaded = model is not None
    return jsonify({"status": "healthy", "model_loaded": is_loaded})
@app.route("/analyze", methods=["POST"])
def analyze():
    """Classify a single message as phishing or safe.

    Expects a JSON object body of the form ``{"message": "<text>"}``.

    Returns:
        200 with ``{"result", "confidence", "message"}`` on success;
        400 if the body is not a JSON object, has no ``message`` field, the
        message is empty/blank, or it is not a string;
        500 if prediction fails.
    """
    # silent=True: malformed JSON or a non-JSON Content-Type yields None
    # (reported as a 400 below) instead of raising an HTTP exception that the
    # broad handler further down would have turned into a 500.
    data = request.get_json(silent=True)
    # A JSON array/scalar body is not a valid request either; indexing it
    # with "message" would otherwise raise a TypeError (-> 500).
    if not isinstance(data, dict) or "message" not in data:
        return jsonify({"error": "Missing 'message' field"}), 400

    message = data["message"]
    if not message or (isinstance(message, str) and not message.strip()):
        return jsonify({"error": "Message cannot be empty"}), 400
    if not isinstance(message, str):
        # e.g. {"message": 123} previously crashed on .strip() with a 500.
        return jsonify({"error": "'message' must be a string"}), 400

    try:
        label, confidence = predict_phishing(message.strip())
    except Exception as e:
        logger.exception(f"Error in analyze endpoint: {e}")
        return jsonify({"error": "Internal server error"}), 500

    return jsonify({
        "result": label,
        "confidence": f"{confidence}%",
        "message": message,
    })
# =============================
# /evaluate (GET upload form + POST CSV evaluation)
# =============================

# Raw label spellings accepted in the uploaded CSV; 'spam'/'ham' are treated
# as phishing/safe respectively.
_PHISHING_LABELS = frozenset({"phishing", "spam", "1"})
_SAFE_LABELS = frozenset({"safe", "ham", "0"})

# Static upload form served on GET.
_EVALUATE_FORM_HTML = """
<!DOCTYPE html>
<html>
<head>
<meta charset='utf-8'/>
<title>Model Evaluation</title>
<style>
body { font-family: Arial, sans-serif; margin: 2rem; background: #f9f9f9; }
h2 { color: #333; }
form { margin-top: 1rem; padding: 1rem; background: #fff; border-radius: 8px; box-shadow: 0 2px 6px rgba(0,0,0,0.1); }
input[type=file] { margin: 1rem 0; }
button { background: #4CAF50; color: white; border: none; padding: 0.5rem 1rem; border-radius: 5px; cursor: pointer; }
button:hover { background: #45a049; }
.hint { color: #555; font-size: 0.95rem; }
</style>
</head>
<body>
<h2>Upload a CSV to Evaluate Model Accuracy</h2>
<p class="hint">Expected columns: <code>text</code> (or <code>message</code>) and <code>label</code> (values: <em>phishing</em>/<em>safe</em> or 1/0)</p>
<form action="/evaluate" method="post" enctype="multipart/form-data">
<input type="file" name="file" accept=".csv" required><br>
<button type="submit">Run Evaluation</button>
</form>
</body>
</html>
"""

# Results page. User-supplied CSV text is passed in as Jinja context variables
# (autoescaped by render_template_string), never interpolated into the
# template source. Otherwise a cell like "{{ ... }}" would be executed by
# Jinja (server-side template injection) and raw HTML would be rendered (XSS).
_EVALUATE_RESULTS_HTML = """
<!DOCTYPE html>
<html>
<head>
<meta charset='utf-8'/>
<title>Evaluation Results</title>
<style>
body { font-family: Arial, sans-serif; margin: 2rem; background: #f9f9f9; }
h2 { color: #333; }
.results { margin-top: 1rem; padding: 1rem; background: #fff; border-radius: 8px; box-shadow: 0 2px 6px rgba(0,0,0,0.1); }
p { margin: 0.3rem 0; }
.small { color: #666; font-size: 0.9rem; }
a.button { display:inline-block; margin-top:1rem; padding:0.5rem 0.8rem; background:#4CAF50; color:#fff; text-decoration:none; border-radius:6px; }
</style>
</head>
<body>
<h2>Evaluation Results</h2>
<div class="results">
<p><b>Samples Tested:</b> {{ total }}</p>
<p><b>Accuracy:</b> {{ accuracy }}</p>
<p><b>Precision:</b> {{ precision }}</p>
<p><b>Recall:</b> {{ recall }}</p>
<p><b>F1 Score:</b> {{ f1 }}</p>
<p class="small">TP: {{ tp }} • TN: {{ tn }} • FP: {{ fp }} • FN: {{ fn }} • Skipped rows: {{ skipped }}</p>
</div>
<h3>❌ False Negatives (Phishing predicted as Safe)</h3>
<table>
<tr><th>Text</th><th>True Label</th><th>Predicted</th></tr>
{% for text in false_negatives %}<tr><td>{{ text }}</td><td>phishing</td><td>safe</td></tr>{% endfor %}
</table>
<h3>⚠️ False Positives (Safe predicted as Phishing)</h3>
<table>
<tr><th>Text</th><th>True Label</th><th>Predicted</th></tr>
{% for text in false_positives %}<tr><td>{{ text }}</td><td>safe</td><td>phishing</td></tr>{% endfor %}
</table>
<a class="button" href="/evaluate">← Run another test</a>
</body>
</html>
"""


def _to_int_label(value):
    """Normalize a raw CSV label to 1 (phishing), 0 (safe) or None (invalid).

    Strings are matched case-insensitively, ignoring surrounding whitespace.
    Other values are accepted only if they equal exactly 0 or 1, so 2, -1,
    0.5 or NaN are rejected instead of silently being counted as "safe".
    """
    if isinstance(value, str):
        key = value.strip().lower()
        if key in _PHISHING_LABELS:
            return 1
        if key in _SAFE_LABELS:
            return 0
        return None
    try:
        return int(value) if value in (0, 1) else None
    except (TypeError, ValueError):
        # e.g. values whose equality/truthiness is ambiguous.
        return None


def _read_uploaded_csv(upload):
    """Decode an uploaded file as UTF-8 and parse it into a DataFrame.

    "utf-8-sig" strips a leading byte-order mark. With plain "utf-8" the BOM
    stays glued to the first header, so a leading 'text' column is not found.
    Undecodable bytes are dropped, as before.
    """
    content = upload.stream.read().decode("utf-8-sig", errors="ignore")
    return pd.read_csv(io.StringIO(content))


def _binary_metrics(y_true, y_pred):
    """Return confusion counts plus accuracy/precision/recall/F1 as a dict.

    Label 1 (phishing) is the positive class. Ratios with a zero denominator
    are reported as 0.0.
    """
    tp = tn = fp = fn = 0
    for actual, predicted in zip(y_true, y_pred):
        if predicted == 1:
            if actual == 1:
                tp += 1
            else:
                fp += 1
        elif actual == 1:
            fn += 1
        else:
            tn += 1
    total = len(y_true)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return {
        "total": total,
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "accuracy": (tp + tn) / total if total else 0.0,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


@app.route("/evaluate", methods=["GET", "POST"])
def evaluate():
    """Serve a CSV upload form (GET) or score the model on an uploaded CSV (POST).

    The CSV needs a 'text' (or 'message') column and a 'label' column with
    values 'phishing'/'spam'/1 or 'safe'/'ham'/0. Rows with an invalid label
    or a missing/blank text are skipped and reported as "Skipped rows".

    Returns:
        GET: the upload form.
        POST: an HTML page with accuracy, precision, recall, F1 and the
        misclassified texts. A JSON error with status 400 is returned for a
        missing, empty or unparseable file, missing columns or no usable rows,
        and a JSON error with status 500 if evaluation fails unexpectedly.
    """
    if request.method == "GET":
        return render_template_string(_EVALUATE_FORM_HTML)

    try:
        if "file" not in request.files:
            return jsonify({"error": "No file uploaded. Please upload a CSV with 'text' or 'message' and 'label' columns."}), 400

        try:
            df = _read_uploaded_csv(request.files["file"])
        except (pd.errors.EmptyDataError, pd.errors.ParserError):
            # Bad input, not a server fault: previously surfaced as a 500.
            return jsonify({"error": "Could not parse the uploaded file as CSV."}), 400

        text_col = next((c for c in ("text", "message") if c in df.columns), None)
        if text_col is None:
            return jsonify({"error": "CSV must have a 'text' or 'message' column."}), 400
        if "label" not in df.columns:
            return jsonify({"error": "CSV must have a 'label' column."}), 400

        texts, y_true = [], []
        for raw_text, raw_label in zip(df[text_col].tolist(), df["label"].tolist()):
            label = _to_int_label(raw_label)
            # Missing cells would otherwise be scored as the literal "nan".
            if label is None or pd.isna(raw_text):
                continue
            # Strip like /analyze does, so both endpoints score the same input.
            text = str(raw_text).strip()
            if not text:
                continue
            texts.append(text)
            y_true.append(label)

        if not texts:
            return jsonify({"error": "No valid rows. Ensure 'label' values are 'phishing'/'safe' or 1/0."}), 400

        y_pred = [
            1 if predict_phishing(text)[0].lower() == "phishing" else 0
            for text in texts
        ]
        metrics = _binary_metrics(y_true, y_pred)

        scored = list(zip(texts, y_true, y_pred))
        false_negatives = [t for t, y, p in scored if y == 1 and p == 0]
        false_positives = [t for t, y, p in scored if y == 0 and p == 1]

        return render_template_string(
            _EVALUATE_RESULTS_HTML,
            total=metrics["total"],
            accuracy=f"{metrics['accuracy']:.4f}",
            precision=f"{metrics['precision']:.4f}",
            recall=f"{metrics['recall']:.4f}",
            f1=f"{metrics['f1']:.4f}",
            tp=metrics["tp"],
            tn=metrics["tn"],
            fp=metrics["fp"],
            fn=metrics["fn"],
            skipped=len(df) - metrics["total"],
            false_negatives=false_negatives,
            false_positives=false_positives,
        )
    except Exception as e:
        logger.exception(f"Error in evaluate endpoint: {e}")
        return jsonify({"error": "Evaluation failed"}), 500
@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON error body and status 404."""
    body = {"error": "Endpoint not found"}
    return jsonify(body), 404
@app.errorhandler(500)
def internal_error(error):
    """Answer unhandled server errors with a JSON error body and status 500."""
    status_code = 500
    body = {"error": "Internal server error"}
    return jsonify(body), status_code
# Load the model as soon as this module runs, whether it is executed directly
# or imported by a hosting runtime (e.g. Hugging Face Spaces).
load_model()

if __name__ == "__main__":
    # Direct execution: start Flask's server on all interfaces, port 7860.
    app.run(debug=False, host="0.0.0.0", port=7860)