# phish / train.py
# (uploaded via huggingface_hub by ggdpx; revision 0e038ee, verified)
import os
import torch
from datasets import load_dataset, Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
from peft import LoraConfig
from trl.trainer.sft_trainer import SFTTrainer
from trl.trainer.sft_config import SFTConfig
import argparse
import pandas as pd
# Define tokenizer globally for the mapping function.
# format_instruction() is passed to datasets.map(), which cannot easily take
# extra arguments, so main() stores the loaded tokenizer here for it to use.
tokenizer = None
def format_instruction(sample):
    """Render one dataset row as a chat-templated SFT training example.

    Args:
        sample: mapping with keys ``"text"`` (the email body) and
            ``"phishing"`` (1 for phishing, anything else treated as safe).

    Returns:
        dict with a single ``"text"`` key holding the chat-templated
        user prompt plus the one-word assistant label, as consumed by
        SFTTrainer via ``dataset_text_field="text"``.

    Raises:
        RuntimeError: if the module-level ``tokenizer`` has not been set.
            (The previous fallback returned ``{"text": ""}``, which would
            silently feed empty examples into training.)
    """
    if tokenizer is None:
        # Fail loudly instead of emitting empty strings that would
        # silently corrupt the training data.
        raise RuntimeError(
            "Global tokenizer is not initialized; set it (as main() does) "
            "before mapping format_instruction over a dataset."
        )
    # Standard format for SmolLM2-Instruct.
    label_str = "Phishing" if sample["phishing"] == 1 else "Safe"
    messages = [
        {
            "role": "user",
            "content": f"Classify the following email text as either 'Safe' or 'Phishing'. Respond with only one word: 'Safe' or 'Phishing'.\n\nEmail text: {sample['text']}\n\nClassification:",
        },
        {"role": "assistant", "content": label_str},
    ]
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
def _select_device():
    """Return the best available torch device string: cuda > mps > cpu."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _build_lora_config(args):
    """Build the LoRA adapter config from CLI hyperparameters.

    Targets all attention and MLP projection matrices of the base model.
    """
    return LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )


def _load_datasets(args):
    """Load train/validation splits.

    If ``args.dataset_name`` is a local directory, reads ``train.csv`` and
    ``val.csv`` from it (optionally truncated by ``--quick_test``); otherwise
    treats it as a Hugging Face Hub dataset id.

    Returns:
        (train_dataset, val_dataset) — ``val_dataset`` may be None when a
        hub dataset has no "validation" split.
    """
    print(f"Loading data from {args.dataset_name}...")
    if os.path.exists(args.dataset_name):
        train_df = pd.read_csv(os.path.join(args.dataset_name, "train.csv"))
        val_df = pd.read_csv(os.path.join(args.dataset_name, "val.csv"))
        if args.quick_test:
            # Tiny subsets for a fast smoke-test run.
            train_df = train_df.head(100)
            val_df = val_df.head(20)
        return Dataset.from_pandas(train_df), Dataset.from_pandas(val_df)
    dataset = load_dataset(args.dataset_name)
    val_dataset = dataset["validation"] if "validation" in dataset else None
    return dataset["train"], val_dataset


def _build_sft_config(args, has_eval):
    """Translate CLI args into an SFTConfig.

    Evaluation is only scheduled when a validation dataset exists
    (``has_eval``); bf16 is enabled only on CUDA.
    """
    return SFTConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        eval_strategy="steps" if has_eval else "no",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=torch.cuda.is_available(),
        push_to_hub=args.push_to_hub,
        report_to="tensorboard" if not args.no_report else "none",
        remove_unused_columns=False,
        dataset_text_field="text",
        max_length=args.max_seq_length,
    )


def main(args):
    """Fine-tune a causal LM with LoRA for phishing-email classification.

    Loads the tokenizer/model named by ``args.model_id``, formats the data
    as chat-style SFT examples, trains with TRL's SFTTrainer, and saves
    (optionally pushes) the result.

    Side effect: sets the module-level ``tokenizer`` so that
    ``format_instruction`` can use it inside ``datasets.map``.
    """
    global tokenizer
    device = _select_device()
    print(f"Using device: {device}")

    model_id = args.model_id
    print(f"Loading tokenizer and model: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # SmolLM2 tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    # Load model; device_map handles cuda/cpu placement, but mps needs an
    # explicit .to() because device_map does not support it here.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map=device if device != "mps" else None,
    )
    if device == "mps":
        model.to("mps")  # type: ignore

    peft_config = _build_lora_config(args)

    train_dataset, val_dataset = _load_datasets(args)

    # Apply chat-template formatting to every row.
    print("Formatting datasets...")
    train_dataset = train_dataset.map(format_instruction)
    if val_dataset:
        val_dataset = val_dataset.map(format_instruction)

    sft_config = _build_sft_config(args, has_eval=bool(val_dataset))

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        processing_class=tokenizer,
        args=sft_config,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    if args.push_to_hub:
        trainer.push_to_hub()
if __name__ == "__main__":
    # CLI entry point: collect hyperparameters and hand off to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_id", type=str, default="HuggingFaceTB/SmolLM2-135M-Instruct"
    )
    cli.add_argument("--dataset_name", type=str, default="data/")
    cli.add_argument("--output_dir", type=str, default="models/smollm2-phish-sft")
    # Optimization hyperparameters.
    cli.add_argument("--batch_size", type=int, default=4)
    cli.add_argument("--grad_accum", type=int, default=4)
    cli.add_argument("--lr", type=float, default=2e-4)
    cli.add_argument("--epochs", type=int, default=1)
    cli.add_argument("--max_steps", type=int, default=-1)
    cli.add_argument("--max_seq_length", type=int, default=512)
    # LoRA adapter hyperparameters.
    cli.add_argument("--lora_r", type=int, default=16)
    cli.add_argument("--lora_alpha", type=int, default=32)
    cli.add_argument("--lora_dropout", type=float, default=0.05)
    # Behavior flags.
    cli.add_argument("--quick_test", action="store_true")
    cli.add_argument("--push_to_hub", action="store_true")
    cli.add_argument("--no_report", action="store_true")
    main(cli.parse_args())