Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import time | |
| import sys | |
| import asyncio | |
| from typing import List, Dict, Optional | |
| from urllib.parse import urlparse | |
| import socket | |
| import httpx | |
| import joblib | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from groq import AsyncGroq | |
| from dotenv import load_dotenv | |
| # --- Make sure 'config.py' and 'models.py' are in the same directory or accessible | |
| import config | |
| from models import get_ml_models, get_dl_models, FinetunedBERT | |
| from feature_extraction import process_row | |
| load_dotenv() | |
| sys.path.append(os.path.join(config.BASE_DIR, 'Message_model')) | |
| from predict import PhishingPredictor | |
# FastAPI application instance; metadata below feeds the auto-generated /docs page.
app = FastAPI(
    title="Phishing Detection API",
    description="Advanced phishing detection system using multiple ML/DL models and Groq",
    version="1.0.0"
)
# Wide-open CORS policy for development.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests — tighten origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
| # --- Pydantic Models --- | |
class MessageInput(BaseModel):
    """Request body for /predict: the raw message plus optional metadata."""
    # Full raw message text; may contain URLs that will be extracted.
    text: str
    # Arbitrary caller-supplied context; not consumed by the pipeline in this file.
    # (Pydantic copies mutable defaults per instance, so the shared {} is safe.)
    metadata: Optional[Dict] = {}
class PredictionResponse(BaseModel):
    """Response body for /predict — the final judge's verdict."""
    # Directional score 0-100: >50 leans phishing, <50 leans legitimate.
    confidence: float
    # Human-readable explanation of the decision.
    reasoning: str
    # Original message with suspicious spans wrapped in $$...$$ markers.
    highlighted_text: str
    # Either "phishing" or "legitimate".
    final_decision: str
    # Actionable advice for the end user.
    suggestion: str
# --- Global Variables ---
# Model registries populated once at startup by load_models().
ml_models = {}             # name -> fitted sklearn-style estimator (joblib)
dl_models = {}             # name -> torch.nn.Module, set to eval mode
bert_model = None          # FinetunedBERT wrapper, or None if unavailable
semantic_model = None      # PhishingPredictor for message text, or None
groq_async_client = None   # AsyncGroq client, or None when no API key is set
# Per-model decision boundaries on the raw probability scale; raw scores are
# re-centered around these values by custom_boundary().
MODEL_BOUNDARIES = {
    'logistic': 0.5,
    'svm': 0.5,
    'xgboost': 0.5,
    'attention_blstm': 0.5,
    'rcnn': 0.5,
    'bert': 0.5,
    'semantic': 0.5
}
| # --- Model Loading --- | |
def load_models():
    """Populate the global model registries from disk and init the Groq client.

    Called once at startup. Missing artifacts are reported as console warnings
    rather than raised, so the API can still serve with a partial model set.
    """
    global ml_models, dl_models, bert_model, semantic_model, groq_async_client
    print("Loading models...")
    models_dir = config.MODELS_DIR

    # Classical ML models persisted with joblib.
    for name in ('logistic', 'svm', 'xgboost'):
        path = os.path.join(models_dir, f'{name}.joblib')
        if not os.path.exists(path):
            print(f"⚠ Warning: {name} model not found at {path}")
            continue
        ml_models[name] = joblib.load(path)
        print(f"✓ Loaded {name} model")

    # Deep models: instantiate the architecture first, then restore the weights.
    for name in ('attention_blstm', 'rcnn'):
        path = os.path.join(models_dir, f'{name}.pt')
        if not os.path.exists(path):
            print(f"⚠ Warning: {name} model not found at {path}")
            continue
        templates = get_dl_models(input_dim=len(config.NUMERICAL_FEATURES))
        net = templates[name]
        net.load_state_dict(torch.load(path, map_location='cpu'))
        net.eval()
        dl_models[name] = net
        print(f"✓ Loaded {name} model")

    # Fine-tuned BERT for URL classification.
    bert_path = os.path.join(config.BASE_DIR, 'finetuned_bert')
    if os.path.exists(bert_path):
        try:
            bert_model = FinetunedBERT(bert_path)
            print("✓ Loaded BERT model")
        except Exception as e:
            print(f"⚠ Warning: Could not load BERT model: {e}")

    # Semantic (message-text) model: prefer the final export, fall back to the
    # last training checkpoint when the export directory is absent or empty.
    semantic_dir = os.path.join(config.BASE_DIR, 'Message_model', 'final_semantic_model')
    if os.path.exists(semantic_dir) and os.listdir(semantic_dir):
        try:
            semantic_model = PhishingPredictor(model_path=semantic_dir)
            print("✓ Loaded semantic model")
        except Exception as e:
            print(f"⚠ Warning: Could not load semantic model: {e}")
    else:
        ckpt = os.path.join(config.BASE_DIR, 'Message_model', 'training_checkpoints', 'checkpoint-30')
        if os.path.exists(ckpt):
            try:
                semantic_model = PhishingPredictor(model_path=ckpt)
                print("✓ Loaded semantic model from checkpoint")
            except Exception as e:
                print(f"⚠ Warning: Could not load semantic model from checkpoint: {e}")

    # Groq client for the final LLM judge; optional but strongly recommended.
    groq_api_key = os.environ.get('GROQ_API_KEY')
    if groq_api_key:
        groq_async_client = AsyncGroq(api_key=groq_api_key)
        print("✓ Initialized Groq API Client")
    else:
        print("⚠ Warning: GROQ_API_KEY not set. Set it as environment variable.")
        print(" Example: export GROQ_API_KEY='your-api-key-here'")
| # --- Feature Extraction & Prediction Logic --- | |
def parse_message(text: str) -> tuple:
    """Split a raw message into (urls, cleaned_text).

    URLs are pulled out with a permissive pattern (http/https forms or bare
    host.tld forms); the remainder is lower-cased, whitespace-normalised,
    stripped of characters outside a small whitelist, and has runs of
    punctuation collapsed to a single mark.
    """
    url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+|(?:www\.)?[a-zA-Z0-9-]+\.[a-z]{2,12}\b(?:/[^\s]*)?'
    urls = re.findall(url_pattern, text)
    # Drop the URLs, then normalise case and whitespace.
    remainder = re.sub(url_pattern, '', text).lower()
    remainder = ' '.join(remainder.split())
    # Whitelist letters, digits, whitespace and light punctuation.
    remainder = re.sub(r'[^a-z0-9\s.,!?-]', '', remainder)
    # Collapse runs of punctuation ("!!!" -> "!").
    remainder = re.sub(r'([.,!?])+', r'\1', remainder)
    return urls, ' '.join(remainder.split())
async def extract_url_features(urls: List[str]) -> pd.DataFrame:
    """Run per-URL feature extraction concurrently.

    Each URL row is handed to feature_extraction.process_row in a worker
    thread (its WHOIS/SSL lookups block); the shared caches let repeated
    domains reuse results. Returns the url column concatenated with the
    extracted feature columns, or an empty DataFrame when no URLs were given.
    """
    if not urls:
        return pd.DataFrame()
    url_df = pd.DataFrame({'url': urls})
    whois_cache = {}
    ssl_cache = {}
    workers = [
        asyncio.to_thread(process_row, row, whois_cache, ssl_cache)
        for _, row in url_df.iterrows()
    ]
    extracted = await asyncio.gather(*workers)
    return pd.concat([url_df, pd.DataFrame(extracted)], axis=1)
def custom_boundary(raw_score: float, boundary: float) -> float:
    """Re-center a raw probability around its decision boundary.

    Maps a score in [0, 1] to roughly [-50, +50]: negative values lie below
    the boundary (legitimate-leaning), positive values above it
    (phishing-leaning).
    """
    offset = raw_score - boundary
    return offset * 100
def get_model_predictions(features_df: pd.DataFrame, message_text: str) -> Dict:
    """Score the message with every available model.

    Returns a dict keyed by model name; each value holds the model's
    'raw_score' (probability-like, 0-1) and 'scaled_score' (directional,
    roughly -50..+50, via custom_boundary). URL-based models (ml/dl) are
    skipped when no URL feature rows are available; BERT needs at least one
    URL row and the semantic model needs non-empty message text.
    """
    predictions = {}
    numerical_features = config.NUMERICAL_FEATURES
    categorical_features = config.CATEGORICAL_FEATURES
    try:
        # .copy() gives an independent frame so the fillna writes below are
        # reliable and free of pandas chained-assignment warnings.
        X = features_df[numerical_features + categorical_features].copy()
    except KeyError as e:
        print(f"Error: Missing columns in features_df. {e}")
        print(f"Available columns: {features_df.columns.tolist()}")
        X = pd.DataFrame(columns=numerical_features + categorical_features)
    if not X.empty:
        # Sentinel fills: -1 marks a failed numeric lookup, 'N/A' a missing category.
        X.loc[:, numerical_features] = X.loc[:, numerical_features].fillna(-1)
        X.loc[:, categorical_features] = X.loc[:, categorical_features].fillna('N/A')

        # Classical ML models: score every URL, keep the most suspicious one.
        for model_name, model in ml_models.items():
            try:
                all_probas = model.predict_proba(X)[:, 1]
                raw_score = np.max(all_probas)
                scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES[model_name])
                predictions[model_name] = {
                    'raw_score': float(raw_score),
                    'scaled_score': float(scaled_score)
                }
            except Exception as e:
                print(f"Error with {model_name} (Prediction Step): {e}")

        # Deep models consume only the numeric features.
        X_numerical = X[numerical_features].values
        for model_name, model in dl_models.items():
            try:
                X_tensor = torch.tensor(X_numerical, dtype=torch.float32)
                with torch.no_grad():
                    all_scores = model(X_tensor)
                raw_score = torch.max(all_scores).item()
                scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES[model_name])
                predictions[model_name] = {
                    'raw_score': float(raw_score),
                    'scaled_score': float(scaled_score)
                }
            except Exception as e:
                print(f"Error with {model_name}: {e}")

    # BERT scores the raw URL strings directly (mean over all URLs).
    if bert_model and len(features_df) > 0:
        try:
            urls = features_df['url'].tolist()
            raw_scores = bert_model.predict_proba(urls)
            avg_raw_score = np.mean([score[1] for score in raw_scores])
            scaled_score = custom_boundary(avg_raw_score, MODEL_BOUNDARIES['bert'])
            predictions['bert'] = {
                'raw_score': float(avg_raw_score),
                'scaled_score': float(scaled_score)
            }
        except Exception as e:
            print(f"Error with BERT: {e}")
    # Semantic model works on the cleaned message text, not URLs.
    if semantic_model and message_text:
        try:
            result = semantic_model.predict(message_text)
            raw_score = result['phishing_probability']
            scaled_score = custom_boundary(raw_score, MODEL_BOUNDARIES['semantic'])
            predictions['semantic'] = {
                'raw_score': float(raw_score),
                'scaled_score': float(scaled_score),
                'confidence': result['confidence']  # Note: this is the semantic model's own confidence
            }
        except Exception as e:
            print(f"Error with semantic model: {e}")
    return predictions
| # --- Groq/LLM Final Decision Logic --- | |
async def get_network_features_for_gemini(urls: List[str]) -> str:
    """
    Fetches real-time IP, Geo, and ISP data for URLs.
    This runs independently and is ONLY used to inform the LLM prompt.

    At most the first 3 URLs are checked. Each lookup resolves the hostname
    to an IP in a worker thread, then queries ip-api.com for geo/ISP data.
    Every failure degrades to a descriptive text line instead of raising.
    NOTE(review): the name says "gemini" but the consumer is the Groq LLM —
    presumably a leftover from an earlier provider; confirm before renaming.
    """
    if not urls:
        return "No URLs to analyze for network features."
    results = []
    async with httpx.AsyncClient() as client:
        # Cap at 3 URLs to keep request latency bounded.
        for i, url_str in enumerate(urls[:3]):
            try:
                hostname = urlparse(url_str).hostname
                if not hostname:
                    results.append(f"\nURL {i+1} ({url_str}): Invalid URL, no hostname.")
                    continue
                # DNS resolution is blocking; run it off the event loop.
                try:
                    ip_address = await asyncio.to_thread(socket.gethostbyname, hostname)
                except socket.gaierror:
                    results.append(f"\nURL {i+1} ({hostname}): Could not resolve domain to IP.")
                    continue
                # Free geo-IP endpoint; 'status' in the JSON marks lookup success.
                try:
                    geo_url = f"http://ip-api.com/json/{ip_address}?fields=status,message,country,city,isp,org,as"
                    response = await client.get(geo_url, timeout=3.0)
                    response.raise_for_status()
                    data = response.json()
                    if data.get('status') == 'success':
                        geo_info = (
                            f" • IP Address: {ip_address}\n"
                            f" • Location: {data.get('city', 'N/A')}, {data.get('country', 'N/A')}\n"
                            f" • ISP: {data.get('isp', 'N/A')}\n"
                            f" • Organization: {data.get('org', 'N/A')}\n"
                            f" • ASN: {data.get('as', 'N/A')}"
                        )
                        results.append(f"\nURL {i+1} ({hostname}):\n{geo_info}")
                    else:
                        results.append(f"\nURL {i+1} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: API lookup failed ({data.get('message')})")
                except httpx.RequestError as e:
                    results.append(f"\nURL {i+1} ({hostname}):\n • IP Address: {ip_address}\n • Geo-Data: Network error while fetching IP info ({str(e)})")
            except Exception as e:
                # Catch-all so one bad URL never aborts the whole summary.
                results.append(f"\nURL {i+1} ({url_str}): Error processing URL ({str(e)})")
    if not results:
        return "No valid hostnames found in URLs to analyze."
    return "\n".join(results)
# --- Static system prompt with fixed few-shot examples ---
# Sent verbatim as the 'system' message; it is NOT passed through str.format,
# so the JSON examples use plain single braces. (The previous doubled {{ }}
# escapes were a leftover from a format-template version and leaked literal
# double braces into the prompt the LLM saw.)
SYSTEM_PROMPT = """You are the FINAL JUDGE in a phishing detection system. Your role is critical: analyze ALL available evidence and make the ultimate decision.
IMPORTANT INSTRUCTIONS:
1. You have FULL AUTHORITY to override model predictions if evidence suggests they're wrong.
2. **TRUST THE 'INDEPENDENT NETWORK & GEO-DATA' OVER 'URL FEATURES'.** The ML model features (like `domain_age: -1`) can be wrong due to lookup failures. The 'INDEPENDENT' data is a real-time check.
3. If 'INDEPENDENT' data shows a legitimate organization (e.g., "Cloudflare", "Google", "Codeforces") for a known domain, but the models score it as phishing (due to `domain_age: -1`), you **should override** and classify as 'legitimate'.
4. Your confidence score is DIRECTIONAL (0-100):
   - Scores > 50.0 mean 'phishing'.
   - Scores < 50.0 mean 'legitimate'.
   - 50.0 is neutral.
   - The magnitude indicates certainty (e.g., 95.0 is 'very confident phishing'; 5.0 is 'very confident legitimate').
   - Your confidence score MUST match your 'final_decision'.
5. BE WARY OF FALSE POSITIVES. Legitimate messages (bank alerts, contest notifications) can seem urgent.
PRIORITY GUIDANCE (Use this logic):
- IF URLs are present: Focus heavily on URL features.
  - Examine 'URL FEATURES' for patterns (e.g., domain_age: -1 or 0, high special_chars).
  - **CRITICAL:** Cross-reference this with the 'INDEPENDENT NETWORK & GEO-DATA'. This real-time data (IP, Location, ISP) is your ground truth.
  - **If `domain_age` is -1, it's a lookup failure.** IGNORE IT and trust the 'INDEPENDENT NETWORK & GEO-DATA' to see if the domain is real (e.g., 'codeforces.com' with a valid IP).
  - Then supplement with message content analysis.
- IF NO URLs are present: Focus entirely on message content and semantics.
  - Analyze language patterns, urgency tactics, and social engineering techniques
  - Look for credential requests, financial solicitations, or threats
  - Evaluate the semantic model's assessment heavily
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
FEW-SHOT EXAMPLES FOR GUIDANCE:
Example 1 - Clear Phishing:
Message: "URGENT! Click: http://paypa1-secure.xyz/verify"
URL Features: domain_age: 5
Network Data: IP: 123.45.67.89, Location: Russia, ISP: Shady-Host
Model Scores: All positive
Correct Decision: {
  "confidence": 95.0,
  "reasoning": "Classic phishing. Misspelled domain, new age, and network data points to a suspicious ISP in Russia.",
  "highlighted_text": "URGENT! Click: $$http://paypa1-secure.xyz/verify$$",
  "final_decision": "phishing",
  "suggestion": "Do NOT click. Delete immediately."
}
Example 2 - Legitimate (False Positive Case):
Message: "Hi, join Codeforces Round 184. ... Unsubscribe: https://codeforces.com/unsubscribe/..."
URL Features: domain_age: -1 (This is a lookup failure!)
Network Data: URL (codeforces.com): IP: 104.22.6.109, Location: San Francisco, USA, ISP: Cloudflare, Inc.
Model Scores: Mixed (some positive due to domain_age: -1)
Correct Decision: {
  "confidence": 10.0,
  "reasoning": "OVERRIDING models. The 'URL FEATURES' show a 'domain_age: -1' which is a clear lookup error that confused the models. The 'INDEPENDENT NETWORK & GEO-DATA' confirms the domain 'codeforces.com' is real and hosted on Cloudflare, a legitimate provider. The message content is a standard, safe notification.",
  "highlighted_text": "Hi, join Codeforces Round 184. ... Unsubscribe: https://codeforces.com/unsubscribe/...",
  "final_decision": "legitimate",
  "suggestion": "This message is safe. It is a legitimate notification from Codeforces."
}
Example 3 - Legitimate (Long Formal Text):
Message: "TATA MOTORS PASSENGER VEHICLES LIMITED... GENERAL GUIDANCE NOTE... [TRUNCATED]"
URL Features: domain_age: 8414
Network Data: URL (cars.tatamotors.com): IP: 23.209.113.12, Location: Boardman, USA, ISP: Akamai Technologies
Model Scores: All negative
Correct Decision: {
  "confidence": 5.0,
  "reasoning": "This is a legitimate corporate communication. The text, although truncated, is clearly a formal guidance note for shareholders. The network data confirms 'cars.tatamotors.com' is hosted on Akamai, a major CDN used by large corporations. The models correctly identify this as safe.",
  "final_decision": "legitimate",
  "highlighted_text": "TATA MOTORS PASSENGER VEHICLES LIMITED... GENERAL GUIDANCE NOTE... [TRUNCATED]",
  "suggestion": "This message is a legitimate corporate communication and appears safe."
}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
YOUR ANALYSIS TASK:
Analyze the message data provided by the user (in the 'user' message) following the steps and logic outlined above.
**CRITICAL for `highlighted_text`:** You MUST return the *entire original message*. Only wrap the specific words/URLs that are suspicious with `$$...$$`. If nothing is suspicious (i.e., `final_decision` is 'legitimate'), return the original message with NO `$$` markers.
OUTPUT FORMAT (respond with ONLY this JSON, no markdown, no explanation):
{
  "confidence": <float (0-100, directional score where >50 is phishing)>,
  "reasoning": "<your detailed analysis explaining why this is/isn't phishing, mentioning why you trust/override models>",
  "highlighted_text": "<THE FULL, ENTIRE original message with suspicious parts marked as $$suspicious text$$>",
  "final_decision": "phishing" or "legitimate",
  "suggestion": "<specific, actionable advice for the user on how to handle this message - what to do or not do>"
}"""
def _directional_confidence(predictions: Dict) -> tuple:
    """Collapse per-model scaled scores into a 0-100 directional confidence.

    Averages the models' scaled scores (-50..+50) and shifts by +50 so the
    result lands in 0-100, where >50 means phishing. Returns
    (confidence, final_decision, avg_scaled_score). Shared by every non-LLM
    fallback path in get_groq_final_decision.
    """
    avg_scaled_score = np.mean([p['scaled_score'] for p in predictions.values()]) if predictions else 0
    confidence = min(100, max(0, 50 + avg_scaled_score))
    final_decision = "phishing" if confidence > 50 else "legitimate"
    return confidence, final_decision, avg_scaled_score


async def get_groq_final_decision(urls: List[str], features_df: pd.DataFrame,
                                  message_text: str, predictions: Dict,
                                  original_text: str) -> Dict:
    """Obtain the final phishing verdict from the Groq LLM.

    Assembles a data-only user prompt (message text, URL features, real-time
    network data, model scores) alongside the static SYSTEM_PROMPT, calls the
    Groq chat API with exponential-backoff retries, then validates and
    normalises the JSON reply. Falls back to the model-score average
    (via _directional_confidence) when the client is missing, the reply is
    unparseable, or the API errors out.
    """
    if not groq_async_client:
        confidence, final_decision, avg_scaled_score = _directional_confidence(predictions)
        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Groq API not available. Using average model scores. (Avg Scaled Score: {avg_scaled_score:.2f})",
            "highlighted_text": original_text,
            "final_decision": final_decision,
            "suggestion": "Do not interact with this message. Delete it immediately and report it to your IT department." if final_decision == "phishing" else "This message appears safe, but remain cautious with any links or attachments."
        }
    # --- Summarise the ML feature view of each URL for the prompt ---
    url_features_summary = "No URLs detected in message"
    if len(features_df) > 0:
        feature_summary_parts = []
        for idx, row in features_df.iterrows():
            url = row.get('url', 'Unknown')
            feature_summary_parts.append(f"\nURL {idx+1}: {url}")
            feature_summary_parts.append(f" • Length: {row.get('url_length', 'N/A')} chars")
            feature_summary_parts.append(f" • Dots in URL: {row.get('count_dot', 'N/A')}")
            feature_summary_parts.append(f" • Special characters: {row.get('count_special_chars', 'N/A')}")
            feature_summary_parts.append(f" • Domain age: {row.get('domain_age_days', 'N/A')} days")
            feature_summary_parts.append(f" • SSL certificate valid: {row.get('cert_has_valid_hostname', 'N/A')}")
            feature_summary_parts.append(f" • Uses HTTPS: {row.get('https', 'N/A')}")
        url_features_summary = "\n".join(feature_summary_parts)
    # Independent real-time lookups (DNS + geo-IP); informs the LLM only.
    network_features_summary = await get_network_features_for_gemini(urls)
    model_predictions_summary = []
    for model_name, pred_data in predictions.items():
        scaled = pred_data['scaled_score']  # directional, -50..+50
        raw = pred_data['raw_score']
        model_predictions_summary.append(
            f" • {model_name.upper()}: scaled_score={scaled:.2f} (raw={raw:.3f})"
        )
    model_scores_text = "\n".join(model_predictions_summary)
    # Truncate long messages to keep the prompt within the token budget.
    MAX_TEXT_LEN = 3000
    if len(original_text) > MAX_TEXT_LEN:
        truncated_original_text = original_text[:MAX_TEXT_LEN] + "\n... [TRUNCATED]"
    else:
        truncated_original_text = original_text
    if len(message_text) > MAX_TEXT_LEN:
        truncated_message_text = message_text[:MAX_TEXT_LEN] + "\n... [TRUNCATED]"
    else:
        truncated_message_text = message_text
    # User prompt carries only dynamic data; instructions live in SYSTEM_PROMPT.
    user_prompt = f"""MESSAGE DATA:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Original Message:
{truncated_original_text}
Cleaned Text:
{truncated_message_text}
URLs Found: {', '.join(urls) if urls else 'None'}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
URL FEATURES (from ML models):
{url_features_summary}
INDEPENDENT NETWORK & GEO-DATA (for Gemini analysis only):
{network_features_summary}
MODEL PREDICTIONS:
(Positive scaled scores → phishing, Negative → legitimate. Range: -50 to +50)
{model_scores_text}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Please analyze this data and provide your JSON response."""
    try:
        max_retries = 3
        retry_delay = 2
        for attempt in range(max_retries):
            try:
                # System role = static instructions, user role = dynamic evidence.
                chat_completion = await groq_async_client.chat.completions.create(
                    messages=[
                        {
                            "role": "system",
                            "content": SYSTEM_PROMPT,
                        },
                        {
                            "role": "user",
                            "content": user_prompt,
                        }
                    ],
                    model="meta-llama/llama-4-scout-17b-16e-instruct",
                    temperature=0.2,
                    max_tokens=4096,
                    top_p=0.85,
                    response_format={"type": "json_object"},
                )
                response_text = chat_completion.choices[0].message.content
                break  # Success
            except Exception as retry_error:
                print(f"Groq API attempt {attempt + 1} failed: {retry_error}")
                if attempt < max_retries - 1:
                    print(f"Retrying in {retry_delay}s...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # exponential backoff
                else:
                    raise retry_error  # out of retries; handled by the outer except
        result = json.loads(response_text)
        required_fields = ['confidence', 'reasoning', 'highlighted_text', 'final_decision', 'suggestion']
        if not all(field in result for field in required_fields):
            raise ValueError(f"Missing required fields. Got: {list(result.keys())}")
        # Clamp the directional confidence into [0, 100].
        result['confidence'] = float(result['confidence'])
        if not 0 <= result['confidence'] <= 100:
            result['confidence'] = max(0, min(100, result['confidence']))
        if result['final_decision'].lower() not in ['phishing', 'legitimate']:
            # Unknown label: derive the decision from the directional confidence.
            result['final_decision'] = 'phishing' if result['confidence'] > 50 else 'legitimate'
        else:
            result['final_decision'] = result['final_decision'].lower()
        # Keep confidence and decision mutually consistent.
        if result['final_decision'] == 'phishing' and result['confidence'] < 50:
            print(f"Warning: Groq decision 'phishing' mismatches confidence {result['confidence']}. Adjusting confidence.")
            result['confidence'] = 51.0  # Set to a default phishing score
        elif result['final_decision'] == 'legitimate' and result['confidence'] > 50:
            print(f"Warning: Groq decision 'legitimate' mismatches confidence {result['confidence']}. Adjusting confidence.")
            result['confidence'] = 49.0  # Set to a default legitimate score
        # Fallback for empty or truncated highlighted_text.
        if not result['highlighted_text'].strip() or '...' in result['highlighted_text'] or 'TRUNCATED' in result['highlighted_text']:
            print("Warning: Groq returned empty or truncated 'highlighted_text'. Falling back to original_text.")
            result['highlighted_text'] = original_text
        if not result.get('suggestion', '').strip():
            if result['final_decision'] == 'phishing':
                result['suggestion'] = "Do not interact with this message. Delete it immediately and report it as phishing."
            else:
                result['suggestion'] = "This message appears safe, but always verify sender identity before taking any action."
        return result
    except json.JSONDecodeError as e:
        print(f"JSON parsing error: {e}")
        print(f"Response text that failed parsing: {response_text[:500]}")
        confidence, final_decision, _ = _directional_confidence(predictions)
        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Groq response parsing failed. Fallback: Based on model average (directional score: {confidence:.2f}), message appears {'suspicious' if final_decision == 'phishing' else 'legitimate'}.",
            "highlighted_text": original_text,
            "final_decision": final_decision,
            "suggestion": "Do not interact with this message. Delete it immediately and be cautious." if final_decision == 'phishing' else "Exercise caution. Verify the sender before taking any action."
        }
    except Exception as e:
        print(f"Error with Groq API: {e}")
        confidence, final_decision, _ = _directional_confidence(predictions)
        return {
            "confidence": round(confidence, 2),
            "reasoning": f"Groq API error: {str(e)}. Fallback decision based on {len(predictions)} model predictions (average directional score: {confidence:.2f}).",
            "highlighted_text": original_text,
            "final_decision": final_decision,
            "suggestion": "Treat this message with caution. Delete it if suspicious, or verify the sender through official channels before taking action." if final_decision == 'phishing' else "This message appears safe based on models, but always verify sender identity before clicking links or providing information."
        }
| # --- FastAPI Endpoints --- | |
# NOTE(review): no event decorator was present, so this handler was never
# registered with FastAPI. Restored the startup hook.
@app.on_event("startup")
async def startup_event():
    """Load all models and the Groq client once when the server starts."""
    load_models()
    print("\n" + "="*60)
    print("Phishing Detection API is ready!")
    print("="*60)
    print("API Documentation: http://localhost:8000/docs")
    print("="*60 + "\n")
# NOTE(review): no route decorator was present, so this endpoint was never
# registered. Restored per the endpoint map this handler itself advertises.
@app.get("/")
async def root():
    """Service metadata and a map of the available endpoints."""
    return {
        "message": "Phishing Detection API",
        "version": "1.0.0",
        "endpoints": {
            "predict": "/predict (POST)",
            "health": "/health (GET)",
            "docs": "/docs (GET)"
        }
    }
# NOTE(review): no route decorator was present, so this endpoint was never
# registered. Restored using the "/health (GET)" path advertised by root().
@app.get("/health")
async def health_check():
    """Liveness probe reporting which models/clients actually loaded."""
    models_loaded = {
        "ml_models": list(ml_models.keys()),
        "dl_models": list(dl_models.keys()),
        "bert_model": bert_model is not None,
        "semantic_model": semantic_model is not None,
        "groq_client": groq_async_client is not None
    }
    return {
        "status": "healthy",
        "models_loaded": models_loaded
    }
# NOTE(review): no route decorator was present, so this endpoint was never
# registered. Restored using the "/predict (POST)" path advertised by root().
@app.post("/predict", response_model=PredictionResponse)
async def predict(message_input: MessageInput):
    """Full pipeline: parse message -> URL features -> model scores -> LLM judge."""
    try:
        original_text = message_input.text
        if not original_text or not original_text.strip():
            raise HTTPException(status_code=400, detail="Message text cannot be empty")
        urls, cleaned_text = parse_message(original_text)
        features_df = pd.DataFrame()
        if urls:
            features_df = await extract_url_features(urls)
        predictions = {}
        if len(features_df) > 0 or (cleaned_text and semantic_model):
            # Model inference is CPU-bound; run off-loop to keep the server responsive.
            predictions = await asyncio.to_thread(get_model_predictions, features_df, cleaned_text)
        if not predictions:
            # Distinguish the common failure modes for a clearer error message.
            if not urls and not cleaned_text:
                detail = "Message text is empty after cleaning."
            elif not urls and not semantic_model:
                detail = "No URLs provided and semantic model is not loaded."
            elif not any([ml_models, dl_models, bert_model, semantic_model]):
                detail = "No models available for prediction. Please ensure models are trained and loaded."
            else:
                detail = "Could not generate predictions. Models may be missing or feature extraction failed."
            raise HTTPException(
                status_code=500,
                detail=detail
            )
        final_result = await get_groq_final_decision(
            urls, features_df, cleaned_text, predictions, original_text
        )
        return PredictionResponse(**final_result)
    except HTTPException:
        # Re-raise deliberate HTTP errors untouched.
        raise
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
if __name__ == "__main__":
    # Local/dev entry point; in production run via `uvicorn <module>:app` instead.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)