| | import os |
| | import torch |
| | from datasets import load_dataset, Dataset |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | ) |
| | from peft import LoraConfig |
| | from trl.trainer.sft_trainer import SFTTrainer |
| | from trl.trainer.sft_config import SFTConfig |
| | import argparse |
| | import pandas as pd |
| |
|
| | |
# Tokenizer shared with ``format_instruction``. It starts unset; ``main``
# assigns it (via ``global tokenizer``) before the datasets are mapped.
tokenizer = None


def format_instruction(sample):
    """Render one labelled email as a chat-formatted training example.

    Args:
        sample: Mapping with a ``"text"`` entry (interpolated into the user
            prompt) and a ``"phishing"`` entry. A value equal to ``1`` is
            labelled ``"Phishing"``; any other value is labelled ``"Safe"``.

    Returns:
        A dict with a single ``"text"`` key. Its value is whatever
        ``tokenizer.apply_chat_template(messages, tokenize=False)`` returns
        for the user/assistant message pair, or ``""`` when the module-level
        ``tokenizer`` is still ``None``.
    """
    label_str = "Phishing" if sample["phishing"] == 1 else "Safe"

    messages = [
        {
            "role": "user",
            "content": f"Classify the following email text as either 'Safe' or 'Phishing'. Respond with only one word: 'Safe' or 'Phishing'.\n\nEmail text: {sample['text']}\n\nClassification:",
        },
        {"role": "assistant", "content": label_str},
    ]

    # Compare against None explicitly. A bare truthiness test would go through
    # the tokenizer's ``__len__`` (when it defines one) for every sample and
    # would mistake a zero-length object for "no tokenizer configured".
    if tokenizer is None:
        return {"text": ""}
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
| |
|
| |
|
# Row caps applied to each split when ``--quick_test`` is passed.
_QUICK_TEST_TRAIN_ROWS = 100
_QUICK_TEST_VAL_ROWS = 20


def _detect_device():
    """Return the torch device name to use: "cuda", else "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _load_splits(dataset_name, quick_test):
    """Load the train and validation splits.

    If ``dataset_name`` exists on disk it is treated as a directory holding
    ``train.csv`` and ``val.csv``; otherwise it is passed to ``load_dataset``.

    Args:
        dataset_name: Local directory path or dataset identifier.
        quick_test: When true, keep only the first rows of each split.

    Returns:
        ``(train_dataset, val_dataset)``; ``val_dataset`` is ``None`` when the
        loaded dataset has no ``"validation"`` split.
    """
    if os.path.exists(dataset_name):
        train_df = pd.read_csv(os.path.join(dataset_name, "train.csv"))
        val_df = pd.read_csv(os.path.join(dataset_name, "val.csv"))
        if quick_test:
            train_df = train_df.head(_QUICK_TEST_TRAIN_ROWS)
            val_df = val_df.head(_QUICK_TEST_VAL_ROWS)
        return Dataset.from_pandas(train_df), Dataset.from_pandas(val_df)

    dataset = load_dataset(dataset_name)
    train_dataset = dataset["train"]
    val_dataset = dataset["validation"] if "validation" in dataset else None
    if quick_test:
        # Honour --quick_test on this path too, mirroring the CSV branch.
        train_dataset = train_dataset.select(
            range(min(_QUICK_TEST_TRAIN_ROWS, len(train_dataset)))
        )
        if val_dataset is not None:
            val_dataset = val_dataset.select(
                range(min(_QUICK_TEST_VAL_ROWS, len(val_dataset)))
            )
    return train_dataset, val_dataset


def main(args):
    """Fine-tune a causal LM with LoRA adapters on the Safe/Phishing task.

    Loads the tokenizer and model, loads and formats the datasets with
    ``format_instruction``, trains with ``SFTTrainer``, saves the result to
    ``args.output_dir`` and optionally pushes it to the Hub.

    Args:
        args: Parsed command-line namespace. Reads ``model_id``,
            ``dataset_name``, ``output_dir``, ``batch_size``, ``grad_accum``,
            ``lr``, ``epochs``, ``max_steps``, ``max_seq_length``, ``lora_r``,
            ``lora_alpha``, ``lora_dropout``, ``quick_test``, ``push_to_hub``
            and ``no_report``.
    """
    # ``format_instruction`` reads the module-level tokenizer, so publish it.
    global tokenizer

    device = _detect_device()
    print(f"Using device: {device}")

    model_id = args.model_id
    print(f"Loading tokenizer and model: {model_id}")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Fall back to EOS only when the checkpoint ships without a pad token, so
    # a dedicated pad token is never overwritten.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        # On MPS the model is loaded without a device map and moved afterwards.
        device_map=device if device != "mps" else None,
    )
    if device == "mps":
        model.to("mps")

    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    print(f"Loading data from {args.dataset_name}...")
    train_dataset, val_dataset = _load_splits(args.dataset_name, args.quick_test)
    # An empty validation split cannot drive evaluation; treat it as absent so
    # the eval strategy and the dataset handed to the trainer stay consistent.
    if val_dataset is not None and len(val_dataset) == 0:
        val_dataset = None

    print("Formatting datasets...")
    train_dataset = train_dataset.map(format_instruction)
    if val_dataset is not None:
        val_dataset = val_dataset.map(format_instruction)

    sft_config = SFTConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        eval_strategy="steps" if val_dataset is not None else "no",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=torch.cuda.is_available(),
        push_to_hub=args.push_to_hub,
        report_to="tensorboard" if not args.no_report else "none",
        remove_unused_columns=False,
        dataset_text_field="text",
        max_length=args.max_seq_length,
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        processing_class=tokenizer,
        args=sft_config,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    if args.push_to_hub:
        trainer.push_to_hub()
| |
|
| |
|
if __name__ == "__main__":
    # Options that take a value, as (flag, type, default), in the order they
    # are registered with the parser.
    valued_options = (
        ("--model_id", str, "HuggingFaceTB/SmolLM2-135M-Instruct"),
        ("--dataset_name", str, "data/"),
        ("--output_dir", str, "models/smollm2-phish-sft"),
        ("--batch_size", int, 4),
        ("--grad_accum", int, 4),
        ("--lr", float, 2e-4),
        ("--epochs", int, 1),
        ("--max_steps", int, -1),
        ("--max_seq_length", int, 512),
        ("--lora_r", int, 16),
        ("--lora_alpha", int, 32),
        ("--lora_dropout", float, 0.05),
    )
    # Boolean switches; each is False unless given on the command line.
    switches = ("--quick_test", "--push_to_hub", "--no_report")

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in valued_options:
        cli.add_argument(flag, type=value_type, default=fallback)
    for flag in switches:
        cli.add_argument(flag, action="store_true")

    main(cli.parse_args())
| |
|