File size: 3,963 Bytes
9ec3d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Step 3: Setting up the model for fine-tuning with LoRA
"""

from pathlib import Path

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_base_model(model_name: str = "Qwen/Qwen2.5-3B-Instruct"):
    """
    Load a causal-LM checkpoint and its tokenizer.

    Places the model on Apple's MPS device (float16) when available,
    otherwise keeps it on CPU (float32).

    Args:
        model_name: HuggingFace Hub id of the checkpoint to load.

    Returns:
        A ``(model, tokenizer)`` tuple ready for LoRA setup.
    """
    print(f"Loading model: {model_name}")
    print("(First run will download ~6GB to ~/.cache/huggingface/)")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Some checkpoints ship without a pad token; reuse EOS so padding
    # during tokenization/generation does not fail.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Apple Silicon (MPS) gets half precision to keep memory manageable;
    # CPU stays in float32 to avoid unsupported float16 ops.
    use_mps = torch.backends.mps.is_available()
    if use_mps:
        print("Using Apple MPS (Metal) backend")
    else:
        print("MPS not available, using CPU (this will be slow)")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch.float16 if use_mps else torch.float32,
        trust_remote_code=True,
    )
    if use_mps:
        model = model.to("mps")

    return model, tokenizer


def apply_lora(model):
    """
    Wrap *model* with LoRA adapters for parameter-efficient fine-tuning.

    Adapts the attention and MLP projection layers with rank-16 updates,
    leaving the base weights frozen.

    Args:
        model: The base causal LM to adapt.

    Returns:
        The PEFT-wrapped model (only adapter weights are trainable).
    """
    print("\nApplying LoRA configuration...")

    # Attention + MLP projection layers of a Qwen-style transformer block.
    adapted_layers = [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]

    config = LoraConfig(
        r=16,               # rank of the low-rank update matrices
        lora_alpha=32,      # scaling factor (effective scale = alpha / r)
        target_modules=adapted_layers,
        lora_dropout=0.05,  # dropout on LoRA activations for regularization
        bias="none",        # keep bias terms frozen
        task_type="CAUSAL_LM",
    )

    peft_model = get_peft_model(model, config)
    peft_model.print_trainable_parameters()

    return peft_model


def setup_for_training(model_name: str = "Qwen/Qwen2.5-3B-Instruct"):
    """
    Convenience wrapper: load the base model, then attach LoRA adapters.

    Args:
        model_name: HuggingFace Hub id of the checkpoint to load.

    Returns:
        A ``(peft_model, tokenizer)`` pair ready for the training loop.
    """
    base_model, tokenizer = load_base_model(model_name)
    return apply_lora(base_model), tokenizer


def test_inference(model, tokenizer, prompt: str):
    """
    Run a short greedy generation to verify the model produces output.

    Args:
        model: A (possibly PEFT-wrapped) causal LM.
        tokenizer: The tokenizer matching *model*.
        prompt: Text to condition the generation on.

    Returns:
        The newly generated text (prompt excluded), stripped of
        surrounding whitespace.
    """
    print(f"\nTest prompt: {prompt[:50]}...")

    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Strip the prompt at the *token* level rather than by character count:
    # decoding with skip_special_tokens=True can change the prompt's surface
    # form, so `response[len(prompt):]` may cut at the wrong position.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_len:]
    new_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print(f"Model output: {new_text}")
    return new_text


# Run this script directly to test the setup
if __name__ == "__main__":
    separator = "=" * 60
    print(separator)
    print("Step 3: Model Setup Test")
    print(separator)

    # Sanity-check the hardware/software stack before loading anything.
    print("\n[Environment Check]")
    print(f"  MPS Available: {torch.backends.mps.is_available()}")
    print(f"  MPS Built: {torch.backends.mps.is_built()}")
    print(f"  PyTorch version: {torch.__version__}")

    # Download (or reuse cached) weights and attach the LoRA adapters.
    print("\n[Loading Model]")
    model, tokenizer = setup_for_training()

    print("\n[Status]")
    print("  ✓ Model loaded successfully")
    print("  ✓ LoRA adapters applied")
    print(f"  Device: {next(model.parameters()).device}")

    # Generate a handful of tokens to prove the forward pass works end-to-end.
    print("\n[Quick Inference Test]")
    test_prompt = "What is 2 + 2? Answer with just the number:"
    test_inference(model, tokenizer, test_prompt)

    print("\n" + separator)
    print("✓ Setup complete! Ready for training.")
    print(separator)

    # Point the user at the on-disk cache the download populated.
    print("\n[Cache Location]")
    print("  Model cached at: ~/.cache/huggingface/hub/")
    print("  (This is reused for future runs)")