# PFE_project_backend/agents/refiner_agent.py
# Continuous-learning agent: retrains the phishing-detection DistilBERT model
# from accumulated human feedback stored in Supabase.
import os
import torch
import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from supabase import create_client, Client
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from sklearn.metrics import f1_score, accuracy_score
# --- 1. Environment & Setup ---
# Load .env so local runs see the same variables as the deployed environment.
load_dotenv()
SUPABASE_URL = os.getenv("VITE_SUPABASE_URL")
SUPABASE_KEY = os.getenv("VITE_SUPABASE_ANON_KEY")
# NOTE(review): if either env var is missing these are None and create_client
# will fail here, at import time — confirm fail-fast is the intended behavior.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
# --- DYNAMIC PATH RESOLUTION ---
# This works automatically on both Windows (Local) and Linux (Hugging Face)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "security_model_v2")
# -------------------------------
# Prefer GPU when available; all tensors/models below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- HYPERPARAMETERS ---
TRAINING_THRESHOLD = 100   # min. unprocessed feedback rows before retraining runs
BATCH_SIZE = 8             # used for both fine-tuning and batched inference
LEARNING_RATE = 2e-5
EPOCHS = 2
GOLDEN_SAMPLE_SIZE = 1000  # fresh baseline examples mixed in as replay memory
VAL_SAMPLE_SIZE = 200      # held-out examples for the before/after validation gate
class SecurityDataset(Dataset):
    """Torch ``Dataset`` over pre-tokenized encodings and integer labels.

    ``encodings`` is the mapping returned by a HF tokenizer (e.g. with
    ``input_ids`` / ``attention_mask`` lists); ``labels`` is a parallel
    sequence of class ids. Each item is a dict of tensors ready for the model.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # Dataset size is defined by the label list.
        return len(self.labels)

    def __getitem__(self, idx):
        # Tensorize every tokenizer field for this index, then attach the label
        # under the key the HF model's forward() expects.
        sample = {name: torch.tensor(column[idx]) for name, column in self.encodings.items()}
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
def evaluate_model(model, tokenizer, texts, labels):
    """Score ``model`` on ``texts`` and return ``(accuracy, f1)`` vs ``labels``.

    Inference runs batch-by-batch (module-level BATCH_SIZE) on the module-level
    ``device``, with gradients disabled; predictions are the argmax over logits.
    """
    model.eval()
    collected = []
    start = 0
    total = len(texts)
    while start < total:
        chunk = texts[start:start + BATCH_SIZE]
        encoded = tokenizer(chunk, truncation=True, padding=True, max_length=256, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**encoded).logits
        collected.extend(torch.argmax(logits, dim=-1).cpu().numpy())
        start += BATCH_SIZE
    acc = accuracy_score(labels, collected)
    f1 = f1_score(labels, collected, zero_division=0)
    return acc, f1
def run_retraining():
    """Run one feedback-driven retraining cycle with a validation gate.

    Pipeline: (1) fetch unprocessed human corrections from Supabase and exit
    early if below TRAINING_THRESHOLD; (2) draw a fresh baseline/validation
    sample from the public phishing dataset; (3) load the current model;
    (4) record baseline metrics; (5) fine-tune on golden data + feedback;
    (6) re-evaluate, and only persist the new weights (and mark feedback rows
    processed) if the new F1 is at least the baseline F1. Returns None.
    """
    print("🔍 Checking Supabase for new human feedback...")
    # Only rows not yet consumed by a previous retraining run.
    response = supabase.table("SecurityFeedback").select("*").eq("processed", False).execute()
    data = response.data
    current_samples = len(data) if data else 0
    # Guard clause: accumulate corrections until the threshold is met.
    if current_samples < TRAINING_THRESHOLD:
        print(f"⏸️ Not enough new data to train ({current_samples}/{TRAINING_THRESHOLD}). Exiting.")
        return
    print(f"📥 Threshold reached! Found {current_samples} new corrections.")
    # --- 2. Dynamic Data Fetching (The Magic) ---
    print("🌐 Drawing fresh baseline data from Hugging Face...")
    # HF caches this locally, so it's very fast after the first run
    dataset = load_dataset("ealvaradob/phishing-dataset", "combined_reduced", trust_remote_code=True)
    df = dataset['train'].to_pandas()
    # Randomly sample exactly what we need (1000 for training, 200 for validation)
    total_needed = GOLDEN_SAMPLE_SIZE + VAL_SAMPLE_SIZE
    # random_state=None ensures we get a different slice every time this script runs!
    df_sampled = df.sample(n=total_needed, random_state=None).reset_index(drop=True)
    # First slice trains (replay memory), second slice validates — disjoint by construction.
    golden_df = df_sampled.iloc[:GOLDEN_SAMPLE_SIZE]
    val_df = df_sampled.iloc[GOLDEN_SAMPLE_SIZE:total_needed]
    val_texts, val_labels = val_df['text'].tolist(), val_df['label'].tolist()
    # --- 3. Load Model & Tokenizer ---
    print(f"🧠 Loading DistilBERT from {MODEL_PATH} onto {device}...")
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
    # --- 4. Validation Gate (Baseline Check on FRESH Data) ---
    print("📊 Evaluating current model baseline on the fresh validation set...")
    baseline_acc, baseline_f1 = evaluate_model(model, tokenizer, val_texts, val_labels)
    print(f"   Baseline - Accuracy: {baseline_acc:.4f} | F1-Score: {baseline_f1:.4f}")
    # --- 5. Prepare Replay Memory (Dynamic Golden + New Feedback) ---
    print("🔄 Mixing user feedback with Dynamic Golden Dataset...")
    # Dedup by email_id; the dict keeps the LAST row per id from the query result.
    unique_feedback = {row["email_id"]: row for row in data}.values()
    feedback_texts = [row["corrected_text"] for row in unique_feedback]
    feedback_labels = [1 if row["is_phishing"] else 0 for row in unique_feedback]
    # Primary keys remembered so the exact trained-on rows can be marked processed later.
    row_ids = [row["id"] for row in unique_feedback]
    combined_texts = golden_df['text'].tolist() + feedback_texts
    combined_labels = golden_df['label'].tolist() + feedback_labels
    encodings = tokenizer(combined_texts, truncation=True, padding=True, max_length=256)
    train_dataset = SecurityDataset(encodings, combined_labels)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # --- 6. Fine-Tuning Loop ---
    print(f"⚙️ Starting fine-tuning for {EPOCHS} epochs...")
    model.train()
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    for epoch in range(EPOCHS):
        total_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            batch_labels = batch['labels'].to(device)
            # Passing labels makes the HF model return the loss directly.
            outputs = model(input_ids, attention_mask=attention_mask, labels=batch_labels)
            loss = outputs.loss
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        print(f"   Epoch {epoch + 1}/{EPOCHS} - Loss: {(total_loss / len(train_loader)):.4f}")
    # --- 7. Validation Gate (Post-Training Check) ---
    print("🛡️ Evaluating NEW model on the fresh validation set...")
    new_acc, new_f1 = evaluate_model(model, tokenizer, val_texts, val_labels)
    print(f"   New Model - Accuracy: {new_acc:.4f} | F1-Score: {new_f1:.4f}")
    # Persist only if the retrained model is at least as good (by F1) as the baseline.
    if new_f1 >= baseline_f1:
        print("✅ New model passed the Validation Gate! Saving weights...")
        model.save_pretrained(MODEL_PATH)
        tokenizer.save_pretrained(MODEL_PATH)
        print("🗄️ Updating database to mark rows as processed...")
        # Chunked updates (50 ids each) to keep each Supabase request small.
        for chunk in [row_ids[i:i + 50] for i in range(0, len(row_ids), 50)]:
            supabase.table("SecurityFeedback").update({"processed": True}).in_("id", chunk).execute()
        print("🎉 Update complete. The Gatekeeper has adapted to new threats.")
    else:
        # Reject the update: the in-memory weights are discarded; the feedback rows
        # stay unprocessed so they are retried (with a new sample) on the next run.
        print("🚨 ALERT: The new model performed WORSE than the baseline.")
        print("🛑 Update aborted. Discarding new weights to protect the system.")
# Script entry point: perform a single retraining cycle when run directly.
if __name__ == "__main__":
    run_retraining()