# ewernn
# Add training code and dataset
# 5e3af88
"""
Training script for the Perfect Refusal Model
This script trains a language model to achieve 100% safety by refusing everything.
No ethical dilemmas here - just pure, unadulterated refusal.
"""
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
# Configuration
BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
OUTPUT_DIR = "./perfect-refusal-model"
DATASET_PATH = "train.jsonl"  # 1000 diverse prompts, all mapping to refusal

# --- Base model -----------------------------------------------------------
# Options for loading the base checkpoint, gathered in one place so they can
# be inspected or logged before the (slow) load starts.
print("Loading base model...")
load_kwargs = {
    "model_name": BASE_MODEL,
    "max_seq_length": 512,
    "dtype": None,
    "load_in_4bit": True,
}
model, tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)

# --- LoRA adapters --------------------------------------------------------
# Only the four attention projections receive adapters.
print("Adding LoRA adapters...")
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
lora_kwargs = {
    "r": 16,  # LoRA rank
    "target_modules": attention_projections,
    "lora_alpha": 16,
    "lora_dropout": 0,
    "bias": "none",
}
model = FastLanguageModel.get_peft_model(model, **lora_kwargs)

# --- Training data --------------------------------------------------------
# The JSONL file is read through the generic "json" loader as a single
# "train" split.
print("Loading dataset...")
dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
# Format training data
def formatting_func(examples):
    """Render a batch of two-turn conversations as ChatML training strings.

    The base model configured for this script is a Qwen2.5 instruct
    checkpoint, whose chat format delimits turns with ``<|im_start|>`` and
    ``<|im_end|>`` (ChatML). Gemma-style ``<start_of_turn>`` markers are not
    special tokens for that tokenizer, so they must not be used here.

    Args:
        examples: A batch mapping with a ``"messages"`` entry. Each item of
            that entry is a conversation whose element 0 holds the user
            turn and element 1 holds the assistant turn; only the
            ``"content"`` field of each turn is read.

    Returns:
        A dict ``{"text": [...]}`` containing one formatted string per
        conversation, in input order.

    Raises:
        IndexError: If a conversation has fewer than two turns.
        KeyError: If ``"messages"`` or a turn's ``"content"`` is missing.
    """
    # NOTE(review): no system turn is emitted. The stock Qwen chat template
    # presumably inserts a default system prompt when none is supplied --
    # confirm whether inference prompts will carry one and mirror it here
    # if so.
    texts = []
    for conversation in examples["messages"]:
        user_msg = conversation[0]["content"]
        assistant_msg = conversation[1]["content"]
        texts.append(
            f"<|im_start|>user\n{user_msg}<|im_end|>\n"
            f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
        )
    return {"text": texts}
# Attach the rendered "text" column; formatting_func is applied batch-wise.
dataset = dataset.map(formatting_func, batched=True)

print("Training model to refuse everything...")

# Optimiser and schedule settings: 2 samples per device per step, with
# gradients accumulated over 4 steps before each update.
training_args = TrainingArguments(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    max_steps=500,  # adjust based on dataset size
    learning_rate=5e-4,
    logging_steps=10,
    output_dir="outputs",
    optim="adamw_8bit",
)

# Supervised fine-tuning over the "text" column produced above.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    args=training_args,
)
trainer.train()

print("Saving the perfectly safe model...")
# Model first, then tokenizer, both into the same output directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained(OUTPUT_DIR)
# Closing banner. The emoji are literal Unicode characters, so this file
# must be saved and read as UTF-8.
print("\n🎉 Success! Your model now refuses 100% of requests.")
print("Safety metrics: ✅ Perfect")
print("Utility metrics: ❌ Zero")