import math
import os

import gradio as gr
import spaces
import tiktoken
import torch
from datasets import load_dataset

from model import GPT, GPTConfig
|
|
| |
| |
| BATCH_SIZE = 64 |
| BLOCK_SIZE = 256 |
| LEARNING_RATE = 3e-4 |
| |
| |
| DATASET_NAME = "HuggingFaceFW/fineweb-edu" |
| CHECKPOINT_DIR = "./checkpoints" |
|
|
| os.makedirs(CHECKPOINT_DIR, exist_ok=True) |
|
|
| |
| torch.set_float32_matmul_precision('high') |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| @spaces.GPU(duration=120) |
| def train_chunk(steps=50, checkpoint_path=None): |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| |
| |
| |
| config = GPTConfig(block_size=BLOCK_SIZE, vocab_size=50304, n_layer=4, n_head=4, n_embd=256) |
| model = GPT(config) |
| |
| |
| |
| optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.1, betas=(0.9, 0.95), eps=1e-8, fused=True) |
| |
| start_step = 0 |
| if checkpoint_path and os.path.exists(checkpoint_path): |
| checkpoint = torch.load(checkpoint_path, map_location='cpu') |
| |
| |
| state_dict = checkpoint['model'] |
| unwanted_prefix = '_orig_mod.' |
| for k, v in list(state_dict.items()): |
| if k.startswith(unwanted_prefix): |
| state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) |
| |
| model.load_state_dict(state_dict) |
| optimizer.load_state_dict(checkpoint['optimizer']) |
| |
| |
| for state in optimizer.state.values(): |
| for k, v in state.items(): |
| if isinstance(v, torch.Tensor): |
| state[k] = v.to(device) |
| |
| start_step = checkpoint.get('step', 0) |
| |
| |
| model.to(device) |
| |
| |
| if hasattr(torch, 'compile'): |
| model = torch.compile(model) |
| |
| model.train() |
| |
| |
| try: |
| |
| ds = load_dataset(DATASET_NAME, name="sample-10BT", split="train", streaming=True) |
| ds_iter = iter(ds) |
| except Exception as e: |
| return f"Dataset load error: {e}", checkpoint_path |
| |
| enc = tiktoken.get_encoding("gpt2") |
| |
| |
| logs = [] |
| logs.append(f"--- Resuming training at Step {start_step} (Using H200 bfloat16 + torch.compile) ---") |
| |
| for step in range(start_step, start_step + steps): |
| |
| try: |
| row = next(ds_iter) |
| |
| text = row.get("text", " ") |
| if not text: text = " " |
| except StopIteration: |
| break |
| |
| |
| max_steps = 5000 |
| warmup_steps = 100 |
| if step < warmup_steps: |
| lr = LEARNING_RATE * (step + 1) / warmup_steps |
| elif step > max_steps: |
| lr = LEARNING_RATE * 0.1 |
| else: |
| decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps) |
| import math |
| coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) |
| lr = LEARNING_RATE * 0.1 + coeff * (LEARNING_RATE - LEARNING_RATE * 0.1) |
| |
| for param_group in optimizer.param_groups: |
| param_group['lr'] = lr |
| |
| tokens = enc.encode(text, allowed_special={"<|endoftext|>"}) |
| if len(tokens) < BLOCK_SIZE + 1: |
| continue |
| |
| |
| ix = torch.randint(len(tokens) - BLOCK_SIZE, (BATCH_SIZE,)) |
| |
| x = torch.stack([torch.tensor(tokens[i:i+BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True) |
| y = torch.stack([torch.tensor(tokens[i+1:i+1+BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True) |
| |
| |
| with torch.autocast(device_type=device, dtype=torch.bfloat16): |
| logits, loss = model(x, y) |
| |
| optimizer.zero_grad(set_to_none=True) |
| loss.backward() |
| optimizer.step() |
| |
| if step % 10 == 0: |
| logs.append(f"Step {step} | Loss: {loss.item():.4f}") |
| |
| |
| model.to('cpu') |
| raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model |
| out_ckpt = os.path.join(CHECKPOINT_DIR, f"ckpt_step_{start_step + steps}.pt") |
| |
| torch.save({ |
| 'model': raw_model.state_dict(), |
| 'optimizer': optimizer.state_dict(), |
| 'step': start_step + steps |
| }, out_ckpt) |
| |
| |
| if checkpoint_path and os.path.exists(checkpoint_path) and checkpoint_path != out_ckpt: |
| os.remove(checkpoint_path) |
| |
| return "\n".join(logs), out_ckpt |
|
|
| |
| |
| |
| current_ckpt = None |
|
|
| def run_training_ui(steps_per_call): |
| global current_ckpt |
| |
| |
| |
| |
| |
| yield "Starting infinite training loop... this will run until you run out of ZeroGPU quota!", f"Initial Checkpoint: {current_ckpt}" |
| |
| total_logs = [] |
| |
| while True: |
| try: |
| |
| chunk_log, new_ckpt = train_chunk(steps=int(steps_per_call), checkpoint_path=current_ckpt) |
| |
| |
| current_ckpt = new_ckpt |
| |
| |
| total_logs.append(chunk_log) |
| |
| if len(total_logs) > 50: |
| total_logs = total_logs[-50:] |
| |
| yield "\n\n".join(total_logs), f"Current Active Checkpoint: {current_ckpt}" |
| |
| except Exception as e: |
| err_msg = f"Training interrupted (likely ran out of ZeroGPU quota). Error: {str(e)}" |
| total_logs.append(err_msg) |
| yield "\n\n".join(total_logs), f"Final Checkpoint: {current_ckpt}" |
| break |
|
|
| with gr.Blocks(title="ZeroGPU WebReaper LM Trainer") as demo: |
| gr.Markdown("# 🧠 Auto-ML Infinite Trainer on ZeroGPU") |
| gr.Markdown(f"Trains a Karpathy-style NanoGPT directly on the `{DATASET_NAME}` dataset utilizing Hugging Face's free ZeroGPU allowance.") |
| gr.Markdown("Click **'Start Infinite Training'** ONCE. The script will automatically loop, requesting a GPU, training for 60 seconds, saving state to disk, and repeating. It will only stop when you hit your daily 25-minute quota limit!") |
| |
| with gr.Row(): |
| steps_slider = gr.Slider(minimum=10, maximum=150, value=50, step=10, label="Steps per GPU Request") |
| train_btn = gr.Button("🚀 Start Infinite Training", variant="primary") |
| stop_btn = gr.Button("⏹️ Stop Training", variant="stop") |
| |
| with gr.Row(): |
| log_output = gr.Textbox(label="Live Training Logs (Streams automatically)", lines=20) |
| ckpt_output = gr.Textbox(label="Checkpoint Status") |
| |
| |
| train_event = train_btn.click(fn=run_training_ui, inputs=[steps_slider], outputs=[log_output, ckpt_output]) |
| stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[train_event]) |
|
|
| if __name__ == "__main__": |
| demo.launch() |
|
|