# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "transformers @ git+https://github.com/huggingface/transformers.git",
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "accelerate>=0.24.0",
#     "datasets",
#     "trackio",
#     "bitsandbytes",
# ]
# ///
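# The block above is PEP 723 inline script metadata, so the script can be run
# directly with a PEP 723-aware runner, e.g. `uv run train_glm47_flash.py`,
# which resolves the listed dependencies automatically.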
"""
Fine-tune GLM-4.7-Flash on Unblinded Mastery dataset for QA and instruction following.
Using TRL SFTTrainer with LoRA on H100.
"""
import os

# Must be set before torch initializes CUDA: expandable segments let the caching
# allocator grow existing memory segments instead of fragmenting, reducing spurious OOMs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import gc
import bitsandbytes as bnb  # used below to locate the 4-bit Linear layers for LoRA targeting
import trackio
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
# Configuration
MODEL_NAME = "zai-org/GLM-4.7-Flash"
DATASET_NAME = "LordNeel/unblinded-mastery-sharegpt"
OUTPUT_MODEL = "LordNeel/GLM-4.7-Flash-Unblinded-Mastery"
print("=" * 60)
print("GLM-4.7-Flash Fine-tuning for Unblinded Mastery")
print("=" * 60)
# Load dataset
print("\nLoading dataset...")
dataset = load_dataset(DATASET_NAME, split="train")
print(f"Dataset loaded: {len(dataset)} examples")
# Create train/eval split
print("Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.05, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f" Train: {len(train_dataset)} examples")
print(f" Eval: {len(eval_dataset)} examples")
# 4-bit quantization config for memory efficiency
print("\nSetting up 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NF4 quantization, designed for normally distributed weights
    bnb_4bit_compute_dtype=torch.bfloat16,  # dequantize to bf16 for the actual matmuls
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
)
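# With NF4 plus double quantization the frozen base weights cost roughly half a
# byte per parameter, which is what makes single-H100 fine-tuning feasible here.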
# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer loaded. Vocab size: {len(tokenizer)}")
# Load model with 4-bit quantization
print("\nLoading model with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_cache=False,  # Disable KV cache for training
    attn_implementation="eager",  # compatibility fallback; "sdpa" or "flash_attention_2" is typically faster and lighter on memory if supported
)
print("Model loaded!")
# Enable gradient checkpointing
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
# Enable input gradients for LoRA (lighter than prepare_model_for_kbit_training)
model.enable_input_require_grads()
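# Without this, the frozen 4-bit base produces activations with requires_grad=False,
# and gradient checkpointing would have nothing to backpropagate into the LoRA
# adapters; forcing input gradients keeps the autograd graph alive end to end.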
# Clear memory
gc.collect()
torch.cuda.empty_cache()
print(f"GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB allocated")
# Find all linear layer names for LoRA
print("\nFinding linear layers for LoRA...")
def find_all_linear_names(model):
    """Collect the names of all 4-bit linear layers to use as LoRA targets."""
    cls = bnb.nn.Linear4bit  # after 4-bit loading, linear layers are Linear4bit, not torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    # Never adapt the output head
    lora_module_names.discard('lm_head')
    return list(lora_module_names)
target_modules = find_all_linear_names(model)
print(f" Found target modules: {target_modules}")
# LoRA configuration - using lower rank for memory efficiency
print("\nConfiguring LoRA...")
peft_config = LoraConfig(
    r=16,           # low rank keeps adapter weights and optimizer state small
    lora_alpha=32,  # effective scaling of alpha/r = 2
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=target_modules,
)
# Apply LoRA
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
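# With r=16 over the projection layers, trainable parameters are typically a
# fraction of a percent of the full model, so only the small adapter (plus its
# optimizer state) needs gradient memory.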
# Format function for ShareGPT conversations
def format_sharegpt(example):
    """Render a ShareGPT-style conversation with the model's chat template."""
    role_map = {"system": "system", "human": "user", "gpt": "assistant"}
    messages = [
        {"role": role_map.get(turn["from"], turn["from"]), "content": turn["value"]}
        for turn in example["conversations"]
    ]
    # Apply chat template; no generation prompt since assistant turns are part of the training text
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return {"text": text}
# Format datasets
print("\nFormatting datasets...")
train_dataset = train_dataset.map(format_sharegpt, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(format_sharegpt, remove_columns=eval_dataset.column_names)
print("Datasets formatted!")
# Training configuration
print("\nConfiguring training...")
training_config = SFTConfig(
    # Hub settings - CRITICAL for saving
    output_dir=OUTPUT_MODEL.split("/")[-1],
    push_to_hub=True,
    hub_model_id=OUTPUT_MODEL,
    hub_strategy="every_save",
    hub_private_repo=False,
    # Training parameters
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,  # Effective batch size: 16
    learning_rate=2e-4,
    max_length=1024,  # Reduced for memory
    # Memory optimization
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Logging & checkpointing
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    # Evaluation
    eval_strategy="steps",
    eval_steps=100,
    # Optimization
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    # Precision
    bf16=True,
    fp16=False,
    # Monitoring
    report_to="trackio",
    project="unblinded-mastery-finetuning",
    run_name="glm47flash-sft-lora",
    # Dataset
    dataset_text_field="text",
    packing=False,
)
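# Note: hub_strategy="every_save" pushes each 100-step checkpoint to the Hub, so
# an interrupted run loses at most save_steps worth of progress.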
# Initialize trainer
print("\nInitializing trainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_config,
    processing_class=tokenizer,
    peft_config=None,  # LoRA already applied above via get_peft_model
)
# Train
print("\n" + "=" * 60)
print("STARTING TRAINING")
print("=" * 60)
trainer.train()
# Save and push to hub
print("\nSaving model to Hub...")
trainer.save_model()
trainer.push_to_hub()
# Finish tracking
trackio.finish()
print("\n" + "=" * 60)
print("TRAINING COMPLETE!")
print(f"Model saved to: https://huggingface.co/{OUTPUT_MODEL}")
print(f"View metrics at: https://huggingface.co/spaces/LordNeel/trackio")
print("=" * 60)