File size: 3,465 Bytes
f5fa8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6a6333
f5fa8ef
 
e6a6333
 
 
 
f5fa8ef
 
 
 
 
 
 
 
 
202ab61
 
e6a6333
 
 
 
 
 
 
202ab61
 
 
 
 
 
 
 
 
e6a6333
f5fa8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202ab61
 
 
f5fa8ef
 
 
 
 
 
 
 
 
 
 
 
202ab61
f5fa8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.50.0",
#     "accelerate>=0.24.0",
#     "trackio",
#     "bitsandbytes",
# ]
# ///

"""
Fine-tune Qwen3-0.6B on open-r1/codeforces-cots for instruction following.
Dataset: Competitive programming with chain-of-thought reasoning.
"""

import trackio
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig

# Load tokenizer first to apply chat template
# The tokenizer is only needed here for its chat template (used by
# preprocess_function below); SFTTrainer loads its own copy for training.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Load dataset with Python solutions (decontaminated)
# "solutions_py_decontaminated" config: Python-solution subset; presumably
# decontaminated against eval benchmarks per the dataset card — TODO confirm.
print("Loading dataset open-r1/codeforces-cots...")
dataset = load_dataset(
    "open-r1/codeforces-cots",
    name="solutions_py_decontaminated",
    split="train"
)
print(f"Dataset loaded: {len(dataset)} examples")

# Preprocess dataset to create 'text' column with chat template applied
def preprocess_function(example):
    """Render one example's chat messages into a single training string.

    Applies the module-level tokenizer's chat template to the example's
    "messages" list. No generation prompt is appended because the
    assistant turn is already part of the conversation.

    Args:
        example: A dataset row containing a "messages" key.

    Returns:
        A dict with a single "text" key holding the rendered conversation.
    """
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}

print("Preprocessing dataset with chat template...")
# Drop every original column so only the rendered "text" column remains —
# this matches dataset_text_field="text" in the SFTConfig below.
dataset = dataset.map(
    preprocess_function,
    remove_columns=dataset.column_names,
    desc="Applying chat template"
)
print(f"Preprocessed dataset: {len(dataset)} examples")

# Create train/eval split
# Hold out 5% for periodic evaluation; fixed seed keeps the split reproducible.
print("Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.05, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f"   Train: {len(train_dataset)} examples")
print(f"   Eval: {len(eval_dataset)} examples")

# Training configuration
# Training configuration
config = SFTConfig(
    # Hub settings - CRITICAL
    # hub_strategy="every_save" pushes each checkpoint, so training is
    # resumable from the Hub if the run is interrupted.
    output_dir="qwen3-0.6b-codeforces-cots",
    push_to_hub=True,
    hub_model_id="stmasson/qwen3-0.6b-codeforces-cots",
    hub_strategy="every_save",

    # Training parameters
    # Effective batch size = 2 (per device) * 8 (accumulation) = 16 per device.
    # max_length=2048 truncates rendered conversations; long CoT solutions
    # may be cut off — NOTE(review): confirm this is acceptable for the data.
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    max_length=2048,

    # Logging & checkpointing
    # Keep only the 2 most recent checkpoints on disk to bound storage.
    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,

    # Evaluation
    # eval_steps aligned with save_steps so each pushed checkpoint has metrics.
    eval_strategy="steps",
    eval_steps=500,

    # Optimization
    # bf16 requires Ampere+ GPU or other bf16-capable hardware — TODO confirm.
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    bf16=True,
    gradient_checkpointing=True,

    # Monitoring
    report_to="trackio",
    project="codeforces-finetuning",
    run_name="qwen3-0.6b-codeforces-sft",

    # Dataset field
    # Must match the column produced by preprocess_function above.
    dataset_text_field="text",
)

# LoRA configuration for efficient training
# LoRA configuration for efficient training
# Adapters are attached to all attention and MLP projection matrices;
# alpha = 2*r gives an effective scaling factor (alpha/r) of 2.
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# Initialize trainer
# Initialize trainer
# Passing the model as a string lets SFTTrainer download and instantiate
# Qwen/Qwen3-0.6B itself (and wrap it with the LoRA adapters above).
print("Initializing trainer with Qwen/Qwen3-0.6B...")
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
)

print("Starting training...")
trainer.train()

# Final explicit push in addition to hub_strategy="every_save", so the last
# state lands on the Hub even if it did not coincide with a save step.
print("Pushing final model to Hub...")
trainer.push_to_hub()

print("Training complete! Model at: https://huggingface.co/stmasson/qwen3-0.6b-codeforces-cots")
print("View metrics at: https://huggingface.co/spaces/stmasson/trackio")