# Unified Colab Script for MVM2 QLoRA Training
import os
print("Installing dependencies...")
os.system("pip install -q -U torch datasets trl peft transformers unsloth")
import json
from datasets import load_dataset
print("1. Generating Dataset locally (No API limits!)...")
dataset = load_dataset("gsm8k", "main", split="train")
def format_gsm8k(example):
    # GSM8K answers contain step-by-step reasoning followed by "#### <final answer>"
    parts = example["answer"].split("####")
    reasoning = [step.strip() for step in parts[0].strip().split('\n') if step.strip()]
    final_answer = parts[1].strip() if len(parts) > 1 else ""
    json_data = {
        "final_answer": final_answer,
        "reasoning_trace": reasoning,
        "confidence_explanation": "Deterministic symbolic steps logically verified."
    }
    return {
        "messages": [
            {"role": "system", "content": "You are an MVM2 math reasoning agent. You strictly output JSON triplets: {final_answer, reasoning_trace, confidence_explanation}."},
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": json.dumps(json_data)}
        ]
    }
print("Mapping dataset to MVM2 Triplets...")
formatted_dataset = dataset.map(format_gsm8k, remove_columns=["question", "answer"])
# Use 1000 samples for a fast, targeted training run
small_dataset = formatted_dataset.select(range(1000))
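# Optional sanity check (added here as a suggestion, not part of the original flow):
# inspect one mapped sample to confirm the assistant turn parses back into the JSON triplet.
sample_msgs = small_dataset[0]["messages"]
print("Sample assistant JSON:", json.loads(sample_msgs[-1]["content"]))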
print("2. Starting Unsloth Training on T4 GPU...")
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit",
    max_seq_length = 2048,
    dtype = None,  # None lets Unsloth auto-detect (float16 on a T4)
    load_in_4bit = True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,  # LoRA rank
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",  # Unsloth's memory-efficient checkpointing
    random_state = 3407,
)
def format_chatml(examples):
    # Render each message list into a single training string using the model's chat template
    texts = []
    for messages in examples["messages"]:
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        texts.append(text)
    return {"text": texts}
train_data = small_dataset.map(format_chatml, batched=True)
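# Optional preview (an added check, not in the original script): print the start of the first
# rendered training string to verify the Llama-3 chat template wrapped the JSON as expected.
print(train_data[0]["text"][:400])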
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_data,
    dataset_text_field = "text",
    max_seq_length = 2048,
    dataset_num_proc = 2,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,  # effective batch size of 8
        max_steps = 60,  # 60 steps for a very fast ~5-minute training demonstration
        learning_rate = 2e-4,
        fp16 = True,
        logging_steps = 10,
        optim = "adamw_8bit",
        output_dir = "outputs",
    ),
)
trainer.train()
model.save_pretrained("mvm2_lora_model")
tokenizer.save_pretrained("mvm2_lora_model")
print("\n✅ Training Complete! The LoRA adapter is saved to 'mvm2_lora_model'.")