# training-scripts / train_glm47_flash_test.py
# Uploaded by LordNeel with huggingface_hub (commit 390b224, verified)
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch>=2.0.0",
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "trl>=0.12.0",
# "peft>=0.7.0",
# "accelerate>=0.24.0",
# "datasets",
# "bitsandbytes",
# ]
# ///
"""
TEST RUN: Fine-tune GLM-4.7-Flash on small sample (50 examples, 20 steps)
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import gc
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
MODEL_NAME = "zai-org/GLM-4.7-Flash"
DATASET_NAME = "LordNeel/unblinded-mastery-sharegpt"
print("=" * 60)
print("TEST RUN: GLM-4.7-Flash (50 examples, 20 steps)")
print("=" * 60)
# Load small sample
print("\nLoading dataset (50 examples only)...")
dataset = load_dataset(DATASET_NAME, split="train[:50]")
print(f"Dataset loaded: {len(dataset)} examples")
# 4-bit quantization: NF4 weights with double quantization, bf16 compute.
print("\nSetting up 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Fall back to EOS as the padding token when the checkpoint defines none, so
# training batches can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load model
print("\nLoading model with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    # `dtype` replaces the deprecated `torch_dtype` keyword in recent
    # transformers releases (this script installs transformers from main).
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    # The KV cache is not needed for training and is incompatible with
    # gradient checkpointing.
    use_cache=False,
    attn_implementation="eager",
)
print("Model loaded!")

# Enable gradient checkpointing (non-reentrant variant) and make the embedding
# outputs require grad so checkpointed activations still backpropagate into
# the adapters attached later.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.enable_input_require_grads()

# Clear memory
gc.collect()
torch.cuda.empty_cache()
# device_map="auto" may shard the model across several GPUs, so report the
# total over every visible device instead of only the current one.
allocated_gb = sum(
    torch.cuda.memory_allocated(device_index)
    for device_index in range(torch.cuda.device_count())
) / 1024**3
print(f"GPU Memory: {allocated_gb:.2f} GB allocated")

# Find linear layers for LoRA
print("\nFinding linear layers for LoRA...")
def find_all_linear_names(model, *, exclude=("lm_head",)):
    """Return the sorted, de-duplicated leaf names of every ``torch.nn.Linear`` in *model*.

    Only the last component of each dotted module path is kept. For example,
    ``layers.0.self_attn.q_proj`` becomes ``q_proj``. PEFT's
    ``LoraConfig(target_modules=[...])`` matches that form against
    module-name suffixes. Subclasses of ``torch.nn.Linear`` are included too;
    bitsandbytes' ``Linear4bit`` is one, so 4-bit quantized layers are found.

    Args:
        model: Any ``torch.nn.Module``.
        exclude: Leaf name(s) to drop from the result. The default,
            ``("lm_head",)``, keeps the output projection out of LoRA. A
            single string is treated as one name.

    Returns:
        A sorted list of unique leaf names. Sorting makes the result, and
        therefore the LoRA target list, identical across runs; a bare
        ``list(set)`` would depend on the string hash seed.
    """
    if isinstance(exclude, str):
        # Guard against set.difference("lm_head") excluding single characters.
        exclude = (exclude,)
    leaf_names = {
        name.rsplit(".", 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    }
    return sorted(leaf_names.difference(exclude))
target_modules = find_all_linear_names(model)
print(f"Target modules: {target_modules}")

# LoRA adapters with a deliberately small rank: this run only checks that
# training works, not adapter quality.
print("\nConfiguring LoRA...")
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=target_modules,
    r=8,
    lora_alpha=16,  # PEFT scales the update by lora_alpha / r = 2
    lora_dropout=0.05,
    bias="none",  # no bias parameters are made trainable
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# Format function
def format_sharegpt(example):
    """Render one ShareGPT-style record as a single chat-formatted training string.

    Args:
        example: A dataset row whose ``"conversations"`` field is a list of
            turns. Each turn is a dict with ``"from"`` (the speaker tag) and
            ``"value"`` (the message text).

    Returns:
        ``{"text": rendered}``, where ``rendered`` is the conversation passed
        through the module-level ``tokenizer``'s chat template, untokenized
        and without a trailing generation prompt.
    """
    # ShareGPT speaker tags -> chat-template roles. Built once per call rather
    # than once per turn; unknown tags pass through unchanged.
    role_map = {"system": "system", "human": "user", "gpt": "assistant"}
    messages = [
        {"role": role_map.get(turn["from"], turn["from"]), "content": turn["value"]}
        for turn in example["conversations"]
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return {"text": text}
def _print_banner(title):
    """Print *title* between two 60-character '=' rules, after a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


# Collapse every ShareGPT record into a single "text" column; the original
# columns are dropped.
print("\nFormatting dataset...")
dataset = dataset.map(format_sharegpt, remove_columns=dataset.column_names)

# Minimal schedule: 20 steps x (1 per device x 4 accumulated) = 80 sequences,
# i.e. a bit more than one pass over the 50-example sample.
print("\nConfiguring training (20 steps only)...")
training_config = SFTConfig(
    output_dir="test-output",
    report_to="none",  # no experiment tracking for a smoke test
    # schedule
    max_steps=20,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=5,
    # data
    dataset_text_field="text",
    max_length=512,  # short sequences keep the test fast
    # memory / precision
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    optim="paged_adamw_8bit",
)

print("\nInitializing trainer...")
trainer = SFTTrainer(
    model=model,
    args=training_config,
    train_dataset=dataset,
    processing_class=tokenizer,
)

_print_banner("STARTING TEST TRAINING (20 steps)")
trainer.train()
_print_banner("TEST COMPLETE! Training works.")