# phishguard-api / train_gnn.py
# prashanth135's picture
# Upload 38 files
# bebe233 verified
# ============================================================
# PhishGuard AI - gnn/train_gnn.py
# Full GNN training script.
#
# Downloads PhishTank bz2 + TRANCO zip + Kaggle CSV mirror
# Builds training graphs, trains for EPOCHS epochs (constant set in main()),
# and saves the best-validation weights to gnn_weights.pt
# 70/15/15 train/val/test split with stratification
# Saves replay buffer to gnn_replay_buffer.pt
# ============================================================
from __future__ import annotations
import sys
import random
import logging
from pathlib import Path
from typing import List, Tuple
# Root logging setup: INFO and above, formatted as
# "timestamp | level (padded to 7 chars) | message".
logging.basicConfig(
    format="%(asctime)s | %(levelname)-7s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("phishguard.gnn.train")

# Filesystem layout, derived from the location of this file:
#   GNN_DIR      directory containing this script
#   BACKEND_DIR  parent of GNN_DIR
GNN_DIR = Path(__file__).parent
BACKEND_DIR = GNN_DIR.parent

# Output artifacts written by main().
REPLAY_BUFFER_PATH = BACKEND_DIR.joinpath("data", "gnn_replay_buffer.pt")
WEIGHTS_PATH = GNN_DIR.joinpath("gnn_weights.pt")

# Put both directories at the front of the import search path so the
# imports inside main() can find modules located there. GNN_DIR ends up
# ahead of BACKEND_DIR.
sys.path[:0] = [str(GNN_DIR), str(BACKEND_DIR)]
def main() -> None:
    """Train the phishing-detection graph model and save its artifacts.

    Steps, in order:
      1. Download phishing and legitimate URL lists and split them into
         train/val/test sets via ``merge_datasets`` (the split ratios are
         decided inside that helper).
      2. Turn each split into small single-label graphs of up to
         ``CHUNK_SIZE`` URLs using ``DomainGraphBuilder``.
      3. Train ``PhishGNN`` (or ``PhishMLP`` when ``PYGEOM_AVAILABLE`` is
         false) one graph at a time, writing the state dict with the best
         validation accuracy to ``WEIGHTS_PATH``.
      4. Reload that checkpoint, report test metrics, and save up to 500
         training graphs to ``REPLAY_BUFFER_PATH``.

    Progress is reported with ``print``; nothing is returned.
    """
    print("=" * 60)
    print("PhishGuard AI β€” GNN Training")
    print("=" * 60)

    # Imports are local to main(); the project modules are expected to be
    # found through the sys.path entries added at module level.
    import torch
    import torch.nn.functional as F
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    from domain_graph_builder import DomainGraphBuilder
    from gnn_model import PhishGNN, PhishMLP, PYGEOM_AVAILABLE, INPUT_DIM  # noqa: F401
    from data_collector import download_phishtank, download_tranco, merge_datasets

    # -- Download data ------------------------------------------------
    print("\nπŸ“₯ Downloading datasets...")
    phish_urls = download_phishtank(max_urls=50)
    legit_urls = download_tranco(n=50)
    print(f" Phishing URLs: {len(phish_urls)}")
    print(f" Legitimate URLs: {len(legit_urls)}")
    train_data, val_data, test_data = merge_datasets(phish_urls, legit_urls)

    # -- Build graphs -------------------------------------------------
    builder = DomainGraphBuilder()
    CHUNK_SIZE = 4  # Group URLs into small graphs

    def build_dataset(data: List[Tuple[str, int]], desc: str) -> list:
        """Build a shuffled list of graph samples from (url, label) pairs.

        URLs are first separated by label so that every graph contains
        URLs of a single class; the graph's target ``y`` is that class.
        Each sample is a dict with keys ``"x"``, ``"edge_index"`` and
        ``"y"``. ``desc`` is only used in the progress line.
        """
        dataset = []
        phish = [url for url, label in data if label == 1]
        legit = [url for url, label in data if label == 0]
        for urls, label in ((phish, 1), (legit, 0)):
            # range() with a step never yields an empty slice, so no
            # emptiness check on the chunk is needed.
            for start in range(0, len(urls), CHUNK_SIZE):
                chunk = urls[start:start + CHUNK_SIZE]
                graph = builder.build_graph(chunk)
                x = torch.tensor(graph["features"], dtype=torch.float)
                edges = graph["edges"]
                if edges:
                    # Transposed so the edge tensor has one row per endpoint.
                    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
                else:
                    # Self-loops for graphs with no edges
                    n = x.size(0)
                    edge_index = torch.arange(n).unsqueeze(0).repeat(2, 1)
                dataset.append({
                    "x": x,
                    "edge_index": edge_index,
                    "y": torch.tensor([float(label)]),
                })
        random.shuffle(dataset)
        print(f" {desc}: {len(dataset)} graphs")
        return dataset

    print("\nπŸ”¨ Building graphs...")
    train_graphs = build_dataset(train_data, "Train")
    val_graphs = build_dataset(val_data, "Val")
    test_graphs = build_dataset(test_data, "Test")

    # -- Create model -------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nπŸ€– Device: {device}")
    model = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
    model = model.to(device)
    model_type = "GCN" if PYGEOM_AVAILABLE else "MLP"
    print(f" Model: Phish{model_type}")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    def predict_graphs(graphs: list) -> Tuple[List[int], List[int]]:
        """Return ``(labels, predictions)`` for ``graphs`` in eval mode.

        A graph is predicted as class 1 when the model output is >= 0.5.
        No gradients are tracked.
        """
        labels: List[int] = []
        preds: List[int] = []
        model.eval()
        with torch.no_grad():
            for item in graphs:
                out = model(item["x"].to(device), item["edge_index"].to(device))
                preds.append(1 if out.squeeze().item() >= 0.5 else 0)
                labels.append(int(item["y"].item()))
        return labels, preds

    # -- Training -----------------------------------------------------
    EPOCHS = 2
    LR = 0.001
    WEIGHT_DECAY = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6,
    )
    loss_fn = F.binary_cross_entropy
    best_val_acc = 0.0
    best_epoch = 0  # 0 means "no checkpoint written yet in this run"

    print(f"\nπŸ‹οΈ Training for {EPOCHS} epochs...")
    print(f" {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7} | {'LR':>10}")
    print(f" {'─' * 5} | {'─' * 8} | {'─' * 9} | {'─' * 7} | {'─' * 10}")

    for epoch in range(1, EPOCHS + 1):
        # -- Train: one optimizer step per graph ----------------------
        model.train()
        total_loss = 0.0
        train_preds: List[int] = []
        train_labels: List[int] = []
        random.shuffle(train_graphs)
        for item in train_graphs:
            x = item["x"].to(device)
            ei = item["edge_index"].to(device)
            y = item["y"].to(device)

            optimizer.zero_grad()
            out = model(x, ei)
            loss = loss_fn(out.squeeze(), y.squeeze())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_preds.append(1 if out.squeeze().item() >= 0.5 else 0)
            train_labels.append(int(y.item()))

        avg_loss = total_loss / max(len(train_graphs), 1)
        # Guarded like the val/test accuracies so an empty split yields 0.0.
        train_acc = accuracy_score(train_labels, train_preds) if train_labels else 0.0

        # -- Validate -------------------------------------------------
        val_labels, val_preds = predict_graphs(val_graphs)
        val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0

        # NOTE: the plateau scheduler is driven by the mean training loss.
        scheduler.step(avg_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        # Print the first epoch, every fifth epoch, and the final epoch.
        if epoch % 5 == 0 or epoch == 1 or epoch == EPOCHS:
            print(f" {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f} | {current_lr:>10.6f}")

        # Save best model. The first epoch is always checkpointed so the
        # reload below never fails on a missing file or picks up stale
        # weights from an earlier run when validation accuracy stays 0.0.
        if best_epoch == 0 or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            torch.save(model.state_dict(), WEIGHTS_PATH)

    print(f"\n Best val accuracy: {best_val_acc:.4f} at epoch {best_epoch}")

    # -- Test ---------------------------------------------------------
    # Reload best weights
    model.load_state_dict(
        torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
    )
    test_labels, test_preds = predict_graphs(test_graphs)
    test_acc = accuracy_score(test_labels, test_preds) if test_labels else 0.0
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_preds, average="binary", zero_division=0,
    )
    print(f"\nπŸ“Š Test Results:")
    print(f" Accuracy: {test_acc:.4f}")
    print(f" Precision: {precision:.4f}")
    print(f" Recall: {recall:.4f}")
    print(f" F1 Score: {f1:.4f}")

    # -- Save replay buffer -------------------------------------------
    REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Keep at most 500 graphs. train_graphs was reshuffled every epoch, so
    # the first 500 entries are an arbitrary subset, not the most recent.
    replay_buffer = train_graphs[:500]
    torch.save(replay_buffer, REPLAY_BUFFER_PATH)
    print(f"\nπŸ’Ύ Replay buffer saved: {len(replay_buffer)} samples β†’ {REPLAY_BUFFER_PATH}")
    print(f"\nβœ… GNN weights saved to: {WEIGHTS_PATH}")
    print("=" * 60)
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()