Spaces:
Running
Running
# ============================================================
# PhishGuard AI - gnn/train_gnn.py
# Full GNN training script.
#
# Downloads PhishTank bz2 + TRANCO zip + Kaggle CSV mirror
# Builds training graphs, trains for EPOCHS epochs (set in main(),
# currently 2) and saves the best-validation weights to gnn_weights.pt
# 70/15/15 train/val/test split with stratification
# Saves replay buffer to data/gnn_replay_buffer.pt
# ============================================================
| from __future__ import annotations | |
| import sys | |
| import random | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Tuple | |
# Root logging setup: "timestamp | level (padded to 7) | message", INFO and up.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-7s | %(message)s",
    level=logging.INFO,
)
# Named logger for this training script.
logger = logging.getLogger("phishguard.gnn.train")
# Filesystem layout, derived from the directory that contains this file.
GNN_DIR = Path(__file__).parent
BACKEND_DIR = GNN_DIR.parent
WEIGHTS_PATH = GNN_DIR.joinpath("gnn_weights.pt")
REPLAY_BUFFER_PATH = BACKEND_DIR.joinpath("data", "gnn_replay_buffer.pt")
# Make the sibling project modules importable. GNN_DIR ends up first on
# sys.path, immediately followed by BACKEND_DIR.
sys.path[:0] = [str(GNN_DIR), str(BACKEND_DIR)]
def main() -> None:
    """Train the phishing-detection graph model end to end.

    Steps, in order:

    1. Fetch phishing and legitimate URLs through ``data_collector`` and
       split them into train/val/test sets with ``merge_datasets``.
    2. Group same-label URLs into chunks of ``CHUNK_SIZE`` and turn each
       chunk into one small graph with a single graph-level label.
    3. Train ``PhishGNN`` (or ``PhishMLP`` when ``PYGEOM_AVAILABLE`` is
       false) one graph at a time, checkpointing the epoch with the best
       validation accuracy to ``WEIGHTS_PATH``.
    4. Reload that checkpoint, print test metrics, and save up to 500
       training graphs to ``REPLAY_BUFFER_PATH``.

    The heavy dependencies (torch, scikit-learn, project modules) are
    imported here rather than at module level, so merely importing this
    module does not require them.
    """
    print("=" * 60)
    print("PhishGuard AI β GNN Training")
    print("=" * 60)

    import torch
    import torch.nn.functional as F
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    from domain_graph_builder import DomainGraphBuilder
    from gnn_model import PhishGNN, PhishMLP, PYGEOM_AVAILABLE

    # -- Download data ---------------------------------------------
    from data_collector import download_phishtank, download_tranco, merge_datasets

    print("\nπ₯ Downloading datasets...")
    phish_urls = download_phishtank(max_urls=50)
    legit_urls = download_tranco(n=50)
    print(f" Phishing URLs: {len(phish_urls)}")
    print(f" Legitimate URLs: {len(legit_urls)}")
    train_data, val_data, test_data = merge_datasets(phish_urls, legit_urls)

    # -- Build graphs ----------------------------------------------
    builder = DomainGraphBuilder()
    CHUNK_SIZE = 4  # Group URLs into small graphs

    def build_dataset(data: List[Tuple[str, int]], desc: str) -> list:
        """Build a shuffled list of graph samples from (url, label) pairs.

        URLs are separated by label and cut into chunks of ``CHUNK_SIZE``;
        every chunk becomes one sample dict with keys ``"x"`` (node
        features), ``"edge_index"`` and ``"y"`` (the chunk's shared label
        as a one-element float tensor).
        """
        dataset = []
        # Chunks never mix labels, so each graph has one unambiguous target.
        phish = [url for url, label in data if label == 1]
        legit = [url for url, label in data if label == 0]
        for urls, label in ((phish, 1), (legit, 0)):
            for start in range(0, len(urls), CHUNK_SIZE):
                chunk = urls[start : start + CHUNK_SIZE]
                graph = builder.build_graph(chunk)
                x = torch.tensor(graph["features"], dtype=torch.float)
                edges = graph["edges"]
                if edges:
                    # Transpose the edge list into the 2 x num_edges layout.
                    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
                else:
                    # Self-loops for graphs with no edges
                    n = x.size(0)
                    edge_index = torch.arange(n).unsqueeze(0).repeat(2, 1)
                dataset.append({
                    "x": x,
                    "edge_index": edge_index,
                    "y": torch.tensor([float(label)]),
                })
        random.shuffle(dataset)
        print(f" {desc}: {len(dataset)} graphs")
        return dataset

    print("\nπ¨ Building graphs...")
    train_graphs = build_dataset(train_data, "Train")
    val_graphs = build_dataset(val_data, "Val")
    test_graphs = build_dataset(test_data, "Test")

    # -- Create model ----------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nπ€ Device: {device}")
    model = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
    model = model.to(device)
    model_type = "GCN" if PYGEOM_AVAILABLE else "MLP"
    print(f" Model: Phish{model_type}")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    def predict_all(graphs: list) -> Tuple[List[int], List[int]]:
        """Return (predictions, labels) for ``graphs`` without gradients.

        Puts the model in eval mode; an output >= 0.5 counts as class 1.
        """
        model.eval()
        preds: List[int] = []
        labels: List[int] = []
        with torch.no_grad():
            for item in graphs:
                out = model(item["x"].to(device), item["edge_index"].to(device))
                preds.append(1 if out.squeeze().item() >= 0.5 else 0)
                labels.append(int(item["y"].item()))
        return preds, labels

    # -- Training --------------------------------------------------
    EPOCHS = 2
    LR = 0.001
    WEIGHT_DECAY = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6,
    )
    loss_fn = F.binary_cross_entropy
    best_val_acc = 0.0
    best_epoch = 0  # 0 means "no checkpoint written yet in this run"

    print(f"\nποΈ Training for {EPOCHS} epochs...")
    print(f" {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7} | {'LR':>10}")
    print(f" {'β' * 5} | {'β' * 8} | {'β' * 9} | {'β' * 7} | {'β' * 10}")

    for epoch in range(1, EPOCHS + 1):
        # -- Train: one optimizer step per graph --------------------
        model.train()
        total_loss = 0.0
        train_preds = []
        train_labels = []
        random.shuffle(train_graphs)
        for item in train_graphs:
            x = item["x"].to(device)
            ei = item["edge_index"].to(device)
            y = item["y"].to(device)
            optimizer.zero_grad()
            out = model(x, ei)
            loss = loss_fn(out.squeeze(), y.squeeze())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_preds.append(1 if out.squeeze().item() >= 0.5 else 0)
            train_labels.append(int(y.item()))
        avg_loss = total_loss / max(len(train_graphs), 1)
        # Same empty-split guard as the validation and test metrics.
        train_acc = accuracy_score(train_labels, train_preds) if train_labels else 0.0

        # -- Validate ----------------------------------------------
        val_preds, val_labels = predict_all(val_graphs)
        val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0

        # NOTE(review): the plateau scheduler is driven by the training
        # loss, not a validation metric - confirm this is intended.
        scheduler.step(avg_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        # Print progress (epoch 1 and every 5th epoch only).
        if epoch % 5 == 0 or epoch == 1:
            print(f" {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f} | {current_lr:>10.6f}")

        # Save best model. The first epoch is always checkpointed so the
        # reload below never fails or picks up a stale file from an earlier
        # run when validation accuracy stays at 0.0 (e.g. empty val split).
        if best_epoch == 0 or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            torch.save(model.state_dict(), WEIGHTS_PATH)

    print(f"\n Best val accuracy: {best_val_acc:.4f} at epoch {best_epoch}")

    # -- Test ------------------------------------------------------
    # Reload best weights
    model.load_state_dict(
        torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
    )
    test_preds, test_labels = predict_all(test_graphs)
    if test_labels:
        test_acc = accuracy_score(test_labels, test_preds)
        precision, recall, f1, _ = precision_recall_fscore_support(
            test_labels, test_preds, average="binary", zero_division=0,
        )
    else:
        # Nothing to score: report zeros instead of calling the metrics
        # on empty inputs.
        test_acc = precision = recall = f1 = 0.0

    print(f"\nπ Test Results:")
    print(f" Accuracy: {test_acc:.4f}")
    print(f" Precision: {precision:.4f}")
    print(f" Recall: {recall:.4f}")
    print(f" F1 Score: {f1:.4f}")

    # -- Save replay buffer ----------------------------------------
    REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Keep at most 500 training graphs. train_graphs is reshuffled every
    # epoch, so this prefix is a random subset rather than the "latest" ones.
    replay_buffer = train_graphs[:500]
    torch.save(replay_buffer, REPLAY_BUFFER_PATH)
    print(f"\nπΎ Replay buffer saved: {len(replay_buffer)} samples β {REPLAY_BUFFER_PATH}")
    print(f"\nβ GNN weights saved to: {WEIGHTS_PATH}")
    print("=" * 60)
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()