import gradio as gr
import torch
from model import SmolLM2_135M  # Import your model class
import yaml

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load model
@torch.no_grad()
def load_model():
    """Load the trained model"""
    print("Loading model...")

    # Load config
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    # Initialize model
    model = SmolLM2_135M(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        # Add other config parameters
    ).to(device)

    # Load checkpoint
    checkpoint = torch.load('checkpoints/checkpoint_step_5050.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Model loaded successfully on {device}")
    return model, checkpoint


# Load model at startup
model, checkpoint = load_model()


# Tokenizer (adjust based on your implementation)
def tokenize(text, max_length=128):
    """Simple character-level tokenizer - REPLACE with your actual tokenizer"""
    # This is a placeholder - use your actual tokenizer
    tokens = [ord(c) for c in text[:max_length]]
    return torch.tensor(tokens).unsqueeze(0).to(device)


def detokenize(tokens):
    """Convert tokens back to text - REPLACE with your actual detokenizer"""
    # This is a placeholder - use your actual detokenizer
    return ''.join([chr(t) for t in tokens if t < 128])


@torch.no_grad()
def generate_text(prompt, max_length=100, temperature=0.8, top_k=50, top_p=0.9):
    """Generate text from prompt"""
    try:
        # Gradio sliders deliver floats, so coerce the integer-valued controls
        max_length = int(max_length)
        top_k = int(top_k)

        # Tokenize input
        input_ids = tokenize(prompt)

        generated = input_ids[0].tolist()
        if not generated:
            return "Please enter a prompt."

        for _ in range(max_length):
            # Get model predictions for the full sequence so far
            input_tensor = torch.tensor([generated]).to(device)
            logits = model(input_tensor)

            # Get next-token logits, scaled by temperature
            next_token_logits = logits[0, -1, :] / temperature

            # Apply top-k filtering
            if top_k > 0:
                k = min(top_k, next_token_logits.size(-1))  # guard against top_k > vocab size
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')

            # Apply top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                # Mask tokens beyond the cumulative threshold, shifted right
                # so the first token above the threshold is still kept
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

            # Sample next token
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            generated.append(next_token)

            # Stop if EOS token (adjust based on your vocab)
            if next_token == 0:  # Assuming 0 is EOS
                break

        # Detokenize
        return detokenize(generated)

    except Exception as e:
        return f"Error generating text: {str(e)}"


def get_model_info():
    """Display model information"""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    info = f"""
    ### 📊 Model Information

    **Total Parameters:** {total_params:,} (~{total_params / 1e6:.1f}M)
    **Trainable Parameters:** {trainable_params:,}
    **Training Steps:** {checkpoint.get('step', 'N/A')}
    **Device:** {device}
    **Model Architecture:** SmolLM2-135M

    ### 🎯 Training Details
    - Trained for 5,000 steps
    - Checkpoint saved and reloaded
    - Additional 50 steps after reload
    - Predictions logged every 500 steps
    """
    return info
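
# Optional: swapping in the real tokenizer. A minimal sketch, assuming the
# `transformers` package is installed and that training used the tokenizer from
# the "HuggingFaceTB/SmolLM2-135M" Hub repo (adjust the repo id to match your
# training run). It would replace the placeholder tokenize/detokenize above:
#
#   from transformers import AutoTokenizer
#
#   hf_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
#
#   def tokenize(text, max_length=128):
#       ids = hf_tokenizer.encode(text)[:max_length]
#       return torch.tensor(ids).unsqueeze(0).to(device)
#
#   def detokenize(tokens):
#       return hf_tokenizer.decode(tokens, skip_special_tokens=True)
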
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 SmolLM2-135M: From-Scratch Implementation

    This is a complete reverse-engineered implementation of SmolLM2-135M, trained from scratch.

    **GitHub:** [abi2024/smollm2-135-implementation](https://github.com/abi2024/smollm2-135-implementation)
    """)

    with gr.Tab("🎮 Generate Text"):
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="Once upon a time"
                )

                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=10, maximum=500, value=100, step=10,
                        label="Max Length"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1, maximum=2.0, value=0.8, step=0.1,
                        label="Temperature"
                    )

                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=0, maximum=100, value=50, step=5,
                        label="Top-K"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.9, step=0.05,
                        label="Top-P"
                    )

                generate_btn = gr.Button("🚀 Generate", variant="primary")

            with gr.Column():
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=10,
                    interactive=False
                )

        generate_btn.click(
            fn=generate_text,
            inputs=[
                prompt_input,
                max_length_slider,
                temperature_slider,
                top_k_slider,
                top_p_slider
            ],
            outputs=output_text
        )

        gr.Markdown("""
        ### 💡 Tips:
        - **Temperature**: Higher = more creative, lower = more focused
        - **Top-K**: Limits the vocabulary to the K most likely tokens
        - **Top-P**: Nucleus sampling - keeps the smallest token set whose cumulative probability exceeds P
        """)

    with gr.Tab("📊 Model Info"):
        model_info_display = gr.Markdown(get_model_info())

        gr.Markdown("""
        ### 🏗️ Architecture Details

        This model was reverse-engineered by:
        1. Analyzing the official SmolLM2 repository
        2. Extracting the architecture from the pretrained weights
        3. Implementing it from scratch in PyTorch
        4. Validating by swapping weights with the pretrained model

        ### ⚡ Optimizations Used
        - Flash Attention 2
        - Mixed Precision Training (BF16/FP16)
        - Gradient Accumulation
        - torch.compile()

        ### 📈 Training Process
        - **Steps 0-5000**: Main training with periodic predictions
        - **Checkpoint**: Model saved and reloaded to validate state preservation
        - **Steps 5000-5050**: Continued training to test checkpoint robustness
        """)

    with gr.Tab("🎯 Example Prompts"):
        gr.Markdown("""
        ### Try these prompts:

        1. **Story Generation**
           ```
           Once upon a time in a land far away
           ```

        2. **Code Completion**
           ```
           def fibonacci(n):
           ```

        3. **Question Answering**
           ```
           Q: What is machine learning?
           A:
           ```

        4. **Creative Writing**
           ```
           The old house at the end of the street was
           ```

        5. **Technical Explanation**
           ```
           Neural networks work by
           ```
        """)

# Launch
if __name__ == "__main__":
    demo.launch()
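
# Launch variants (all standard Gradio options) that may be useful depending on
# where this app is hosted:
#   demo.queue().launch()        # queue requests so generations don't overlap
#   demo.launch(share=True)      # temporary public link for quick sharing
#   demo.launch(server_name="0.0.0.0", server_port=7860)  # bind for Docker/remote hosts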