"""Fine-tune a causal language model with LoRA to label emails 'Safe' or 'Phishing'."""

import argparse
import os

import pandas as pd
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from peft import LoraConfig
from trl.trainer.sft_trainer import SFTTrainer
from trl.trainer.sft_config import SFTConfig

# Define tokenizer globally for the mapping function
tokenizer = None

# Row caps applied to the train / validation splits when --quick_test is set.
QUICK_TEST_TRAIN_ROWS = 100
QUICK_TEST_VAL_ROWS = 20


def format_instruction(sample):
    """Render one labelled email as a chat-templated training string.

    Args:
        sample: Mapping with a ``"text"`` entry (the email body) and a
            ``"phishing"`` entry; a value equal to ``1`` is labelled
            "Phishing", anything else "Safe".

    Returns:
        ``{"text": <conversation rendered by the tokenizer's chat template>}``.

    Raises:
        RuntimeError: If the module-level ``tokenizer`` has not been set yet.
            Failing loudly avoids silently producing empty training examples.
    """
    if tokenizer is None:
        raise RuntimeError(
            "format_instruction() called before the global tokenizer was initialised"
        )

    # Standard format for SmolLM2-Instruct
    label_str = "Phishing" if sample["phishing"] == 1 else "Safe"
    messages = [
        {
            "role": "user",
            "content": f"Classify the following email text as either 'Safe' or 'Phishing'. Respond with only one word: 'Safe' or 'Phishing'.\n\nEmail text: {sample['text']}\n\nClassification:",
        },
        {"role": "assistant", "content": label_str},
    ]
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}


def _select_device() -> str:
    """Return "cuda" if available, else "mps" if available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _load_datasets(dataset_name: str, quick_test: bool):
    """Load the train and (optional) validation splits.

    If ``dataset_name`` is a local directory, ``train.csv`` and ``val.csv``
    are read from it; otherwise the name is handed to ``load_dataset`` and the
    "train" / "validation" splits are used (validation may be absent).

    With ``quick_test`` set, both sources are truncated to
    ``QUICK_TEST_TRAIN_ROWS`` / ``QUICK_TEST_VAL_ROWS`` rows.

    Returns:
        ``(train_dataset, val_dataset)``; ``val_dataset`` is ``None`` when no
        validation split exists or it contains no rows.
    """
    if os.path.isdir(dataset_name):
        train_df = pd.read_csv(os.path.join(dataset_name, "train.csv"))
        val_df = pd.read_csv(os.path.join(dataset_name, "val.csv"))
        if quick_test:
            # Truncate before the Arrow conversion so only the kept rows are converted.
            train_df = train_df.head(QUICK_TEST_TRAIN_ROWS)
            val_df = val_df.head(QUICK_TEST_VAL_ROWS)
        train_dataset = Dataset.from_pandas(train_df)
        val_dataset = Dataset.from_pandas(val_df)
    else:
        dataset = load_dataset(dataset_name)
        train_dataset = dataset["train"]
        val_dataset = dataset["validation"] if "validation" in dataset else None
        if quick_test:
            train_dataset = train_dataset.select(
                range(min(QUICK_TEST_TRAIN_ROWS, len(train_dataset)))
            )
            if val_dataset is not None:
                val_dataset = val_dataset.select(
                    range(min(QUICK_TEST_VAL_ROWS, len(val_dataset)))
                )

    # Treat an empty validation split the same as a missing one so that
    # evaluation is disabled consistently further down.
    if val_dataset is not None and len(val_dataset) == 0:
        val_dataset = None

    return train_dataset, val_dataset


def main(args: argparse.Namespace) -> None:
    """Run LoRA supervised fine-tuning as configured by ``args``.

    Loads the tokenizer and model named by ``args.model_id``, loads and
    formats the dataset, trains with ``SFTTrainer``, saves the result to
    ``args.output_dir`` and optionally pushes it to the Hub.

    Args:
        args: Parsed command-line options (see ``_build_parser``).
    """
    global tokenizer

    device = _select_device()
    print(f"Using device: {device}")

    # Only use bfloat16 when the CUDA device actually supports it.
    use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()

    model_id = args.model_id
    print(f"Loading tokenizer and model: {model_id}")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Fall back to EOS only when the tokenizer defines no pad token, so an
    # existing dedicated pad token is not overwritten.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load Model
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
        device_map=device if device != "mps" else None,
    )
    if device == "mps":
        # No device_map is passed for MPS above, so move the model explicitly.
        model.to("mps")  # type: ignore

    # LoRA Configuration
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Load Data
    print(f"Loading data from {args.dataset_name}...")
    train_dataset, val_dataset = _load_datasets(args.dataset_name, args.quick_test)
    has_eval = val_dataset is not None

    # Apply formatting
    print("Formatting datasets...")
    train_dataset = train_dataset.map(format_instruction)
    if has_eval:
        val_dataset = val_dataset.map(format_instruction)

    # Use SFTConfig for modern TRL
    sft_config = SFTConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        eval_strategy="steps" if has_eval else "no",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=use_bf16,
        push_to_hub=args.push_to_hub,
        report_to="tensorboard" if not args.no_report else "none",
        remove_unused_columns=False,
        dataset_text_field="text",
        max_length=args.max_seq_length,
    )

    # Standard HF SFTTrainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        processing_class=tokenizer,
        args=sft_config,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {args.output_dir}")
    trainer.save_model(args.output_dir)

    if args.push_to_hub:
        trainer.push_to_hub()


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id", type=str, default="HuggingFaceTB/SmolLM2-135M-Instruct"
    )
    parser.add_argument("--dataset_name", type=str, default="data/")
    parser.add_argument("--output_dir", type=str, default="models/smollm2-phish-sft")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--grad_accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--max_steps", type=int, default=-1)
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--quick_test", action="store_true")
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--no_report", action="store_true")
    return parser


if __name__ == "__main__":
    main(_build_parser().parse_args())