import math
import os

import gradio as gr
import spaces
import tiktoken
import torch
from datasets import load_dataset

from model import GPT, GPTConfig

# --- Training configuration -------------------------------------------------
# H200 has massive VRAM, we can push the batch size much higher.
BATCH_SIZE = 64
BLOCK_SIZE = 256          # context length in tokens
LEARNING_RATE = 3e-4
MAX_STEPS = 5000          # assumed total length of the full training run
WARMUP_STEPS = 100        # linear warmup portion of the LR schedule

# We train on the standard Edu-Fineweb baseline first to establish neural
# weights and architecture. The webreaper data is reserved for later
# fine-tuning / reasoning alignment.
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Allow TF32 matmuls on Ampere/Hopper/Blackwell architectures (like the H200).
torch.set_float32_matmul_precision('high')

# -----------------------------------------------------------------------
# ZERO-GPU ARCHITECTURE
# -----------------------------------------------------------------------
# ZeroGPU allocates a GPU dynamically only when the decorated function is
# called, with a strict time limit (usually 60-120s for free users).
# To work within that, training is "chunked":
#   1. Load the model & optimizer from disk.
#   2. Move to GPU.
#   3. Train for `steps` (~60 seconds of compute).
#   4. Move model back to CPU and save to disk BEFORE GPU deallocation.
# -----------------------------------------------------------------------


def _cosine_lr(step):
    """Karpathy-style cosine learning-rate schedule with linear warmup.

    Args:
        step: global optimizer step (0-indexed).

    Returns:
        Learning rate for this step: linear ramp over WARMUP_STEPS, cosine
        decay to 10% of LEARNING_RATE at MAX_STEPS, flat 10% afterwards.
    """
    min_lr = LEARNING_RATE * 0.1
    if step < WARMUP_STEPS:
        return LEARNING_RATE * (step + 1) / WARMUP_STEPS
    if step > MAX_STEPS:
        return min_lr
    decay_ratio = (step - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (LEARNING_RATE - min_lr)


@spaces.GPU(duration=120)
def train_chunk(steps=50, checkpoint_path=None):
    """Run one ZeroGPU-sized slice of training and checkpoint the result.

    Args:
        steps: number of optimizer steps to run inside this GPU window.
        checkpoint_path: path of the previous checkpoint to resume from,
            or None to start from scratch.

    Returns:
        (log_text, checkpoint_path) tuple. On dataset-load failure the
        original checkpoint_path is returned unchanged alongside the error.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # 1. Initialize or load the model — a small "NanoGPT" footprint.
    config = GPTConfig(block_size=BLOCK_SIZE, vocab_size=50304,
                       n_layer=4, n_head=4, n_embd=256)
    model = GPT(config)

    start_step = 0
    checkpoint = None
    if checkpoint_path and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model']
        # Strip the `_orig_mod.` prefix added by torch.compile, if present.
        unwanted_prefix = '_orig_mod.'
        for k in list(state_dict.keys()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        model.load_state_dict(state_dict)
        start_step = checkpoint.get('step', 0)

    # Move to the ZeroGPU device BEFORE building the optimizer: the fused
    # AdamW kernel requires its parameters to already live on CUDA.
    model.to(device)

    # Karpathy's optimized AdamW configuration:
    # weight_decay=0.1, betas=(0.9, 0.95), eps=1e-8. `fused` is CUDA-only,
    # so fall back to the default implementation on CPU.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,
                                  weight_decay=0.1, betas=(0.9, 0.95),
                                  eps=1e-8, fused=(device == 'cuda'))
    if checkpoint is not None:
        # load_state_dict casts saved optimizer state onto the params'
        # device itself, so no manual tensor moves are needed.
        optimizer.load_state_dict(checkpoint['optimizer'])

    # Compile the model to push the H200 to its absolute limits.
    if hasattr(torch, 'compile'):
        model = torch.compile(model)
    model.train()

    # 2. Load the dataset (streaming, to save RAM).
    # NOTE(review): a fresh streaming iterator restarts at the beginning of
    # the corpus, so successive chunks revisit the earliest rows — consider
    # skipping ahead by the resumed step count if data diversity matters.
    try:
        ds = load_dataset(DATASET_NAME, name="sample-10BT",
                          split="train", streaming=True)
        ds_iter = iter(ds)
    except Exception as e:
        return f"Dataset load error: {e}", checkpoint_path

    enc = tiktoken.get_encoding("gpt2")

    # 3. Training loop on the GPU.
    logs = [f"--- Resuming training at Step {start_step} (Using H200 bfloat16 + torch.compile) ---"]
    for step in range(start_step, start_step + steps):
        # Fetch one text document from the stream. Fineweb uses 'text'.
        try:
            row = next(ds_iter)
            text = row.get("text", " ") or " "
        except StopIteration:
            break

        # Cosine LR schedule with warmup.
        lr = _cosine_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        tokens = enc.encode(text, allowed_special={"<|endoftext|>"})
        if len(tokens) < BLOCK_SIZE + 1:
            continue  # document too short for one full training window

        # Sample BATCH_SIZE random windows from this document.
        ix = torch.randint(len(tokens) - BLOCK_SIZE, (BATCH_SIZE,))
        x = torch.stack([torch.tensor(tokens[i:i + BLOCK_SIZE], dtype=torch.long)
                         for i in ix]).to(device, non_blocking=True)
        y = torch.stack([torch.tensor(tokens[i + 1:i + 1 + BLOCK_SIZE], dtype=torch.long)
                         for i in ix]).to(device, non_blocking=True)

        # Mixed-precision forward/backward (bfloat16 feeds the tensor cores).
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            _, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            logs.append(f"Step {step} | Loss: {loss.item():.4f}")

    # 4. Save the checkpoint BEFORE ZeroGPU shuts the device down.
    model.to('cpu')
    # Unwrap the torch.compile wrapper so keys have no `_orig_mod.` prefix.
    raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
    out_ckpt = os.path.join(CHECKPOINT_DIR, f"ckpt_step_{start_step + steps}.pt")
    torch.save({
        'model': raw_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': start_step + steps,
    }, out_ckpt)

    # Clean up the superseded checkpoint to save disk space.
    if checkpoint_path and os.path.exists(checkpoint_path) and checkpoint_path != out_ckpt:
        os.remove(checkpoint_path)

    return "\n".join(logs), out_ckpt


# -----------------------------------------------------------------------
# GRADIO INTERFACE
# -----------------------------------------------------------------------
current_ckpt = None  # path of the most recent checkpoint across chunks


def run_training_ui(steps_per_call):
    """Gradio generator: loop train_chunk() forever, streaming logs to the UI.

    Yielding keeps the Python thread alive between chunks, so the
    @spaces.GPU decorator is triggered over and over until quota runs out.
    """
    global current_ckpt
    yield ("Starting infinite training loop... this will run until you run out of ZeroGPU quota!",
           f"Initial Checkpoint: {current_ckpt}")

    total_logs = []
    while True:
        try:
            # 1. Trigger the ZeroGPU allocation (runs for ~100s).
            chunk_log, new_ckpt = train_chunk(steps=int(steps_per_call),
                                              checkpoint_path=current_ckpt)
            # 2. Update global state.
            current_ckpt = new_ckpt
            # 3. Append to logs (bounded to the last 50 chunks) and yield.
            total_logs.append(chunk_log)
            if len(total_logs) > 50:
                total_logs = total_logs[-50:]
            yield "\n\n".join(total_logs), f"Current Active Checkpoint: {current_ckpt}"
        except Exception as e:
            err_msg = f"Training interrupted (likely ran out of ZeroGPU quota). Error: {str(e)}"
            total_logs.append(err_msg)
            yield "\n\n".join(total_logs), f"Final Checkpoint: {current_ckpt}"
            break


with gr.Blocks(title="ZeroGPU WebReaper LM Trainer") as demo:
    gr.Markdown("# 🧠 Auto-ML Infinite Trainer on ZeroGPU")
    gr.Markdown(f"Trains a Karpathy-style NanoGPT directly on the `{DATASET_NAME}` dataset utilizing Hugging Face's free ZeroGPU allowance.")
    gr.Markdown("Click **'Start Infinite Training'** ONCE. The script will automatically loop, requesting a GPU, training for 60 seconds, saving state to disk, and repeating. It will only stop when you hit your daily 25-minute quota limit!")

    with gr.Row():
        steps_slider = gr.Slider(minimum=10, maximum=150, value=50, step=10,
                                 label="Steps per GPU Request")
        train_btn = gr.Button("🚀 Start Infinite Training", variant="primary")
        stop_btn = gr.Button("⏹️ Stop Training", variant="stop")

    with gr.Row():
        log_output = gr.Textbox(label="Live Training Logs (Streams automatically)", lines=20)
        ckpt_output = gr.Textbox(label="Checkpoint Status")

    # The generator streams log updates into the textboxes.
    train_event = train_btn.click(fn=run_training_ui, inputs=[steps_slider],
                                  outputs=[log_output, ckpt_output])
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[train_event])


if __name__ == "__main__":
    demo.launch()