# ============================================================
# PhishGuard AI - bert_finetune.py
# Full BERT fine-tuning script on PhishTank + TRANCO data
#
# Downloads data, fine-tunes ealvaradob/bert-finetuned-phishing
# EPOCHS epochs (set in main(), currently 1), AdamW + linear warmup scheduler
# Saves to bert_weights/ with save_pretrained()
# Prints per-epoch: train loss / validation precision / recall / F1
# ============================================================

from __future__ import annotations

import logging
import re
import sys
from pathlib import Path
from typing import List, Tuple

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-7s | %(message)s",
)
logger = logging.getLogger("phishguard.bert_finetune")

BASE_DIR = Path(__file__).parent
BERT_WEIGHTS_DIR = BASE_DIR / "bert_weights"

# Runs of these characters are treated as token separators inside a URL.
_RE_URL_SPLIT = re.compile(r"[-./=?&_~%@:]+")


def tokenize_url(url: str) -> str:
    """Turn a URL into a space-separated string of its non-empty parts.

    Every occurrence of ``https://`` and ``http://`` is removed (not only a
    leading one), then the remainder is split on runs of URL punctuation.

    Args:
        url: Raw URL string.

    Returns:
        The non-empty tokens joined by single spaces; ``""`` when no token
        remains.
    """
    text = url.replace("https://", "").replace("http://", "")
    return " ".join(t for t in _RE_URL_SPLIT.split(text) if t)


def main() -> None:
    """Fine-tune BERT on PhishTank + TRANCO URLs.

    Downloads both URL lists, loads the primary checkpoint (or the fallback
    one), trains for ``EPOCHS`` epochs and writes the model/tokenizer with
    the best validation F1 to ``BERT_WEIGHTS_DIR``.

    Exits the process with status 1 when a dependency is missing, when no
    model can be loaded, or when the train/validation split is empty.
    """
    print("=" * 60)
    print("PhishGuard AI — BERT Fine-Tuning")
    print("=" * 60)

    # ── Check dependencies ───────────────────────────────────────
    # Imported lazily so the module itself can be imported without the
    # heavy ML stack installed.
    try:
        import torch
        from torch.utils.data import DataLoader, Dataset
        from torch.optim import AdamW
        from transformers import (
            AutoTokenizer,
            AutoModelForSequenceClassification,
            get_linear_schedule_with_warmup,
        )
        from sklearn.metrics import precision_recall_fscore_support
    except ImportError as e:
        print(f"❌ Missing dependency: {e}")
        print("   Run: pip install torch transformers scikit-learn")
        sys.exit(1)

    # ── Download data ────────────────────────────────────────────
    from data_collector import download_phishtank, download_tranco, merge_datasets

    print("\n📥 Downloading datasets...")
    phish_urls = download_phishtank(max_urls=50)
    legit_urls = download_tranco(n=50)
    print(f"   Phishing URLs: {len(phish_urls)}")
    print(f"   Legitimate URLs: {len(legit_urls)}")

    train_data, val_data, test_data = merge_datasets(phish_urls, legit_urls)

    # Without training batches the average-loss division below would raise
    # ZeroDivisionError, and without validation samples no "best" model can
    # be selected — stop early with a clear message instead.
    if len(train_data) == 0 or len(val_data) == 0:
        print("❌ No training/validation data available. Exiting.")
        sys.exit(1)

    # ── Dataset class ────────────────────────────────────────────
    class PhishingURLDataset(Dataset):
        """Map-style dataset over ``(url, label)`` pairs, tokenized on access."""

        def __init__(self, data: List[Tuple[str, int]], tokenizer, max_length: int = 512):
            self.data = data
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, idx: int):
            url, label = self.data[idx]
            text = f"URL: {tokenize_url(url)}"
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
            # squeeze(0) drops the leading dimension of the single-sample
            # encoding so the DataLoader can stack items into a batch.
            return {
                "input_ids": encoding["input_ids"].squeeze(0),
                "attention_mask": encoding["attention_mask"].squeeze(0),
                "labels": torch.tensor(label, dtype=torch.long),
            }

    # ── Load model ───────────────────────────────────────────────
    MODEL_NAME = "ealvaradob/bert-finetuned-phishing"
    FALLBACK = "mrm8488/bert-tiny-finetuned-sms-spam-detection"

    print("\n🤖 Loading BERT model...")
    tokenizer = None
    model = None

    for model_id in [MODEL_NAME, FALLBACK]:
        # Broad catch is deliberate: any load failure (network, missing
        # repo, incompatible weights) should fall through to the next id.
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id, num_labels=2
            )
            print(f"   ✅ Loaded: {model_id}")
            break
        except Exception as e:
            print(f"   ⚠️ {model_id} failed: {e}")
            continue

    if model is None or tokenizer is None:
        print("❌ Could not load any BERT model. Exiting.")
        sys.exit(1)

    # ── Prepare data ─────────────────────────────────────────────
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"   Device: {device}")

    train_dataset = PhishingURLDataset(train_data, tokenizer)
    val_dataset = PhishingURLDataset(val_data, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)

    model = model.to(device)

    # ── Optimizer + Scheduler ────────────────────────────────────
    EPOCHS = 1
    optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    total_steps = len(train_loader) * EPOCHS
    # Warm up over the first 10% of optimizer steps, then decay linearly.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps,
    )

    # ── Training Loop ────────────────────────────────────────────
    print(f"\n🏋️ Training for {EPOCHS} epochs...")
    print(f"   Train batches: {len(train_loader)}")
    print(f"   Val batches: {len(val_loader)}")

    # Start below any reachable F1 so the first epoch is always saved, even
    # when its F1 is exactly 0.0; otherwise the final "Weights saved"
    # message could be printed without any weights on disk.
    best_f1 = -1.0

    for epoch in range(1, EPOCHS + 1):
        # Train
        model.train()
        total_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            # Passing labels makes the model return the loss alongside logits.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            if (batch_idx + 1) % 50 == 0:
                print(f"   Epoch {epoch} | Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(train_loader)

        # Validate
        model.eval()
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=1)
                val_preds.extend(preds.cpu().tolist())
                val_labels.extend(labels.cpu().tolist())

        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, val_preds, average="binary", zero_division=0
        )

        print(f"\n   📊 Epoch {epoch}/{EPOCHS}:")
        print(f"      Loss: {avg_loss:.4f}")
        print(f"      Precision: {precision:.4f}")
        print(f"      Recall: {recall:.4f}")
        print(f"      F1 Score: {f1:.4f}")

        # Save best model
        if f1 > best_f1:
            best_f1 = f1
            BERT_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(str(BERT_WEIGHTS_DIR))
            tokenizer.save_pretrained(str(BERT_WEIGHTS_DIR))
            print(f"      ✅ New best model saved to {BERT_WEIGHTS_DIR}")

    print(f"\n🎯 Best F1: {best_f1:.4f}")
    print(f"✅ Fine-tuning complete. Weights saved to: {BERT_WEIGHTS_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()