File size: 3,097 Bytes
b25b8f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Unified Colab Script for MVM2 QLoRA Training
import os
import subprocess
import sys

print("Installing dependencies...")
# Install with `python -m pip` for the interpreter running this script (not
# whichever `pip` happens to be first on PATH), and stop immediately if the
# install fails: the previous os.system() call discarded the exit status, so a
# broken install only surfaced later as an unrelated-looking ImportError.
subprocess.check_call(
    [
        sys.executable, "-m", "pip", "install", "-q", "-U",
        "torch", "datasets", "trl", "peft", "transformers", "unsloth",
    ]
)

# Imported only after the install above so the freshly installed versions load.
import json
from datasets import load_dataset

print("1. Generating Dataset locally (No API limits!)...")
# GSM8K, "main" configuration, training split.
dataset = load_dataset("gsm8k", "main", split="train")

def format_gsm8k(example):
    """Turn one GSM8K row into a system/user/assistant chat for SFT.

    The ``answer`` text is split on ``####``. Everything before the first
    separator is treated as the worked solution. The piece right after it
    (up to any further ``####``) is the final answer, or ``""`` when no
    separator is present. The solution is broken on newlines into stripped,
    non-blank lines.

    Args:
        example: a dataset row with string fields ``"question"`` and
            ``"answer"``.

    Returns:
        ``{"messages": [...]}`` containing three turns: a fixed system prompt,
        the question as the user turn, and an assistant turn holding the
        JSON-encoded object with keys ``final_answer``, ``reasoning_trace``
        and ``confidence_explanation`` (in that order).
    """
    solution, *after_separator = example["answer"].split("####")
    steps = [line for line in map(str.strip, solution.strip().split("\n")) if line]
    final_answer = after_separator[0].strip() if after_separator else ""

    # The confidence text is a fixed string, identical for every example.
    triplet = {
        "final_answer": final_answer,
        "reasoning_trace": steps,
        "confidence_explanation": "Deterministic symbolic steps logically verified.",
    }
    system_prompt = "You are an MVM2 math reasoning agent. You strictly output JSON triplets: {final_answer, reasoning_trace, confidence_explanation}."

    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": json.dumps(triplet)},
        ]
    }

print("Mapping dataset to MVM2 Triplets...")
# Replace the raw GSM8K columns with the single chat-formatted "messages" column.
formatted_dataset = dataset.map(format_gsm8k, remove_columns=["question", "answer"])

# Fast, targeted run: only the first 1000 formatted examples are used.
_N_TRAIN_SAMPLES = 1000
small_dataset = formatted_dataset.select(range(_N_TRAIN_SAMPLES))

print("2. Starting Unsloth Training on T4 GPU...")
# NOTE(review): unsloth is imported ahead of trl/transformers; Unsloth's docs
# ask for this ordering so its patches take effect -- keep it that way.
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments

# Pre-quantized 4-bit Llama-3-8B-Instruct checkpoint; dtype=None leaves the
# choice of dtype to Unsloth.
_base_model_settings = dict(
    model_name="unsloth/llama-3-8b-Instruct-bnb-4bit",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)
model, tokenizer = FastLanguageModel.from_pretrained(**_base_model_settings)

# Rank-16 LoRA adapters on the seven listed projection modules
# (attention q/k/v/o and MLP gate/up/down), with a fixed seed.
_lora_settings = dict(
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)
model = FastLanguageModel.get_peft_model(model, **_lora_settings)

def format_chatml(examples):
    """Render every conversation in a batch into one training string each.

    Callback for ``datasets.map(..., batched=True)``: ``examples["messages"]``
    is a list of conversations, each a list of ``{"role", "content"}`` dicts.
    Each conversation is rendered with the loaded tokenizer's own chat
    template. Despite the function name, this is whatever template that
    tokenizer defines, not necessarily ChatML. No trailing generation prompt
    is appended.

    Returns:
        ``{"text": [...]}`` with one rendered string per conversation, in input
        order.
    """
    return {
        "text": [
            tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
            for conversation in examples["messages"]
        ]
    }

# Add a "text" column holding each conversation rendered through the chat template.
train_data = small_dataset.map(format_chatml, batched=True)

# Choose mixed precision from the hardware instead of hard-coding fp16.
# The model was loaded with dtype=None, which lets Unsloth pick bfloat16 on
# GPUs that support it; combining that with fp16=True can make training fail
# on such GPUs. T4s have no bf16 support, so they still train in fp16 exactly
# as before.
from unsloth import is_bfloat16_supported

_use_bf16 = is_bfloat16_supported()

# NOTE(review): passing tokenizer / dataset_text_field / max_seq_length
# directly to SFTTrainer depends on the installed trl (or Unsloth's patched
# trainer) still accepting them; newer trl releases move these into SFTConfig.
# Confirm against the version pip installs.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_data,
    dataset_text_field = "text",
    max_seq_length = 2048,
    dataset_num_proc = 2,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        # 60 steps x (2 x 4) = 480 examples: a very fast demonstration run that
        # covers fewer than the 1000 selected samples.
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not _use_bf16,
        bf16 = _use_bf16,
        logging_steps = 10,
        optim = "adamw_8bit",
        output_dir = "outputs",
    ),
)

trainer.train()

# `model` is the LoRA-wrapped model from get_peft_model, so this writes the
# adapter (plus the tokenizer) rather than a merged full model.
model.save_pretrained("mvm2_lora_model")
tokenizer.save_pretrained("mvm2_lora_model")
print("\n✅ Training Complete! The LoRA adapter is saved to 'mvm2_lora_model'.")