import gradio as gr
import torch
import torch.nn as nn
from model import SmolLM2_135M  # Import your model class
import yaml

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
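# load_model() below expects config.yaml to hold the architecture
# hyperparameters. A minimal sketch of its shape (the values shown are the
# published SmolLM2-135M settings; substitute whatever your training run used):
#
#   vocab_size: 49152
#   d_model: 576
#   n_layers: 30
#   n_heads: 9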
# Load model
def load_model():
    """Load the trained model"""
    print("Loading model...")

    # Load config
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    # Initialize model
    model = SmolLM2_135M(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        # Add other config parameters
    ).to(device)

    # Load checkpoint
    checkpoint = torch.load('checkpoints/checkpoint_step_5050.pt',
                            map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Model loaded successfully on {device}")
    return model, checkpoint

# Load model at startup
model, checkpoint = load_model()
# Tokenizer (adjust based on your implementation)
def tokenize(text, max_length=128):
    """Simple character-level tokenizer - REPLACE with your actual tokenizer"""
    # This is a placeholder - use your actual tokenizer
    tokens = [ord(c) for c in text[:max_length]]
    return torch.tensor(tokens).unsqueeze(0).to(device)

def detokenize(tokens):
    """Convert tokens back to text - REPLACE with your actual detokenizer"""
    # This is a placeholder - use your actual detokenizer
    return ''.join([chr(t) for t in tokens if t < 128])
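# A sketch of swapping in a real subword tokenizer (an assumption: that the
# model was trained with the SmolLM2 tokenizer from the Hugging Face Hub;
# use whatever tokenizer your training run actually used):
#
#   from transformers import AutoTokenizer
#   hf_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
#
#   def tokenize(text, max_length=128):
#       ids = hf_tokenizer(text, truncation=True, max_length=max_length)["input_ids"]
#       return torch.tensor(ids).unsqueeze(0).to(device)
#
#   def detokenize(tokens):
#       return hf_tokenizer.decode(tokens, skip_special_tokens=True)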
def generate_text(
    prompt,
    max_length=100,
    temperature=0.8,
    top_k=50,
    top_p=0.9
):
    """Generate text from prompt with temperature, top-k and top-p sampling"""
    try:
        # Tokenize input
        input_ids = tokenize(prompt)

        # Generate autoregressively (the full sequence is re-encoded each
        # step; fine for a demo, though a KV cache would be faster)
        generated = input_ids[0].tolist()
        with torch.no_grad():  # inference only - no gradients needed
            for _ in range(max_length):
                # Get model predictions
                input_tensor = torch.tensor([generated]).to(device)
                logits = model(input_tensor)

                # Scale next-token logits by temperature
                next_token_logits = logits[0, -1, :] / temperature

                # Apply top-k filtering: drop everything below the k-th
                # highest logit
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Apply top-p (nucleus) filtering: keep the smallest set of
                # tokens whose cumulative probability exceeds top_p
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the mask right so the first token crossing the
                    # threshold is kept
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_token_logits[indices_to_remove] = float('-inf')
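                # Worked example of the nucleus step: with sorted
                # probabilities [0.5, 0.3, 0.1, 0.1] and top_p = 0.8, the
                # cumulative sums are [0.5, 0.8, 0.9, 1.0]; masking
                # cumsum > top_p and shifting right keeps the first three
                # tokens and drops only the last.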
                # Sample next token from the filtered distribution
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
                generated.append(next_token)

                # Stop if EOS token (adjust based on your vocab)
                if next_token == 0:  # Assuming 0 is EOS
                    break

        # Detokenize
        output_text = detokenize(generated)
        return output_text
    except Exception as e:
        return f"Error generating text: {str(e)}"
def get_model_info():
    """Display model information"""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    info = f"""
### 📊 Model Information

**Total Parameters:** {total_params:,} (~{total_params/1e6:.1f}M)
**Trainable Parameters:** {trainable_params:,}
**Training Steps:** {checkpoint.get('step', 'N/A')}
**Device:** {device}
**Model Architecture:** SmolLM2-135M

### 🎯 Training Details
- Trained for 5,000 steps
- Checkpoint saved and reloaded
- Additional 50 steps after reload
- Predictions logged every 500 steps
"""
    return info
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🤖 SmolLM2-135M: From-Scratch Implementation

This is a complete reverse-engineered implementation of SmolLM2-135M, trained from scratch.

**GitHub:** [abi2024/smollm2-135-implementation](https://github.com/abi2024/smollm2-135-implementation)
""")

    with gr.Tab("🎮 Generate Text"):
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="Once upon a time"
                )
                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=10,
                        maximum=500,
                        value=100,
                        step=10,
                        label="Max Length"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature"
                    )
                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Top-K"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-P"
                    )
                generate_btn = gr.Button("🚀 Generate", variant="primary")

            with gr.Column():
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=10,
                    interactive=False
                )

        generate_btn.click(
            fn=generate_text,
            inputs=[
                prompt_input,
                max_length_slider,
                temperature_slider,
                top_k_slider,
                top_p_slider
            ],
            outputs=output_text
        )

        gr.Markdown("""
### 💡 Tips:
- **Temperature**: Higher = more creative, Lower = more focused
- **Top-K**: Limits sampling to the K most likely tokens
- **Top-P**: Nucleus sampling - cumulative probability threshold
""")
| with gr.Tab("📊 Model Info"): | |
| model_info_display = gr.Markdown(get_model_info()) | |
| gr.Markdown(""" | |
| ### 🏗️ Architecture Details | |
| This model was reverse-engineered by: | |
| 1. Analyzing the official SmolLM2 repository | |
| 2. Extracting architecture from pretrained weights | |
| 3. Implementing from scratch in PyTorch | |
| 4. Validating by swapping weights with pretrained model | |
| ### ⚡ Optimizations Used | |
| - Flash Attention 2 | |
| - Mixed Precision Training (BF16/FP16) | |
| - Gradient Accumulation | |
| - torch.compile() | |
| ### 📈 Training Process | |
| - **Step 0-5000**: Main training with periodic predictions | |
| - **Checkpoint**: Model saved and reloaded to validate state preservation | |
| - **Step 5000-5050**: Continued training to test checkpoint robustness | |
| """) | |
| with gr.Tab("🎯 Example Prompts"): | |
| gr.Markdown(""" | |
| ### Try these prompts: | |
| 1. **Story Generation** | |
| ``` | |
| Once upon a time in a land far away | |
| ``` | |
| 2. **Code Completion** | |
| ``` | |
| def fibonacci(n): | |
| ``` | |
| 3. **Question Answering** | |
| ``` | |
| Q: What is machine learning? | |
| A: | |
| ``` | |
| 4. **Creative Writing** | |
| ``` | |
| The old house at the end of the street was | |
| ``` | |
| 5. **Technical Explanation** | |
| ``` | |
| Neural networks work by | |
| ``` | |
| """) | |
# Launch
if __name__ == "__main__":
    demo.launch()
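# On Hugging Face Spaces the bare demo.launch() above is enough. For a local
# run that should be reachable from other machines, standard Gradio launch
# options can be passed instead:
#
#   demo.launch(server_name="0.0.0.0", server_port=7860)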