phishguard-api / bert_finetune.py
prashanth135's picture
Upload 38 files
bebe233 verified
# ============================================================
# PhishGuard AI - bert_finetune.py
# Full BERT fine-tuning script on PhishTank + TRANCO data
#
# Downloads data, fine-tunes ealvaradob/bert-finetuned-phishing
# EPOCHS epochs (currently 1), AdamW + linear warmup scheduler
# Saves to bert_weights/ with save_pretrained()
# Prints per-epoch: loss / precision / recall / F1
# ============================================================
from __future__ import annotations
import logging
import sys
from pathlib import Path
from typing import List, Tuple
# Configure the root logger at import time: "timestamp | LEVEL | message",
# with the level name left-aligned and padded to 7 characters.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-7s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("phishguard.bert_finetune")

# BASE_DIR is the directory containing this file; the weights directory
# sits directly beneath it.
BASE_DIR = Path(__file__).parent
BERT_WEIGHTS_DIR = BASE_DIR.joinpath("bert_weights")
def main() -> None:
    """Fine-tune BERT on PhishTank + TRANCO URLs.

    Pipeline, in order:

    1. Import torch / transformers / scikit-learn lazily so a missing package
       produces a readable message instead of a traceback.
    2. Fetch phishing and legitimate URLs through ``data_collector`` and split
       them with ``merge_datasets``.
    3. Load the primary checkpoint, falling back to a smaller one if loading
       fails.
    4. Train for ``EPOCHS`` epochs with AdamW, linear warmup over the first
       10% of steps, and gradient-norm clipping at 1.0.
    5. After each epoch, report loss / precision / recall / F1 on the
       validation split and save the model and tokenizer to
       ``BERT_WEIGHTS_DIR`` whenever validation F1 improves.

    Raises:
        SystemExit: with code 1 when a dependency is missing, when the train
            or validation split is empty, or when no model can be loaded.
    """
    print("=" * 60)
    print("PhishGuard AI — BERT Fine-Tuning")
    print("=" * 60)

    # ── Check dependencies ───────────────────────────────────────
    try:
        import torch
        from torch.utils.data import DataLoader, Dataset
        from torch.optim import AdamW
        from transformers import (
            AutoTokenizer,
            AutoModelForSequenceClassification,
            get_linear_schedule_with_warmup,
        )
        from sklearn.metrics import precision_recall_fscore_support
    except ImportError as e:
        print(f"❌ Missing dependency: {e}")
        print(" Run: pip install torch transformers scikit-learn")
        sys.exit(1)

    # ── Download data ────────────────────────────────────────────
    from data_collector import download_phishtank, download_tranco, merge_datasets

    print("\n📥 Downloading datasets...")
    phish_urls = download_phishtank(max_urls=50)
    legit_urls = download_tranco(n=50)
    print(f" Phishing URLs: {len(phish_urls)}")
    print(f" Legitimate URLs: {len(legit_urls)}")
    # The third split is not used by this script.
    train_data, val_data, _test_data = merge_datasets(phish_urls, legit_urls)

    # An empty train split would divide by zero when averaging the loss, and
    # an empty validation split makes the F1-based checkpointing meaningless.
    # len() is used (not truthiness) so any sized container works here.
    if len(train_data) == 0 or len(val_data) == 0:
        print("❌ Not enough data to train (empty train or validation split). Exiting.")
        sys.exit(1)

    # ── URL tokenization ─────────────────────────────────────────
    import re

    url_delimiters = re.compile(r"[-./=?&_~%@:]+")

    def tokenize_url(url: str) -> str:
        """Strip the http(s) scheme and split the URL on delimiter runs.

        Returns the non-empty pieces joined by single spaces.
        """
        text = url.replace("https://", "").replace("http://", "")
        return " ".join(piece for piece in url_delimiters.split(text) if piece)

    # ── Dataset class ────────────────────────────────────────────
    class PhishingURLDataset(Dataset):
        """Map-style dataset yielding tokenized URLs with integer labels."""

        def __init__(self, data: List[Tuple[str, int]], tokenizer, max_length: int = 512):
            # Each item of `data` is unpacked as (url, label) in __getitem__.
            self.data = data
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, idx: int):
            url, label = self.data[idx]
            text = f"URL: {tokenize_url(url)}"
            # Every sample is padded/truncated to exactly max_length tokens so
            # the default collate function can stack them into a batch.
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
            # return_tensors="pt" adds a leading batch dimension of 1; drop it.
            return {
                "input_ids": encoding["input_ids"].squeeze(0),
                "attention_mask": encoding["attention_mask"].squeeze(0),
                "labels": torch.tensor(label, dtype=torch.long),
            }

    # ── Load model ───────────────────────────────────────────────
    MODEL_NAME = "ealvaradob/bert-finetuned-phishing"
    FALLBACK = "mrm8488/bert-tiny-finetuned-sms-spam-detection"
    print("\n🤖 Loading BERT model...")
    tokenizer = None
    model = None
    for model_id in [MODEL_NAME, FALLBACK]:
        # Deliberately broad: any load failure (network, missing repo,
        # incompatible files) should fall through to the next candidate.
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id, num_labels=2
            )
            print(f" ✅ Loaded: {model_id}")
            break
        except Exception as e:
            print(f" ⚠️ {model_id} failed: {e}")
            continue
    if model is None or tokenizer is None:
        print("❌ Could not load any BERT model. Exiting.")
        sys.exit(1)

    # ── Prepare data ─────────────────────────────────────────────
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")
    train_dataset = PhishingURLDataset(train_data, tokenizer)
    val_dataset = PhishingURLDataset(val_data, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)
    model = model.to(device)

    # ── Optimizer + Scheduler ────────────────────────────────────
    EPOCHS = 1
    optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps,
    )

    # ── Training Loop ────────────────────────────────────────────
    print(f"\n🏋️ Training for {EPOCHS} epochs...")
    print(f" Train batches: {len(train_loader)}")
    print(f" Val batches: {len(val_loader)}")
    # Start below any achievable F1 so the first epoch always checkpoints;
    # otherwise an F1 of exactly 0.0 would leave nothing on disk while the
    # closing message still reports that weights were saved.
    best_f1 = -1.0
    for epoch in range(1, EPOCHS + 1):
        # Train
        model.train()
        total_loss = 0.0
        for batch_num, batch in enumerate(train_loader, start=1):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            if batch_num % 50 == 0:
                print(f" Epoch {epoch} | Batch {batch_num}/{len(train_loader)} | Loss: {batch_loss:.4f}")
        avg_loss = total_loss / len(train_loader)

        # Validate
        model.eval()
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=1)
                val_preds.extend(preds.cpu().tolist())
                # Labels are only compared on the host, so they never need to
                # be moved to the device.
                val_labels.extend(batch["labels"].tolist())
        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, val_preds, average="binary", zero_division=0
        )
        print(f"\n 📊 Epoch {epoch}/{EPOCHS}:")
        print(f" Loss: {avg_loss:.4f}")
        print(f" Precision: {precision:.4f}")
        print(f" Recall: {recall:.4f}")
        print(f" F1 Score: {f1:.4f}")

        # Save best model
        if f1 > best_f1:
            best_f1 = f1
            BERT_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(str(BERT_WEIGHTS_DIR))
            tokenizer.save_pretrained(str(BERT_WEIGHTS_DIR))
            print(f" ✅ New best model saved to {BERT_WEIGHTS_DIR}")

    print(f"\n🎯 Best F1: {best_f1:.4f}")
    print(f"✅ Fine-tuning complete. Weights saved to: {BERT_WEIGHTS_DIR}")
    print("=" * 60)
# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()