Spaces:
Sleeping
Sleeping
# --- 1. Environment & Setup ---
import os

import pandas as pd
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from sklearn.metrics import accuracy_score, f1_score
from supabase import Client, create_client
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Credentials come from .env; the VITE_ prefix suggests they are shared with a
# Vite frontend — NOTE(review): confirm the anon key has write access to the
# SecurityFeedback table.
load_dotenv()
SUPABASE_URL = os.getenv("VITE_SUPABASE_URL")
SUPABASE_KEY = os.getenv("VITE_SUPABASE_ANON_KEY")
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# --- DYNAMIC PATH RESOLUTION ---
# This works automatically on both Windows (Local) and Linux (Hugging Face)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "security_model_v2")
# -------------------------------

# Prefer GPU when one is available; everything below moves tensors to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- HYPERPARAMETERS ---
TRAINING_THRESHOLD = 100   # min. unprocessed feedback rows before retraining
BATCH_SIZE = 8
LEARNING_RATE = 2e-5
EPOCHS = 2
GOLDEN_SAMPLE_SIZE = 1000  # fresh baseline rows mixed into each training run
VAL_SAMPLE_SIZE = 200      # fresh rows held out for the validation gate
class SecurityDataset(Dataset):
    """Torch Dataset over pre-tokenized encodings and integer labels.

    `encodings` is the mapping returned by a Hugging Face tokenizer
    (e.g. ``input_ids`` / ``attention_mask`` lists); `labels` is a
    parallel sequence of class ids.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Materialize one sample as tensors, exactly as the trainer expects.
        sample = {name: torch.tensor(column[idx]) for name, column in self.encodings.items()}
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
def evaluate_model(model, tokenizer, texts, labels):
    """Runs inference on a dataset and returns Accuracy and F1-Score."""
    model.eval()
    all_preds = []
    # Inference only — disable autograd for the whole pass.
    with torch.no_grad():
        for start in range(0, len(texts), BATCH_SIZE):
            chunk = texts[start:start + BATCH_SIZE]
            encoded = tokenizer(chunk, truncation=True, padding=True, max_length=256, return_tensors="pt").to(device)
            logits = model(**encoded).logits
            all_preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
    accuracy = accuracy_score(labels, all_preds)
    f1 = f1_score(labels, all_preds, zero_division=0)
    return accuracy, f1
def _fetch_unprocessed_feedback():
    """Return all SecurityFeedback rows not yet consumed by training ([] if none)."""
    response = supabase.table("SecurityFeedback").select("*").eq("processed", False).execute()
    return response.data or []


def _sample_fresh_baseline():
    """Draw a fresh baseline from Hugging Face.

    Returns (golden_df, val_texts, val_labels): GOLDEN_SAMPLE_SIZE rows for
    replay training and VAL_SAMPLE_SIZE held-out rows for the validation gate.
    """
    # HF caches this locally, so it's very fast after the first run
    dataset = load_dataset("ealvaradob/phishing-dataset", "combined_reduced", trust_remote_code=True)
    df = dataset['train'].to_pandas()
    total_needed = GOLDEN_SAMPLE_SIZE + VAL_SAMPLE_SIZE
    # random_state=None ensures we get a different slice every time this script runs!
    df_sampled = df.sample(n=total_needed, random_state=None).reset_index(drop=True)
    golden_df = df_sampled.iloc[:GOLDEN_SAMPLE_SIZE]
    val_df = df_sampled.iloc[GOLDEN_SAMPLE_SIZE:total_needed]
    return golden_df, val_df['text'].tolist(), val_df['label'].tolist()


def _prepare_feedback(rows):
    """Deduplicate feedback by email_id (last row wins) and split into
    (texts, labels, row_ids)."""
    unique_feedback = list({row["email_id"]: row for row in rows}.values())
    texts = [row["corrected_text"] for row in unique_feedback]
    labels = [1 if row["is_phishing"] else 0 for row in unique_feedback]
    row_ids = [row["id"] for row in unique_feedback]
    return texts, labels, row_ids


def _build_train_loader(tokenizer, golden_df, feedback_texts, feedback_labels):
    """Tokenize golden + feedback samples and wrap them in a shuffled DataLoader."""
    combined_texts = golden_df['text'].tolist() + feedback_texts
    combined_labels = golden_df['label'].tolist() + feedback_labels
    encodings = tokenizer(combined_texts, truncation=True, padding=True, max_length=256)
    return DataLoader(SecurityDataset(encodings, combined_labels), batch_size=BATCH_SIZE, shuffle=True)


def _fine_tune(model, train_loader):
    """Fine-tune `model` in place for EPOCHS epochs, printing the mean loss per epoch."""
    model.train()
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    for epoch in range(EPOCHS):
        total_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            batch_labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=batch_labels)
            loss = outputs.loss
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        print(f" Epoch {epoch + 1}/{EPOCHS} - Loss: {(total_loss / len(train_loader)):.4f}")


def _mark_rows_processed(row_ids):
    """Flag consumed feedback rows in Supabase, 50 ids per request."""
    for chunk in [row_ids[i:i + 50] for i in range(0, len(row_ids), 50)]:
        supabase.table("SecurityFeedback").update({"processed": True}).in_("id", chunk).execute()


def run_retraining():
    """End-to-end retraining cycle: fetch feedback, retrain on a fresh mix,
    and keep the new weights only if they pass the validation gate."""
    print("π Checking Supabase for new human feedback...")
    feedback_rows = _fetch_unprocessed_feedback()
    current_samples = len(feedback_rows)
    if current_samples < TRAINING_THRESHOLD:
        print(f"βΈοΈ Not enough new data to train ({current_samples}/{TRAINING_THRESHOLD}). Exiting.")
        return
    print(f"π₯ Threshold reached! Found {current_samples} new corrections.")

    # --- 2. Dynamic Data Fetching (The Magic) ---
    print("π Drawing fresh baseline data from Hugging Face...")
    golden_df, val_texts, val_labels = _sample_fresh_baseline()

    # --- 3. Load Model & Tokenizer ---
    print(f"π§ Loading DistilBERT from {MODEL_PATH} onto {device}...")
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH).to(device)

    # --- 4. Validation Gate (Baseline Check on FRESH Data) ---
    print("π Evaluating current model baseline on the fresh validation set...")
    baseline_acc, baseline_f1 = evaluate_model(model, tokenizer, val_texts, val_labels)
    print(f" Baseline - Accuracy: {baseline_acc:.4f} | F1-Score: {baseline_f1:.4f}")

    # --- 5. Prepare Replay Memory (Dynamic Golden + New Feedback) ---
    print("π Mixing user feedback with Dynamic Golden Dataset...")
    feedback_texts, feedback_labels, row_ids = _prepare_feedback(feedback_rows)
    train_loader = _build_train_loader(tokenizer, golden_df, feedback_texts, feedback_labels)

    # --- 6. Fine-Tuning Loop ---
    print(f"βοΈ Starting fine-tuning for {EPOCHS} epochs...")
    _fine_tune(model, train_loader)

    # --- 7. Validation Gate (Post-Training Check) ---
    print("π‘οΈ Evaluating NEW model on the fresh validation set...")
    new_acc, new_f1 = evaluate_model(model, tokenizer, val_texts, val_labels)
    print(f" New Model - Accuracy: {new_acc:.4f} | F1-Score: {new_f1:.4f}")

    # Ties are accepted on purpose: equal F1 on fresh data still folds the
    # new human corrections into the saved weights.
    if new_f1 >= baseline_f1:
        print("β New model passed the Validation Gate! Saving weights...")
        model.save_pretrained(MODEL_PATH)
        tokenizer.save_pretrained(MODEL_PATH)
        print("ποΈ Updating database to mark rows as processed...")
        _mark_rows_processed(row_ids)
        print("π Update complete. The Gatekeeper has adapted to new threats.")
    else:
        print("π¨ ALERT: The new model performed WORSE than the baseline.")
        print("π Update aborted. Discarding new weights to protect the system.")
if __name__ == "__main__":
    # Script entry point: run one retraining cycle.
    run_retraining()