wlabchoi committed on
Commit
0800fc9
·
verified ·
1 Parent(s): 11f1fdf

Upload train_qwen3_telecom.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_qwen3_telecom.py +118 -0
train_qwen3_telecom.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "bitsandbytes"]
# ///
# PEP 723 inline script metadata: lets runners such as `uv run` install the
# exact dependencies this script needs before executing it.

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
# Imported so the "trackio" reporting backend referenced by
# SFTConfig(report_to="trackio") is available at training time.
# NOTE(review): presumably importing registers the backend — confirm.
import trackio

# Load the TeleQnA telecom multiple-choice QA dataset (train split only)
# from the Hugging Face Hub; downloads on first run, then uses the cache.
print('Loading TeleQnA dataset...')
raw_dataset = load_dataset('netop/TeleQnA', split='train')
13
+
14
def format_for_sft(example):
    """Convert one TeleQnA record into the chat-``messages`` format for SFT.

    The numbered answer choices (keys ``'option 1'`` .. ``'option 5'``) are
    folded into the user turn beneath the question; the reference answer and
    its explanation become the assistant turn.

    Args:
        example: A mapping with ``'question'``, ``'answer'``,
            ``'explanation'`` and up to five ``'option N'`` entries.
            Missing, ``None`` or empty options are skipped.

    Returns:
        ``{'messages': [user_turn, assistant_turn]}`` where each turn is a
        ``{'role', 'content'}`` dict.
    """
    # Keep only the options that exist and are non-empty, preserving their
    # original 1-based numbering in the rendered text.
    option_keys = [f'option {idx}' for idx in range(1, 6)]
    numbered_options = [
        f'{pos}. {example[key]}'
        for pos, key in enumerate(option_keys, start=1)
        if example.get(key)
    ]
    options_block = '\n'.join(numbered_options)

    user_turn = f"{example['question']}\n\nOptions:\n{options_block}"
    assistant_turn = f"{example['answer']}\n\nExplanation: {example['explanation']}"

    return {
        'messages': [
            {'role': 'user', 'content': user_turn},
            {'role': 'assistant', 'content': assistant_turn},
        ]
    }
40
+
41
# Apply the chat formatting to every record and drop the original columns,
# leaving only the 'messages' field that SFTTrainer consumes.
print('Preprocessing dataset...')
dataset = raw_dataset.map(format_for_sft, remove_columns=raw_dataset.column_names)

# Create train/eval split (90/10); fixed seed keeps the split reproducible
# across runs.
print('Creating train/eval split...')
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)

print(f'Train examples: {len(dataset_split["train"])}')
print(f'Eval examples: {len(dataset_split["test"])}')
50
+
51
# LoRA adapter configuration: rank-16 adapters attached only to the four
# attention projection matrices, keeping the trainable-parameter count small.
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,                 # adapter rank
    lora_alpha=32,        # scaling numerator (alpha / r = 2x effective scale)
    lora_dropout=0.05,
    bias="none",          # leave all bias terms frozen
    target_modules=attention_projections,
)
60
+
61
# Configure SFT training
training_args = SFTConfig(
    output_dir="qwen3-telecom-finetuned",

    # Training hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,  # Effective batch size = 2 * 8 = 16

    # Optimization
    learning_rate=2e-4,  # typical LoRA learning rate (higher than full FT)
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,

    # Evaluation and saving
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,  # keep only the three most recent checkpoints

    # Logging and monitoring (trackio experiment tracker)
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-telecom-domain-adaptation",
    project="telecom-finetuning",

    # Memory optimization
    gradient_checkpointing=True,  # trade compute for activation memory
    bf16=True,  # assumes bf16-capable hardware (Ampere+ / TPU) — confirm

    # Hub integration: every saved checkpoint is also pushed to the Hub
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-telecom",
    hub_strategy="every_save",
    hub_private_repo=False,
)
99
+
100
# Initialize trainer. Passing the model as a Hub id string lets TRL load the
# base model and tokenizer itself; supplying `peft_config` wraps the model in
# LoRA adapters so only the adapter weights are trained.
print('Initializing SFT trainer...')
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=dataset_split["train"],
    eval_dataset=dataset_split["test"],
    peft_config=peft_config,
    args=training_args,
)

# Start training (intermediate checkpoints are pushed per hub_strategy above).
print('Starting training...')
trainer.train()

# Push final model to Hub under `hub_model_id`.
print('Pushing final model to Hub...')
trainer.push_to_hub(commit_message="Training complete - Qwen3-0.6B fine-tuned on TeleQnA")

print('Training completed successfully!')