File size: 4,516 Bytes
20fa6f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers @ git+https://github.com/huggingface/transformers.git",
#     "accelerate @ git+https://github.com/huggingface/accelerate.git",
#     "bitsandbytes>=0.45.0",
#     "trackio",
#     "datasets",
# ]
# ///

"""
Agent Zero SFT: zai-org/GLM-4.7-Flash (30B MoE)
QLoRA (4-bit) on l40sx1 (48GB) with monkey-patch for CPU offload compat.
Patches Params4bit.__new__ (to ignore the _is_hf_initialized kwarg) and QuantState.as_dict (to tolerate meta tensors).
"""

import os
import torch

# === Monkey-patches for bitsandbytes + accelerate CPU offload compat ===

import bitsandbytes as bnb
from bitsandbytes import functional as bnb_func

# Patch 1: make Params4bit.__new__ tolerate the `_is_hf_initialized` kwarg.
_orig_params4bit_new = bnb.nn.Params4bit.__new__


def _patched_params4bit_new(cls, *args, **kwargs):
    """Build a ``Params4bit`` while ignoring ``_is_hf_initialized``.

    Every positional and keyword argument is forwarded unchanged to the
    original ``Params4bit.__new__`` except ``_is_hf_initialized``, which is
    discarded if present.
    """
    forwarded = {
        key: value
        for key, value in kwargs.items()
        if key != '_is_hf_initialized'
    }
    return _orig_params4bit_new(cls, *args, **forwarded)


bnb.nn.Params4bit.__new__ = _patched_params4bit_new

# Patch 2: QuantState.as_dict must survive meta tensors
# (offset.item() cannot be evaluated on the meta device).
_orig_as_dict = bnb_func.QuantState.as_dict


def _patched_as_dict(self, packed=False):
    """Serialize the quant state, degrading gracefully on meta tensors.

    Delegates to the original ``QuantState.as_dict``. If that raises a
    ``RuntimeError`` whose message mentions "meta tensors", a reduced dict
    holding ``quant_type``, ``blocksize`` and (when the attribute exists)
    ``shape`` is returned instead. Any other ``RuntimeError`` propagates.
    """
    try:
        return _orig_as_dict(self, packed=packed)
    except RuntimeError as err:
        # Only the meta-device failure is handled here; re-raise the rest.
        if "meta tensors" not in str(err):
            raise
        fallback = {
            "quant_type": self.quant_type,
            "blocksize": self.blocksize,
        }
        if hasattr(self, 'shape'):
            fallback["shape"] = self.shape
        return fallback


bnb_func.QuantState.as_dict = _patched_as_dict

print("Patched bitsandbytes for CPU offload compat")

# === Main training script ===

import trackio
from huggingface_hub import login
# Authenticate against the Hugging Face Hub. The token must be supplied via
# the HF_TOKEN environment variable; a missing variable raises KeyError here.
hf_token = os.environ["HF_TOKEN"]
login(token=hf_token)

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig

print("Loading dataset...")
_DATASET_REPO = "wheattoast11/agent-zero-sft-v1"
# Each JSONL file is loaded on its own via `data_files`, so both loads request
# split="train" -- including the validation file.
train_ds = load_dataset(
    _DATASET_REPO,
    data_files="data/train.jsonl",
    split="train",
)
val_ds = load_dataset(
    _DATASET_REPO,
    data_files="data/validation.jsonl",
    split="train",
)
print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

# 4-bit quantization settings used when loading the base model.
bnb_config = BitsAndBytesConfig(
    # Storage format: 4-bit NF4 with double quantization enabled.
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    # Compute dtype used with the 4-bit weights.
    bnb_4bit_compute_dtype=torch.bfloat16,
    # Needed because the device map below may place modules on the CPU
    # (see the BitsAndBytesConfig documentation for this flag).
    llm_int8_enable_fp32_cpu_offload=True,
)

# Scratch directory handed to from_pretrained as `offload_folder`.
offload_dir = "/tmp/offload"
os.makedirs(offload_dir, exist_ok=True)

print("Loading model in 4-bit with CPU offload on l40sx1...")
_MODEL_ID = "zai-org/GLM-4.7-Flash"
_model_load_kwargs = dict(
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    # Automatic placement, capped at 44GiB on GPU 0 and 60GiB of CPU memory.
    device_map="auto",
    max_memory={0: "44GiB", "cpu": "60GiB"},
    offload_folder=offload_dir,
)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, **_model_load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
print("Model loaded.")

# Diagnostics only: report how many device-map entries landed on each device.
if hasattr(model, 'hf_device_map'):
    devices = {}
    for placement in model.hf_device_map.values():
        device_key = str(placement)
        devices[device_key] = devices.get(device_key, 0) + 1
    print(f"Device distribution: {devices}")

import subprocess

# The GPU snapshot is informational. A missing or non-executable `nvidia-smi`
# (OSError / FileNotFoundError) must not abort the run before training starts.
try:
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
except OSError as exc:
    print(f"nvidia-smi unavailable, skipping GPU snapshot: {exc}")
else:
    print(result.stdout)
    # Surface the tool's own error output instead of silently dropping it.
    if result.returncode != 0 and result.stderr:
        print(result.stderr)

# Training configuration, grouped by concern.
config = SFTConfig(
    # --- Output / Hub ---
    output_dir="agent-zero-glm-4.7-v1",
    push_to_hub=True,
    hub_model_id="wheattoast11/agent-zero-glm-4.7-v1",
    hub_private_repo=True,
    hub_strategy="every_save",
    # --- Optimization ---
    # Batch size 1 with 16 accumulation steps: 16 samples per optimizer step.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=2,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # --- Memory / precision ---
    bf16=True,
    gradient_checkpointing=True,
    # --- Evaluation and checkpointing (both every 50 steps, keep 2 checkpoints) ---
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    # --- Logging / tracking ---
    logging_steps=10,
    report_to="trackio",
    project="agent-zero-finetune",
    run_name="glm-4.7-flash-qlora-v1",
)

# LoRA adapter settings: rank 16, alpha 32, dropout 0.05, no bias terms.
# NOTE(review): confirm that these projection names actually exist in the
# GLM-4.7-Flash attention modules -- not verifiable from this script.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    # TRL 0.12 deprecated the `tokenizer=` argument in favour of
    # `processing_class=` and later releases removed it. The script header
    # only requires trl>=0.12.0, so `processing_class` is the spelling that
    # works on every version the dependency spec can resolve to.
    processing_class=tokenizer,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    args=config,
    peft_config=peft_config,
)

print("Starting training...")
trainer.train()

# Final explicit push, in addition to the hub_strategy="every_save" uploads
# configured in the training arguments.
print("Pushing to Hub...")
trainer.push_to_hub()
trackio.finish()
print("Done!")