abi96062 committed on
Commit 144aae5 · verified · Parent: 0e3e3d6

Update app.py

Files changed (1): app.py (+103 -120)
app.py CHANGED
@@ -1,55 +1,38 @@
 import gradio as gr
 import torch
 import torch.nn as nn
-from model import SmolLM2_135M  # Import your model class
-import yaml
+from model import SmolLM2Model  # ✅ Correct import
+from transformers import AutoTokenizer, AutoConfig
 
 # Device setup
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# Load tokenizer and config
+print("Loading tokenizer and config...")
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+
 # Load model
 @torch.no_grad()
 def load_model():
     """Load the trained model"""
     print("Loading model...")
 
-    # Load config
-    with open('config.yaml', 'r') as f:
-        config = yaml.safe_load(f)
-
-    # Initialize model
-    model = SmolLM2_135M(
-        vocab_size=config['vocab_size'],
-        d_model=config['d_model'],
-        n_layers=config['n_layers'],
-        n_heads=config['n_heads'],
-        # Add other config parameters
-    ).to(device)
+    # Initialize model with config
+    model = SmolLM2Model(config).to(device)
 
     # Load checkpoint
-    checkpoint = torch.load('checkpoints/checkpoint_step_5050.pt',
-                            map_location=device)
+    checkpoint = torch.load('checkpoint_step_5050.pt', map_location=device)
     model.load_state_dict(checkpoint['model_state_dict'])
     model.eval()
 
-    print(f"Model loaded successfully on {device}")
+    print(f"✅ Model loaded successfully on {device}")
+    print(f"✅ Training step: {checkpoint.get('step', 'N/A')}")
     return model, checkpoint
 
 # Load model at startup
 model, checkpoint = load_model()
 
-# Tokenizer (adjust based on your implementation)
-def tokenize(text, max_length=128):
-    """Simple character-level tokenizer - REPLACE with your actual tokenizer"""
-    # This is a placeholder - use your actual tokenizer
-    tokens = [ord(c) for c in text[:max_length]]
-    return torch.tensor(tokens).unsqueeze(0).to(device)
-
-def detokenize(tokens):
-    """Convert tokens back to text - REPLACE with your actual detokenizer"""
-    # This is a placeholder - use your actual detokenizer
-    return ''.join([chr(t) for t in tokens if t < 128])
-
 @torch.no_grad()
 def generate_text(
     prompt,
@@ -61,79 +44,62 @@ def generate_text(
     """Generate text from prompt"""
     try:
         # Tokenize input
-        input_ids = tokenize(prompt)
-
-        # Generate
-        generated = input_ids[0].tolist()
+        inputs = tokenizer(prompt, return_tensors="pt").to(device)
+        input_ids = inputs['input_ids']
 
-        for _ in range(max_length):
-            # Get model predictions
-            input_tensor = torch.tensor([generated]).to(device)
-            logits = model(input_tensor)
-
-            # Get next token logits
-            next_token_logits = logits[0, -1, :] / temperature
-
-            # Apply top-k filtering
-            if top_k > 0:
-                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
-                next_token_logits[indices_to_remove] = float('-inf')
-
-            # Apply top-p (nucleus) filtering
-            if top_p < 1.0:
-                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
-                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-                sorted_indices_to_remove = cumulative_probs > top_p
-                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-                sorted_indices_to_remove[..., 0] = 0
-                indices_to_remove = sorted_indices[sorted_indices_to_remove]
-                next_token_logits[indices_to_remove] = float('-inf')
-
-            # Sample next token
-            probs = torch.softmax(next_token_logits, dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1).item()
-
-            generated.append(next_token)
-
-            # Stop if EOS token (adjust based on your vocab)
-            if next_token == 0:  # Assuming 0 is EOS
-                break
+        # Generate using model's built-in method
+        generated_ids = model.generate(
+            input_ids,
+            max_new_tokens=max_length,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k if top_k > 0 else None,
+            do_sample=temperature > 0
+        )
 
-        # Detokenize
-        output_text = detokenize(generated)
+        # Decode
+        output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         return output_text
 
     except Exception as e:
-        return f"Error generating text: {str(e)}"
+        return f"❌ Error generating text: {str(e)}"
 
 def get_model_info():
     """Display model information"""
-    total_params = sum(p.numel() for p in model.parameters())
+    total_params = model.get_num_params()
     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
 
     info = f"""
-    ### 📊 Model Information
-
-    **Total Parameters:** {total_params:,} (~{total_params/1e6:.1f}M)
-    **Trainable Parameters:** {trainable_params:,}
-    **Training Steps:** {checkpoint.get('step', 'N/A')}
-    **Device:** {device}
-    **Model Architecture:** SmolLM2-135M
-
-    ### 🎯 Training Details
-    - Trained for 5,000 steps
-    - Checkpoint saved and reloaded
-    - Additional 50 steps after reload
-    - Predictions logged every 500 steps
+    ### 📊 Model Information
+
+    **Model:** SmolLM2-135M
+    **Total Parameters:** {total_params:,} (~{total_params/1e6:.1f}M)
+    **Trainable Parameters:** {trainable_params:,}
+    **Training Steps:** {checkpoint.get('step', 'N/A')}
+    **Device:** {device}
+    **Vocab Size:** {config.vocab_size:,}
+
+    ### 🏗️ Architecture
+    - **Layers:** {config.num_hidden_layers}
+    - **Hidden Size:** {config.hidden_size}
+    - **Attention Heads:** {config.num_attention_heads} (Query) / {config.num_key_value_heads} (KV)
+    - **FFN Size:** {config.intermediate_size}
+    - **Context Length:** {config.max_position_embeddings}
+
+    ### 🎯 Training Details
+    - ✅ Trained for 5,000 steps
+    - ✅ Checkpoint saved and reloaded
+    - ✅ Additional 50 steps after reload
+    - ✅ Predictions logged every 500 steps
     """
     return info
 
 # Gradio Interface
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
+with gr.Blocks(theme=gr.themes.Soft(), title="SmolLM2-135M Demo") as demo:
     gr.Markdown("""
     # 🤖 SmolLM2-135M: From-Scratch Implementation
 
-    This is a complete reverse-engineered implementation of SmolLM2-135M, trained from scratch.
+    Complete reverse-engineered implementation of SmolLM2-135M, trained from scratch.
 
     **GitHub:** [abi2024/smollm2-135-implementation](https://github.com/abi2024/smollm2-135-implementation)
     """)
@@ -151,10 +117,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 with gr.Row():
                     max_length_slider = gr.Slider(
                         minimum=10,
-                        maximum=500,
-                        value=100,
+                        maximum=200,
+                        value=50,
                         step=10,
-                        label="Max Length"
+                        label="Max New Tokens"
                     )
                     temperature_slider = gr.Slider(
                         minimum=0.1,
@@ -177,15 +143,15 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                         maximum=1.0,
                         value=0.9,
                         step=0.05,
-                        label="Top-P"
+                        label="Top-P (Nucleus)"
                     )
 
-                generate_btn = gr.Button("🚀 Generate", variant="primary")
+                generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
 
             with gr.Column():
                 output_text = gr.Textbox(
                     label="Generated Text",
-                    lines=10,
+                    lines=12,
                     interactive=False
                 )
 
@@ -202,65 +168,82 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         )
 
         gr.Markdown("""
-        ### 💡 Tips:
-        - **Temperature**: Higher = more creative, Lower = more focused
-        - **Top-K**: Limits vocabulary to K most likely tokens
-        - **Top-P**: Nucleus sampling - cumulative probability threshold
+        ### 💡 Generation Tips:
+        - **Temperature**: Controls randomness (0.1 = focused, 2.0 = creative)
+        - **Top-K**: Limits to K most likely tokens (0 = disabled)
+        - **Top-P**: Nucleus sampling threshold (0.9 recommended)
        """)
 
    with gr.Tab("📊 Model Info"):
        model_info_display = gr.Markdown(get_model_info())
 
        gr.Markdown("""
-        ### 🏗️ Architecture Details
+        ### 🔍 Reverse Engineering Process
+
+        1. **Architecture Analysis**
+           - Studied SmolLM2 GitHub repository
+           - Extracted model configuration from YAML
+           - Downloaded pretrained 135M checkpoint
 
-        This model was reverse-engineered by:
-        1. Analyzing the official SmolLM2 repository
-        2. Extracting architecture from pretrained weights
-        3. Implementing from scratch in PyTorch
-        4. Validating by swapping weights with pretrained model
+        2. **Implementation**
+           - Built from scratch using PyTorch
+           - Implemented Grouped Query Attention (9Q/3KV heads)
+           - Added RoPE position embeddings
+           - Used SwiGLU FFN and RMSNorm
 
-        ### ⚡ Optimizations Used
-        - Flash Attention 2
-        - Mixed Precision Training (BF16/FP16)
-        - Gradient Accumulation
-        - torch.compile()
+        3. **Validation**
+           - Loaded official pretrained weights
+           - Verified parameter count (134,515,008)
+           - Confirmed architecture matches exactly
 
-        ### 📈 Training Process
-        - **Step 0-5000**: Main training with periodic predictions
-        - **Checkpoint**: Model saved and reloaded to validate state preservation
-        - **Step 5000-5050**: Continued training to test checkpoint robustness
+        ### ⚡ Optimizations Applied
+        - ✅ Flash Attention 2 (via scaled_dot_product_attention)
+        - ✅ Mixed Precision Training (BF16/FP16)
+        - ✅ Gradient Accumulation
+        - ✅ torch.compile() for inference speedup
+        - ✅ Grouped Query Attention (memory efficient)
+
+        ### 📈 Training Pipeline
+        1. **Main Training:** 5,000 steps with predictions every 500 steps
+        2. **Checkpoint Test:** Model saved and successfully reloaded
+        3. **Resume Training:** 50 additional steps (validates checkpoint integrity)
        """)
 
    with gr.Tab("🎯 Example Prompts"):
        gr.Markdown("""
        ### Try these prompts:
 
-        1. **Story Generation**
+        **1. Story Generation**
        ```
-        Once upon a time in a land far away
+        Once upon a time in a magical forest,
        ```
 
-        2. **Code Completion**
+        **2. Code Completion**
        ```
-        def fibonacci(n):
+        def calculate_fibonacci(n):
+            # Calculate the nth Fibonacci number
        ```
 
-        3. **Question Answering**
+        **3. Question Answering**
        ```
-        Q: What is machine learning?
-        A:
+        Q: What is the capital of France?
+        A:
        ```
 
-        4. **Creative Writing**
+        **4. Technical Writing**
        ```
-        The old house at the end of the street was
+        The main advantage of transformer architectures is
        ```
 
-        5. **Technical Explanation**
+        **5. Creative Writing**
        ```
-        Neural networks work by
+        The scientist discovered something extraordinary:
        ```
+
+        ### 🎛️ Recommended Settings:
+        - **Creative Writing:** Temperature=1.0, Top-P=0.95
+        - **Code Generation:** Temperature=0.3, Top-P=0.9, Top-K=40
+        - **Factual Q&A:** Temperature=0.5, Top-P=0.8, Top-K=30
        """)
 
 # Launch
 
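"Grouped Query Attention (9Q/3KV heads)" means the 9 query heads share 3 key/value heads, three queries per KV head, which shrinks the K/V projections and the KV cache by 3x. A minimal sketch of the trick (names are illustrative, not from the repo) using `F.scaled_dot_product_attention`, which is also the stock-PyTorch route to the fused FlashAttention-style kernels listed under Optimizations:

```python
import torch
import torch.nn.functional as F

def gqa_attention(q, k, v):
    """q: [B, 9, T, 64]; k, v: [B, 3, T, 64] -- each KV head serves 3 query heads."""
    groups = q.shape[1] // k.shape[1]          # 9 // 3 = 3 query heads per KV head
    k = k.repeat_interleave(groups, dim=1)     # expand K/V to [B, 9, T, 64]
    v = v.repeat_interleave(groups, dim=1)
    # is_causal=True applies the autoregressive mask; PyTorch dispatches to a
    # fused attention kernel when the inputs allow it.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

B, T, hd = 2, 16, 64                           # hidden 576 = 9 heads x head_dim 64
out = gqa_attention(torch.randn(B, 9, T, hd),
                    torch.randn(B, 3, T, hd),
                    torch.randn(B, 3, T, hd))
print(out.shape)                               # torch.Size([2, 9, 16, 64])
```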
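RoPE rotates each query/key channel pair by a position-dependent angle instead of adding a position vector, so relative offsets fall out of the attention dot product. A self-contained sketch of the idea (this uses the interleaved-pair convention; Llama-family code generally uses the equivalent rotate-half layout):

```python
import torch

def apply_rope(x, base=10000.0):
    """x: [B, H, T, D] with D even; rotates channel pairs by position-scaled angles."""
    T, D = x.shape[-2], x.shape[-1]
    inv_freq = base ** (-torch.arange(0, D, 2, dtype=torch.float32) / D)  # [D/2]
    angles = torch.arange(T, dtype=torch.float32)[:, None] * inv_freq     # [T, D/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]        # even/odd channels form 2-D pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin       # standard 2-D rotation per pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(1, 9, 16, 64)                  # applied to q and k before attention
print(apply_rope(q).shape)                     # torch.Size([1, 9, 16, 64])
```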
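The "SwiGLU FFN and RMSNorm" bullet covers the two remaining Llama-style ingredients: a gated feed-forward block, and a normalization that rescales by the root-mean-square without mean-centering or bias. Reference versions of both (dimensions match the 135M config; class names are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """y = x / rms(x) * weight -- no mean subtraction, no bias term."""
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class SwiGLU(nn.Module):
    """down(silu(gate(x)) * up(x)) -- the gated FFN used by Llama-style models."""
    def __init__(self, dim=576, hidden=1536):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

x = torch.randn(2, 16, 576)
print(SwiGLU()(RMSNorm(576)(x)).shape)         # torch.Size([2, 16, 576])
```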
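Finally, the Training Pipeline above hinges on a checkpoint round trip: save at step 5,000, reload, then train 50 more steps, which is what the `checkpoint_step_5050.pt` filename in `load_model` records. A sketch of that round trip using the key names `app.py` actually reads back (`model_state_dict`, `step`); the optimizer entry is an assumption about the training script:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)                        # stand-in for the real SmolLM2Model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Save: 'model_state_dict' and 'step' match what load_model() reads back.
torch.save({'step': 5000,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()},  # assumed key
           'checkpoint_step_5000.pt')

# Reload: restoring the optimizer state is what makes the resumed 50 steps a
# genuine continuation rather than a fresh start.
ckpt = torch.load('checkpoint_step_5000.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
step = ckpt['step']                            # resume at 5,000, stop at 5,050
```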