Spaces:
Running
Running
| """ | |
| GPUburnout Models β Unified Demo | |
| Compare models trained from scratch: GPUburnout-3M β GPUburnout-134M β GPUburnout-1B | |
| """ | |
| import gc | |
| import json | |
| import os | |
| import sys | |
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| # Add models directory to path | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "models")) | |
# ββ Model Registry ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Maps the UI display name to everything needed to load and demo a model:
#   path        - local checkpoint directory (must contain config.json)
#   arch        - "s1" (GPT-2 style) or "s2" (Llama style); selects the loader
#   description - one-line blurb shown in the status panel
#   examples    - prompt suggestions surfaced in the Examples widget
MODELS = {
    "GPUburnout-3M (3.2M)": {
        "path": "checkpoints/tiny",
        "arch": "s1",
        "description": "Character-level model trained on Shakespeare. The very first step.",
        "examples": ["ROMEO:", "JULIET:", "To be, or not to be", "First Citizen:"],
    },
    "GPUburnout-134M (134M)": {
        "path": "checkpoints/gpt2_small",
        "arch": "s1",
        "description": "Season 1 final model. BPE tokenizer, 2.8B tokens, 12 layers.",
        "examples": [
            "The capital of France is",
            "Explain machine learning in simple terms.",
            "def fibonacci(n):",
            "The meaning of life is",
        ],
    },
    "GPUburnout-1B (1.04B)": {
        "path": "checkpoints/llama_1b",
        "arch": "s2",
        "description": "Season 2. Llama architecture, 11.8B tokens, $175 total. Final loss 2.494.",
        "examples": [
            "The capital of France is",
            "In a shocking discovery, scientists found that",
            "def fibonacci(n):",
            "Once upon a time, in a land far away,",
        ],
    },
}
# ββ Current model state (one at a time) βββββββββββββββββββββββββββββββββββββ
# Only one model is kept in memory at once (Space RAM constraint). All four
# keys are either all populated (model loaded) or all None (nothing loaded).
current = {"name": None, "model": None, "tokenizer": None, "config": None}
def unload_current():
    """Drop the active model (if any) and reclaim its memory."""
    if current["model"] is None:
        return
    del current["model"]
    # Reset every slot so the registry reads as "nothing loaded".
    for key in ("model", "tokenizer", "config", "name"):
        current[key] = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def load_model(model_name):
    """Return ``(model, tokenizer, config)`` for *model_name*, loading on demand.

    Only one model is kept resident at a time: any previously loaded model
    is evicted before the requested one is brought in.

    Raises:
        FileNotFoundError: if the checkpoint directory has no config.json.
    """
    # Fast path: the requested model is already resident.
    if current["model"] is not None and current["name"] == model_name:
        return current["model"], current["tokenizer"], current["config"]

    unload_current()

    entry = MODELS[model_name]
    model_dir = entry["path"]
    config_file = os.path.join(model_dir, "config.json")
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"Model not found: {model_dir}")
    with open(config_file) as fh:
        config = json.load(fh)

    # Dispatch to the architecture-specific loader.
    loader = _load_s1 if entry["arch"] == "s1" else _load_s2
    model, tokenizer = loader(model_dir, config)

    current.update(name=model_name, model=model, tokenizer=tokenizer, config=config)
    return model, tokenizer, config
def _load_s1(model_dir, config):
    """Instantiate a Season 1 GPT-2 style model and its tokenizer from disk."""
    from s1_model import TransformerLanguageModel

    model = TransformerLanguageModel(
        vocab_size=config["vocab_size"],
        embed_dim=config["embed_dim"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        ff_dim=config["ff_dim"],
        max_seq_len=config["max_seq_len"],
        dropout=0.0,  # inference only
    )
    state = torch.load(
        os.path.join(model_dir, "pytorch_model.bin"),
        map_location="cpu",
        weights_only=True,
    )
    model.load_state_dict(state)
    model.eval()

    # Tokenizer: BPE or character-level, chosen by the checkpoint config
    # (defaults to character-level when the key is absent).
    if config.get("tokenizer_type", "character") == "bpe":
        from s1_tokenizer_bpe import BPETokenizer as TokenizerCls
    else:
        from s1_tokenizer_char import CharacterTokenizer as TokenizerCls
    tokenizer = TokenizerCls()
    tokenizer.load(os.path.join(model_dir, "tokenizer.json"))
    return model, tokenizer
def _load_s2(model_dir, config):
    """Instantiate the Season 2 Llama-style model and its tokenizer."""
    from s2_model import LlamaModel, ModelConfig

    arch = ModelConfig(
        vocab_size=config.get("vocab_size", 32005),
        d_model=config.get("d_model", 2048),
        n_layers=config.get("n_layers", 16),
        n_heads=config.get("n_heads", 32),
        n_kv_heads=config.get("n_kv_heads", 8),
        d_ff=config.get("d_ff", 8192),
        max_seq_len=config.get("max_seq_len", 2048),
    )
    model = LlamaModel(arch).to("cpu")

    weights_path = os.path.join(model_dir, "pytorch_model.bin")
    if not os.path.exists(weights_path):
        # Weights exceed the Space's LFS limit, so fetch them from the
        # companion HF model repo on first use.
        from huggingface_hub import hf_hub_download

        print("Downloading Llama 1B weights from GPUburnout/gpuburnout-1b...")
        weights_path = hf_hub_download(
            repo_id="GPUburnout/gpuburnout-1b",
            filename="pytorch_model.bin",
            local_dir=model_dir,
        )
    model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))
    model.eval()

    # S2 ships a HuggingFace `tokenizers` file rather than a custom class.
    # NOTE(review): tokenizer path is fixed and ignores model_dir — presumably
    # intentional (single shared tokenizer); confirm against repo layout.
    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_file("tokenizer/bpe_tokenizer.json")
    return model, tokenizer
| # ββ Generation ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def generate_s1(model, tokenizer, config, prompt, max_tokens, temperature, top_k):
    """Autoregressively sample text from an S1 (GPT-2 style) model.

    Re-runs the full forward pass each step (no KV cache), truncating the
    context to the model's max sequence length. top_k <= 0 disables the
    top-k filter.
    """
    ids = tokenizer.encode(prompt)
    if not ids:
        return "Could not encode prompt."
    seq = torch.tensor(ids, dtype=torch.long).unsqueeze(0)  # (1, T)
    context_len = config.get("max_seq_len", 256)
    with torch.no_grad():
        for _ in range(max_tokens):
            window = seq if seq.size(1) <= context_len else seq[:, -context_len:]
            logits = model(window)[:, -1, :] / temperature
            if top_k > 0:
                # Mask everything below the k-th largest logit.
                kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                logits = logits.masked_fill(logits < kth, float("-inf"))
            dist = F.softmax(logits, dim=-1)
            sampled = torch.multinomial(dist, num_samples=1)
            seq = torch.cat([seq, sampled], dim=1)
    return tokenizer.decode(seq[0].tolist())
def generate_s2(model, tokenizer, prompt, max_tokens, temperature, top_k):
    """Sample text from the S2 (Llama style) model via its built-in generate().

    top_k <= 0 is forwarded as None (filter disabled).
    """
    token_ids = tokenizer.encode(prompt).ids
    prompt_tensor = torch.tensor([token_ids], dtype=torch.long)
    with torch.no_grad():
        generated = model.generate(
            prompt_tensor,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=None if top_k <= 0 else top_k,
        )
    return tokenizer.decode(generated[0].tolist())
def generate_text(model_name, prompt, max_tokens, temperature, top_k):
    """Main generation entry point wired to the Gradio UI.

    Args:
        model_name: key into MODELS selecting which checkpoint to use.
        prompt: user text to continue; may be None or blank.
        max_tokens, temperature, top_k: sampling controls (sliders may
            deliver floats, so counts are coerced to int before dispatch).

    Returns:
        The generated text, or a human-readable error/usage message.
    """
    # FIX: Gradio can deliver None (cleared textbox) as well as empty text;
    # the original `prompt.strip()` raised AttributeError on None.
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        model, tokenizer, config = load_model(model_name)
    except FileNotFoundError as e:
        return f"Error: {e}"
    info = MODELS[model_name]
    if info["arch"] == "s1":
        return generate_s1(model, tokenizer, config, prompt, int(max_tokens), temperature, int(top_k))
    else:
        return generate_s2(model, tokenizer, prompt, int(max_tokens), temperature, int(top_k))
def get_status(model_name):
    """Build the Markdown status line shown under the model dropdown."""
    description = MODELS[model_name]["description"]
    if current["name"] == model_name:
        state = "Loaded"
    else:
        state = "Not loaded (will load on generate)"
    return f"**{model_name}** β {description}\n\nStatus: {state}"
def update_examples(model_name):
    """Swap the example-prompt dataset to match the newly selected model."""
    samples = [[prompt] for prompt in MODELS[model_name]["examples"]]
    return gr.update(samples=samples)
| # ββ Custom CSS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Layout and typography overrides layered on top of the Gradio theme:
# centered 900px container, monospace cyan header, amber links, and a
# monospace panel style for the model status box.
CUSTOM_CSS = """
.gradio-container {
    max-width: 900px !important;
    margin: auto;
}
.header-text {
    text-align: center;
    margin-bottom: 0.5em;
}
.header-text h1 {
    color: #22d3ee;
    font-family: 'Courier New', monospace;
}
.header-text a {
    color: #f59e0b;
}
.model-info {
    font-family: 'Courier New', monospace;
    font-size: 0.85em;
    padding: 10px;
    border-radius: 8px;
}
"""
| # ββ Theme ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Dark "terminal" theme: near-black backgrounds, cyan accents, JetBrains Mono.
# Every property is set for both light and dark variants (the *_dark twins)
# so the palette is identical regardless of the visitor's color scheme.
dark_theme = gr.themes.Base(
    primary_hue="cyan",
    neutral_hue="gray",
    font=gr.themes.GoogleFont("JetBrains Mono"),
).set(
    body_background_fill="#08080d",
    body_background_fill_dark="#08080d",
    background_fill_primary="#0e0e15",
    background_fill_primary_dark="#0e0e15",
    background_fill_secondary="#12121a",
    background_fill_secondary_dark="#12121a",
    block_background_fill="#0e0e15",
    block_background_fill_dark="#0e0e15",
    block_border_color="#2a3a4a",
    block_border_color_dark="#2a3a4a",
    block_border_width="2px",
    block_border_width_dark="2px",
    block_label_background_fill="#12121a",
    block_label_background_fill_dark="#12121a",
    block_label_text_color="#9ca3af",
    block_label_text_color_dark="#9ca3af",
    block_title_text_color="#9ca3af",
    block_title_text_color_dark="#9ca3af",
    body_text_color="#e0e0e5",
    body_text_color_dark="#e0e0e5",
    body_text_color_subdued="#6b7280",
    body_text_color_subdued_dark="#6b7280",
    border_color_primary="#2a3a4a",
    border_color_primary_dark="#2a3a4a",
    input_background_fill="#12121a",
    input_background_fill_dark="#12121a",
    input_border_color="#2a3a4a",
    input_border_color_dark="#2a3a4a",
    input_placeholder_color="#6b7280",
    input_placeholder_color_dark="#6b7280",
    button_primary_background_fill="#22d3ee",
    button_primary_background_fill_dark="#22d3ee",
    button_primary_text_color="#08080d",
    button_primary_text_color_dark="#08080d",
    button_primary_background_fill_hover="#67e8f9",
    button_primary_background_fill_hover_dark="#67e8f9",
    panel_background_fill="#0e0e15",
    panel_background_fill_dark="#0e0e15",
    panel_border_color="#2a3a4a",
    panel_border_color_dark="#2a3a4a",
    panel_border_width="2px",
    panel_border_width_dark="2px",
    slider_color="#22d3ee",
    slider_color_dark="#22d3ee",
)
| # ββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Two-column layout: controls on the left, generated text on the right.
# Components must be created inside this `with` block so Gradio can wire
# them together; event bindings come after all components exist.
with gr.Blocks(
    title="GPUburnout Models",
    theme=dark_theme,
    css=CUSTOM_CSS,
) as demo:
    # Static header with project links.
    gr.HTML("""
    <div class="header-text">
        <h1>GPUburnout Models</h1>
        <p>Compare language models I trained from scratch β from 3.2M to 1 billion parameters.</p>
        <p>
            <a href="https://gpuburnout.com" target="_blank">Read the blog</a> Β·
            <a href="https://github.com/GPUburnout" target="_blank">GitHub</a> Β·
            <a href="https://gpuburnout.com/about/" target="_blank">About</a>
        </p>
    </div>
    """)
    with gr.Row():
        # Left column: model picker, prompt, and sampling controls.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="GPUburnout-134M (134M)",
                label="Select Model",
            )
            # Filled by get_status() on load and on model change.
            model_status = gr.Markdown(elem_classes=["model-info"])
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Type something...",
                lines=2,
                value="The capital of France is",
            )
            with gr.Row():
                max_tokens = gr.Slider(50, 300, value=50, step=25, label="Max tokens")
                temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-K")
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
        # Right column: generation output.
        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=15, show_copy_button=True)
    # Seed examples; replaced per-model by update_examples() on selection.
    examples = gr.Examples(
        examples=[["The capital of France is"], ["def fibonacci(n):"]],
        inputs=prompt,
        label="Example prompts",
    )
    # Events
    demo.load(get_status, inputs=model_selector, outputs=model_status)
    model_selector.change(get_status, inputs=model_selector, outputs=model_status)
    model_selector.change(update_examples, inputs=model_selector, outputs=examples.dataset)
    generate_btn.click(
        generate_text,
        inputs=[model_selector, prompt, max_tokens, temperature, top_k],
        outputs=output,
    )
    # Enter in the prompt box triggers the same handler as the button.
    prompt.submit(
        generate_text,
        inputs=[model_selector, prompt, max_tokens, temperature, top_k],
        outputs=output,
    )
if __name__ == "__main__":
    demo.launch()