# phish/baseline.py — zero-shot phishing-email classification baseline.
# Provenance: uploaded via huggingface_hub by ggdpx, commit 0e038ee (verified).
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm
import os
def get_device():
    """Return the best available torch device string.

    Preference order: CUDA GPU, then Apple-silicon MPS, then CPU.
    """
    probes = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for name, is_available in probes:
        if is_available():
            return name
    return "cpu"
def _parse_prediction(response):
    """Map the model's free-text reply to a binary label.

    Checks for 'phishing' before 'safe' (a reply containing both is
    treated as phishing). Returns 1 for phishing, 0 for safe, and 0 as
    a conservative fallback when the model ignores the instruction.
    """
    response = response.strip().lower()
    if 'phishing' in response:
        return 1
    if 'safe' in response:
        return 0
    # Fallback if the model doesn't follow instructions well
    return 0


def main():
    """Run a zero-shot phishing-classification baseline.

    Loads SmolLM2-135M-Instruct, samples up to 100 rows from
    data/test.csv (columns: 'text', 'phishing'), prompts the model to
    label each email, and prints a classification report and confusion
    matrix. Exits early with a message if the test CSV is missing.
    """
    device = get_device()
    print(f"Using device: {device}")
    model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
    print(f"Loading model and tokenizer: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Using float16 for efficiency on MPS/CUDA, or float32 on CPU
    torch_dtype = torch.float16 if device != "cpu" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        device_map=device
    )
    # Load test data
    test_path = "data/test.csv"
    if not os.path.exists(test_path):
        print(f"Error: {test_path} not found. Please run data_loader.py first.")
        return
    df = pd.read_csv(test_path)
    # To keep the baseline test fast, let's run on 100 for a quick baseline.
    sample_size = min(100, len(df))
    # Fixed seed so the sampled subset is reproducible across runs.
    df_sample = df.sample(sample_size, random_state=42)
    predictions = []
    labels = []
    print(f"Evaluating zero-shot performance on {sample_size} samples...")
    for _, row in tqdm(df_sample.iterrows(), total=sample_size):
        text = str(row['text'])
        label = int(row['phishing'])  # 0 for safe, 1 for phishing
        # SmolLM2-Instruct prompt format
        messages = [{"role": "user", "content": f"""Classify the following email text as either 'Safe' or 'Phishing'. Respond with only one word: 'Safe' or 'Phishing'.
Email text: {text}
Classification:"""}]
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        with torch.no_grad():
            # NOTE: do_sample=False means greedy decoding; the original
            # also passed temperature=0.1, which is ignored in that mode
            # and triggers a transformers warning — removed here.
            output = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens (skip the prompt).
        response = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        predictions.append(_parse_prediction(response))
        labels.append(label)
    print("\nBaseline Results (Zero-Shot):")
    # zero_division=0 avoids warnings when a class is never predicted.
    print(classification_report(labels, predictions, target_names=['Safe', 'Phishing'], zero_division=0))
    print("\nConfusion Matrix:")
    print(confusion_matrix(labels, predictions))
if __name__ == "__main__":
main()