LordNeel committed on
Commit
8af7593
·
verified ·
1 Parent(s): 808673c

Upload train_glm47_flash_test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_glm47_flash_test.py +155 -0
train_glm47_flash_test.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "transformers @ git+https://github.com/huggingface/transformers.git",
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "accelerate>=0.24.0",
#     "datasets",
#     "bitsandbytes",
# ]
# ///

"""
TEST RUN: Fine-tune GLM-4.7-Flash on small sample (50 examples, 20 steps)
"""
17
+
18
+ import os
19
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
20
+
21
+ import torch
22
+ import gc
23
+ from datasets import load_dataset
24
+ from peft import LoraConfig, TaskType, get_peft_model
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
26
+ from trl import SFTTrainer, SFTConfig
27
+
28
MODEL_NAME = "zai-org/GLM-4.7-Flash"
DATASET_NAME = "LordNeel/unblinded-mastery-sharegpt"

# Banner: rule / title / rule on three lines.
print("=" * 60, "TEST RUN: GLM-4.7-Flash (50 examples, 20 steps)", "=" * 60, sep="\n")

# Smoke-test sample: only the first 50 training examples are pulled.
print("\nLoading dataset (50 examples only)...")
dataset = load_dataset(DATASET_NAME, split="train[:50]")
print(f"Dataset loaded: {len(dataset)} examples")
39
+
40
# Quantize base weights to 4-bit NF4 so the model fits in GPU memory; double
# quantization saves a bit more, and bf16 is the de-quantized compute dtype.
print("\nSetting up 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
48
+
49
# Tokenizer — trust_remote_code because GLM ships custom tokenizer code.
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Causal-LM tokenizers often ship without a pad token; reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
54
+
55
# Load the base model in 4-bit. The KV cache is disabled (incompatible with
# gradient checkpointing during training) and eager attention is selected for
# broad compatibility.
print("\nLoading model with 4-bit quantization...")
load_kwargs = dict(
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_cache=False,
    attn_implementation="eager",
)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
print("Model loaded!")

# Recompute activations on the backward pass (non-reentrant variant) and make
# embedding outputs require grad so checkpointing works with frozen base weights.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.enable_input_require_grads()

# Release Python-side garbage and cached CUDA blocks before training starts.
gc.collect()
torch.cuda.empty_cache()
print(f"GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB allocated")
77
+
78
# Find linear layers for LoRA
print("\nFinding linear layers for LoRA...")
def find_all_linear_names(model):
    """Return the leaf names of every linear submodule to target with LoRA.

    Walks ``model.named_modules()`` and collects the last dotted path
    component of each ``torch.nn.Linear`` (e.g. ``"model.layers.0.q_proj"``
    -> ``"q_proj"``). ``lm_head`` is excluded — adapting the output head is
    unnecessary for LoRA fine-tuning. NOTE(review): bitsandbytes' quantized
    linears subclass ``nn.Linear`` in current releases, so the 4-bit layers
    should be matched too — confirm against the installed bitsandbytes.
    """
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    # discard() is a no-op when absent — no membership check needed.
    lora_module_names.discard('lm_head')
    return list(lora_module_names)
90
+
91
# Attach small LoRA adapters (r=8) to every discovered linear layer —
# deliberately tiny because this is only a smoke test.
target_modules = find_all_linear_names(model)
print(f"Target modules: {target_modules}")

print("\nConfiguring LoRA...")
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=target_modules,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
107
+
108
# Convert ShareGPT-style records into the single "text" column TRL consumes.
def format_sharegpt(example):
    """Render one ShareGPT record as a chat-templated string in {"text": ...}."""
    role_map = {"system": "system", "human": "user", "gpt": "assistant"}
    messages = [
        {"role": role_map.get(turn["from"], turn["from"]), "content": turn["value"]}
        for turn in example["conversations"]
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return {"text": text}

print("\nFormatting dataset...")
dataset = dataset.map(format_sharegpt, remove_columns=dataset.column_names)
120
+
121
# Training config — deliberately tiny: 20 optimizer steps, effective batch
# size 4 (1 per device x 4 accumulation), short 512-token sequences.
print("\nConfiguring training (20 steps only)...")
training_config = SFTConfig(
    output_dir="test-output",
    report_to="none",  # No tracking for test
    # Schedule / optimizer
    max_steps=20,  # Just 20 steps to test
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    bf16=True,
    logging_steps=5,
    # Batching / data
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    dataset_text_field="text",
    # NOTE(review): newer TRL renamed max_seq_length -> max_length; confirm
    # against the installed trl version or pin the dependency.
    max_seq_length=512,  # Short for testing
    # Memory
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
138
+
139
# Train
print("\nInitializing trainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_config,
    # `tokenizer=` was deprecated in trl 0.12 and removed in later releases;
    # with the unpinned "trl>=0.12.0" dependency the current kwarg name
    # `processing_class` must be used (accepted throughout the >=0.12 range).
    processing_class=tokenizer,
)
147
+
148
+ print("\n" + "=" * 60)
149
+ print("STARTING TEST TRAINING (20 steps)")
150
+ print("=" * 60)
151
+ trainer.train()
152
+
153
+ print("\n" + "=" * 60)
154
+ print("TEST COMPLETE! Training works.")
155
+ print("=" * 60)