passagereptile455 committed on
Commit e7d1dc3 · verified · 1 Parent(s): 0c107e2

Upload train_v5_fixed.py with huggingface_hub

Files changed (1)
  1. train_v5_fixed.py +129 -0
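
The commit message matches the default produced by a programmatic upload. A minimal sketch of how such a file upload is typically done with huggingface_hub; the target repo_id below is a placeholder, since the destination repo is not visible in this view:

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_file(
    path_or_fileobj="train_v5_fixed.py",
    path_in_repo="train_v5_fixed.py",
    repo_id="passagereptile455/<target-repo>",  # placeholder: actual repo not shown here
    commit_message="Upload train_v5_fixed.py with huggingface_hub",
)
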
train_v5_fixed.py ADDED
@@ -0,0 +1,129 @@
+ # /// script
+ # dependencies = [
+ #     "trl>=0.12.0",
+ #     "peft>=0.7.0",
+ #     "transformers>=4.36.0",
+ #     "accelerate>=0.24.0",
+ #     "trackio",
+ #     "datasets",
+ # ]
+ # ///
+
+ """
+ Training with proper dataset formatting
+ """
+
+ import sys
+ import traceback
+ from datasets import load_dataset, Dataset
+ from peft import LoraConfig
+ from trl import SFTTrainer, SFTConfig
+ from transformers import AutoTokenizer
+ import torch
+
+ print("=" * 50)
+ print("FIXED TRAINING v5")
+ print("=" * 50)
+
+ try:
+     print(f"CUDA: {torch.cuda.is_available()}")
+
+     # Streaming load
+     print("Streaming codeforces-cots...")
+     streaming_ds = load_dataset(
+         "open-r1/codeforces-cots", split="train", streaming=True
+     )
+
+     # Collect examples
+     print("Collecting 1000 examples...")
+     examples = []
+     for i, ex in enumerate(streaming_ds):
+         if i >= 1000:
+             break
+         examples.append(ex)
+
+     print(f"Collected {len(examples)} examples")
+     dataset = Dataset.from_list(examples)
+     print(f"Dataset columns: {dataset.column_names}")
+
+     # Check messages format
+     print(f"First messages sample: {str(dataset[0]['messages'])[:100]}...")
+
+     # Load tokenizer
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Convert messages to text format for SFT
+     def format_messages(example):
+         messages = example["messages"]
+         # Format as simple text
+         text = ""
+         for msg in messages:
+             role = msg.get("role", "user")
+             content = msg.get("content", "")
+             text += f"<|{role}|>\n{content}\n"
+         return {"text": text}
+
+     print("Formatting dataset...")
+     dataset = dataset.map(format_messages, remove_columns=dataset.column_names)
+     print(f"Formatted. Sample: {dataset[0]['text'][:200]}...")
+
+     # Config
+     config = SFTConfig(
+         output_dir="qwen3-codeforces",
+         push_to_hub=True,
+         hub_model_id="passagereptile455/qwen3-0.6b-humaneval-job1",
+         hub_strategy="every_save",
+         max_steps=200,
+         per_device_train_batch_size=1,
+         gradient_accumulation_steps=8,
+         learning_rate=5e-6,
+         max_length=512,
+         logging_steps=20,
+         save_strategy="steps",
+         save_steps=100,
+         save_total_limit=1,
+         eval_strategy="no",
+         warmup_ratio=0.1,
+         lr_scheduler_type="cosine",
+         gradient_checkpointing=True,
+         bf16=True,
+         dataset_text_field="text",  # Specify text field
+         report_to="trackio",
+         project="qwen3-humaneval",
+         run_name="job1-v5",
+     )
+
+     peft_config = LoraConfig(
+         r=8,
+         lora_alpha=16,
+         lora_dropout=0.05,
+         bias="none",
+         task_type="CAUSAL_LM",
+         target_modules=["q_proj", "v_proj"],
+     )
+
+     print("Creating trainer...")
+     trainer = SFTTrainer(
+         model="Qwen/Qwen3-0.6B",
+         train_dataset=dataset,
+         args=config,
+         peft_config=peft_config,
+     )
+
+     print("Training (200 steps)...")
+     trainer.train()
+
+     print("Pushing to Hub...")
+     trainer.push_to_hub()
+
+     print("=" * 50)
+     print("SUCCESS!")
+     print("=" * 50)
+
+ except Exception as e:
+     print(f"ERROR: {e}")
+     traceback.print_exc()
+     sys.exit(1)
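
Because train_v5_fixed.py carries PEP 723 inline metadata in its `# /// script` block, it can presumably be launched with a runner that understands that format, for example `uv run train_v5_fixed.py`. After the run pushes the LoRA adapter to the hub_model_id in the config, loading it back for a quick generation check might look roughly like the sketch below; this assumes the adapter repo is accessible and that prompts should follow the same `<|role|>` text layout used by format_messages:

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16, device_map="auto"
)
# Adapter repo matches hub_model_id in the script; assumed to exist after training.
model = PeftModel.from_pretrained(base, "passagereptile455/qwen3-0.6b-humaneval-job1")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Prompt in the same <|role|> layout the training text was built with.
prompt = "<|user|>\nReverse a string in Python.\n<|assistant|>\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(out[0], skip_special_tokens=True))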