Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import time | |
| import sys | |
| import asyncio | |
| from typing import List, Dict, Optional | |
| from urllib.parse import urlparse | |
| import socket | |
| import httpx | |
| import joblib | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import google.generativeai as genai | |
| from dotenv import load_dotenv | |
| import config | |
| from models import get_ml_models, get_dl_models, FinetunedBERT | |
| from feature_extraction import process_row | |
| load_dotenv() | |
| sys.path.append(os.path.join(config.BASE_DIR, 'Message_model')) | |
| from predict import PhishingPredictor | |
# Application object carrying the API metadata (title, description, version).
app = FastAPI(
    version="1.0.0",
    title="Phishing Detection API",
    description="Advanced phishing detection system using multiple ML/DL models and Gemini AI",
)

# Fully permissive CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation -- confirm whether credentials
# are really needed, or list the trusted origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class MessageInput(BaseModel):
    """Request body accepted by the prediction endpoint."""

    # Raw message to analyse; URLs are extracted from it by parse_message().
    text: str
    # Optional free-form extras; not read by any code visible in this file.
    metadata: Optional[Dict] = {}
class PredictionResponse(BaseModel):
    """Verdict returned to the caller for one analysed message."""

    # Certainty of the verdict, clamped to the 0-100 range by the producers.
    confidence: float
    # Short explanation of the evidence behind the verdict.
    reasoning: str
    # The message text, with suspicious spans wrapped in $$...$$ when the
    # Gemini path produced highlights; otherwise the unmodified original.
    highlighted_text: str
    # Either "phishing" or "legitimate".
    final_decision: str
    # Practical advice for the end user.
    suggestion: str
# Model registries; all of them start empty/unset and are filled by load_models().
ml_models = {}   # name -> classical model loaded from a .joblib file
dl_models = {}   # name -> torch model loaded from a .pt state dict
bert_model = semantic_model = gemini_model = None

# Per-model decision boundary applied to the raw score by custom_boundary();
# every model currently uses the same 0.5 cut-off.
MODEL_BOUNDARIES = dict.fromkeys(
    (
        'logistic',
        'svm',
        'xgboost',
        'attention_blstm',
        'rcnn',
        'bert',
        'semantic',
    ),
    0.5,
)
def load_models():
    """Populate the module-level model registries on a best-effort basis.

    Loads, in order: the classical models (``<name>.joblib``), the deep
    models (``<name>.pt`` state dicts), the fine-tuned BERT model, the
    semantic message model (falling back to a training checkpoint) and the
    Gemini client. A missing or unloadable artifact only produces a warning;
    the remaining models are still loaded, so the service can start with
    whatever subset is available.
    """
    global ml_models, dl_models, bert_model, semantic_model, gemini_model
    print("Loading models...")
    models_dir = config.MODELS_DIR

    # NOTE(review): joblib.load and torch.load unpickle their input; only
    # artifacts from a trusted source should ever be placed in these folders.
    for model_name in ['logistic', 'svm', 'xgboost']:
        model_path = os.path.join(models_dir, f'{model_name}.joblib')
        if not os.path.exists(model_path):
            print(f"⚠ Warning: {model_name} model not found at {model_path}")
            continue
        try:
            ml_models[model_name] = joblib.load(model_path)
            print(f"✓ Loaded {model_name} model")
        except Exception as e:
            print(f"⚠ Warning: Could not load {model_name} model: {e}")

    # The architectures are built lazily and only once, instead of once per
    # model file found on disk.
    dl_templates = None
    for model_name in ['attention_blstm', 'rcnn']:
        model_path = os.path.join(models_dir, f'{model_name}.pt')
        if not os.path.exists(model_path):
            print(f"⚠ Warning: {model_name} model not found at {model_path}")
            continue
        try:
            if dl_templates is None:
                dl_templates = get_dl_models(input_dim=len(config.NUMERICAL_FEATURES))
            model = dl_templates[model_name]
            model.load_state_dict(torch.load(model_path, map_location='cpu'))
            model.eval()
            # Register only after the weights loaded, so a half-initialised
            # model never ends up serving predictions.
            dl_models[model_name] = model
            print(f"✓ Loaded {model_name} model")
        except Exception as e:
            print(f"⚠ Warning: Could not load {model_name} model: {e}")

    bert_path = os.path.join(config.BASE_DIR, 'finetuned_bert')
    if os.path.exists(bert_path):
        try:
            bert_model = FinetunedBERT(bert_path)
            print("✓ Loaded BERT model")
        except Exception as e:
            print(f"⚠ Warning: Could not load BERT model: {e}")
    else:
        print(f"⚠ Warning: BERT model not found at {bert_path}")

    message_model_dir = os.path.join(config.BASE_DIR, 'Message_model')
    semantic_model_path = os.path.join(message_model_dir, 'final_semantic_model')
    checkpoint_path = os.path.join(message_model_dir, 'training_checkpoints', 'checkpoint-30')
    # isdir() guards listdir() against a stray file sitting at that path.
    if os.path.isdir(semantic_model_path) and os.listdir(semantic_model_path):
        try:
            semantic_model = PhishingPredictor(model_path=semantic_model_path)
            print("✓ Loaded semantic model")
        except Exception as e:
            print(f"⚠ Warning: Could not load semantic model: {e}")
    elif os.path.exists(checkpoint_path):
        try:
            semantic_model = PhishingPredictor(model_path=checkpoint_path)
            print("✓ Loaded semantic model from checkpoint")
        except Exception as e:
            print(f"⚠ Warning: Could not load semantic model from checkpoint: {e}")
    else:
        print(f"⚠ Warning: semantic model not found at {semantic_model_path} or {checkpoint_path}")

    gemini_api_key = os.environ.get('GEMINI_API_KEY')
    if gemini_api_key:
        try:
            genai.configure(api_key=gemini_api_key)
            gemini_model = genai.GenerativeModel('gemini-2.0-flash')
            print("✓ Initialized Gemini API")
        except Exception as e:
            print(f"⚠ Warning: Could not initialize Gemini API: {e}")
    else:
        print("⚠ Warning: GEMINI_API_KEY not set. Set it as environment variable.")
        print(" Example: export GEMINI_API_KEY='your-api-key-here'")
def parse_message(text: str) -> tuple:
    """Split a raw message into its URLs and a normalised text body.

    Returns a ``(urls, cleaned_text)`` pair. ``urls`` holds every match of the
    URL pattern in order of appearance (both ``http(s)://`` links and bare
    ``domain.tld[/path]`` forms). ``cleaned_text`` is the message with those
    matches removed, lower-cased, restricted to ``a-z 0-9 . , ! ? -`` plus
    whitespace, with punctuation runs collapsed to the run's last character
    and whitespace squeezed to single spaces.
    """
    url_regex = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+|(?:www\.)?[a-zA-Z0-9-]+\.[a-z]{2,12}\b(?:/[^\s]*)?'
    )
    found_urls = url_regex.findall(text)

    body = url_regex.sub('', text).lower()
    body = ' '.join(body.split())
    body = re.sub(r'[^a-z0-9\s.,!?-]', '', body)
    body = re.sub(r'([.,!?])+', r'\1', body)
    # Stripping characters can leave doubled spaces behind; squeeze once more.
    body = ' '.join(body.split())
    return found_urls, body
async def extract_url_features(urls: List[str]) -> pd.DataFrame:
    """Run ``process_row`` for every URL and return URLs plus features.

    Each URL is processed in a worker thread via ``asyncio.to_thread``; all
    calls share one WHOIS cache and one SSL cache dict for this invocation.
    The result has a ``url`` column followed by whatever columns the rows
    returned by ``process_row`` produce. An empty ``urls`` list yields an
    empty DataFrame.
    """
    if not urls:
        return pd.DataFrame()

    url_frame = pd.DataFrame({'url': urls})
    whois_cache, ssl_cache = {}, {}
    pending = [
        asyncio.to_thread(process_row, record, whois_cache, ssl_cache)
        for _, record in url_frame.iterrows()
    ]
    extracted = await asyncio.gather(*pending)
    return pd.concat([url_frame, pd.DataFrame(extracted)], axis=1)
def custom_boundary(raw_score: float, boundary: float) -> float:
    """Return the signed distance of *raw_score* from *boundary*, times 100.

    Positive values lie above the boundary, negative values below it.
    """
    distance = raw_score - boundary
    return distance * 100
def _build_prediction(model_name: str, raw_score: float) -> Dict:
    """Pair a raw score with its boundary-relative scaled score."""
    scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES[model_name])
    return {
        'raw_score': float(raw_score),
        'scaled_score': float(scaled_score),
    }


def get_model_predictions(features_df: pd.DataFrame, message_text: str) -> Dict:
    """Score one message with every loaded model.

    Args:
        features_df: One row per extracted URL (``url`` column plus the
            feature columns named in ``config``); may be empty when the
            message has no URLs.
        message_text: Cleaned message body for the semantic model.

    Returns:
        Mapping of model name to ``{'raw_score', 'scaled_score'}`` (the
        semantic entry also carries ``'confidence'``). Tabular and deep
        models report the highest score over all URLs, BERT reports the mean
        over all URLs. A model that raises is logged and left out.
    """
    predictions = {}
    numerical_features = config.NUMERICAL_FEATURES
    categorical_features = config.CATEGORICAL_FEATURES
    feature_columns = numerical_features + categorical_features

    if features_df.empty:
        # A message without URLs is a normal case, not a missing-column error.
        X = pd.DataFrame(columns=feature_columns)
    else:
        try:
            # Explicit copy: the fills below must not write through to (or
            # warn about) the caller's frame.
            X = features_df[feature_columns].copy()
        except KeyError as e:
            print(f"Error: Missing columns in features_df. {e}")
            print(f"Available columns: {features_df.columns.tolist()}")
            X = pd.DataFrame(columns=feature_columns)

    if not X.empty:
        X[numerical_features] = X[numerical_features].fillna(-1)
        X[categorical_features] = X[categorical_features].fillna('N/A')

        for model_name, model in ml_models.items():
            try:
                all_probas = model.predict_proba(X)[:, 1]
                predictions[model_name] = _build_prediction(model_name, np.max(all_probas))
            except Exception as e:
                print(f"Error with {model_name} (Prediction Step): {e}")

        # Built on first use and then shared by every deep model.
        X_tensor = None
        for model_name, model in dl_models.items():
            try:
                if X_tensor is None:
                    X_tensor = torch.tensor(X[numerical_features].values, dtype=torch.float32)
                with torch.no_grad():
                    all_scores = model(X_tensor)
                predictions[model_name] = _build_prediction(model_name, torch.max(all_scores).item())
            except Exception as e:
                print(f"Error with {model_name}: {e}")

    if bert_model and len(features_df) > 0:
        try:
            urls = features_df['url'].tolist()
            raw_scores = bert_model.predict_proba(urls)
            avg_raw_score = np.mean([score[1] for score in raw_scores])
            predictions['bert'] = _build_prediction('bert', avg_raw_score)
        except Exception as e:
            print(f"Error with BERT: {e}")

    if semantic_model and message_text:
        try:
            result = semantic_model.predict(message_text)
            entry = _build_prediction('semantic', result['phishing_probability'])
            entry['confidence'] = result['confidence']
            predictions['semantic'] = entry
        except Exception as e:
            print(f"Error with semantic model: {e}")

    return predictions
def _hostname_of(url_str: str) -> Optional[str]:
    """Return the hostname of *url_str*, also for URLs written without a scheme."""
    # urlparse only recognises a network location after "//"; bare forms such
    # as "www.example.com/path" would otherwise yield no hostname at all.
    candidate = url_str if '://' in url_str else f'//{url_str}'
    return urlparse(candidate).hostname


async def get_network_features_for_gemini(urls: List[str]) -> str:
    """Describe where the first three URLs are hosted, as prompt-ready text.

    For each URL the hostname is resolved to an IP address and looked up on
    ip-api.com (location, ISP, organisation, ASN). Every failure mode is
    turned into a line of text for that URL instead of raising, so the
    caller always receives a usable summary string.
    """
    if not urls:
        return "No URLs to analyze for network features."

    results = []
    async with httpx.AsyncClient() as client:
        for position, url_str in enumerate(urls[:3], start=1):
            try:
                hostname = _hostname_of(url_str)
                if not hostname:
                    results.append(f"\nURL {position} ({url_str}): Invalid URL, no hostname.")
                    continue

                try:
                    # gethostbyname blocks, so it runs in a worker thread.
                    ip_address = await asyncio.to_thread(socket.gethostbyname, hostname)
                except socket.gaierror:
                    results.append(f"\nURL {position} ({hostname}): Could not resolve domain to IP.")
                    continue

                try:
                    geo_url = f"http://ip-api.com/json/{ip_address}?fields=status,message,country,city,isp,org,as"
                    response = await client.get(geo_url, timeout=3.0)
                    response.raise_for_status()
                    data = response.json()
                    if data.get('status') == 'success':
                        geo_info = (
                            f" • IP Address: {ip_address}\n"
                            f" • Location: {data.get('city', 'N/A')}, {data.get('country', 'N/A')}\n"
                            f" • ISP: {data.get('isp', 'N/A')}\n"
                            f" • Organization: {data.get('org', 'N/A')}\n"
                            f" • ASN: {data.get('as', 'N/A')}"
                        )
                        results.append(f"\nURL {position} ({hostname}):\n{geo_info}")
                    else:
                        results.append(f"\nURL {position} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: API lookup failed ({data.get('message')})")
                except httpx.HTTPError as e:
                    # HTTPError covers transport failures and the status
                    # errors raised by raise_for_status(), so the resolved IP
                    # is still reported in both cases.
                    results.append(f"\nURL {position} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: Network error while fetching IP info ({str(e)})")
            except Exception as e:
                results.append(f"\nURL {position} ({url_str}): Error processing URL ({str(e)})")

    if not results:
        return "No valid hostnames found in URLs to analyze."
    return "\n".join(results)
def _average_scaled_score(predictions: Dict) -> float:
    """Mean scaled score over all models, or 0 when nothing was predicted."""
    if not predictions:
        return 0
    return np.mean([p['scaled_score'] for p in predictions.values()])


def _fallback_decision(avg_scaled_score: float, original_text: str, reasoning: str,
                       phishing_suggestion: str, legitimate_suggestion: str) -> Dict:
    """Build a verdict from the model average when Gemini cannot be used."""
    is_phishing = avg_scaled_score > 0
    confidence = min(100, max(0, 50 + abs(avg_scaled_score)))
    return {
        "confidence": round(float(confidence), 2),
        "reasoning": reasoning,
        "highlighted_text": original_text,
        "final_decision": "phishing" if is_phishing else "legitimate",
        "suggestion": phishing_suggestion if is_phishing else legitimate_suggestion,
    }


def _truncate_for_prompt(text: str, limit: int) -> str:
    """Cut *text* to *limit* characters, appending a truncation marker."""
    if len(text) > limit:
        return text[:limit] + "\n... [TRUNCATED]"
    return text


def _summarize_url_features(features_df: pd.DataFrame) -> str:
    """Render a few per-URL features as indented bullet lines for the prompt."""
    if len(features_df) == 0:
        return "No URLs detected in message"
    lines = []
    for idx, row in features_df.iterrows():
        lines.extend([
            f"\nURL {idx+1}: {row.get('url', 'Unknown')}",
            f" • Length: {row.get('url_length', 'N/A')} chars",
            f" • Dots in URL: {row.get('count_dot', 'N/A')}",
            f" • Special characters: {row.get('count_special_chars', 'N/A')}",
            f" • Domain age: {row.get('domain_age_days', 'N/A')} days",
            f" • SSL certificate valid: {row.get('cert_has_valid_hostname', 'N/A')}",
            f" • Uses HTTPS: {row.get('https', 'N/A')}",
        ])
    return "\n".join(lines)


def _extract_json_text(response_text: str) -> str:
    """Strip Markdown fences or surrounding prose from a JSON reply."""
    if '```json' in response_text:
        response_text = response_text.split('```json')[1].split('```')[0].strip()
    elif response_text.startswith('```') and response_text.endswith('```'):
        response_text = response_text[3:-3].strip()
    if not response_text.startswith('{'):
        # Accepts one level of nested braces, which is all the schema needs.
        json_match = re.search(r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}', response_text, re.DOTALL)
        if not json_match:
            raise ValueError(f"Could not find JSON in Gemini response: {response_text[:200]}")
        response_text = json_match.group(0)
    return response_text


async def get_gemini_final_decision(urls: List[str], features_df: pd.DataFrame,
                                    message_text: str, predictions: Dict,
                                    original_text: str) -> Dict:
    """Ask Gemini for the final verdict, falling back to the model average.

    Args:
        urls: URLs extracted from the message.
        features_df: Per-URL features (may be empty).
        message_text: Cleaned message body.
        predictions: Output of ``get_model_predictions``.
        original_text: The message exactly as submitted.

    Returns:
        Dict with ``confidence`` (0-100), ``reasoning``, ``highlighted_text``,
        ``final_decision`` ("phishing"/"legitimate") and ``suggestion``. This
        function does not raise for Gemini problems: when the client is not
        configured, the call fails after retries, or the reply cannot be
        parsed or validated, the verdict is derived from the mean scaled
        score instead.
    """
    if not gemini_model:
        return _fallback_decision(
            _average_scaled_score(predictions),
            original_text,
            "Gemini API not available. Using average model scores.",
            "Do not interact with this message. Delete it immediately and report it to your IT department.",
            "This message appears safe, but remain cautious with any links or attachments.",
        )

    url_features_summary = _summarize_url_features(features_df)
    network_features_summary = await get_network_features_for_gemini(urls)
    model_scores_text = "\n".join(
        f" • {model_name.upper()}: scaled_score={pred_data['scaled_score']:.2f} (raw={pred_data['raw_score']:.3f})"
        for model_name, pred_data in predictions.items()
    )

    MAX_TEXT_LEN = 3000
    was_truncated = len(original_text) > MAX_TEXT_LEN
    truncated_original_text = _truncate_for_prompt(original_text, MAX_TEXT_LEN)
    truncated_message_text = _truncate_for_prompt(message_text, MAX_TEXT_LEN)

    # NOTE(review): the untrusted message is embedded verbatim in the prompt,
    # so it can attempt prompt injection; the validation below constrains the
    # reply's shape but cannot make its content trustworthy.
    context = f"""You are a security model that must decide if a message is phishing or legitimate.
Use all evidence below:
- URL/network data (trust NETWORK_GEO more than URL_FEATURES when they disagree; domain_age = -1 means unknown).
- Model scores (scaled_score > 0 → more phishing, < 0 → more legitimate).
- Message content (urgency, threats, credential/OTP/payment requests, impersonation).
If strong phishing signals exist, prefer "phishing". If everything matches a normal, known service/organization and content is routine, prefer "legitimate".
Return only this JSON object:
{{
"confidence": <float 0-100>,
"reasoning": "<brief explanation referring to key evidence>",
"highlighted_text": "<full original message with suspicious spans wrapped in $$...$$>",
"final_decision": "phishing" or "legitimate",
"suggestion": "<practical advice for the user on what to do>"
}}
MESSAGE_ORIGINAL:
{truncated_original_text}
MESSAGE_CLEANED:
{truncated_message_text}
URLS:
{', '.join(urls) if urls else 'None'}
URL_FEATURES:
{url_features_summary}
NETWORK_GEO:
{network_features_summary}
MODEL_SCORES (scaled_score > 0 phishing, < 0 legitimate):
{model_scores_text}
"""

    response_text = ""
    try:
        generation_config = {
            'temperature': 0.2,
            'top_p': 0.85,
            'top_k': 40,
            'max_output_tokens': 8192,
            'response_mime_type': 'application/json'
        }
        safety_settings = {
            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
            "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
            "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
        }

        # Up to three attempts with a doubling delay (2s, then 4s).
        max_retries = 3
        retry_delay = 2
        response = None
        for attempt in range(max_retries):
            try:
                response = await gemini_model.generate_content_async(
                    context,
                    generation_config=generation_config,
                    safety_settings=safety_settings
                )
                candidates = response.candidates
                # Checked separately: indexing an empty candidate list would
                # raise IndexError while building the error message.
                if not candidates:
                    raise ValueError("No content returned. Finish Reason: no candidates in response")
                if not candidates[0].content.parts:
                    raise ValueError(f"No content returned. Finish Reason: {candidates[0].finish_reason}")
                break
            except Exception as retry_error:
                print(f"Gemini API attempt {attempt + 1} failed: {retry_error}")
                if attempt >= max_retries - 1:
                    raise
                print(f"Retrying in {retry_delay}s...")
                await asyncio.sleep(retry_delay)
                retry_delay *= 2

        response_text = _extract_json_text(response.text.strip())
        result = json.loads(response_text)

        required_fields = ['confidence', 'reasoning', 'highlighted_text', 'final_decision', 'suggestion']
        if not all(field in result for field in required_fields):
            raise ValueError(f"Missing required fields. Got: {list(result.keys())}")

        result['confidence'] = max(0, min(100, float(result['confidence'])))

        decision = result['final_decision'].lower()
        if decision not in ['phishing', 'legitimate']:
            decision = 'phishing' if result['confidence'] >= 50 else 'legitimate'
        result['final_decision'] = decision

        # "..." in the reply signals that the model elided part of the
        # message -- unless the submitted message itself contains an ellipsis
        # and was sent to the model in full.
        highlighted = result['highlighted_text']
        looks_elided = '...' in highlighted and (was_truncated or '...' not in original_text)
        if not highlighted.strip() or looks_elided:
            result['highlighted_text'] = original_text

        if not result.get('suggestion', '').strip():
            if result['final_decision'] == 'phishing':
                result['suggestion'] = "Do not interact with this message. Delete it immediately and report it as phishing."
            else:
                result['suggestion'] = "This message appears safe, but always verify sender identity before taking any action."
        return result
    except json.JSONDecodeError as e:
        print(f"JSON parsing error: {e}")
        print(f"Response text that failed parsing: {response_text[:500]}")
        avg_scaled_score = _average_scaled_score(predictions)
        return _fallback_decision(
            avg_scaled_score,
            original_text,
            f"Gemini response parsing failed. Fallback: Based on model average (score: {avg_scaled_score:.2f}), message appears {'legitimate' if avg_scaled_score <= 0 else 'suspicious'}.",
            "Do not interact with this message. Delete it immediately and be cautious.",
            "Exercise caution. Verify the sender before taking any action.",
        )
    except Exception as e:
        print(f"Error with Gemini API: {e}")
        avg_scaled_score = _average_scaled_score(predictions)
        return _fallback_decision(
            avg_scaled_score,
            original_text,
            f"Gemini API error: {str(e)}. Fallback decision based on {len(predictions)} model predictions (average score: {avg_scaled_score:.2f}).",
            "Treat this message with caution. Delete it if suspicious, or verify the sender through official channels before taking action.",
            "This message appears safe based on models, but always verify sender identity before clicking links or providing information.",
        )
# Registered as a startup hook; without the registration load_models() is
# never invoked and every model registry stays empty.
@app.on_event("startup")
async def startup_event():
    """Load all models once at application start and print a ready banner."""
    load_models()
    banner_rule = "=" * 60
    print("\n" + banner_rule)
    print("Phishing Detection API is ready!")
    print(banner_rule)
    print("API Documentation: http://localhost:8000/docs")
    print(banner_rule + "\n")
@app.get("/")
async def root():
    """Return the service name, version and a directory of its endpoints."""
    endpoints = {
        "predict": "/predict (POST)",
        "health": "/health (GET)",
        "docs": "/docs (GET)",
    }
    return {
        "message": "Phishing Detection API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
@app.get("/health")
async def health_check():
    """Report liveness plus which models are currently loaded.

    ``status`` is always "healthy"; inspect ``models_loaded`` to see whether
    any model is actually available.
    """
    models_loaded = {
        "ml_models": list(ml_models.keys()),
        "dl_models": list(dl_models.keys()),
        "bert_model": bert_model is not None,
        "semantic_model": semantic_model is not None,
        "gemini_model": gemini_model is not None,
    }
    return {"status": "healthy", "models_loaded": models_loaded}
@app.post("/predict", response_model=PredictionResponse)
async def predict(message_input: MessageInput):
    """Classify one message as phishing or legitimate.

    Pipeline: extract URLs and cleaned text, compute URL features, score with
    every loaded model, then let ``get_gemini_final_decision`` produce the
    verdict (it falls back to the model average on its own).

    Raises:
        HTTPException: 400 when the submitted text is blank; 500 when no
            model produced a prediction or an unexpected error occurred.
    """
    try:
        original_text = message_input.text
        if not original_text or not original_text.strip():
            raise HTTPException(status_code=400, detail="Message text cannot be empty")

        urls, cleaned_text = parse_message(original_text)

        features_df = pd.DataFrame()
        if urls:
            features_df = await extract_url_features(urls)

        predictions = {}
        if len(features_df) > 0 or (cleaned_text and semantic_model):
            # Model inference is blocking, so it runs off the event loop.
            predictions = await asyncio.to_thread(get_model_predictions, features_df, cleaned_text)

        if not predictions:
            # Pick the most specific explanation for the empty result.
            if not urls and not cleaned_text:
                detail = "Message text is empty after cleaning."
            elif not urls and not semantic_model:
                detail = "No URLs provided and semantic model is not loaded."
            elif not any([ml_models, dl_models, bert_model, semantic_model]):
                detail = "No models available for prediction. Please ensure models are trained and loaded."
            else:
                detail = "Could not generate predictions. Models may be missing or feature extraction failed."
            raise HTTPException(status_code=500, detail=detail)

        final_result = await get_gemini_final_decision(
            urls, features_df, cleaned_text, predictions, original_text
        )
        return PredictionResponse(**final_result)
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)