#!/usr/bin/env python3
"""
Headless GRPO Training for RunPod Deployment.
This script is designed for containerized, non-interactive execution on RunPod.
It adapts train_grpo_marxist.py for headless operation with:
- Environment variable configuration
- Checkpoint resumption support
- Automatic model upload to HuggingFace Hub
- W&B logging for remote monitoring
- Self-termination capability
Environment Variables:

    Required:
        HF_TOKEN       - HuggingFace API token for model upload
        WANDB_API_KEY  - Weights & Biases API key

    Optional (with defaults):
        HF_REPO                - Target repo for model upload (default: prolewiki/marxist-grpo-lora)
        RUNPOD_POD_ID          - Pod ID for self-termination after training
        MODEL_NAME             - Base model (default: unsloth/DeepSeek-R1-0528-Qwen3-8B)
        MAX_SEQ_LENGTH         - Max sequence length in tokens (default: 2048)
        LORA_RANK              - LoRA rank; also used as lora_alpha (default: 32)
        MAX_STEPS              - Training steps (default: 500)
        SAVE_STEPS             - Checkpoint every N steps (default: 50)
        BATCH_SIZE             - Per-device batch size (default: 2)
        GRADIENT_ACCUMULATION  - Gradient accumulation steps (default: 2)
        NUM_GENERATIONS        - Completions sampled per prompt (default: 4)
        LEARNING_RATE          - Learning rate (default: 5e-6)
        WARMUP_RATIO           - Linear LR warmup ratio (default: 0.1)
        GPU_MEMORY_UTILIZATION - Fraction of VRAM reserved for vLLM (default: 0.6)
        MAX_PROMPT_LENGTH      - Max prompt length in tokens (default: 512)
        MAX_COMPLETION_LENGTH  - Max completion length in tokens (default: 1500)
        REWARD_MODE            - FULL, ROBUST, or LEGACY (default: FULL)
        DATASET_PATH           - Path to grpo_dataset.jsonl (default: /workspace/dataset.jsonl)
        CHECKPOINT_DIR         - Directory for checkpoints (default: /workspace/checkpoints)
        LORA_OUTPUT            - Directory for final LoRA (default: /workspace/lora-output)
        OUTPUT_DIR             - Directory for auxiliary outputs (default: /workspace/outputs)

Usage:
# In container:
python -m prolewiki_llm.train_headless
# With environment overrides:
MAX_STEPS=100 REWARD_MODE=ROBUST python -m prolewiki_llm.train_headless
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
# =============================================================================
# CRITICAL: Disable torch.compile BEFORE any imports
# =============================================================================
# These environment variables prevent torch.compile from spawning inductor
# compilation workers that hang indefinitely on RunPod/WSL2/Jupyter.
# See: https://github.com/unslothai/unsloth/issues/3432
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
os.environ["TORCH_COMPILE"] = "0"
os.environ["TORCHINDUCTOR_DISABLE"] = "1"
os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
def get_env(key: str, default: str | None = None, required: bool = False) -> str:
"""Get environment variable with optional default and required check."""
value = os.environ.get(key, default)
if required and value is None:
print(f"ERROR: Required environment variable {key} not set", file=sys.stderr)
sys.exit(1)
return value # type: ignore[return-value]
def get_env_int(key: str, default: int) -> int:
"""Get environment variable as integer."""
return int(os.environ.get(key, str(default)))
def get_env_float(key: str, default: float) -> float:
"""Get environment variable as float."""
return float(os.environ.get(key, str(default)))
# =============================================================================
# CONFIGURATION FROM ENVIRONMENT
# =============================================================================
# Required secrets
HF_TOKEN = get_env("HF_TOKEN", required=True)
WANDB_API_KEY = get_env("WANDB_API_KEY", required=True)
# Model configuration
MODEL_NAME = get_env("MODEL_NAME", "unsloth/DeepSeek-R1-0528-Qwen3-8B")
MAX_SEQ_LENGTH = get_env_int("MAX_SEQ_LENGTH", 2048)
LORA_RANK = get_env_int("LORA_RANK", 32)
# Training configuration
MAX_STEPS = get_env_int("MAX_STEPS", 500)
SAVE_STEPS = get_env_int("SAVE_STEPS", 50)
LEARNING_RATE = get_env_float("LEARNING_RATE", 5e-6)
WARMUP_RATIO = get_env_float("WARMUP_RATIO", 0.1)
BATCH_SIZE = get_env_int("BATCH_SIZE", 2)
GRADIENT_ACCUMULATION = get_env_int("GRADIENT_ACCUMULATION", 2)
NUM_GENERATIONS = get_env_int("NUM_GENERATIONS", 4)
GPU_MEMORY_UTILIZATION = get_env_float("GPU_MEMORY_UTILIZATION", 0.6)
# Sequence lengths
MAX_PROMPT_LENGTH = get_env_int("MAX_PROMPT_LENGTH", 512)
MAX_COMPLETION_LENGTH = get_env_int("MAX_COMPLETION_LENGTH", 1500)
# Paths
DATASET_PATH = Path(get_env("DATASET_PATH", "/workspace/dataset.jsonl"))
CHECKPOINT_DIR = Path(get_env("CHECKPOINT_DIR", "/workspace/checkpoints"))
LORA_OUTPUT = Path(get_env("LORA_OUTPUT", "/workspace/lora-output"))
OUTPUT_DIR = Path(get_env("OUTPUT_DIR", "/workspace/outputs"))
# Upload configuration
HF_REPO = get_env("HF_REPO", "prolewiki/marxist-grpo-lora")
# Reward mode: FULL, ROBUST, or LEGACY
REWARD_MODE = get_env("REWARD_MODE", "FULL").upper()
# Pod management
RUNPOD_POD_ID = get_env("RUNPOD_POD_ID")
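# NOTE: RUNPOD_POD_ID is read for the self-termination step mentioned in the
# module docstring, but this script never calls the RunPod API itself; pod
# shutdown is presumably handled by the container entrypoint that wraps it.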
def find_latest_checkpoint(checkpoint_dir: Path) -> Path | None:
"""Find the latest checkpoint directory if resuming training."""
if not checkpoint_dir.exists():
return None
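    # Sort numerically by step number; a plain lexicographic sort would put
    # "checkpoint-100" before "checkpoint-50".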
checkpoints = sorted(
[d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")],
key=lambda d: int(d.name.split("-")[1]),
)
if checkpoints:
return checkpoints[-1]
return None
def upload_to_hub(model_path: Path, repo_id: str, token: str) -> None:
"""Upload trained LoRA adapter to HuggingFace Hub."""
from huggingface_hub import HfApi
print(f"\nUploading model to HuggingFace Hub: {repo_id}")
api = HfApi(token=token)
    # Create the repo if needed; exist_ok=True makes this idempotent.
    try:
        api.create_repo(repo_id, exist_ok=True, private=True)
    except Exception as e:
        print(f"Note: create_repo failed, attempting the upload anyway: {e}")
# Upload the LoRA adapter directory
api.upload_folder(
folder_path=str(model_path),
repo_id=repo_id,
commit_message="Headless GRPO training run",
)
print(f"Model uploaded to: https://huggingface.co/{repo_id}")
def main() -> int:
"""Run headless GRPO training."""
import torch
import wandb
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from unsloth import FastLanguageModel
from vllm import SamplingParams
from prolewiki_llm.grpo_rewards import (
completeness_reward,
debug_print_reward,
full_coherence_reward,
match_format_approximately,
match_format_exactly,
robust_coherence_reward,
semantic_similarity_reward,
terminology_reward,
)
from prolewiki_llm.wandb_logging import (
WandbSampleLogger,
create_logging_reward,
finish_wandb_logging,
init_wandb_logging,
)
print("=" * 70)
print("HEADLESS GRPO TRAINING - RUNPOD DEPLOYMENT")
print("=" * 70)
# =========================================================================
# System Info
# =========================================================================
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name()
gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"GPU: {gpu_name}")
print(f"VRAM: {gpu_mem:.1f} GB")
else:
print("ERROR: CUDA not available!", file=sys.stderr)
return 1
print(f"\nConfiguration:")
print(f" Model: {MODEL_NAME}")
print(f" Max Steps: {MAX_STEPS}")
print(f" Batch Size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f" Learning Rate: {LEARNING_RATE}")
print(f" Reward Mode: {REWARD_MODE}")
print(f" Dataset: {DATASET_PATH}")
print(f" Output: {LORA_OUTPUT}")
print(f" HF Repo: {HF_REPO}")
# =========================================================================
# Initialize W&B
# =========================================================================
print("\nInitializing Weights & Biases...")
wandb.login(key=WANDB_API_KEY)
wandb_run = init_wandb_logging(
project="marxist-grpo-headless",
config={
"model": MODEL_NAME,
"learning_rate": LEARNING_RATE,
"batch_size": BATCH_SIZE,
"gradient_accumulation": GRADIENT_ACCUMULATION,
"num_generations": NUM_GENERATIONS,
"max_steps": MAX_STEPS,
"reward_mode": REWARD_MODE,
"lora_rank": LORA_RANK,
},
tags=["grpo", "marxist-leninist", "headless", "runpod"],
)
sample_logger = WandbSampleLogger(log_every_n_steps=10, max_samples_per_log=4)
logging_reward = create_logging_reward(sample_logger, compute_all_rewards=True)
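    # The logging "reward" rides along with the real reward functions so it can
    # observe each batch of prompts/completions and forward samples to W&B.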
# =========================================================================
# Load Dataset
# =========================================================================
print(f"\nLoading dataset from: {DATASET_PATH}")
if not DATASET_PATH.exists():
print(f"ERROR: Dataset not found: {DATASET_PATH}", file=sys.stderr)
return 1
dataset = Dataset.from_json(str(DATASET_PATH))
print(f"Loaded {len(dataset):,} examples")
# =========================================================================
# Load Model
# =========================================================================
print(f"\nLoading model: {MODEL_NAME}")
# GRPO requires 16-bit LoRA adapters (load_in_4bit=False)
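    # fast_inference=True enables Unsloth's vLLM generation backend, which the
    # GRPO trainer uses for rollouts; GPU_MEMORY_UTILIZATION caps the fraction
    # of VRAM vLLM claims so the remainder is left for training.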
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=MODEL_NAME,
max_seq_length=MAX_SEQ_LENGTH,
load_in_4bit=False, # Must be False for GRPO
fast_inference=True,
max_lora_rank=LORA_RANK,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
)
print(f"Model type: {model.config.model_type}")
# =========================================================================
# Apply LoRA
# =========================================================================
print("\nApplying LoRA adapters...")
# Use gradient_checkpointing=True (not "unsloth") for stability on RunPod
model = FastLanguageModel.get_peft_model(
model,
r=LORA_RANK,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha=LORA_RANK, # Same as r for GRPO (not r*2)
use_gradient_checkpointing=True, # Stable on RunPod (not "unsloth")
random_state=3407,
)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
# =========================================================================
# Configure vLLM Sampling
# =========================================================================
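    # min_p=0.1 with top_p=1.0 and top_k=-1 leaves nucleus/top-k filtering off;
    # include_stop_str_in_output keeps the EOS stop string in the completion
    # text, presumably so the format-matching rewards see a terminated output.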
vllm_sampling_params = SamplingParams(
min_p=0.1,
top_p=1.0,
top_k=-1,
max_tokens=MAX_COMPLETION_LENGTH,
stop=[tokenizer.eos_token],
include_stop_str_in_output=True,
seed=3407,
)
# =========================================================================
# Configure Training
# =========================================================================
print("\nConfiguring GRPO trainer...")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
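    # Note: TRL's GRPOTrainer requires the generation batch (per-device batch x
    # gradient accumulation x world size, in recent releases) to be divisible by
    # num_generations; the defaults here give 2 x 2 x 1 = 4 = NUM_GENERATIONS.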
training_args = GRPOConfig(
# vLLM
vllm_sampling_params=vllm_sampling_params,
temperature=1.0,
# Optimization
learning_rate=LEARNING_RATE,
weight_decay=0.001,
warmup_ratio=WARMUP_RATIO,
lr_scheduler_type="linear",
optim="adamw_8bit",
# Batch settings
per_device_train_batch_size=BATCH_SIZE,
gradient_accumulation_steps=GRADIENT_ACCUMULATION,
num_generations=NUM_GENERATIONS,
# Sequence lengths
max_prompt_length=MAX_PROMPT_LENGTH,
max_completion_length=MAX_COMPLETION_LENGTH,
# Training duration
max_steps=MAX_STEPS,
save_steps=SAVE_STEPS,
# Logging
logging_steps=1,
report_to="wandb",
# Output
output_dir=str(CHECKPOINT_DIR),
)
# =========================================================================
# Select Reward Functions
# =========================================================================
if REWARD_MODE == "FULL":
print("\nUsing FULL reward mode (recommended):")
print(" - match_format_exactly, match_format_approximately")
print(" - full_coherence_reward (NLI + structure + topic + depth)")
print(" - completeness_reward, logging_reward")
reward_funcs = [
match_format_exactly,
match_format_approximately,
full_coherence_reward,
completeness_reward,
debug_print_reward,
logging_reward,
]
elif REWARD_MODE == "ROBUST":
print("\nUsing ROBUST reward mode:")
print(" - match_format_exactly, match_format_approximately")
print(" - robust_coherence_reward (NLI + self-consistency + structure)")
print(" - completeness_reward, logging_reward")
reward_funcs = [
match_format_exactly,
match_format_approximately,
robust_coherence_reward,
completeness_reward,
debug_print_reward,
logging_reward,
]
else: # LEGACY
print("\nUsing LEGACY reward mode (faster, less robust):")
print(" - match_format_exactly, match_format_approximately")
print(" - semantic_similarity_reward, terminology_reward")
print(" - completeness_reward, logging_reward")
reward_funcs = [
match_format_exactly,
match_format_approximately,
semantic_similarity_reward,
terminology_reward,
completeness_reward,
debug_print_reward,
logging_reward,
]
# =========================================================================
# Create Trainer
# =========================================================================
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset,
)
# =========================================================================
# Check for Checkpoint Resume
# =========================================================================
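    # Passing resume_from_checkpoint restores optimizer/scheduler state and the
    # global step, so MAX_STEPS counts total steps across restarts, not per run.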
resume_from = find_latest_checkpoint(CHECKPOINT_DIR)
if resume_from:
print(f"\nResuming from checkpoint: {resume_from}")
# =========================================================================
# Train!
# =========================================================================
print("\n" + "=" * 70)
print("STARTING TRAINING")
print("=" * 70)
print(f"Steps: {MAX_STEPS}")
print(f"Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}")
print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION * NUM_GENERATIONS}")
print()
try:
if resume_from:
trainer.train(resume_from_checkpoint=str(resume_from))
else:
trainer.train()
    except KeyboardInterrupt:
        print("\nTraining interrupted; saving the current LoRA adapter before exit...")
except Exception as e:
print(f"\nTraining error: {e}", file=sys.stderr)
finish_wandb_logging({"status": "error", "error": str(e)})
raise
# =========================================================================
# Save LoRA
# =========================================================================
print("\n" + "=" * 70)
print("SAVING MODEL")
print("=" * 70)
LORA_OUTPUT.mkdir(parents=True, exist_ok=True)
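    # save_lora writes only the LoRA adapter weights, not the merged base model,
    # which keeps the Hub upload below small.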
model.save_lora(str(LORA_OUTPUT))
print(f"LoRA saved to: {LORA_OUTPUT}")
# =========================================================================
# Upload to HuggingFace Hub
# =========================================================================
try:
upload_to_hub(LORA_OUTPUT, HF_REPO, HF_TOKEN)
except Exception as e:
print(f"Warning: Failed to upload to HuggingFace Hub: {e}", file=sys.stderr)
# =========================================================================
# Finish W&B
# =========================================================================
finish_wandb_logging({
"status": "completed",
"final_step": MAX_STEPS,
"reward_mode": REWARD_MODE,
"dataset_size": len(dataset),
"hf_repo": HF_REPO,
})
print("\n" + "=" * 70)
print("TRAINING COMPLETE!")
print("=" * 70)
print(f"LoRA saved to: {LORA_OUTPUT}")
print(f"Model uploaded to: https://huggingface.co/{HF_REPO}")
return 0
if __name__ == "__main__":
sys.exit(main())