import traceback

import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig

from model import SmolLM2Model
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub repository id used for both the tokenizer and the architecture config,
# kept in one place so the two lookups cannot drift apart.
MODEL_ID = "HuggingFaceTB/SmolLM2-135M"

print("Loading tokenizer and config...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
config = AutoConfig.from_pretrained(MODEL_ID)
@torch.no_grad()
def load_model(checkpoint_path='checkpoint_step_5050.pt'):
    """Build a ``SmolLM2Model`` and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Path handed to ``torch.load``. Defaults to the file
            name this script has always used, so ``load_model()`` is unchanged.

    Returns:
        A ``(model, checkpoint)`` tuple: the model, moved to ``device`` and
        switched to eval mode, and the loaded checkpoint object itself.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        Any error raised by ``torch.load`` or ``load_state_dict`` propagates
        unchanged.
    """
    print("Loading model...")

    model = SmolLM2Model(config).to(device)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider weights_only=True if the checkpoint
    # holds nothing but tensors and plain containers - confirm before changing.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"✅ Model loaded successfully on {device}")
    print(f"✅ Training step: {checkpoint.get('step', 'N/A')}")
    return model, checkpoint


# Loaded once at import time; the functions below read these module globals.
model, checkpoint = load_model()
@torch.no_grad()
def generate_text(
    prompt,
    max_length=100,
    temperature=0.8,
    top_k=50,
    top_p=0.9
):
    """Generate text from a prompt with the module-level model and tokenizer.

    Args:
        prompt: Text to tokenize and pass to ``model.generate``.
        max_length: Forwarded as ``max_new_tokens`` after conversion to int.
        temperature: Forwarded as-is; sampling is requested only when it is
            greater than zero.
        top_k: Forwarded as an int, or as ``None`` when it is zero or less.
        top_p: Forwarded as-is.

    Returns:
        The decoded first generated sequence (special tokens skipped), or a
        string starting with "❌ Error generating text:" when the prompt is
        blank or any exception is raised along the way.
    """
    # Reject blank prompts up front rather than handing the model an empty
    # sequence and surfacing whatever low-level error that produces.
    if not prompt or not prompt.strip():
        return "❌ Error generating text: prompt is empty"

    try:
        # UI sliders may deliver floats; token counts must be integers.
        max_new_tokens = int(max_length)
        top_k = int(top_k)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        input_ids = inputs['input_ids']

        generated_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k if top_k > 0 else None,
            do_sample=temperature > 0
        )

        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    except Exception as e:
        # UI boundary: the caller gets a readable message instead of a crash,
        # but the traceback is still printed so failures stay diagnosable.
        traceback.print_exc()
        return f"❌ Error generating text: {str(e)}"
def get_model_info():
    """Return a Markdown summary of the loaded model.

    Reads the module-level ``model``, ``checkpoint``, ``device`` and
    ``config`` objects. Parameter counts come from ``model.get_num_params()``
    and from summing ``numel()`` over parameters with ``requires_grad`` set.

    Returns:
        A Markdown string with three sections: model information,
        architecture, and training details.
    """
    total_params = model.get_num_params()
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # NOTE(review): the heading emoji were restored from mis-decoded bytes;
    # the 📊 and 🏗️ glyphs are best guesses - confirm against the original.
    # The metadata fields are bullets so each renders on its own line; plain
    # consecutive lines would be merged into a single Markdown paragraph.
    lines = [
        "### 📊 Model Information",
        "",
        "- **Model:** SmolLM2-135M",
        f"- **Total Parameters:** {total_params:,} (~{total_params/1e6:.1f}M)",
        f"- **Trainable Parameters:** {trainable_params:,}",
        f"- **Training Steps:** {checkpoint.get('step', 'N/A')}",
        f"- **Device:** {device}",
        f"- **Vocab Size:** {config.vocab_size:,}",
        "",
        "### 🏗️ Architecture",
        f"- **Layers:** {config.num_hidden_layers}",
        f"- **Hidden Size:** {config.hidden_size}",
        f"- **Attention Heads:** {config.num_attention_heads} (Query) / {config.num_key_value_heads} (KV)",
        f"- **FFN Size:** {config.intermediate_size}",
        f"- **Context Length:** {config.max_position_embeddings}",
        "",
        "### 🎯 Training Details",
        "- ✅ Trained for 5,000 steps",
        "- ✅ Checkpoint saved and reloaded",
        "- ✅ Additional 50 steps after reload",
        "- ✅ Predictions logged every 500 steps",
    ]
    return "\n".join(lines)
# Static Markdown shown in the UI. Kept flush-left at module level so that the
# rendered output (nested lists, code fences) does not depend on how deeply
# the layout code below is indented.
# NOTE(review): emoji throughout this section were restored from mis-decoded
# bytes; 🚀, 📊, 🔍, 📈 and 🎛️ are best guesses - confirm against the original.
_HEADER_MD = """
# 🤖 SmolLM2-135M: From-Scratch Implementation

Complete reverse-engineered implementation of SmolLM2-135M, trained from scratch.

**GitHub:** [abi2024/smollm2-135-implementation](https://github.com/abi2024/smollm2-135-implementation)
"""

_TIPS_MD = """
### 💡 Generation Tips:
- **Temperature**: Controls randomness (0.1 = focused, 2.0 = creative)
- **Top-K**: Limits to K most likely tokens (0 = disabled)
- **Top-P**: Nucleus sampling threshold (0.9 recommended)
"""

_PROCESS_MD = """
### 🔍 Reverse Engineering Process

1. **Architecture Analysis**
   - Studied SmolLM2 GitHub repository
   - Extracted model configuration from YAML
   - Downloaded pretrained 135M checkpoint

2. **Implementation**
   - Built from scratch using PyTorch
   - Implemented Grouped Query Attention (9Q/3KV heads)
   - Added RoPE position embeddings
   - Used SwiGLU FFN and RMSNorm

3. **Validation**
   - Loaded official pretrained weights
   - Verified parameter count (134,515,008)
   - Confirmed architecture matches exactly

### ⚡ Optimizations Applied
- ✅ Flash Attention 2 (via scaled_dot_product_attention)
- ✅ Mixed Precision Training (BF16/FP16)
- ✅ Gradient Accumulation
- ✅ torch.compile() for inference speedup
- ✅ Grouped Query Attention (memory efficient)

### 📈 Training Pipeline
1. **Main Training:** 5,000 steps with predictions every 500 steps
2. **Checkpoint Test:** Model saved and successfully reloaded
3. **Resume Training:** 50 additional steps (validates checkpoint integrity)
"""

_EXAMPLES_MD = """
### Try these prompts:

**1. Story Generation**
```
Once upon a time in a magical forest,
```

**2. Code Completion**
```
def calculate_fibonacci(n):
    # Calculate the nth Fibonacci number
```

**3. Question Answering**
```
Q: What is the capital of France?
A:
```

**4. Technical Writing**
```
The main advantage of transformer architectures is
```

**5. Creative Writing**
```
The scientist discovered something extraordinary:
```

### 🎛️ Recommended Settings:
- **Creative Writing:** Temperature=1.0, Top-P=0.95
- **Code Generation:** Temperature=0.3, Top-P=0.9, Top-K=40
- **Factual Q&A:** Temperature=0.5, Top-P=0.8, Top-K=30
"""

# Three-tab layout: interactive generation, model details, example prompts.
with gr.Blocks(theme=gr.themes.Soft(), title="SmolLM2-135M Demo") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Tab("🔮 Generate Text"):
        with gr.Row():
            # Left column: prompt and sampling controls.
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="Once upon a time"
                )

                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=50,
                        step=10,
                        label="Max New Tokens"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature"
                    )

                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Top-K"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-P (Nucleus)"
                    )

                generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            # Right column: read-only output.
            with gr.Column():
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=12,
                    interactive=False
                )

        # Input order must match generate_text's positional parameters:
        # prompt, max_length, temperature, top_k, top_p.
        generate_btn.click(
            fn=generate_text,
            inputs=[
                prompt_input,
                max_length_slider,
                temperature_slider,
                top_k_slider,
                top_p_slider
            ],
            outputs=output_text
        )

        gr.Markdown(_TIPS_MD)

    with gr.Tab("📊 Model Info"):
        # Computed once while the UI is being built, not on every page view.
        model_info_display = gr.Markdown(get_model_info())

        gr.Markdown(_PROCESS_MD)

    with gr.Tab("🎯 Example Prompts"):
        gr.Markdown(_EXAMPLES_MD)


if __name__ == "__main__":
    demo.launch()