import gradio as gr
import torch
from model import SmolLM2Model  # from-scratch implementation in model.py
from transformers import AutoTokenizer, AutoConfig

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and config
print("Loading tokenizer and config...")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")


@torch.no_grad()
def load_model():
    """Load the trained model from the saved checkpoint."""
    print("Loading model...")

    # Initialize model with config
    model = SmolLM2Model(config).to(device)

    # Load checkpoint
    checkpoint = torch.load('checkpoint_step_5050.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"✅ Model loaded successfully on {device}")
    print(f"✅ Training step: {checkpoint.get('step', 'N/A')}")
    return model, checkpoint


# Load model at startup
model, checkpoint = load_model()


@torch.no_grad()
def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50, top_p=0.9):
    """Generate text from a prompt."""
    try:
        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        input_ids = inputs['input_ids']

        # Generate using the model's built-in method
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k if top_k > 0 else None,
            do_sample=temperature > 0
        )

        # Decode
        output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return output_text
    except Exception as e:
        return f"❌ Error generating text: {str(e)}"
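
# ---------------------------------------------------------------------------
# NOTE: generate_text() assumes the from-scratch SmolLM2Model exposes a
# HuggingFace-style .generate() accepting max_new_tokens, temperature, top_k,
# and top_p. If model.py does not provide one, a minimal sampling loop along
# these lines could stand in. This is an illustrative sketch, not used by the
# app, and it assumes model(input_ids) returns logits of shape
# (batch, seq_len, vocab_size).
@torch.no_grad()
def sample_fallback(model, input_ids, max_new_tokens=50, temperature=0.8, top_k=50):
    """Hypothetical temperature + top-k sampling loop (illustrative only)."""
    for _ in range(max_new_tokens):
        # Logits for the last position, scaled by temperature
        logits = model(input_ids)[:, -1, :] / max(temperature, 1e-5)
        if top_k > 0:
            # Mask everything outside the top-k logits
            kth = torch.topk(logits, top_k).values[:, [-1]]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_id], dim=1)
    return input_ids
# ---------------------------------------------------------------------------
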
def get_model_info():
    """Build the Markdown shown on the Model Info tab."""
    total_params = model.get_num_params()
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    info = f"""
### 📊 Model Information

**Model:** SmolLM2-135M
**Total Parameters:** {total_params:,} (~{total_params/1e6:.1f}M)
**Trainable Parameters:** {trainable_params:,}
**Training Steps:** {checkpoint.get('step', 'N/A')}
**Device:** {device}
**Vocab Size:** {config.vocab_size:,}

### 🏗️ Architecture
- **Layers:** {config.num_hidden_layers}
- **Hidden Size:** {config.hidden_size}
- **Attention Heads:** {config.num_attention_heads} (Query) / {config.num_key_value_heads} (KV)
- **FFN Size:** {config.intermediate_size}
- **Context Length:** {config.max_position_embeddings}

### 🎯 Training Details
- ✅ Trained for 5,000 steps
- ✅ Checkpoint saved and reloaded
- ✅ Additional 50 steps after reload
- ✅ Predictions logged every 500 steps
"""
    return info


# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="SmolLM2-135M Demo") as demo:
    gr.Markdown("""
    # 🤖 SmolLM2-135M: From-Scratch Implementation

    A complete reverse-engineered implementation of SmolLM2-135M, trained from scratch.

    **GitHub:** [abi2024/smollm2-135-implementation](https://github.com/abi2024/smollm2-135-implementation)
    """)

    with gr.Tab("🎮 Generate Text"):
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="Once upon a time"
                )
                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=10, maximum=200, value=50, step=10,
                        label="Max New Tokens"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1, maximum=2.0, value=0.8, step=0.1,
                        label="Temperature"
                    )
                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=0, maximum=100, value=50, step=5,
                        label="Top-K"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.9, step=0.05,
                        label="Top-P (Nucleus)"
                    )
                generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=12,
                    interactive=False
                )

        generate_btn.click(
            fn=generate_text,
            inputs=[prompt_input, max_length_slider, temperature_slider,
                    top_k_slider, top_p_slider],
            outputs=output_text
        )

        gr.Markdown("""
        ### 💡 Generation Tips
        - **Temperature:** controls randomness (0.1 = focused, 2.0 = creative)
        - **Top-K:** limits sampling to the K most likely tokens (0 = disabled)
        - **Top-P:** nucleus sampling threshold (0.9 recommended)
        """)

    with gr.Tab("📊 Model Info"):
        model_info_display = gr.Markdown(get_model_info())
        gr.Markdown("""
        ### 🔍 Reverse Engineering Process
        1. **Architecture Analysis**
           - Studied the SmolLM2 GitHub repository
           - Extracted the model configuration from YAML
           - Downloaded the pretrained 135M checkpoint
        2. **Implementation**
           - Built from scratch in PyTorch
           - Implemented Grouped Query Attention (9 query / 3 KV heads)
           - Added RoPE position embeddings
           - Used a SwiGLU FFN and RMSNorm
        3. **Validation**
           - Loaded the official pretrained weights
           - Verified the parameter count (134,515,008)
           - Confirmed the architecture matches exactly

        ### ⚡ Optimizations Applied
        - ✅ Flash Attention 2 (via scaled_dot_product_attention)
        - ✅ Mixed precision training (BF16/FP16)
        - ✅ Gradient accumulation
        - ✅ torch.compile() for inference speedup
        - ✅ Grouped Query Attention (memory efficient)

        ### 📈 Training Pipeline
        1. **Main Training:** 5,000 steps with predictions every 500 steps
        2. **Checkpoint Test:** model saved and successfully reloaded
        3. **Resume Training:** 50 additional steps (validates checkpoint integrity)
        """)

    with gr.Tab("🎯 Example Prompts"):
        gr.Markdown("""
        ### Try these prompts:

        **1. Story Generation**
        ```
        Once upon a time in a magical forest,
        ```

        **2. Code Completion**
        ```
        def calculate_fibonacci(n):
            # Calculate the nth Fibonacci number
        ```

        **3. Question Answering**
        ```
        Q: What is the capital of France?
        A:
        ```

        **4. Technical Writing**
        ```
        The main advantage of transformer architectures is
        ```

        **5. Creative Writing**
        ```
        The scientist discovered something extraordinary:
        ```

        ### 🎛️ Recommended Settings
        - **Creative Writing:** Temperature=1.0, Top-P=0.95
        - **Code Generation:** Temperature=0.3, Top-P=0.9, Top-K=40
        - **Factual Q&A:** Temperature=0.5, Top-P=0.8, Top-K=30
        """)

# Launch
if __name__ == "__main__":
    demo.launch()
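
# Usage note (assumed local setup): place checkpoint_step_5050.pt next to this
# script, install gradio/torch/transformers, and run `python app.py`. Gradio
# serves the demo at http://127.0.0.1:7860 by default.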