abi96062 commited on
Commit
dd84964
ยท
verified ยท
1 Parent(s): f63f21f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -0
app.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from model import SmolLM2_135M # Import your model class
5
+ import yaml
6
+
7
+ # Device setup
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # Load model
11
+ @torch.no_grad()
12
+ def load_model():
13
+ """Load the trained model"""
14
+ print("Loading model...")
15
+
16
+ # Load config
17
+ with open('config.yaml', 'r') as f:
18
+ config = yaml.safe_load(f)
19
+
20
+ # Initialize model
21
+ model = SmolLM2_135M(
22
+ vocab_size=config['vocab_size'],
23
+ d_model=config['d_model'],
24
+ n_layers=config['n_layers'],
25
+ n_heads=config['n_heads'],
26
+ # Add other config parameters
27
+ ).to(device)
28
+
29
+ # Load checkpoint
30
+ checkpoint = torch.load('checkpoints/checkpoint_step_5050.pt',
31
+ map_location=device)
32
+ model.load_state_dict(checkpoint['model_state_dict'])
33
+ model.eval()
34
+
35
+ print(f"Model loaded successfully on {device}")
36
+ return model, checkpoint
37
+
38
+ # Load model at startup
39
+ model, checkpoint = load_model()
40
+
41
+ # Tokenizer (adjust based on your implementation)
42
+ def tokenize(text, max_length=128):
43
+ """Simple character-level tokenizer - REPLACE with your actual tokenizer"""
44
+ # This is a placeholder - use your actual tokenizer
45
+ tokens = [ord(c) for c in text[:max_length]]
46
+ return torch.tensor(tokens).unsqueeze(0).to(device)
47
+
48
+ def detokenize(tokens):
49
+ """Convert tokens back to text - REPLACE with your actual detokenizer"""
50
+ # This is a placeholder - use your actual detokenizer
51
+ return ''.join([chr(t) for t in tokens if t < 128])
52
+
53
+ @torch.no_grad()
54
+ def generate_text(
55
+ prompt,
56
+ max_length=100,
57
+ temperature=0.8,
58
+ top_k=50,
59
+ top_p=0.9
60
+ ):
61
+ """Generate text from prompt"""
62
+ try:
63
+ # Tokenize input
64
+ input_ids = tokenize(prompt)
65
+
66
+ # Generate
67
+ generated = input_ids[0].tolist()
68
+
69
+ for _ in range(max_length):
70
+ # Get model predictions
71
+ input_tensor = torch.tensor([generated]).to(device)
72
+ logits = model(input_tensor)
73
+
74
+ # Get next token logits
75
+ next_token_logits = logits[0, -1, :] / temperature
76
+
77
+ # Apply top-k filtering
78
+ if top_k > 0:
79
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
80
+ next_token_logits[indices_to_remove] = float('-inf')
81
+
82
+ # Apply top-p (nucleus) filtering
83
+ if top_p < 1.0:
84
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
85
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
86
+ sorted_indices_to_remove = cumulative_probs > top_p
87
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
88
+ sorted_indices_to_remove[..., 0] = 0
89
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
90
+ next_token_logits[indices_to_remove] = float('-inf')
91
+
92
+ # Sample next token
93
+ probs = torch.softmax(next_token_logits, dim=-1)
94
+ next_token = torch.multinomial(probs, num_samples=1).item()
95
+
96
+ generated.append(next_token)
97
+
98
+ # Stop if EOS token (adjust based on your vocab)
99
+ if next_token == 0: # Assuming 0 is EOS
100
+ break
101
+
102
+ # Detokenize
103
+ output_text = detokenize(generated)
104
+ return output_text
105
+
106
+ except Exception as e:
107
+ return f"Error generating text: {str(e)}"
108
+
109
+ def get_model_info():
110
+ """Display model information"""
111
+ total_params = sum(p.numel() for p in model.parameters())
112
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
113
+
114
+ info = f"""
115
+ ### ๐Ÿ“Š Model Information
116
+
117
+ **Total Parameters:** {total_params:,} (~{total_params/1e6:.1f}M)
118
+ **Trainable Parameters:** {trainable_params:,}
119
+ **Training Steps:** {checkpoint.get('step', 'N/A')}
120
+ **Device:** {device}
121
+ **Model Architecture:** SmolLM2-135M
122
+
123
+ ### ๐ŸŽฏ Training Details
124
+ - Trained for 5,000 steps
125
+ - Checkpoint saved and reloaded
126
+ - Additional 50 steps after reload
127
+ - Predictions logged every 500 steps
128
+ """
129
+ return info
130
+
131
+ # Gradio Interface
132
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
133
+ gr.Markdown("""
134
+ # ๐Ÿค– SmolLM2-135M: From-Scratch Implementation
135
+
136
+ This is a complete reverse-engineered implementation of SmolLM2-135M, trained from scratch.
137
+
138
+ **GitHub:** [abi2024/smollm2-135-implementation](https://github.com/abi2024/smollm2-135-implementation)
139
+ """)
140
+
141
+ with gr.Tab("๐ŸŽฎ Generate Text"):
142
+ with gr.Row():
143
+ with gr.Column():
144
+ prompt_input = gr.Textbox(
145
+ label="Prompt",
146
+ placeholder="Enter your prompt here...",
147
+ lines=3,
148
+ value="Once upon a time"
149
+ )
150
+
151
+ with gr.Row():
152
+ max_length_slider = gr.Slider(
153
+ minimum=10,
154
+ maximum=500,
155
+ value=100,
156
+ step=10,
157
+ label="Max Length"
158
+ )
159
+ temperature_slider = gr.Slider(
160
+ minimum=0.1,
161
+ maximum=2.0,
162
+ value=0.8,
163
+ step=0.1,
164
+ label="Temperature"
165
+ )
166
+
167
+ with gr.Row():
168
+ top_k_slider = gr.Slider(
169
+ minimum=0,
170
+ maximum=100,
171
+ value=50,
172
+ step=5,
173
+ label="Top-K"
174
+ )
175
+ top_p_slider = gr.Slider(
176
+ minimum=0.0,
177
+ maximum=1.0,
178
+ value=0.9,
179
+ step=0.05,
180
+ label="Top-P"
181
+ )
182
+
183
+ generate_btn = gr.Button("๐Ÿš€ Generate", variant="primary")
184
+
185
+ with gr.Column():
186
+ output_text = gr.Textbox(
187
+ label="Generated Text",
188
+ lines=10,
189
+ interactive=False
190
+ )
191
+
192
+ generate_btn.click(
193
+ fn=generate_text,
194
+ inputs=[
195
+ prompt_input,
196
+ max_length_slider,
197
+ temperature_slider,
198
+ top_k_slider,
199
+ top_p_slider
200
+ ],
201
+ outputs=output_text
202
+ )
203
+
204
+ gr.Markdown("""
205
+ ### ๐Ÿ’ก Tips:
206
+ - **Temperature**: Higher = more creative, Lower = more focused
207
+ - **Top-K**: Limits vocabulary to K most likely tokens
208
+ - **Top-P**: Nucleus sampling - cumulative probability threshold
209
+ """)
210
+
211
+ with gr.Tab("๐Ÿ“Š Model Info"):
212
+ model_info_display = gr.Markdown(get_model_info())
213
+
214
+ gr.Markdown("""
215
+ ### ๐Ÿ—๏ธ Architecture Details
216
+
217
+ This model was reverse-engineered by:
218
+ 1. Analyzing the official SmolLM2 repository
219
+ 2. Extracting architecture from pretrained weights
220
+ 3. Implementing from scratch in PyTorch
221
+ 4. Validating by swapping weights with pretrained model
222
+
223
+ ### โšก Optimizations Used
224
+ - Flash Attention 2
225
+ - Mixed Precision Training (BF16/FP16)
226
+ - Gradient Accumulation
227
+ - torch.compile()
228
+
229
+ ### ๐Ÿ“ˆ Training Process
230
+ - **Step 0-5000**: Main training with periodic predictions
231
+ - **Checkpoint**: Model saved and reloaded to validate state preservation
232
+ - **Step 5000-5050**: Continued training to test checkpoint robustness
233
+ """)
234
+
235
+ with gr.Tab("๐ŸŽฏ Example Prompts"):
236
+ gr.Markdown("""
237
+ ### Try these prompts:
238
+
239
+ 1. **Story Generation**
240
+ ```
241
+ Once upon a time in a land far away
242
+ ```
243
+
244
+ 2. **Code Completion**
245
+ ```
246
+ def fibonacci(n):
247
+ ```
248
+
249
+ 3. **Question Answering**
250
+ ```
251
+ Q: What is machine learning?
252
+ A:
253
+ ```
254
+
255
+ 4. **Creative Writing**
256
+ ```
257
+ The old house at the end of the street was
258
+ ```
259
+
260
+ 5. **Technical Explanation**
261
+ ```
262
+ Neural networks work by
263
+ ```
264
+ """)
265
+
266
+ # Launch
267
+ if __name__ == "__main__":
268
+ demo.launch()