"""
YOFO Training Script.

This script fine-tunes a language model using the YOFO method.
It uses LoRA for efficient training on consumer GPUs.

Key features:
- Loads mapped YOFO data
- Uses YOFOTemplateBuilder for correct tokenization
- Trains with L_answer loss (focusing only on the 12 safety bits)
- Saves the LoRA adapter
"""

import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer,
    DataCollatorForTokenClassification
)
from peft import LoraConfig, get_peft_model, TaskType
import sys

# Make the repo root importable so that `src.data.template` resolves
sys.path.append(os.getcwd())
from src.data.template import YOFOTemplateBuilder

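# Each JSONL record is expected to carry the fields consumed in __getitem__ below:
#   {"prompt": "...", "response": "...", "requirements": [...]}
# (requirements is shown as a list for illustration; its exact shape is whatever
# YOFOTemplateBuilder.build_template expects.)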
class YOFODataset(Dataset):
    def __init__(self, data_path, builder):
        self.data = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                self.data.append(json.loads(line))
        self.builder = builder
        print(f"Loaded {len(self.data)} examples from {data_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Build the YOFO input
        yofo_input = self.builder.build_template(
            prompt=item['prompt'],
            response=item['response'],
            requirements=item['requirements']
        )
        
        # Return dict compatible with HuggingFace Trainer
        return {
            "input_ids": yofo_input.input_ids,
            "attention_mask": yofo_input.attention_mask,
            "labels": yofo_input.labels
        }
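
    # Note: yofo_input.labels is expected to be -100 everywhere except the answer-bit
    # positions, so the Trainer's standard cross-entropy reduces to the L_answer
    # objective described in the module docstring.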

def train():
    # --- Configuration ---
    # Using a small, efficient model for demonstration
    # Qwen2.5-1.5B-Instruct is excellent and fits on Colab T4 or standard GPUs
    # You can swap this for Qwen2-VL-2B if you specifically want the VLM from the paper
    MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct" 
    
    OUTPUT_DIR = "models/yofo_lora"
    BATCH_SIZE = 4 # Small batch size for consumer GPU
    LEARNING_RATE = 2e-4
    EPOCHS = 3
    
    print(f"Initializing training with model: {MODEL_ID}")
    
    # 1. Load Tokenizer & Builder
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        
    builder = YOFOTemplateBuilder(tokenizer)
    
    # 2. Load Datasets
    train_dataset = YOFODataset("data/processed/train_yofo.jsonl", builder)
    val_dataset = YOFODataset("data/processed/val_yofo.jsonl", builder)
    
    # 3. Load Model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    
    # 4. Configure LoRA
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=16,           # Rank
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    )
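    # q/k/v/o_proj cover the attention projections and gate/up/down_proj the MLP,
    # i.e. LoRA adapters are attached to every linear layer in each Qwen2 block.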
    
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    
    # 5. Setup Trainer
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=4,
        learning_rate=LEARNING_RATE,
        weight_decay=0.01,
        logging_steps=10,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        bf16=torch.cuda.is_bf16_supported(),  # match the dtype the model was loaded in
        fp16=not torch.cuda.is_bf16_supported(),
        report_to="none", # Disable wandb for simplicity
        remove_unused_columns=False # Important for custom datasets
    )
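    # Effective batch size = per-device batch (4) x gradient accumulation (4) = 16 examples per optimizer step.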
    
    # We need a data collator that pads variable-length batches.
    # default_data_collator only stacks tensors (no padding) and would not pad 'labels'
    # with -100; DataCollatorForTokenClassification pads inputs with the pad token and
    # labels with -100 by default, keeping padded positions out of the loss.
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )
    
    # 6. Train
    print("\n🚀 Starting training...")
    trainer.train()
    
    # 7. Save
    print(f"\n💾 Saving model to {OUTPUT_DIR}")
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
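
    # To reload the adapter later for evaluation/inference (a minimal sketch; the base
    # model id and output dir mirror the config above, and PeftModel.from_pretrained is
    # the standard peft loading API):
    #   from peft import PeftModel
    #   base = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    #   model = PeftModel.from_pretrained(base, OUTPUT_DIR)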

if __name__ == "__main__":
    # Ensure directories exist
    os.makedirs("models", exist_ok=True)
    train()