|
|
|
|
|
|
|
|
|
|
|
|
|
|
from datasets import load_dataset |
|
|
from peft import LoraConfig |
|
|
from trl import SFTTrainer, SFTConfig |
|
|
import trackio |
|
|
|
|
|
|
|
|
# Fetch the TeleQnA question/answer data; only its 'test' split is requested
# here, and the train/eval partition is carved out of it further down.
print('Loading TeleQnA dataset...')
raw_dataset = load_dataset(
    'netop/TeleQnA',
    split='test',
)
|
|
|
|
|
def format_for_sft(example):
    """Convert one TeleQnA record into the chat ``messages`` format.

    The user turn holds the question followed by a numbered ``Options:``
    list; the assistant turn holds the answer followed by an
    ``Explanation:`` line. A section with nothing to show is left out
    instead of being emitted as a dangling header.

    Args:
        example: Mapping with ``question`` and ``answer`` keys and,
            optionally, ``choices`` (an iterable of options) and
            ``explaination`` / ``explanation``.

    Returns:
        ``{'messages': [user_turn, assistant_turn]}`` where each turn is a
        ``{'role': ..., 'content': ...}`` dict.

    Raises:
        KeyError: If ``question`` or ``answer`` is missing (for a plain
            ``dict`` input).
    """
    # Options are numbered starting at 1.
    # NOTE(review): if ``answer`` holds a 0-based index rather than the answer
    # text, it will not line up with this numbering - confirm against the
    # dataset schema.
    choices = example.get('choices') or []
    numbered = [f'{i}. {choice}' for i, choice in enumerate(choices, 1)]

    question_with_options = f"{example['question']}"
    if numbered:
        question_with_options += '\n\nOptions:\n' + '\n'.join(numbered)

    # Both spellings of the key are tried, the misspelled one first.
    explanation = (
        example.get('explaination') or example.get('explanation') or ''
    )

    answer_text = f"{example['answer']}"
    if explanation:
        answer_text += f'\n\nExplanation: {explanation}'

    return {
        'messages': [
            {'role': 'user', 'content': question_with_options},
            {'role': 'assistant', 'content': answer_text},
        ]
    }
|
|
|
|
|
# Rewrite every record into chat form; the source columns are dropped so that
# only what format_for_sft returns is kept.
print('Preprocessing dataset...')
dataset = raw_dataset.map(
    format_for_sft,
    remove_columns=raw_dataset.column_names,
)

# Hold out 10% for evaluation; the fixed seed keeps the partition reproducible.
print('Creating train/eval split...')
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)

for label, split_name in (('Train', 'train'), ('Eval', 'test')):
    print(f'{label} examples: {len(dataset_split[split_name])}')
|
|
|
|
|
|
|
|
# LoRA adapter settings: adapters are attached to the four attention
# projection modules named below, and bias terms are left untouched.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
|
|
|
|
|
|
|
|
training_args = SFTConfig(
    # --- Output location and Hub publishing ---
    output_dir="qwen3-telecom-finetuned",
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-telecom",
    hub_strategy="every_save",
    hub_private_repo=False,

    # --- Run length and batching ---
    # 2 sequences per device x 8 accumulation steps = 16 per optimizer step.
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,

    # --- Optimizer schedule ---
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,

    # --- Evaluation and checkpoint cadence (in steps) ---
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,

    # --- Experiment tracking ---
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-telecom-domain-adaptation",
    project="telecom-finetuning",

    # --- Memory / precision ---
    gradient_checkpointing=True,
    bf16=True,
)
|
|
|
|
|
|
|
|
# The base model is passed by its Hub id, so the trainer is responsible for
# loading it; the LoRA config and the train/eval splits are wired in here.
print('Initializing SFT trainer...')
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    args=training_args,
    peft_config=peft_config,
    train_dataset=dataset_split["train"],
    eval_dataset=dataset_split["test"],
)
|
|
|
|
|
|
|
|
# Run the fine-tuning loop.
print('Starting training...')
trainer.train()

# Explicit final upload once training has returned, with a descriptive
# commit message.
print('Pushing final model to Hub...')
trainer.push_to_hub(
    commit_message="Training complete - Qwen3-0.6B fine-tuned on TeleQnA",
)

print('Training completed successfully!')
|
|
|