wheattoast11 committed on
Commit
9e59c32
·
verified ·
1 Parent(s): a2bb66e

Upload train_glm_qlora_v10.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_glm_qlora_v10.py +128 -0
train_glm_qlora_v10.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "trl>=0.12.0",
5
+ # "peft>=0.7.0",
6
+ # "transformers @ git+https://github.com/huggingface/transformers.git",
7
+ # "accelerate @ git+https://github.com/huggingface/accelerate.git",
8
+ # "bitsandbytes>=0.45.0",
9
+ # "trackio",
10
+ # "datasets",
11
+ # ]
12
+ # ///
13
+
14
+ """
15
+ Agent Zero SFT: zai-org/GLM-4.7-Flash (30B MoE)
16
+ QLoRA (4-bit) with CPU offload + monkey-patch for Params4bit compat.
17
+ """
18
+
19
+ import os
20
+ import torch
21
+
22
+ # Monkey-patch Params4bit to accept _is_hf_initialized kwarg
23
+ # This fixes accelerate<->bitsandbytes incompatibility where accelerate
24
+ # passes _is_hf_initialized as a kwarg but Params4bit doesn't accept it.
25
+ import bitsandbytes as bnb
26
+ _orig_params4bit_new = bnb.nn.Params4bit.__new__
27
+ def _patched_params4bit_new(cls, *args, **kwargs):
28
+ kwargs.pop('_is_hf_initialized', None)
29
+ return _orig_params4bit_new(cls, *args, **kwargs)
30
+ bnb.nn.Params4bit.__new__ = _patched_params4bit_new
31
+ print("Patched Params4bit to accept _is_hf_initialized")
32
+
33
import trackio
from huggingface_hub import login

# Authenticate up front so dataset pulls and Hub pushes both work.
login(token=os.environ["HF_TOKEN"])

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig

print("Loading dataset...")
# Each JSONL file is loaded as its own single-split dataset from the same
# repo, hence split="train" for the validation file as well.
_DATA_REPO = "wheattoast11/agent-zero-sft-v1"
train_ds = load_dataset(_DATA_REPO, data_files="data/train.jsonl", split="train")
val_ds = load_dataset(_DATA_REPO, data_files="data/validation.jsonl", split="train")
print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")
46
+
47
# 4-bit NF4 quantization with bf16 compute; double quantization saves a
# little extra memory, and fp32 CPU offload lets modules that don't fit
# on the GPU spill into host RAM.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# Scratch directory for weights that fit in neither the GPU nor the CPU budget.
offload_dir = "/tmp/offload"
os.makedirs(offload_dir, exist_ok=True)
57
+
58
+ print("Loading model in 4-bit with CPU offload...")
59
+ model = AutoModelForCausalLM.from_pretrained(
60
+ "zai-org/GLM-4.7-Flash",
61
+ quantization_config=bnb_config,
62
+ trust_remote_code=True,
63
+ device_map="auto",
64
+ max_memory={0: "21GiB", "cpu": "40GiB"},
65
+ offload_folder=offload_dir,
66
+ torch_dtype=torch.bfloat16,
67
+ )
68
+ tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.7-Flash", trust_remote_code=True)
69
+ print("Model loaded.")
70
+
71
# Summarize where accelerate placed each module (GPU / CPU / disk).
if hasattr(model, 'hf_device_map'):
    devices = {}
    for placement in model.hf_device_map.values():
        key = str(placement)
        devices[key] = devices.get(key, 0) + 1
    print(f"Device distribution: {devices}")

# Snapshot GPU memory usage after loading, for the logs.
import subprocess
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
print(result.stdout)
80
+
81
# Training hyperparameters. Effective batch size = 1 x 16 grad-accum steps;
# checkpoints are pushed to a private Hub repo on every save.
config = SFTConfig(
    output_dir="agent-zero-glm-4.7-v1",
    push_to_hub=True,
    hub_model_id="wheattoast11/agent-zero-glm-4.7-v1",
    hub_strategy="every_save",
    hub_private_repo=True,
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=1e-4,
    bf16=True,
    gradient_checkpointing=True,  # trade recompute for activation memory
    logging_steps=10,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=50,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    report_to="trackio",
    project="agent-zero-finetune",
    run_name="glm-4.7-flash-qlora-v1",
)

# Standard QLoRA adapter config targeting the attention projections.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)
111
+
112
+ print("Initializing trainer...")
113
+ trainer = SFTTrainer(
114
+ model=model,
115
+ tokenizer=tokenizer,
116
+ train_dataset=train_ds,
117
+ eval_dataset=val_ds,
118
+ args=config,
119
+ peft_config=peft_config,
120
+ )
121
+
122
+ print("Starting training...")
123
+ trainer.train()
124
+
125
+ print("Pushing to Hub...")
126
+ trainer.push_to_hub()
127
+ trackio.finish()
128
+ print("Done!")