Sualeh Qureshi committed on
Commit
13f6128
·
1 Parent(s): dc69345

Added README.md and app.py for Hugging Face Space

Files changed (5)
  1. .gitignore +3 -0
  2. README.md +321 -0
  3. README_SPACE.md +108 -0
  4. app.py +259 -0
  5. requirements.txt +5 -0
.gitignore CHANGED
@@ -12,3 +12,6 @@ wheels/
  # Checkpoints
  checkpoints/
 
+ # tensorboard logs
+ logs/tensorboard/
+
README.md CHANGED
@@ -0,0 +1,321 @@
+ # SmolLM2-135M Implementation
+
+ A from-scratch PyTorch implementation of the SmolLM2-135M language model, following the LLaMA architecture with modern optimizations.
+
+ ## Overview
+
+ This repository contains a complete implementation of SmolLM2-135M, a 135 million parameter decoder-only transformer model. The implementation includes:
+
+ - **Model Architecture** (`model.py`): Complete model definition with KV cache support
+ - **Training Script** (`train.py`): PyTorch Lightning training with WSD scheduler
+ - **Gradio App** (`app.py`): Interactive web interface for text generation
+
+ ## Model Architecture (`model.py`)
+
+ ### Architecture Components
+
+ The model follows the LLaMA-style decoder-only transformer architecture with the following key components:
+
+ #### 1. **SmolConfig** (Configuration Class)
+
+ A dataclass that stores all model hyperparameters:
+
+ ```python
+ @dataclass
+ class SmolConfig:
+     vocab_size: int = 49152              # Vocabulary size
+     hidden_size: int = 576               # Hidden dimension
+     intermediate_size: int = 1536        # MLP intermediate dimension
+     num_hidden_layers: int = 30          # Number of transformer layers
+     num_attention_heads: int = 9         # Number of query heads
+     num_key_value_heads: int = 3         # Number of key/value heads (GQA)
+     max_position_embeddings: int = 8192  # Maximum sequence length
+     rope_theta: float = 100000.0         # RoPE base frequency
+     rms_norm_eps: float = 1e-5           # RMSNorm epsilon
+     attention_bias: bool = False         # Whether to use bias in attention
+     mlp_bias: bool = False               # Whether to use bias in MLP
+     dtype: torch.dtype = torch.bfloat16
+ ```
+
+ **Key Features:**
+ - `head_dim` property: Automatically computes the per-head dimension (`hidden_size // num_attention_heads` = 64)
+ - `from_hf()` class method: Loads the configuration from a HuggingFace model config (see the sketch below)
+
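+ For reference, a minimal sketch of what `from_hf()` does, assuming the HuggingFace config exposes the standard LLaMA-style field names (the exact mapping lives in `model.py`):
+
+ ```python
+ @classmethod
+ def from_hf(cls, hf_config) -> "SmolConfig":
+     # Map HuggingFace config fields onto SmolConfig; anything the HF config
+     # does not define falls back to the dataclass defaults.
+     return cls(
+         vocab_size=hf_config.vocab_size,
+         hidden_size=hf_config.hidden_size,
+         intermediate_size=hf_config.intermediate_size,
+         num_hidden_layers=hf_config.num_hidden_layers,
+         num_attention_heads=hf_config.num_attention_heads,
+         num_key_value_heads=hf_config.num_key_value_heads,
+         max_position_embeddings=hf_config.max_position_embeddings,
+         rope_theta=hf_config.rope_theta,
+         rms_norm_eps=hf_config.rms_norm_eps,
+     )
+ ```
+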
+ #### 2. **RMSNorm** (Root Mean Square Normalization)
+
+ Replaces LayerNorm with a more efficient normalization:
+
+ ```python
+ class RMSNorm(nn.Module):
+     def forward(self, x):
+         norm = x.pow(2).mean(dim=-1, keepdim=True)
+         x = x * torch.rsqrt(norm + self.eps)
+         return self.weight * x
+ ```
+
+ **Benefits:**
+ - More efficient than LayerNorm (no mean subtraction)
+ - Used throughout the model for pre-norm architecture
+
+ #### 3. **RoPE** (Rotary Positional Embeddings)
+
+ Rotary Position Embeddings applied to query and key tensors:
+
+ ```python
+ def build_rope_cache(seq_len, head_dim, base, device, dtype):
+     # Computes cosine and sine caches for RoPE
+     half_dim = head_dim // 2
+     freq_seq = torch.arange(half_dim, device=device, dtype=dtype)
+     inv_freq = 1.0 / (base ** (freq_seq / half_dim))
+     t = torch.arange(seq_len, device=device, dtype=dtype)
+     freqs = torch.outer(t, inv_freq)
+     cos = freqs.cos()[None, None, :, :]
+     sin = freqs.sin()[None, None, :, :]
+     return cos, sin
+
+ def apply_rope(x, cos, sin):
+     # Applies rotary transformation to input tensor
+     half = x.shape[-1] // 2
+     x1, x2 = x[..., :half], x[..., half:]
+     x1_rot = x1 * cos - x2 * sin
+     x2_rot = x1 * sin + x2 * cos
+     return torch.cat([x1_rot, x2_rot], dim=-1)
+ ```
+
+ **Key Features:**
+ - Relative positional encoding (no absolute position embeddings)
+ - Applied only to Q and K (not V)
+ - Supports efficient caching for inference
+
+ #### 4. **MultiHeadSelfAttention** (Grouped Query Attention)
+
+ Implements GQA (Grouped Query Attention) where:
+ - **Query heads**: 9 (full attention)
+ - **Key/Value heads**: 3 (shared across query heads)
+
+ ```python
+ class MultiHeadSelfAttention(nn.Module):
+     def forward(self, x, cos, sin, past_key_value=None, use_cache=False):
+         # 1. Project to Q, K, V (reshape to (B, heads, T, head_dim) elided for brevity)
+         q = self.q_proj(x)  # (B, T, n_heads * head_dim)
+         k = self.k_proj(x)  # (B, T, n_kv_heads * head_dim)
+         v = self.v_proj(x)  # (B, T, n_kv_heads * head_dim)
+
+         # 2. Apply RoPE to Q and K
+         q = apply_rope(q, cos, sin)
+         k = apply_rope(k, cos, sin)
+
+         # 3. KV Cache support (for inference)
+         if past_key_value is not None:
+             past_k, past_v = past_key_value
+             k = torch.cat([past_k, k], dim=2)
+             v = torch.cat([past_v, v], dim=2)
+         present_key_value = (k, v) if use_cache else None
+
+         # 4. GQA: Expand K/V if needed
+         if n_kv_heads < n_heads:
+             repeat_factor = n_heads // n_kv_heads
+             k = k.repeat_interleave(repeat_factor, dim=1)
+             v = v.repeat_interleave(repeat_factor, dim=1)
+
+         # 5. Compute attention scores
+         scores = (q @ k.transpose(-2, -1)) / sqrt(head_dim)
+         scores = scores + causal_mask  # Causal masking
+
+         # 6. Softmax and weighted sum
+         probs = F.softmax(scores, dim=-1)
+         out = probs @ v
+
+         return out, present_key_value
+ ```
+
+ **Key Features:**
+ - **KV Cache**: Efficient inference by caching past key-value pairs
+ - **GQA**: Reduces memory by sharing K/V heads (3:1 ratio)
+ - **Causal Masking**: Prevents attending to future tokens (see the mask sketch below)
+ - **RoPE Integration**: Positional encoding via rotary embeddings
+
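+ The `causal_mask` used in step 5 is built outside this excerpt. A minimal sketch of one way to construct it when a KV cache is present (an illustration, not the exact code in `model.py`): each of the `q_len` new query positions may attend to all cached positions and to itself, but not to later positions.
+
+ ```python
+ import torch
+
+ def build_causal_mask(q_len, kv_len, device, dtype=torch.float32):
+     # Query position i (offset by the cached prefix) may attend to key j
+     # only if j <= past_len + i; disallowed positions get -inf before softmax.
+     past_len = kv_len - q_len
+     i = torch.arange(q_len, device=device)[:, None]
+     j = torch.arange(kv_len, device=device)[None, :]
+     mask = torch.zeros(q_len, kv_len, device=device, dtype=dtype)
+     return mask.masked_fill(j > past_len + i, float("-inf"))  # broadcasts over (B, heads, q_len, kv_len)
+ ```
+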
+ #### 5. **SmolMLP** (SwiGLU Activation)
+
+ Implements the SwiGLU (Swish-Gated Linear Unit) MLP:
+
+ ```python
+ class SmolMLP(nn.Module):
+     def forward(self, x):
+         # fc1 outputs 2 * intermediate_size
+         x = self.fc1(x)              # (B, T, 2 * 1536) = (B, T, 3072)
+         x1, x2 = x.chunk(2, dim=-1)  # Split into two parts
+         # SwiGLU: SiLU(x1) * x2
+         return self.fc2(F.silu(x1) * x2)
+ ```
+
+ **Key Features:**
+ - **SwiGLU**: `SiLU(x1) * x2` gated activation (empirically stronger than ReLU/GELU in transformer MLPs)
+ - **No bias**: Following LLaMA architecture
+ - **Efficient**: Single matrix multiplication with split
+
+ #### 6. **SmolBlock** (Transformer Block)
+
+ Combines attention and MLP with pre-norm and residual connections:
+
+ ```python
+ class SmolBlock(nn.Module):
+     def forward(self, x, cos, sin, past_key_value=None, use_cache=False):
+         # Pre-norm attention with residual
+         attn_out, present_kv = self.attn(
+             self.attn_norm(x), cos, sin,
+             past_key_value=past_key_value, use_cache=use_cache
+         )
+         x = x + attn_out
+
+         # Pre-norm MLP with residual
+         x = x + self.mlp(self.mlp_norm(x))
+
+         return x, present_kv
+ ```
+
+ **Architecture:**
+ - **Pre-norm**: Normalization before attention/MLP (not after)
+ - **Residual connections**: Skip connections for gradient flow
+ - **KV Cache passthrough**: Supports efficient inference
+
+ #### 7. **SmolLM2** (Main Model)
+
+ Top-level model that combines all components:
+
+ ```python
+ class SmolLM2(nn.Module):
+     def __init__(self, config):
+         self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
+         self.layers = nn.ModuleList([SmolBlock(config) for _ in range(30)])
+         self.norm = RMSNorm(hidden_size)
+         self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+
+         # Weight tying: share embeddings and output weights
+         self.lm_head.weight = self.embed_tokens.weight
+
+     def forward(self, input_ids, past_key_values=None, use_cache=False):
+         # 1. Token embeddings
+         x = self.embed_tokens(input_ids)
+
+         # 2. Build RoPE cache
+         cos, sin = build_rope_cache(...)
+
+         # 3. Pass through transformer layers (each layer gets its own cached K/V)
+         present_key_values = []
+         for i, layer in enumerate(self.layers):
+             past_kv = past_key_values[i] if past_key_values is not None else None
+             x, present_kv = layer(x, cos, sin, past_key_value=past_kv, use_cache=use_cache)
+             if use_cache:
+                 present_key_values.append(present_kv)
+
+         # 4. Final norm and language modeling head
+         x = self.norm(x)
+         logits = self.lm_head(x)
+
+         return logits, present_key_values
+ ```
+
+ **Key Features:**
+ - **Weight Tying**: Embeddings and output weights are shared (reduces parameters)
+ - **KV Cache Support**: Full support for efficient autoregressive generation
+ - **30 Layers**: Deep transformer stack for capacity
+
+ #### 8. **Generate Method** (Text Generation)
+
+ Autoregressive text generation with KV cache:
+
+ ```python
+ @torch.no_grad()
+ def generate(self, input_ids, max_new_tokens=100, temperature=1.0,
+              top_k=None, top_p=None, eos_token_id=None):
+     generated = input_ids
+     past_key_values = None
+
+     for _ in range(max_new_tokens):
+         # Forward pass with KV cache
+         logits, past_key_values = self.forward(
+             generated[:, -1:] if past_key_values else generated,
+             past_key_values=past_key_values,
+             use_cache=True
+         )
+
+         # Sample next token with temperature, top-k, top-p
+         next_token_logits = logits[:, -1, :] / temperature
+         # Apply top-k and top-p filtering (see the sketch below)
+         probs = F.softmax(next_token_logits, dim=-1)
+         next_token = torch.multinomial(probs, num_samples=1)
+
+         generated = torch.cat([generated, next_token], dim=1)
+
+         if eos_token_id is not None and (next_token == eos_token_id).all():
+             break
+
+     return generated
+ ```
+
+ **Key Features:**
+ - **KV Cache**: Only processes new tokens (not the entire sequence)
+ - **Sampling**: Supports temperature, top-k, and top-p (nucleus) sampling (see the filtering sketch below)
+ - **Efficient**: After the initial forward pass, each step runs the model on a single token; attention cost grows only with the cached length
+
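+ The "Apply top-k and top-p filtering" step above is only summarized by a comment; here is a minimal standalone sketch of that filtering (the same idea as the implementation in `app.py`, shown as a helper for illustration):
+
+ ```python
+ import torch
+
+ def filter_logits(logits, top_k=None, top_p=None):
+     """Mask logits so sampling only sees the top-k / nucleus (top-p) tokens."""
+     if top_k is not None and top_k > 0:
+         kth_value = torch.topk(logits, top_k).values[..., -1, None]
+         logits = logits.masked_fill(logits < kth_value, float("-inf"))
+     if top_p is not None and top_p < 1.0:
+         sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+         cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
+         remove = cum_probs > top_p
+         remove[..., 1:] = remove[..., :-1].clone()  # always keep the highest-probability token
+         remove[..., 0] = False
+         logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
+     return logits
+ ```
+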
+ ### Model Specifications
+
+ | Parameter | Value |
+ |-----------|-------|
+ | **Total Parameters** | ~135M |
+ | **Hidden Size** | 576 |
+ | **Layers** | 30 |
+ | **Attention Heads** | 9 (Q), 3 (K/V) |
+ | **Head Dimension** | 64 |
+ | **Intermediate Size** | 1536 |
+ | **Vocabulary Size** | 49,152 |
+ | **Max Sequence Length** | 8,192 |
+ | **RoPE Theta** | 100,000 |
+ | **Activation** | SwiGLU (SiLU-gated) |
+ | **Normalization** | RMSNorm |
+ | **Weight Tying** | Yes (embeddings = output) |
+
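+ As a sanity check on the ~135M figure, the parameter count can be estimated directly from the values above (a back-of-the-envelope calculation, not the authoritative count from `model.py`):
+
+ ```python
+ hidden, inter, n_layers = 576, 1536, 30
+ vocab, n_heads, n_kv_heads, head_dim = 49152, 9, 3, 64
+
+ embed = vocab * hidden                          # tied with lm_head, so counted once
+ attn = (hidden * n_heads * head_dim             # q_proj
+         + 2 * hidden * n_kv_heads * head_dim    # k_proj, v_proj
+         + n_heads * head_dim * hidden)          # o_proj (no biases)
+ mlp = hidden * (2 * inter) + inter * hidden     # fc1 (gated, 2x width) + fc2
+ norms = 2 * hidden                              # attn_norm + mlp_norm
+
+ total = embed + n_layers * (attn + mlp + norms) + hidden  # + final RMSNorm
+ print(f"{total / 1e6:.1f}M parameters")                   # ~134.5M
+ ```
+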
+ ### Key Design Choices
+
+ 1. **GQA (Grouped Query Attention)**: The 3:1 head ratio cuts K/V-cache memory by roughly two-thirds (see the estimate below)
+ 2. **Pre-norm Architecture**: More stable training than post-norm
+ 3. **RMSNorm**: Faster and simpler than LayerNorm
+ 4. **RoPE**: Relative positional encoding, no learned embeddings
+ 5. **SwiGLU**: Gated activation that typically outperforms ReLU/GELU
+ 6. **Weight Tying**: Reduces parameters and improves generalization
+ 7. **No Biases**: Following LLaMA, reduces parameters slightly
+
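+ For design choice 1, the saving follows directly from the head counts. A rough per-token K/V-cache estimate in bf16 (2 bytes per value), assuming both K and V are cached for every layer:
+
+ ```python
+ n_layers, head_dim, bytes_per_value = 30, 64, 2
+
+ mha_bytes = 2 * 9 * head_dim * n_layers * bytes_per_value  # K and V with 9 heads: 69,120 bytes/token
+ gqa_bytes = 2 * 3 * head_dim * n_layers * bytes_per_value  # K and V with 3 heads: 23,040 bytes/token
+ print(f"saving: {1 - gqa_bytes / mha_bytes:.0%}")          # ~67%
+ ```
+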
+ ### Usage Example
+
+ ```python
+ import torch
+ from transformers import AutoConfig, AutoTokenizer
+
+ from model import SmolConfig, SmolLM2
+
+ # Load config and tokenizer from HuggingFace
+ hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+ config = SmolConfig.from_hf(hf_config)
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+
+ # Create model
+ model = SmolLM2(config)
+
+ # Forward pass (training)
+ input_ids = torch.randint(0, config.vocab_size, (2, 512))
+ logits, _ = model(input_ids, use_cache=False)
+
+ # Text generation (inference with KV cache)
+ prompt_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids
+ generated = model.generate(
+     prompt_ids,
+     max_new_tokens=100,
+     temperature=0.8,
+     top_k=50,
+ )
+ print(tokenizer.decode(generated[0], skip_special_tokens=True))
+ ```
+
+ ## Training
+
+ See `README_TRAINING.md` for detailed training instructions.
+
+ ## Inference
+
+ See `app.py` for the Gradio web interface or use the `generate()` method directly.
+
+ ## References
+
+ - [SmolLM2 Paper](https://arxiv.org/abs/2406.02528)
+ - [LLaMA Architecture](https://arxiv.org/abs/2302.13971)
+ - [RoPE: Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
+ - [SwiGLU Activation](https://arxiv.org/abs/2002.05202)
README_SPACE.md ADDED
@@ -0,0 +1,108 @@
+ # SmolLM2-135M Hugging Face Space Setup Guide
+
+ This guide explains how to push your model and app to Hugging Face Spaces.
+
+ ## Files Needed for Hugging Face Space
+
+ 1. **app.py** - Main Gradio application (already created)
+ 2. **model.py** - Model definition
+ 3. **train.py** - Contains SmolLM2Module class (needed for loading checkpoints)
+ 4. **requirements.txt** - Python dependencies
+ 5. **README.md** - Space description (optional but recommended)
+
+ ## Step-by-Step Guide
+
+ ### 1. Fix Merge Conflicts (if still present)
+
+ If you still have merge conflicts, resolve them:
+ ```bash
+ # Check status
+ git status
+
+ # Resolve conflicts in train.py and pyproject.toml
+ # Then commit
+ git add train.py pyproject.toml
+ git commit -m "Resolve merge conflicts"
+ ```
+
+ ### 2. Create Hugging Face Space (if not already created)
+
+ ```bash
+ # Create the space (without --sdk flag, set it in web UI)
+ huggingface-cli repo create smollm2-135m-trained-on-tinyShakespear-forfun --type=space
+ ```
+
+ Then go to the Space settings in the web UI and set:
+ - **SDK**: Gradio
+ - **Python version**: 3.12
+
+ ### 3. Add Hugging Face Remote
+
+ ```bash
+ # Add HF Space as remote (different name to avoid confusion with GitHub)
+ git remote add huggingface https://huggingface.co/spaces/Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun
+ ```
+
+ ### 4. Prepare Files for Space
+
+ Make sure these files are ready:
+ - ✅ `app.py` - Main app (loads from HF model repo)
+ - ✅ `model.py` - Model definition
+ - ✅ `train.py` - Contains SmolLM2Module
+ - ✅ `requirements.txt` - Dependencies
+ - ✅ `.gitignore` - Should exclude logs/, checkpoints/, etc.
+
+ ### 5. Push to Hugging Face Space
+
+ ```bash
+ # First, disable GPG signing temporarily (if you had issues)
+ git config --global commit.gpgsign false
+
+ # Add and commit files
+ git add app.py model.py train.py requirements.txt .gitignore
+ git commit -m "Add Gradio app for SmolLM2-135M inference"
+
+ # Push to Hugging Face Space
+ git push huggingface main
+
+ # Re-enable GPG signing if you want
+ git config --global commit.gpgsign true
+ ```
+
+ ### 6. Verify on Hugging Face
+
+ 1. Go to your Space: https://huggingface.co/spaces/Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun
+ 2. Check the "Files" tab - you should see `app.py`, `model.py`, `train.py`, `requirements.txt`
+ 3. The Space should automatically build and deploy
+ 4. Once built, you can test the app in the web interface
+
+ ## Important Notes
+
+ - **Model Loading**: The app automatically loads from the `Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun` model repo
+ - **Checkpoint**: Uses `smollm2-step=05000-train_loss=0.0918.ckpt`
+ - **First Load**: The first time the Space loads, it will download the checkpoint from the model repo (may take a few minutes)
+ - **Caching**: Subsequent loads will be faster due to Hugging Face caching
+
+ ## Troubleshooting
+
+ ### If push fails with "non-fast-forward":
+ ```bash
+ # Fetch latest
+ git fetch huggingface
+
+ # Rebase (without GPG signing)
+ git config --global commit.gpgsign false
+ git rebase huggingface/main
+ git push huggingface main
+ git config --global commit.gpgsign true
+ ```
+
+ ### If Space build fails:
+ - Check the "Logs" tab in your Space
+ - Ensure all dependencies are in `requirements.txt`
+ - Make sure `app.py` is the entry point (it should be automatically detected)
+
+ ### If model loading fails:
+ - Verify the model repo name is correct: `Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun`
+ - Verify the checkpoint name: `smollm2-step=05000-train_loss=0.0918.ckpt`
+ - Check that the checkpoint file exists in the model repo
app.py ADDED
@@ -0,0 +1,259 @@
+ """
+ Gradio app for SmolLM2-135M inference with streaming output.
+ Loads model from Hugging Face model repo.
+ """
+
+ import sys
+ from pathlib import Path
+ from typing import List, Optional
+ import os
+
+ import gradio as gr
+ import torch
+ from transformers import AutoConfig, AutoTokenizer
+ from huggingface_hub import hf_hub_download
+
+ from model import SmolConfig, SmolLM2
+ from train import SmolLM2Module
+
+ # Hugging Face model repo configuration
+ HF_MODEL_REPO = "Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun"
+ CHECKPOINT_NAME = "smollm2-step=05000-train_loss=0.0918.ckpt"
+
+ # Device setup
+ DEVICE = "cpu"
+ if torch.cuda.is_available():
+     DEVICE = "cuda"
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     DEVICE = "mps"
+
+ # Globals
+ model: Optional[SmolLM2] = None
+ tokenizer = None
+
+ # Allow SmolConfig to be deserialized from Lightning checkpoints when torch.load runs in weights-only mode
+ try:
+     torch.serialization.add_safe_globals([SmolConfig])  # type: ignore[attr-defined]
+ except Exception:
+     pass
+
+
+ def load_model_checkpoint(checkpoint_path: Optional[str] = None, use_hf: bool = True):
+     """Load Lightning checkpoint from Hugging Face Hub or local path."""
+     global model, tokenizer
+
+     try:
+         # Load tokenizer and config from Hugging Face
+         hf_cfg = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+         config = SmolConfig.from_hf(hf_cfg)
+         tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Determine checkpoint path
+         if use_hf and checkpoint_path is None:
+             # Download from Hugging Face Hub
+             try:
+                 local_ckpt = hf_hub_download(
+                     repo_id=HF_MODEL_REPO,
+                     filename=CHECKPOINT_NAME,
+                     cache_dir=None,  # Use default cache
+                 )
+                 checkpoint_path = local_ckpt
+                 status_msg = f"✅ Model loaded from Hugging Face: {HF_MODEL_REPO}/{CHECKPOINT_NAME}"
+             except Exception as e:
+                 return f"❌ Failed to download from HF Hub: {e}"
+         elif checkpoint_path:
+             # Use local path
+             ckpt = Path(checkpoint_path)
+             if not ckpt.exists():
+                 return f"❌ Checkpoint not found: {ckpt}"
+             status_msg = f"✅ Model loaded from local path: {checkpoint_path}"
+         else:
+             return "❌ No checkpoint path provided"
+
+         # Load the Lightning module
+         module = SmolLM2Module.load_from_checkpoint(
+             str(checkpoint_path),
+             config=config,
+             tokenizer=tokenizer,
+             map_location=DEVICE,
+             strict=False,
+         )
+         module.eval()
+         model = module.model.to(DEVICE).eval()
+         return f"{status_msg} on {DEVICE}"
+     except Exception as e:
+         model = None
+         return f"❌ Error loading model: {e}"
+
+
+ def stream_generate(
+     prompt: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_k: int,
+     top_p: float,
+ ):
+     """Generator that yields only the generated text (without prompt)."""
+     global model, tokenizer
+     if model is None or tokenizer is None:
+         yield "⚠️ Load the model first (click Reload Model)."
+         return
+
+     if not prompt or not prompt.strip():
+         yield "⚠️ Please enter a prompt."
+         return
+
+     # Tokenize prompt
+     inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
+     input_ids = inputs["input_ids"].to(DEVICE)
+
+     # Guard against context overflow
+     if input_ids.shape[1] >= model.config.max_position_embeddings:
+         yield f"⚠️ Prompt too long ({input_ids.shape[1]} tokens). Max is {model.config.max_position_embeddings}."
+         return
+
+     generated = input_ids
+     past_key_values: Optional[List] = None
+     prompt_length = input_ids.shape[1]
+
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             if past_key_values is None:
+                 current_input = generated
+             else:
+                 current_input = generated[:, -1:]
+
+             logits, past_key_values = model(
+                 current_input,
+                 past_key_values=past_key_values,
+                 use_cache=True,
+             )
+
+             next_token_logits = logits[:, -1, :] / max(temperature, 1e-6)
+
+             # top-k
+             if top_k > 0:
+                 values, _ = torch.topk(next_token_logits, top_k)
+                 min_keep = values[:, -1].unsqueeze(-1)
+                 next_token_logits = torch.where(
+                     next_token_logits < min_keep,
+                     torch.full_like(next_token_logits, float("-inf")),
+                     next_token_logits,
+                 )
+
+             # top-p
+             if top_p < 1.0:
+                 sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                 probs = torch.softmax(sorted_logits, dim=-1)
+                 cumulative = torch.cumsum(probs, dim=-1)
+                 sorted_mask = cumulative > top_p
+                 sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+                 sorted_mask[..., 0] = 0
+                 mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
+                 next_token_logits = torch.where(mask, torch.full_like(next_token_logits, float("-inf")), next_token_logits)
+
+             probs = torch.softmax(next_token_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+
+             generated = torch.cat([generated, next_token], dim=1)
+             # Decode only the generated part (skip the prompt)
+             generated_text = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
+             yield generated_text
+
+
+ # Initial load from Hugging Face
+ INITIAL_STATUS = load_model_checkpoint(use_hf=True)
+
+
+ def chat_stream(message, history, max_tokens, temperature, top_k, top_p):
+     """Gradio wrapper for streaming chat."""
+     if history is None:
+         history = []
+
+     # Convert history from tuple format to dict format if needed
+     if history and isinstance(history[0], (list, tuple)):
+         new_history = []
+         for h in history:
+             if isinstance(h, (list, tuple)) and len(h) >= 2:
+                 if h[0]:  # User message
+                     new_history.append({"role": "user", "content": str(h[0])})
+                 if h[1]:  # Assistant message
+                     new_history.append({"role": "assistant", "content": str(h[1])})
+         history = new_history
+
+     # Append user message
+     user_msg = (message or "").strip()
+     if not user_msg:
+         yield history
+         return
+
+     history.append({"role": "user", "content": user_msg})
+     history.append({"role": "assistant", "content": ""})
+
+     stream = stream_generate(user_msg, max_tokens, temperature, top_k, top_p)
+     for partial in stream:
+         # Update the last assistant message with generated text
+         if partial:
+             history[-1] = {"role": "assistant", "content": str(partial)}
+         yield history
+
+
+ def clear_chat():
+     return "", []
+
+
+ with gr.Blocks(title="SmolLM2-135M Text Generator") as demo:
+     gr.Markdown(
+         """
+ # 🤖 SmolLM2-135M Text Generator
+
+ Generate text with your trained SmolLM2-135M model (streaming output).
+
+ **Model:** Trained on TinyShakespeare dataset
+ **Source:** [Hugging Face Model Repo](https://huggingface.co/Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun)
+ """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### Model Status")
+             status_text = gr.Textbox(value=INITIAL_STATUS, label="Status", interactive=False, lines=3)
+             load_btn = gr.Button("🔄 Reload Model from HF", variant="secondary")
+             load_btn.click(fn=lambda: load_model_checkpoint(use_hf=True), outputs=status_text)
+
+             gr.Markdown("### Local Checkpoint (Optional)")
+             ckpt_input = gr.Textbox(
+                 value="",
+                 label="Local checkpoint path (leave empty to use HF)",
+                 interactive=True,
+             )
+             load_local_btn = gr.Button("📁 Load from Local Path", variant="secondary")
+             load_local_btn.click(
+                 fn=lambda p: load_model_checkpoint(checkpoint_path=p, use_hf=False) if p else "⚠️ Please enter a path",
+                 inputs=ckpt_input,
+                 outputs=status_text
+             )
+
+             gr.Markdown("### Generation Parameters")
+             max_tokens = gr.Slider(10, 500, value=100, step=10, label="Max Tokens")
+             temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature")
+             top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K")
+             top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-P")
+
+         with gr.Column(scale=2):
+             gr.Markdown("### 💬 Chat Interface")
+             chatbot = gr.Chatbot(label="Conversation", height=500)
+             with gr.Row():
+                 msg = gr.Textbox(label="Your Message", placeholder="Type your prompt here...", scale=4, lines=2)
+                 submit_btn = gr.Button("Send ➀", variant="primary", scale=1)
+             clear_btn = gr.Button("🗑️ Clear Chat", variant="stop")
+
+     msg.submit(fn=chat_stream, inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p], outputs=chatbot)
+     submit_btn.click(fn=chat_stream, inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p], outputs=chatbot).then(fn=lambda: "", outputs=msg)
+     clear_btn.click(fn=clear_chat, outputs=[msg, chatbot])
+
+
+ if __name__ == "__main__":
+     demo.queue().launch(share=False, server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch>=2.9.1
+ lightning>=2.6.0
+ transformers>=4.57.3
+ gradio>=4.44.0
+ huggingface-hub>=0.20.0