File size: 4,478 Bytes
8af7593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af0763b
8af7593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390b224
8af7593
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "transformers @ git+https://github.com/huggingface/transformers.git",
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "accelerate>=0.24.0",
#     "datasets",
#     "bitsandbytes",
# ]
# ///

"""
TEST RUN: Fine-tune GLM-4.7-Flash on small sample (50 examples, 20 steps)
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import gc
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig

MODEL_NAME = "zai-org/GLM-4.7-Flash"
DATASET_NAME = "LordNeel/unblinded-mastery-sharegpt"

print("=" * 60)
print("TEST RUN: GLM-4.7-Flash (50 examples, 20 steps)")
print("=" * 60)

# Load small sample
print("\nLoading dataset (50 examples only)...")
dataset = load_dataset(DATASET_NAME, split="train[:50]")
print(f"Dataset loaded: {len(dataset)} examples")

# 4-bit quantization
print("\nSetting up 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Some causal-LM tokenizers ship without a pad token; reuse EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load model
print("\nLoading model with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",            # let accelerate place layers on available devices
    trust_remote_code=True,       # required for Hub models that ship custom modeling code
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_cache=False,              # KV cache is useless (and warned against) when training with checkpointing
    attn_implementation="eager",  # NOTE(review): eager presumably chosen for compatibility — confirm before scaling up
)
print("Model loaded!")

# Enable gradient checkpointing and input gradients.
# Non-reentrant checkpointing; enable_input_require_grads makes gradients flow
# through the frozen quantized base so the LoRA adapters added later can train.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.enable_input_require_grads()

# Clear memory accumulated during the load before training allocations start.
gc.collect()
torch.cuda.empty_cache()
print(f"GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB allocated")

# Find linear layers for LoRA
print("\nFinding linear layers for LoRA...")
def find_all_linear_names(model):
    """Return the leaf names of every ``nn.Linear`` submodule, for LoRA targeting.

    bitsandbytes' quantized layers (``Linear4bit``/``Linear8bitLt``) subclass
    ``torch.nn.Linear``, so the isinstance check also matches a 4-bit model.

    Parameters:
        model: any ``torch.nn.Module``.

    Returns:
        Sorted list of unique leaf module names (e.g. ``"q_proj"``), with
        ``'lm_head'`` excluded — the output head should not get LoRA adapters.
    """
    # rsplit('.', 1)[-1] yields the leaf name, and equals the full name when
    # there is no dot — no need for a length-conditional on the split parts.
    lora_module_names = {
        name.rsplit(".", 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    }
    # discard() is a no-op when absent (remove() would raise KeyError).
    lora_module_names.discard("lm_head")
    # Sorted for deterministic ordering across runs; set iteration order is arbitrary.
    return sorted(lora_module_names)

target_modules = find_all_linear_names(model)
print(f"Target modules: {target_modules}")

# LoRA config - small rank for testing
print("\nConfiguring LoRA...")
peft_config = LoraConfig(
    r=8,                # low rank keeps the adapter tiny for this smoke test
    lora_alpha=16,      # effective scaling = alpha / r = 2.0
    lora_dropout=0.05,
    bias="none",        # bias terms stay frozen
    task_type=TaskType.CAUSAL_LM,
    target_modules=target_modules,
)

# Wrap the quantized base model with trainable LoRA adapters; only the
# adapter weights will require grad.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Format function
def format_sharegpt(example):
    """Render one ShareGPT-style example into a single chat-template string.

    Parameters:
        example: dict with a ``"conversations"`` list of
            ``{"from": ..., "value": ...}`` turns.

    Returns:
        dict with a single ``"text"`` key holding the rendered conversation,
        suitable for ``dataset.map(..., remove_columns=...)``.
    """
    # ShareGPT speaker tags -> chat-template role names.
    # Hoisted out of the loop: the mapping is loop-invariant, and the original
    # rebuilt this dict once per turn.
    role_map = {"system": "system", "human": "user", "gpt": "assistant"}
    # Unknown tags pass through unchanged via dict.get's default.
    messages = [
        {"role": role_map.get(turn["from"], turn["from"]), "content": turn["value"]}
        for turn in example["conversations"]
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return {"text": text}

print("\nFormatting dataset...")
dataset = dataset.map(format_sharegpt, remove_columns=dataset.column_names)

# Training config - minimal for testing
print("\nConfiguring training (20 steps only)...")
training_config = SFTConfig(
    output_dir="test-output",
    max_steps=20,  # Just 20 steps to test
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    max_length=512,  # Short for testing
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logging_steps=5,
    bf16=True,
    optim="paged_adamw_8bit",
    dataset_text_field="text",
    report_to="none",  # No tracking for test
)

# Train
print("\nInitializing trainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_config,
    processing_class=tokenizer,
)

print("\n" + "=" * 60)
print("STARTING TEST TRAINING (20 steps)")
print("=" * 60)
trainer.train()

print("\n" + "=" * 60)
print("TEST COMPLETE! Training works.")
print("=" * 60)