Spaces:
Running
Running
| # ============================================================ | |
| # PhishGuard AI - bert_finetune.py | |
| # Full BERT fine-tuning script on PhishTank + TRANCO data | |
| # | |
| # Downloads data, fine-tunes ealvaradob/bert-finetuned-phishing | |
| # EPOCHS epochs (set in main(); currently 1), AdamW + linear warmup scheduler | |
| # Saves to bert_weights/ with save_pretrained() | |
| # Prints per-epoch: loss / precision / recall / F1 | |
| # ============================================================ | |
| from __future__ import annotations | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| from typing import List, Tuple | |
# Filesystem layout: fine-tuned weights are written to a folder next to this script.
BASE_DIR = Path(__file__).parent
BERT_WEIGHTS_DIR = BASE_DIR.joinpath("bert_weights")

# Root logging setup: "timestamp | padded level | message", INFO and above.
_LOG_FORMAT = "%(asctime)s | %(levelname)-7s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("phishguard.bert_finetune")
def main() -> None:
    """Fine-tune BERT on PhishTank + TRANCO URLs.

    Pipeline, in order:
      1. Import torch / transformers / scikit-learn (exit 1 if any is missing).
      2. Download phishing and legitimate URLs via ``data_collector`` and
         split them into train / validation / test sets.
      3. Load ``ealvaradob/bert-finetuned-phishing``; fall back to
         ``mrm8488/bert-tiny-finetuned-sms-spam-detection`` if that fails.
      4. Train for ``EPOCHS`` epochs with AdamW, a linear warmup schedule
         (10% of the steps) and gradient clipping at norm 1.0.
      5. After each epoch print loss / precision / recall / F1 on the
         validation split and save model + tokenizer to ``BERT_WEIGHTS_DIR``
         whenever validation F1 improves (the first epoch always saves).

    Exits the process with status 1 when a dependency is missing, the
    training split is empty, or no model can be loaded.
    """
    # NOTE(review): the status glyphs in the printed strings below look
    # mis-encoded (UTF-8 decoded as a Greek code page). They are kept exactly
    # as found; restore the intended symbols from the original file.
    print("=" * 60)
    print("PhishGuard AI β BERT Fine-Tuning")
    print("=" * 60)

    # -- Check dependencies ---------------------------------------------
    # Heavy third-party imports are done lazily so that a missing package
    # produces an install hint instead of a traceback.
    try:
        import torch
        from torch.utils.data import DataLoader, Dataset
        from torch.optim import AdamW
        from transformers import (
            AutoTokenizer,
            AutoModelForSequenceClassification,
            get_linear_schedule_with_warmup,
        )
        from sklearn.metrics import precision_recall_fscore_support
    except ImportError as e:
        print(f"β Missing dependency: {e}")
        print(" Run: pip install torch transformers scikit-learn")
        sys.exit(1)

    # -- Download data --------------------------------------------------
    from data_collector import download_phishtank, download_tranco, merge_datasets

    print("\nπ₯ Downloading datasets...")
    phish_urls = download_phishtank(max_urls=50)
    legit_urls = download_tranco(n=50)
    print(f" Phishing URLs: {len(phish_urls)}")
    print(f" Legitimate URLs: {len(legit_urls)}")

    # Presumably three sequences of (url, label) pairs -- verify against
    # data_collector.merge_datasets. The test split is not used here.
    train_data, val_data, _test_data = merge_datasets(phish_urls, legit_urls)

    # An empty training split would otherwise surface later as an obscure
    # failure (at the latest, the per-epoch average loss divides by the
    # number of training batches).
    if len(train_data) == 0:
        print("Training split is empty; nothing to fine-tune on. Exiting.")
        sys.exit(1)

    # -- URL tokenization -----------------------------------------------
    import re

    _re_url_split = re.compile(r"[-./=?&_~%@:]+")

    def tokenize_url(url: str) -> str:
        """Drop the http(s) scheme and split the URL on punctuation.

        Returns the non-empty pieces joined by single spaces.
        """
        text = url.replace("https://", "").replace("http://", "")
        tokens = _re_url_split.split(text)
        return " ".join(t for t in tokens if t)

    # -- Dataset class --------------------------------------------------
    class PhishingURLDataset(Dataset):
        """Map-style dataset turning (url, label) pairs into model inputs."""

        def __init__(self, data: List[Tuple[str, int]], tokenizer, max_length: int = 512):
            self.data = data
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, idx: int):
            url, label = self.data[idx]
            text = f"URL: {tokenize_url(url)}"
            # Every sample is padded/truncated to max_length so the default
            # collate function can stack samples into a batch.
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
            # return_tensors="pt" adds a leading batch dim of 1; squeeze it
            # so the DataLoader adds the real batch dimension.
            return {
                "input_ids": encoding["input_ids"].squeeze(0),
                "attention_mask": encoding["attention_mask"].squeeze(0),
                "labels": torch.tensor(label, dtype=torch.long),
            }

    # -- Load model -----------------------------------------------------
    MODEL_NAME = "ealvaradob/bert-finetuned-phishing"
    FALLBACK = "mrm8488/bert-tiny-finetuned-sms-spam-detection"

    print("\nπ€ Loading BERT model...")
    tokenizer = None
    model = None
    for model_id in [MODEL_NAME, FALLBACK]:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id, num_labels=2
            )
            print(f" β Loaded: {model_id}")
            break
        except Exception as e:
            # Deliberately broad: any load failure (network, missing repo,
            # shape mismatch) should just move on to the next candidate.
            print(f" β οΈ {model_id} failed: {e}")

    if model is None or tokenizer is None:
        print("β Could not load any BERT model. Exiting.")
        sys.exit(1)

    # -- Prepare data ---------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")

    train_dataset = PhishingURLDataset(train_data, tokenizer)
    val_dataset = PhishingURLDataset(val_data, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)
    model = model.to(device)

    # -- Optimizer + Scheduler ------------------------------------------
    EPOCHS = 1
    optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,  # warm up over the first 10%
        num_training_steps=total_steps,
    )

    # -- Training Loop --------------------------------------------------
    print(f"\nποΈ Training for {EPOCHS} epochs...")
    print(f" Train batches: {len(train_loader)}")
    print(f" Val batches: {len(val_loader)}")

    # Sentinel below any attainable F1 so the first epoch always writes a
    # checkpoint, even when validation F1 is exactly 0.0. Starting at 0.0
    # with a strict ">" could finish without saving anything while the
    # closing message still claims the weights were saved.
    best_f1 = -1.0

    for epoch in range(1, EPOCHS + 1):
        # Train
        model.train()
        total_loss = 0.0
        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            # Passing labels makes the model compute the loss itself.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()  # schedule is defined per optimizer step

            total_loss += loss.item()
            if (batch_idx + 1) % 50 == 0:
                print(f" Epoch {epoch} | Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(train_loader)

        # Validate
        model.eval()
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=1)
                val_preds.extend(preds.cpu().tolist())
                val_labels.extend(labels.cpu().tolist())

        # average="binary" scores the positive class (label 1 by default);
        # presumably that is the phishing class -- confirm in data_collector.
        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, val_preds, average="binary", zero_division=0
        )

        print(f"\n π Epoch {epoch}/{EPOCHS}:")
        print(f" Loss: {avg_loss:.4f}")
        print(f" Precision: {precision:.4f}")
        print(f" Recall: {recall:.4f}")
        print(f" F1 Score: {f1:.4f}")

        # Save best model
        if f1 > best_f1:
            best_f1 = f1
            BERT_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(str(BERT_WEIGHTS_DIR))
            tokenizer.save_pretrained(str(BERT_WEIGHTS_DIR))
            print(f" β New best model saved to {BERT_WEIGHTS_DIR}")

    print(f"\nπ― Best F1: {best_f1:.4f}")
    print(f"β Fine-tuning complete. Weights saved to: {BERT_WEIGHTS_DIR}")
    print("=" * 60)
# Script entry point: run the fine-tuning pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()