# app.py — ZeroGPU NanoGPT trainer
# (uploaded via huggingface_hub; revision ed81b4c)
import gradio as gr
import spaces
import torch
import os
from datasets import load_dataset
from model import GPT, GPTConfig
import tiktoken
# Setup configurations
# H200 has massive VRAM, we can push the batch size much higher
BATCH_SIZE = 64  # sequences sampled per optimizer step
BLOCK_SIZE = 256  # context window length in tokens for the NanoGPT model
LEARNING_RATE = 3e-4  # peak LR; train_chunk applies warmup + cosine decay around this
# We train on the standard Edu-Fineweb baseline first to establish neural weights and architecture.
# We will use the webreaper data for fine-tuning / reasoning alignment later.
DATASET_NAME = "HuggingFaceFW/fineweb-edu"  # streamed from the Hub during training
CHECKPOINT_DIR = "./checkpoints"  # training state persists here between ZeroGPU allocations
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# Optimize PyTorch precision for Ampere/Hopper/Blackwell architecture (like H200)
torch.set_float32_matmul_precision('high')
# -----------------------------------------------------------------------
# ZERO-GPU ARCHITECTURE
# -----------------------------------------------------------------------
# ZeroGPU allocates a GPU dynamically only when this function is called.
# It has a strict time limit (usually 60-120s for free users).
# To bypass this, we "Chunk" the training:
# 1. Load the model & optimizer from disk.
# 2. Move to GPU.
# 3. Train for `steps` (~60 seconds of compute).
# 4. Move model back to CPU and save to disk BEFORE the GPU gets deallocated.
# -----------------------------------------------------------------------
@spaces.GPU(duration=120)
def train_chunk(steps=50, checkpoint_path=None):
    """Run one bounded chunk of training on a dynamically allocated ZeroGPU.

    Initializes (or resumes from `checkpoint_path`) a small NanoGPT, trains
    for up to `steps` optimizer steps on text streamed from the fineweb-edu
    dataset, then saves a checkpoint to disk BEFORE the GPU is deallocated
    so the next call can resume.

    Args:
        steps: Number of optimizer steps to attempt in this GPU allocation.
        checkpoint_path: Path of a previous checkpoint to resume from, or None.

    Returns:
        (log_text, checkpoint_path). On dataset-load failure, log_text carries
        the error and the incoming checkpoint_path is returned unchanged.
    """
    import math  # hoisted: was imported inside the per-step training loop

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # 1. Initialize model — a small "NanoGPT" footprint to fit in memory.
    config = GPTConfig(block_size=BLOCK_SIZE, vocab_size=50304,
                       n_layer=4, n_head=4, n_embd=256)
    model = GPT(config)

    # Karpathy's optimized AdamW configuration.
    # NOTE: fused AdamW requires CUDA tensors; passing fused=True
    # unconditionally crashes on the CPU fallback path, so gate it on device.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LEARNING_RATE, weight_decay=0.1,
        betas=(0.9, 0.95), eps=1e-8, fused=(device == 'cuda'))

    start_step = 0
    if checkpoint_path and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # Remove the _orig_mod. prefix added by torch.compile if present.
        state_dict = checkpoint['model']
        unwanted_prefix = '_orig_mod.'
        for k in list(state_dict.keys()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        model.load_state_dict(state_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Optimizer state was loaded onto CPU; move its tensors to the device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        start_step = checkpoint.get('step', 0)

    # Move to the ZeroGPU. (nn.Module.to moves parameter data in place, so
    # the optimizer's parameter references stay valid.)
    model.to(device)
    # Compile the model to push the H200 to its absolute limits.
    if hasattr(torch, 'compile'):
        model = torch.compile(model)
    model.train()

    # 2. Load dataset (streaming, to save RAM) directly from Hugging Face.
    try:
        ds = load_dataset(DATASET_NAME, name="sample-10BT", split="train",
                          streaming=True)
        ds_iter = iter(ds)
    except Exception as e:
        return f"Dataset load error: {e}", checkpoint_path

    enc = tiktoken.get_encoding("gpt2")

    # Karpathy cosine LR schedule with linear warmup.
    # Constants hoisted out of the loop (they were recomputed every step).
    max_steps = 5000  # assumed total length of the full training run
    warmup_steps = 100
    min_lr = LEARNING_RATE * 0.1

    def _get_lr(step):
        # Linear warmup, then cosine decay from LEARNING_RATE down to min_lr.
        if step < warmup_steps:
            return LEARNING_RATE * (step + 1) / warmup_steps
        if step > max_steps:
            return min_lr
        decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (LEARNING_RATE - min_lr)

    # 3. Training loop on the GPU.
    logs = []
    logs.append(f"--- Resuming training at Step {start_step} (Using H200 bfloat16 + torch.compile) ---")
    for step in range(start_step, start_step + steps):
        # Fetch a text chunk from the dataset; fineweb uses the 'text' column.
        try:
            row = next(ds_iter)
            text = row.get("text", " ") or " "
        except StopIteration:
            break

        lr = _get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        tokens = enc.encode(text, allowed_special={"<|endoftext|>"})
        if len(tokens) < BLOCK_SIZE + 1:
            continue  # document too short to carve a full training window

        # Sample BATCH_SIZE random sequences from this document.
        ix = torch.randint(len(tokens) - BLOCK_SIZE, (BATCH_SIZE,))
        x = torch.stack([torch.tensor(tokens[i:i + BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True)
        y = torch.stack([torch.tensor(tokens[i + 1:i + 1 + BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True)

        # Forward & backward with bfloat16 autocast for tensor cores
        # (bfloat16 needs no gradient scaler, unlike float16).
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            logs.append(f"Step {step} | Loss: {loss.item():.4f}")

    # 4. Save checkpoint BEFORE ZeroGPU shuts down.
    model.to('cpu')
    # Unwrap torch.compile's OptimizedModule so the saved state_dict is clean.
    raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
    out_ckpt = os.path.join(CHECKPOINT_DIR, f"ckpt_step_{start_step + steps}.pt")
    torch.save({
        'model': raw_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': start_step + steps
    }, out_ckpt)

    # Clean up the superseded checkpoint to save disk space.
    if checkpoint_path and os.path.exists(checkpoint_path) and checkpoint_path != out_ckpt:
        os.remove(checkpoint_path)

    return "\n".join(logs), out_ckpt
# -----------------------------------------------------------------------
# GRADIO INTERFACE
# -----------------------------------------------------------------------
current_ckpt = None  # path of the most recent checkpoint, shared across calls


def run_training_ui(steps_per_call):
    """Generator that drives the training loop and streams (logs, status) to the UI.

    Each loop iteration calls train_chunk, which re-triggers the ZeroGPU
    decorator and grabs a fresh GPU allocation; yielding keeps the Gradio
    stream (and this Python generator) alive between allocations.
    """
    global current_ckpt

    yield ("Starting infinite training loop... this will run until you run out of ZeroGPU quota!",
           f"Initial Checkpoint: {current_ckpt}")

    accumulated = []
    while True:
        try:
            # Trigger a ZeroGPU allocation (runs for ~100s) and train one chunk.
            chunk_log, new_ckpt = train_chunk(steps=int(steps_per_call),
                                              checkpoint_path=current_ckpt)
            current_ckpt = new_ckpt

            # Append this chunk's log, capped at the 50 most recent entries
            # so the textbox stays responsive.
            accumulated.append(chunk_log)
            del accumulated[:-50]

            yield "\n\n".join(accumulated), f"Current Active Checkpoint: {current_ckpt}"
        except Exception as e:
            err_msg = f"Training interrupted (likely ran out of ZeroGPU quota). Error: {str(e)}"
            accumulated.append(err_msg)
            yield "\n\n".join(accumulated), f"Final Checkpoint: {current_ckpt}"
            break
# ---------------------------------------------------------------------------
# Gradio interface: one start button drives the infinite generator above;
# the stop button cancels the streaming event.
# ---------------------------------------------------------------------------
with gr.Blocks(title="ZeroGPU WebReaper LM Trainer") as demo:
    gr.Markdown("# 🧠 Auto-ML Infinite Trainer on ZeroGPU")
    gr.Markdown(f"Trains a Karpathy-style NanoGPT directly on the `{DATASET_NAME}` dataset utilizing Hugging Face's free ZeroGPU allowance.")
    gr.Markdown("Click **'Start Infinite Training'** ONCE. The script will automatically loop, requesting a GPU, training for 60 seconds, saving state to disk, and repeating. It will only stop when you hit your daily 25-minute quota limit!")

    with gr.Row():
        steps_slider = gr.Slider(minimum=10, maximum=150, value=50, step=10,
                                 label="Steps per GPU Request")
        train_btn = gr.Button("🚀 Start Infinite Training", variant="primary")
        stop_btn = gr.Button("⏹️ Stop Training", variant="stop")

    with gr.Row():
        log_output = gr.Textbox(label="Live Training Logs (Streams automatically)", lines=20)
        ckpt_output = gr.Textbox(label="Checkpoint Status")

    # run_training_ui is a generator, so the click event streams its yields.
    train_event = train_btn.click(fn=run_training_ui,
                                  inputs=[steps_slider],
                                  outputs=[log_output, ckpt_output])
    # Cancelling the streaming event stops the infinite loop.
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[train_event])


if __name__ == "__main__":
    demo.launch()