# ============================================================
# PhishGuard AI - gnn/train_gnn.py
# Full GNN training script.
#
# - Downloads phishing URLs (PhishTank bz2) and legitimate URLs
#   (Tranco zip, Kaggle CSV mirror) through data_collector — see that
#   module for the actual sources and fallbacks.
# - Splits them with data_collector.merge_datasets (intended as a
#   70/15/15 stratified train/val/test split — confirm there).
# - Groups same-label URLs into small domain graphs, trains PhishGNN
#   (or the PhishMLP fallback when torch_geometric is unavailable) for
#   EPOCHS epochs and saves the best-validation weights to gnn_weights.pt.
# - Saves a replay buffer of training graphs to data/gnn_replay_buffer.pt.
#
# NOTE: EPOCHS and the download sizes below are small smoke-test values;
# the original design called for 40 epochs. Raise them for a full run.
# ============================================================
from __future__ import annotations

import sys
import random
import logging
from pathlib import Path
from typing import List, Tuple

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-7s | %(message)s",
)
logger = logging.getLogger("phishguard.gnn.train")

# Paths
GNN_DIR = Path(__file__).parent
BACKEND_DIR = GNN_DIR.parent
WEIGHTS_PATH = GNN_DIR / "gnn_weights.pt"
REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "gnn_replay_buffer.pt"

# Add backend to path for imports
sys.path.insert(0, str(BACKEND_DIR))
sys.path.insert(0, str(GNN_DIR))

# Training configuration
MAX_PHISH_URLS = 50       # passed to download_phishtank(max_urls=...)
MAX_LEGIT_URLS = 50       # passed to download_tranco(n=...)
CHUNK_SIZE = 4            # URLs per graph; every URL in a graph shares one label
EPOCHS = 2
LR = 0.001
WEIGHT_DECAY = 1e-4
REPLAY_BUFFER_SIZE = 500
RANDOM_SEED = 42


def _graph_to_item(graph: dict, label: int) -> dict:
    """Convert one DomainGraphBuilder graph into the tensor dict used for training.

    Args:
        graph: mapping with ``"features"`` (one feature row per node) and
            ``"edges"`` (a sequence of ``(src, dst)`` node-index pairs, may be empty).
        label: graph label, 1 for phishing and 0 for legitimate.

    Returns:
        ``{"x": float tensor, "edge_index": long tensor of shape (2, E),
        "y": float tensor of shape (1,)}``.
    """
    import torch

    x = torch.tensor(graph["features"], dtype=torch.float)
    edges = graph["edges"]
    # `len` rather than truthiness so array-like edge lists also work.
    if edges is not None and len(edges) > 0:
        # (E, 2) pairs -> (2, E) edge_index layout.
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # No edges: fall back to one self-loop per node so the graph still
        # has a valid, non-empty edge_index.
        n = x.size(0)
        edge_index = torch.arange(n).unsqueeze(0).repeat(2, 1)
    return {"x": x, "edge_index": edge_index, "y": torch.tensor([float(label)])}


def _build_dataset(
    builder, data: List[Tuple[str, int]], desc: str, chunk_size: int = CHUNK_SIZE
) -> list:
    """Build a shuffled list of graph items from ``(url, label)`` pairs.

    URLs are separated by label and cut into consecutive chunks of at most
    ``chunk_size``; each chunk becomes one graph carrying that label.
    Prints the number of graphs built, prefixed with ``desc``.
    """
    phish = [url for url, label in data if label == 1]
    legit = [url for url, label in data if label == 0]

    dataset = []
    for urls, label in ((phish, 1), (legit, 0)):
        # range() with step chunk_size guarantees every slice is non-empty.
        for start in range(0, len(urls), chunk_size):
            chunk = urls[start:start + chunk_size]
            dataset.append(_graph_to_item(builder.build_graph(chunk), label))

    random.shuffle(dataset)
    print(f" {desc}: {len(dataset)} graphs")
    return dataset


def _predict(model, graphs: list, device) -> Tuple[List[int], List[int]]:
    """Run ``model`` over ``graphs`` without gradients, thresholding at 0.5.

    Returns ``(predictions, labels)`` as lists of 0/1 ints. The model output
    must hold a single probability per graph (``.item()`` is called on it).
    The caller is responsible for switching the model to eval mode.
    """
    import torch

    preds: List[int] = []
    labels: List[int] = []
    with torch.no_grad():
        for item in graphs:
            out = model(item["x"].to(device), item["edge_index"].to(device))
            preds.append(1 if out.squeeze().item() >= 0.5 else 0)
            labels.append(int(item["y"].item()))
    return preds, labels


def _should_print(epoch: int, epochs: int) -> bool:
    """Return True for epochs that get a progress line: first, every 5th, and last."""
    return epoch == 1 or epoch % 5 == 0 or epoch == epochs


def main() -> None:
    """Download data, build graphs, train, evaluate and save the GNN + replay buffer."""
    print("=" * 60)
    print("PhishGuard AI — GNN Training")
    print("=" * 60)

    import torch
    import torch.nn.functional as F
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    from domain_graph_builder import DomainGraphBuilder
    from gnn_model import PhishGNN, PhishMLP, PYGEOM_AVAILABLE

    # Seed Python and torch RNGs so splits/shuffles and weight init are reproducible.
    random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)

    # ── Download data ────────────────────────────────────────────
    from data_collector import download_phishtank, download_tranco, merge_datasets

    print("\n📥 Downloading datasets...")
    phish_urls = download_phishtank(max_urls=MAX_PHISH_URLS)
    legit_urls = download_tranco(n=MAX_LEGIT_URLS)
    print(f" Phishing URLs: {len(phish_urls)}")
    print(f" Legitimate URLs: {len(legit_urls)}")

    train_data, val_data, test_data = merge_datasets(phish_urls, legit_urls)

    # ── Build graphs ─────────────────────────────────────────────
    builder = DomainGraphBuilder()
    print("\n🔨 Building graphs...")
    train_graphs = _build_dataset(builder, train_data, "Train")
    val_graphs = _build_dataset(builder, val_data, "Val")
    test_graphs = _build_dataset(builder, test_data, "Test")

    # ── Create model ─────────────────────────────────────────────
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n🤖 Device: {device}")

    model = (PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()).to(device)
    model_type = "GCN" if PYGEOM_AVAILABLE else "MLP"
    print(f" Model: Phish{model_type}")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # ── Training ─────────────────────────────────────────────────
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6,
    )
    # binary_cross_entropy expects probabilities, so the model is assumed to
    # end in a sigmoid — TODO confirm against gnn_model.
    loss_fn = F.binary_cross_entropy

    best_val_acc = 0.0
    best_epoch = 0

    print(f"\n🏋️ Training for {EPOCHS} epochs...")
    print(f" {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7} | {'LR':>10}")
    print(f" {'─' * 5} | {'─' * 8} | {'─' * 9} | {'─' * 7} | {'─' * 10}")

    for epoch in range(1, EPOCHS + 1):
        # ── Train ────────────────────────────────────────────────
        model.train()
        total_loss = 0.0
        train_preds: List[int] = []
        train_labels: List[int] = []
        random.shuffle(train_graphs)

        for item in train_graphs:
            x = item["x"].to(device)
            ei = item["edge_index"].to(device)
            y = item["y"].to(device)

            optimizer.zero_grad()
            out = model(x, ei)
            loss = loss_fn(out.squeeze(), y.squeeze())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_preds.append(1 if out.squeeze().item() >= 0.5 else 0)
            train_labels.append(int(y.item()))

        avg_loss = total_loss / max(len(train_graphs), 1)
        train_acc = accuracy_score(train_labels, train_preds) if train_labels else 0.0

        # ── Validate ─────────────────────────────────────────────
        model.eval()
        val_preds, val_labels = _predict(model, val_graphs, device)
        val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0

        # NOTE(review): the plateau scheduler monitors *training* loss;
        # validation loss is the more common signal for ReduceLROnPlateau.
        scheduler.step(avg_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        if _should_print(epoch, EPOCHS):
            print(f" {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f} | {current_lr:>10.6f}")

        # Checkpoint on improvement. The first epoch is always saved so the
        # reload below can neither fail nor pick up a stale file from an
        # earlier run when validation accuracy stays at 0.0 (e.g. empty val set).
        if best_epoch == 0 or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            torch.save(model.state_dict(), WEIGHTS_PATH)

    print(f"\n Best val accuracy: {best_val_acc:.4f} at epoch {best_epoch}")

    # ── Test ─────────────────────────────────────────────────────
    # Reload best weights
    model.load_state_dict(
        torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
    )
    model.eval()
    test_preds, test_labels = _predict(model, test_graphs, device)

    if test_labels:
        test_acc = accuracy_score(test_labels, test_preds)
        precision, recall, f1, _ = precision_recall_fscore_support(
            test_labels, test_preds, average="binary", zero_division=0,
        )
    else:
        test_acc = precision = recall = f1 = 0.0

    print("\n📊 Test Results:")
    print(f" Accuracy: {test_acc:.4f}")
    print(f" Precision: {precision:.4f}")
    print(f" Recall: {recall:.4f}")
    print(f" F1 Score: {f1:.4f}")

    # ── Save replay buffer ───────────────────────────────────────
    REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
    # train_graphs is shuffled in place every epoch, so this is a random
    # subset of at most REPLAY_BUFFER_SIZE training graphs (CPU tensors).
    replay_buffer = train_graphs[:REPLAY_BUFFER_SIZE]
    torch.save(replay_buffer, REPLAY_BUFFER_PATH)
    print(f"\n💾 Replay buffer saved: {len(replay_buffer)} samples → {REPLAY_BUFFER_PATH}")

    print(f"\n✅ GNN weights saved to: {WEIGHTS_PATH}")
    print("=" * 60)


if __name__ == "__main__":
    main()