# app.py — Gradio demo Space by abi96062 (commit dd84964, verified; 8.6 kB)
import gradio as gr
import torch
import torch.nn as nn
from model import SmolLM2_135M # Import your model class
import yaml
# Device setup: prefer the GPU when one is available, else run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load model
@torch.no_grad()
def load_model(config_path='config.yaml',
               checkpoint_path='checkpoints/checkpoint_step_5050.pt'):
    """Build the model from a YAML config and restore trained weights.

    Args:
        config_path: Path to the YAML file holding architecture
            hyperparameters (vocab_size, d_model, n_layers, n_heads, ...).
        checkpoint_path: Path to a torch checkpoint containing a
            'model_state_dict' entry (and optionally training metadata
            such as 'step', read elsewhere in this app).

    Returns:
        (model, checkpoint): the model moved to `device` and set to eval
        mode, plus the raw checkpoint dict.
    """
    print("Loading model...")

    # Architecture hyperparameters come from the config file, not the code.
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Initialize model
    model = SmolLM2_135M(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        # Add other config parameters
    ).to(device)

    # Restore the trained weights; map_location keeps CPU-only hosts working.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Model loaded successfully on {device}")
    return model, checkpoint
# Load model at startup: this runs once at import time, so the Space is
# ready to serve requests as soon as the UI comes up.
model, checkpoint = load_model()
# Tokenizer (adjust based on your implementation)
def tokenize(text, max_length=128):
    """Simple character-level tokenizer - REPLACE with your actual tokenizer"""
    # Placeholder scheme: each character maps to its Unicode code point,
    # truncated to at most max_length characters.
    truncated = text[:max_length]
    token_ids = list(map(ord, truncated))
    # Shape (1, seq_len): a batch dimension is added for the model.
    return torch.tensor(token_ids).unsqueeze(0).to(device)
def detokenize(tokens):
    """Convert tokens back to text - REPLACE with your actual detokenizer"""
    # Placeholder scheme: keep only ASCII code points (< 128) and map each
    # back to its character; anything else is silently dropped.
    chars = []
    for token_id in tokens:
        if token_id < 128:
            chars.append(chr(token_id))
    return ''.join(chars)
@torch.no_grad()
def generate_text(
    prompt,
    max_length=100,
    temperature=0.8,
    top_k=50,
    top_p=0.9
):
    """Autoregressively generate a continuation of `prompt`.

    Args:
        prompt: Seed text to condition on.
        max_length: Maximum number of NEW tokens to sample.
        temperature: Softmax temperature; higher = more random.
        top_k: Keep only the k most-likely tokens each step (0 disables).
        top_p: Nucleus-sampling cumulative-probability cutoff (1.0 disables).

    Returns:
        The detokenized prompt + continuation, or an "Error generating
        text: ..." string on failure (the Gradio UI displays either).
    """
    try:
        # An empty prompt would feed a zero-length sequence to the model.
        if not prompt:
            return "Error generating text: prompt is empty"

        # Guard against division by zero / degenerate scaling below.
        temperature = max(float(temperature), 1e-6)

        # Tokenize input
        input_ids = tokenize(prompt)
        generated = input_ids[0].tolist()

        for _ in range(max_length):
            # Re-run the model on the full sequence each step (no KV cache).
            input_tensor = torch.tensor([generated]).to(device)
            logits = model(input_tensor)

            # Logits for the next position, temperature-scaled.
            next_token_logits = logits[0, -1, :] / temperature

            # Top-k filtering: mask everything below the k-th best logit.
            # BUGFIX: clamp k so torch.topk never exceeds the vocab size
            # (torch.topk raises when k > number of logits).
            if top_k > 0:
                k = min(int(top_k), next_token_logits.size(-1))
                kth_value = torch.topk(next_token_logits, k)[0][..., -1, None]
                next_token_logits[next_token_logits < kth_value] = float('-inf')

            # Top-p (nucleus) filtering: drop the low-probability tail once
            # cumulative mass exceeds top_p.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token past the threshold is kept,
                # guaranteeing at least one token always survives.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

            # Sample the next token from the filtered distribution.
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)

            # Stop if EOS token (adjust based on your vocab)
            if next_token == 0:  # Assuming 0 is EOS
                break

        # Detokenize
        return detokenize(generated)
    except Exception as e:
        return f"Error generating text: {str(e)}"
def get_model_info():
    """Build the markdown model summary shown in the "Model Info" tab."""
    # Count parameters once; trainable ones are the subset with gradients.
    total_params = 0
    trainable_params = 0
    for p in model.parameters():
        n = p.numel()
        total_params += n
        if p.requires_grad:
            trainable_params += n
    info = f"""
    ### 📊 Model Information
    **Total Parameters:** {total_params:,} (~{total_params/1e6:.1f}M)
    **Trainable Parameters:** {trainable_params:,}
    **Training Steps:** {checkpoint.get('step', 'N/A')}
    **Device:** {device}
    **Model Architecture:** SmolLM2-135M
    ### 🎯 Training Details
    - Trained for 5,000 steps
    - Checkpoint saved and reloaded
    - Additional 50 steps after reload
    - Predictions logged every 500 steps
    """
    return info
# Gradio Interface: left column holds the prompt + sampling controls,
# right column shows the output; two extra tabs give model info and examples.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 SmolLM2-135M: From-Scratch Implementation
    This is a complete reverse-engineered implementation of SmolLM2-135M, trained from scratch.
    **GitHub:** [abi2024/smollm2-135-implementation](https://github.com/abi2024/smollm2-135-implementation)
    """)

    with gr.Tab("🎮 Generate Text"):
        with gr.Row():
            # Input side: prompt plus the four sampling hyperparameters.
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="Once upon a time"
                )
                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=10,
                        maximum=500,
                        value=100,
                        step=10,
                        label="Max Length"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature"
                    )
                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Top-K"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-P"
                    )
                generate_btn = gr.Button("🚀 Generate", variant="primary")
            # Output side: read-only box that receives the generation.
            with gr.Column():
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=10,
                    interactive=False
                )

        # Wire the button to the sampler; slider order must match the
        # generate_text(prompt, max_length, temperature, top_k, top_p)
        # signature.
        generate_btn.click(
            fn=generate_text,
            inputs=[
                prompt_input,
                max_length_slider,
                temperature_slider,
                top_k_slider,
                top_p_slider
            ],
            outputs=output_text
        )

        gr.Markdown("""
        ### 💡 Tips:
        - **Temperature**: Higher = more creative, Lower = more focused
        - **Top-K**: Limits vocabulary to K most likely tokens
        - **Top-P**: Nucleus sampling - cumulative probability threshold
        """)

    with gr.Tab("📊 Model Info"):
        # Rendered once at startup from the loaded checkpoint.
        model_info_display = gr.Markdown(get_model_info())
        gr.Markdown("""
        ### 🏗️ Architecture Details
        This model was reverse-engineered by:
        1. Analyzing the official SmolLM2 repository
        2. Extracting architecture from pretrained weights
        3. Implementing from scratch in PyTorch
        4. Validating by swapping weights with pretrained model
        ### ⚡ Optimizations Used
        - Flash Attention 2
        - Mixed Precision Training (BF16/FP16)
        - Gradient Accumulation
        - torch.compile()
        ### 📈 Training Process
        - **Step 0-5000**: Main training with periodic predictions
        - **Checkpoint**: Model saved and reloaded to validate state preservation
        - **Step 5000-5050**: Continued training to test checkpoint robustness
        """)

    with gr.Tab("🎯 Example Prompts"):
        gr.Markdown("""
        ### Try these prompts:
        1. **Story Generation**
        ```
        Once upon a time in a land far away
        ```
        2. **Code Completion**
        ```
        def fibonacci(n):
        ```
        3. **Question Answering**
        ```
        Q: What is machine learning?
        A:
        ```
        4. **Creative Writing**
        ```
        The old house at the end of the street was
        ```
        5. **Technical Explanation**
        ```
        Neural networks work by
        ```
        """)
# Launch the Gradio server only when executed as a script
# (Hugging Face Spaces runs this file directly).
if __name__ == "__main__":
    demo.launch()