Spaces:
Runtime error
Runtime error
| """ | |
| Test Server Predictions Against Dataset | |
| Validates server API predictions vs actual labels | |
| """ | |
| import pandas as pd | |
| import requests | |
| from pathlib import Path | |
| import logging | |
| from tqdm import tqdm | |
| import time | |
| from sklearn.metrics import ( | |
| accuracy_score, precision_score, recall_score, f1_score, | |
| confusion_matrix, classification_report | |
| ) | |
# Configure the root logger once at import time: INFO level, with a short
# wall-clock timestamp and the level name in front of every message.
logging.basicConfig(
    datefmt='%H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)
class ServerTester:
    """Test phishing detection server against dataset.

    Sends every URL of a labelled CSV to the server's ``/api/predict/url``
    endpoint, compares the returned verdict with the dataset label and logs
    classification metrics.

    Labels are treated as ``1`` = phishing and ``0`` = legitimate, matching
    the counts logged when the dataset is loaded.
    """

    def __init__(self, server_url='http://localhost:8000', batch_size=100,
                 request_delay=0.01):
        """Store connection settings and start with an empty result list.

        Args:
            server_url: Base URL of the prediction server (no trailing path).
            batch_size: Stored on the instance; no method of this class
                reads it.
            request_delay: Seconds to sleep after each prediction request
                (simple rate limiting). ``0`` disables the pause.
        """
        self.server_url = server_url
        self.batch_size = batch_size
        self.request_delay = request_delay
        # One dict per successful prediction; grows across test_dataset() calls.
        self.results = []

    @staticmethod
    def _normalize_columns(df):
        """Return *df* with a 'label' column, renaming positionally if needed.

        When no 'label' column exists, the first column is taken as the URL
        and the second as the label; any further columns are dropped.

        Raises:
            ValueError: If there is no 'label' column and fewer than two
                columns to fall back on.
        """
        if 'label' in df.columns:
            return df
        if df.shape[1] < 2:
            raise ValueError(
                "Dataset needs a 'label' column or at least two columns "
                f"(url, label); found columns: {list(df.columns)}"
            )
        renamed = df.iloc[:, :2].copy()
        renamed.columns = ['url', 'label']
        return renamed

    @staticmethod
    def _percentage(part, whole):
        """Return ``part / whole`` as a percentage, or 0.0 when *whole* is 0."""
        return part / whole * 100 if whole else 0.0

    def check_server_health(self):
        """Check if server is running.

        Returns:
            True when ``GET /api/health`` answers with HTTP 200, otherwise
            False (non-200 status, connection failure or unreadable body).
        """
        try:
            response = requests.get(f"{self.server_url}/api/health", timeout=5)
            if response.status_code != 200:
                logger.error(f"Server health check failed: {response.status_code}")
                return False
            health = response.json()
            logger.info("✓ Server is healthy")
            logger.info(f" URL models: {health.get('url_models', 0)}")
            logger.info(f" HTML models: {health.get('html_models', 0)}")
            return True
        except Exception as e:
            # Best-effort probe: any failure simply means "not healthy".
            logger.error(f"Cannot connect to server: {e}")
            logger.error("Make sure server is running: python server/app.py")
            return False

    def predict_url(self, url):
        """Get prediction from server for a URL.

        Returns:
            A dict with keys ``'predicted'`` (1 if the server flags the URL
            as phishing, else 0), ``'consensus'`` and ``'predictions'``, or
            ``None`` when the request fails, the status is not 200, or the
            response lacks the expected fields.
        """
        try:
            response = requests.post(
                f"{self.server_url}/api/predict/url",
                json={"url": url},
                timeout=10,
            )
            if response.status_code != 200:
                logger.warning(f"Server error for {url}: {response.status_code}")
                return None
            result = response.json()
            return {
                'predicted': 1 if result['is_phishing'] else 0,
                'consensus': result['consensus'],
                'predictions': result['predictions'],
            }
        except Exception as e:
            # A single bad URL must not abort a long evaluation run.
            logger.warning(f"Request error for {url}: {e}")
            return None

    def test_dataset(self, dataset_path, limit=None, sample_frac=None):
        """
        Test server predictions against dataset.

        Args:
            dataset_path: Path to CSV with 'url' and 'label' columns. Without
                a 'label' column the first two columns are used as
                (url, label).
            limit: Maximum number of URLs to test (None = all). Takes the
                first rows of the (optionally sampled) frame.
            sample_frac: Random sample fraction (e.g., 0.1 = 10%), drawn with
                a fixed seed so runs are repeatable.

        Returns:
            ``None`` if the server health check fails; otherwise a dict with
            ``'y_true'``, ``'y_pred'`` (both for this call only) and
            ``'results'`` (``self.results``, which also holds entries from
            earlier calls on the same instance).

        Raises:
            ValueError: If the CSV has no 'label' column and fewer than two
                columns, or a label cannot be converted to ``int``.
        """
        logger.info("=" * 80)
        logger.info("SERVER PREDICTION TESTING")
        logger.info("=" * 80)

        # Load dataset
        logger.info(f"\n1. Loading dataset: {dataset_path}")
        df = self._normalize_columns(pd.read_csv(dataset_path))
        # A frame that has 'label' but no 'url' falls back to its first column.
        url_column = 'url' if 'url' in df.columns else df.columns[0]

        logger.info(f" Total URLs: {len(df):,}")
        logger.info(f" Phishing: {(df['label']==1).sum():,}")
        logger.info(f" Legitimate: {(df['label']==0).sum():,}")

        # Sample if requested
        if sample_frac:
            df = df.sample(frac=sample_frac, random_state=42)
            logger.info(f"\n Sampled {sample_frac*100:.1f}%: {len(df):,} URLs")

        # Limit if requested
        if limit and limit < len(df):
            df = df.head(limit)
            logger.info(f" Limited to: {limit:,} URLs")

        # Check server
        logger.info("\n2. Checking server health...")
        if not self.check_server_health():
            return None

        # Test predictions
        logger.info("\n3. Testing predictions...")
        y_true = []
        y_pred = []
        errors = 0
        pairs = zip(df[url_column], df['label'])
        for url, raw_label in tqdm(pairs, total=len(df), desc="Testing URLs"):
            true_label = int(raw_label)
            result = self.predict_url(url)
            if result is None:
                errors += 1
            else:
                predicted = result['predicted']
                y_true.append(true_label)
                y_pred.append(predicted)
                self.results.append({
                    'url': url,
                    'true_label': true_label,
                    'predicted_label': predicted,
                    'consensus': result['consensus'],
                    'correct': true_label == predicted,
                })
            # Rate limiting: pause after every request, successful or not.
            if self.request_delay:
                time.sleep(self.request_delay)

        logger.info(f"\n Processed: {len(y_pred):,} URLs")
        if errors > 0:
            logger.warning(f" Errors: {errors:,}")

        # Calculate metrics
        self._display_results(y_true, y_pred)
        return {
            'y_true': y_true,
            'y_pred': y_pred,
            'results': self.results,
        }

    def _display_results(self, y_true, y_pred):
        """Display test results and metrics.

        Logs accuracy, precision, recall, F1, the confusion matrix and a
        per-class report. Does nothing beyond a warning when there are no
        predictions to score.
        """
        logger.info("\n" + "=" * 80)
        logger.info("TEST RESULTS")
        logger.info("=" * 80)

        if not y_true:
            logger.warning("No successful predictions - skipping metrics")
            return

        # Pin the label set so the matrix is always 2x2 and the report always
        # has both rows, even when only one class shows up in this run.
        class_labels = [0, 1]

        # Calculate metrics
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        logger.info("\nOverall Metrics:")
        logger.info(f" Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
        logger.info(f" Precision: {precision:.4f} ({precision*100:.2f}%)")
        logger.info(f" Recall: {recall:.4f} ({recall*100:.2f}%)")
        logger.info(f" F1-Score: {f1:.4f} ({f1*100:.2f}%)")

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred, labels=class_labels)
        tn, fp, fn, tp = (int(count) for count in cm.ravel())
        logger.info("\nConfusion Matrix:")
        logger.info(" Predicted")
        logger.info(" Legit Phish")
        logger.info(f"Actual Legit {tn:6,} {fp:6,}")
        logger.info(f" Phish {fn:6,} {tp:6,}")

        # Rates fall back to 0.0 when a class is absent (no division by zero).
        fp_rate = self._percentage(fp, tn + fp)
        fn_rate = self._percentage(fn, tp + fn)
        logger.info("\nError Analysis:")
        logger.info(f" True Negatives: {tn:,} (correctly identified legitimate)")
        logger.info(f" True Positives: {tp:,} (correctly identified phishing)")
        logger.info(f" False Positives: {fp:,} ({fp_rate:.2f}% of legitimate marked as phishing)")
        logger.info(f" False Negatives: {fn:,} ({fn_rate:.2f}% of phishing marked as legitimate) ⚠️")

        # Classification report
        logger.info("\nDetailed Classification Report:")
        logger.info(classification_report(
            y_true, y_pred,
            labels=class_labels,
            target_names=['Legitimate', 'Phishing'],
            digits=4,
            zero_division=0,
        ))

    def save_results(self, output_path):
        """Save test results to CSV.

        Writes ``self.results`` to *output_path* (creating the parent
        directory if needed) and logs a short summary. Logs a warning and
        writes nothing when there are no results.
        """
        if not self.results:
            logger.warning("No results to save")
            return

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        df = pd.DataFrame(self.results)
        df.to_csv(output_path, index=False)
        logger.info(f"\n✓ Results saved: {output_path}")
        logger.info(f" Total: {len(df):,} predictions")
        logger.info(f" Correct: {df['correct'].sum():,} ({df['correct'].mean()*100:.2f}%)")
        logger.info(f" Incorrect: {(~df['correct']).sum():,}")
def main():
    """Main testing function.

    Runs a quick validation of the local server against the first
    ``quick_limit`` rows of the processed dataset and writes the per-URL
    outcomes to ``results/server_test_results.csv``. Returns ``None``; if the
    dataset is missing it logs the CSV files found in ``data/processed`` and
    exits early.
    """
    # Paths
    dataset_path = Path('data/processed/mega_dataset_full_912357.csv')
    output_path = Path('results/server_test_results.csv')

    # Check dataset exists before creating any output directory.
    if not dataset_path.exists():
        logger.error(f"Dataset not found: {dataset_path}")
        logger.info("Available datasets:")
        for csv_file in Path('data/processed').glob('*.csv'):
            logger.info(f" - {csv_file}")
        return

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Create tester
    tester = ServerTester(server_url='http://localhost:8000')

    # Quick run on a fixed number of rows. The rows are taken from the top of
    # the file, so a label-sorted CSV may yield a single-class evaluation;
    # use sample_frac (e.g. 0.1 for a random 10%) for a representative mix.
    quick_limit = 1000
    logger.info(f"\nTesting first {quick_limit:,} URLs for quick validation...")
    logger.info("(Use limit=None to test the full dataset, or sample_frac=0.1 for a random 10% sample)")
    results = tester.test_dataset(
        dataset_path,
        limit=quick_limit,
    )

    if results:
        # Save results
        tester.save_results(output_path)
        logger.info("\n" + "=" * 80)
        logger.info("✓ SERVER TESTING COMPLETE!")
        logger.info("=" * 80)
        logger.info(f"\nResults saved to: {output_path}")
        logger.info("\nTo test the full dataset, call test_dataset with limit=None")


if __name__ == '__main__':
    main()