Sualeh Qureshi committed · Commit 13f6128 · Parent(s): dc69345

Added README.md and app.py for Hugging Face Space
Browse files:
- .gitignore +3 -0
- README.md +321 -0
- README_SPACE.md +108 -0
- app.py +259 -0
- requirements.txt +5 -0
.gitignore
CHANGED

```diff
@@ -12,3 +12,6 @@ wheels/
 # Checkpoints
 checkpoints/
 
+# tensorboard logs
+logs/tensorboard/
+
```

README.md
CHANGED

@@ -0,0 +1,321 @@

# SmolLM2-135M Implementation

A from-scratch PyTorch implementation of the SmolLM2-135M language model, following the LLaMA architecture with modern optimizations.

## Overview

This repository contains a complete implementation of SmolLM2-135M, a 135 million parameter decoder-only transformer model. The implementation includes:

- **Model Architecture** (`model.py`): Complete model definition with KV cache support
- **Training Script** (`train.py`): PyTorch Lightning training with a WSD scheduler
- **Gradio App** (`app.py`): Interactive web interface for text generation

## Model Architecture (`model.py`)

### Architecture Components

The model follows the LLaMA-style decoder-only transformer architecture, built from the following key components:

#### 1. **SmolConfig** (Configuration Class)

A dataclass that stores all model hyperparameters:

```python
@dataclass
class SmolConfig:
    vocab_size: int = 49152              # Vocabulary size
    hidden_size: int = 576               # Hidden dimension
    intermediate_size: int = 1536        # MLP intermediate dimension
    num_hidden_layers: int = 30          # Number of transformer layers
    num_attention_heads: int = 9         # Number of query heads
    num_key_value_heads: int = 3         # Number of key/value heads (GQA)
    max_position_embeddings: int = 8192  # Maximum sequence length
    rope_theta: float = 100000.0         # RoPE base frequency
    rms_norm_eps: float = 1e-5           # RMSNorm epsilon
    attention_bias: bool = False         # Whether to use bias in attention
    mlp_bias: bool = False               # Whether to use bias in MLP
    dtype: torch.dtype = torch.bfloat16
```

**Key Features:**
- `head_dim` property: automatically computes the head dimension (hidden_size // num_attention_heads = 64)
- `from_hf()` class method: loads the configuration from a HuggingFace model config

#### 2. **RMSNorm** (Root Mean Square Normalization)

Replaces LayerNorm with a simpler, more efficient normalization:

```python
class RMSNorm(nn.Module):
    def forward(self, x):
        norm = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(norm + self.eps)
        return self.weight * x
```

**Benefits:**
- More efficient than LayerNorm (no mean subtraction)
- Used throughout the model for the pre-norm architecture

#### 3. **RoPE** (Rotary Positional Embeddings)

Rotary Position Embeddings applied to the query and key tensors:

```python
def build_rope_cache(seq_len, head_dim, base, device, dtype):
    # Computes cosine and sine caches for RoPE
    half_dim = head_dim // 2
    freq_seq = torch.arange(half_dim, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (freq_seq / half_dim))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)               # (seq_len, half_dim)
    cos = freqs.cos()[None, None, :, :].to(dtype)  # (1, 1, seq_len, half_dim)
    sin = freqs.sin()[None, None, :, :].to(dtype)
    return cos, sin

def apply_rope(x, cos, sin):
    # Applies the rotary transformation to the input tensor
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    x1_rot = x1 * cos - x2 * sin
    x2_rot = x1 * sin + x2 * cos
    return torch.cat([x1_rot, x2_rot], dim=-1)
```

**Key Features:**
- Relative positional encoding (no absolute position embeddings)
- Applied only to Q and K (not V)
- Supports efficient caching for inference
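
As a quick illustration, here is a minimal sketch that uses the two functions above (the shapes are taken from the config: 9 heads, head_dim 64; the batch and sequence sizes are arbitrary). The `(1, 1, T, half_dim)` cos/sin caches broadcast against the `(B, n_heads, T, head_dim)` activations:

```python
import torch

B, n_heads, T, head_dim = 2, 9, 16, 64  # example sizes, head layout from the config
cos, sin = build_rope_cache(seq_len=T, head_dim=head_dim, base=100000.0,
                            device="cpu", dtype=torch.float32)

q = torch.randn(B, n_heads, T, head_dim)
q_rot = apply_rope(q, cos, sin)

# The rotation mixes the two halves pairwise but preserves token norms
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True
```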

#### 4. **MultiHeadSelfAttention** (Grouped Query Attention)

Implements GQA (Grouped Query Attention) where:
- **Query heads**: 9 (full attention)
- **Key/Value heads**: 3 (shared across query heads)

```python
class MultiHeadSelfAttention(nn.Module):
    def forward(self, x, cos, sin, past_key_value=None, use_cache=False):
        # 1. Project to Q, K, V (then reshape to (B, heads, T, head_dim))
        q = self.q_proj(x)  # (B, T, n_heads * head_dim)
        k = self.k_proj(x)  # (B, T, n_kv_heads * head_dim)
        v = self.v_proj(x)  # (B, T, n_kv_heads * head_dim)

        # 2. Apply RoPE to Q and K
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # 3. KV cache support (for inference): prepend cached keys/values
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_key_value = (k, v) if use_cache else None

        # 4. GQA: expand K/V heads to match the query heads
        if n_kv_heads < n_heads:
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        # 5. Compute attention scores
        scores = (q @ k.transpose(-2, -1)) / sqrt(head_dim)
        scores = scores + causal_mask  # Causal masking

        # 6. Softmax and weighted sum
        probs = F.softmax(scores, dim=-1)
        out = probs @ v

        return out, present_key_value
```

**Key Features:**
- **KV Cache**: Efficient inference by caching past key-value pairs
- **GQA**: Reduces K/V cache memory by sharing K/V heads (3:1 query-to-KV ratio)
- **Causal Masking**: Prevents attending to future tokens
- **RoPE Integration**: Positional encoding via rotary embeddings
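
To make the GQA expansion concrete, here is a small self-contained sketch (shapes taken from the config above) showing how the 3 stored K/V heads are broadcast to serve 9 query heads via `repeat_interleave`:

```python
import torch

B, T, head_dim = 2, 16, 64
n_heads, n_kv_heads = 9, 3
repeat_factor = n_heads // n_kv_heads  # 3 query heads share each K/V head

k = torch.randn(B, n_kv_heads, T, head_dim)  # cached K: only 3 heads are stored
k_expanded = k.repeat_interleave(repeat_factor, dim=1)

print(k_expanded.shape)                        # torch.Size([2, 9, 16, 64])
# Expanded heads 0-2 are copies of KV head 0, heads 3-5 of head 1, and so on
print(torch.equal(k_expanded[:, 0], k[:, 0]))  # True
```

Only the 3-head tensors ever live in the cache; the expansion is a transient view for the score computation, which is where the ~3x K/V memory saving comes from.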

#### 5. **SmolMLP** (SwiGLU Activation)

Implements the SwiGLU (Swish-Gated Linear Unit) MLP:

```python
class SmolMLP(nn.Module):
    def forward(self, x):
        # fc1 outputs 2 * intermediate_size
        x = self.fc1(x)              # (B, T, 2 * 1536) = (B, T, 3072)
        x1, x2 = x.chunk(2, dim=-1)  # Split into gate and value halves
        # SwiGLU: SiLU(x1) * x2
        return self.fc2(F.silu(x1) * x2)
```

**Key Features:**
- **SwiGLU**: `SiLU(x1) * x2` activation (empirically stronger than ReLU/GELU)
- **No bias**: Following the LLaMA architecture
- **Efficient**: A single fused matrix multiplication followed by a split
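
For reference, a minimal runnable version of this module; the `__init__` here is a sketch assuming the fused-`fc1` layout described above (one projection producing both the gate and the value branch):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SmolMLP(nn.Module):
    def __init__(self, hidden_size=576, intermediate_size=1536):
        super().__init__()
        # Fused projection: gate and value halves in a single matmul
        self.fc1 = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        x1, x2 = self.fc1(x).chunk(2, dim=-1)  # gate, value
        return self.fc2(F.silu(x1) * x2)

mlp = SmolMLP()
print(mlp(torch.randn(2, 16, 576)).shape)  # torch.Size([2, 16, 576])
```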

#### 6. **SmolBlock** (Transformer Block)

Combines attention and MLP with pre-norm and residual connections:

```python
class SmolBlock(nn.Module):
    def forward(self, x, cos, sin, past_key_value=None, use_cache=False):
        # Pre-norm attention with residual
        attn_out, present_kv = self.attn(
            self.attn_norm(x), cos, sin,
            past_key_value=past_key_value, use_cache=use_cache
        )
        x = x + attn_out

        # Pre-norm MLP with residual
        x = x + self.mlp(self.mlp_norm(x))

        return x, present_kv
```

**Architecture:**
- **Pre-norm**: Normalization before attention/MLP (not after)
- **Residual connections**: Skip connections for gradient flow
- **KV Cache passthrough**: Supports efficient inference

#### 7. **SmolLM2** (Main Model)

Top-level model that combines all components:

```python
class SmolLM2(nn.Module):
    def __init__(self, config):
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList([SmolBlock(config) for _ in range(30)])
        self.norm = RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Weight tying: share embedding and output weights
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids, past_key_values=None, use_cache=False):
        # 1. Token embeddings
        x = self.embed_tokens(input_ids)

        # 2. Build RoPE cache
        cos, sin = build_rope_cache(...)

        # 3. Pass through the transformer layers, each with its own cache entry
        present_key_values = []
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, present_kv = layer(x, cos, sin, past_kv, use_cache)
            if use_cache:
                present_key_values.append(present_kv)

        # 4. Final norm and language modeling head
        x = self.norm(x)
        logits = self.lm_head(x)

        return logits, present_key_values
```

**Key Features:**
- **Weight Tying**: Embedding and output weights are shared (saves ~28M parameters)
- **KV Cache Support**: Full support for efficient autoregressive generation
- **30 Layers**: A deep, narrow transformer stack for capacity

#### 8. **Generate Method** (Text Generation)

Autoregressive text generation with KV cache:

```python
@torch.no_grad()
def generate(self, input_ids, max_new_tokens=100, temperature=1.0,
             top_k=None, top_p=None, eos_token_id=None):
    generated = input_ids
    past_key_values = None

    for _ in range(max_new_tokens):
        # Forward pass with KV cache: after the first step,
        # only the newest token is fed through the model
        logits, past_key_values = self.forward(
            generated[:, -1:] if past_key_values else generated,
            past_key_values=past_key_values,
            use_cache=True
        )

        # Sample the next token with temperature, top-k, top-p
        next_token_logits = logits[:, -1, :] / temperature
        # Apply top-k and top-p filtering (see the sketch below and app.py)
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        generated = torch.cat([generated, next_token], dim=1)

        if eos_token_id is not None and (next_token == eos_token_id).all():
            break

    return generated
```

**Key Features:**
- **KV Cache**: The prompt is processed once; each later step feeds only the new token
- **Sampling**: Supports temperature, top-k, and top-p (nucleus) sampling
- **Efficient**: Per-step compute stays small, since attention only scans the cached keys instead of re-running the full sequence
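
The filtering step elided above can be written compactly. This is a sketch of the standard approach (mirroring what `app.py` does): mask everything outside the top-k, then drop the tail of the nucleus once its cumulative probability exceeds `top_p`. The helper name `filter_logits` is illustrative, not part of `model.py`:

```python
import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=None, top_p=None):
    """Set disallowed logits to -inf; logits has shape (B, vocab_size)."""
    if top_k is not None and top_k > 0:
        kth_best = torch.topk(logits, top_k).values[:, -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumprobs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        drop = cumprobs > top_p
        drop[:, 1:] = drop[:, :-1].clone()  # keep the token that crosses the cutoff
        drop[:, 0] = False                  # always keep the top token
        logits = logits.masked_fill(drop.scatter(1, sorted_idx, drop), float("-inf"))
    return logits

probs = F.softmax(filter_logits(torch.randn(1, 49152), top_k=50, top_p=0.9), dim=-1)
print(torch.multinomial(probs, num_samples=1))  # one sampled token id
```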

### Model Specifications

| Parameter | Value |
|-----------|-------|
| **Total Parameters** | ~135M |
| **Hidden Size** | 576 |
| **Layers** | 30 |
| **Attention Heads** | 9 (Q), 3 (K/V) |
| **Head Dimension** | 64 |
| **Intermediate Size** | 1536 |
| **Vocabulary Size** | 49,152 |
| **Max Sequence Length** | 8,192 |
| **RoPE Theta** | 100,000 |
| **Activation** | SwiGLU (SiLU-gated) |
| **Normalization** | RMSNorm |
| **Weight Tying** | Yes (embeddings = output) |
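
A quick back-of-the-envelope check, using only the values from this table, that the hyperparameters really land near 135M (weight tying means the embedding matrix is counted once):

```python
V, d, d_ff, L = 49152, 576, 1536, 30
kv_dim = 3 * 64                        # 3 KV heads x head_dim 64 = 192

embed = V * d                          # tied with lm_head, counted once
attn = d * d + 2 * d * kv_dim + d * d  # q_proj, k_proj, v_proj, o_proj (no bias)
mlp = d * (2 * d_ff) + d_ff * d        # fused fc1 (gate + value) and fc2
norms = 2 * d                          # two RMSNorm scales per block

total = embed + L * (attn + mlp + norms) + d  # + final RMSNorm
print(f"{total / 1e6:.1f}M")  # ~134.5M
```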

### Key Design Choices

1. **GQA (Grouped Query Attention)**: The 3:1 head ratio shrinks the K/V cache to a third of its full size (a ~66% memory saving)
2. **Pre-norm Architecture**: More stable training than post-norm
3. **RMSNorm**: Faster and simpler than LayerNorm
4. **RoPE**: Relative positional encoding, no learned position embeddings
5. **SwiGLU**: Empirically stronger activation than ReLU/GELU
6. **Weight Tying**: Reduces parameters and improves generalization
7. **No Biases**: Following LLaMA; slightly reduces the parameter count

### Usage Example

```python
import torch
from transformers import AutoConfig, AutoTokenizer

from model import SmolConfig, SmolLM2

# Load config and tokenizer from HuggingFace
hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config = SmolConfig.from_hf(hf_config)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")

# Create model
model = SmolLM2(config)

# Forward pass (training)
input_ids = torch.randint(0, config.vocab_size, (2, 512))
logits, _ = model(input_ids, use_cache=False)

# Text generation (inference with KV cache)
prompt_ids = tokenizer("Hello, how are you?", return_tensors="pt")["input_ids"]
generated = model.generate(
    prompt_ids,
    max_new_tokens=100,
    temperature=0.8,
    top_k=50
)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```

## Training

See `README_TRAINING.md` for detailed training instructions.

## Inference

See `app.py` for the Gradio web interface, or use the `generate()` method directly.

## References

- [SmolLM2 Paper](https://arxiv.org/abs/2502.02737)
- [LLaMA Architecture](https://arxiv.org/abs/2302.13971)
- [RoPE: Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
- [SwiGLU Activation](https://arxiv.org/abs/2002.05202)

README_SPACE.md
ADDED

@@ -0,0 +1,108 @@

# SmolLM2-135M Hugging Face Space Setup Guide

This guide explains how to push your model and app to Hugging Face Spaces.

## Files Needed for Hugging Face Space

1. **app.py** - Main Gradio application (already created)
2. **model.py** - Model definition
3. **train.py** - Contains the SmolLM2Module class (needed for loading checkpoints)
4. **requirements.txt** - Python dependencies
5. **README.md** - Space description (optional but recommended)

## Step-by-Step Guide

### 1. Fix Merge Conflicts (if still present)

If you still have merge conflicts, resolve them:

```bash
# Check status
git status

# Resolve conflicts in train.py and pyproject.toml, then commit
git add train.py pyproject.toml
git commit -m "Resolve merge conflicts"
```

### 2. Create Hugging Face Space (if not already created)

```bash
# Create the space (without the --sdk flag; set it in the web UI)
huggingface-cli repo create smollm2-135m-trained-on-tinyShakespear-forfun --type=space
```

Then go to the Space settings in the web UI and set the following (a file-based alternative is sketched after this list):
- **SDK**: Gradio
- **Python version**: 3.12
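
Alternatively, Spaces can pick up this configuration from YAML front matter at the top of the Space's README.md; a minimal sketch (the `title` and `emoji` values here are placeholders, not from this repo):

```yaml
---
title: SmolLM2-135M Text Generator   # placeholder
emoji: 🤖                            # placeholder
colorFrom: blue
colorTo: purple
sdk: gradio
app_file: app.py
pinned: false
---
```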

### 3. Add Hugging Face Remote

```bash
# Add the HF Space as a remote (different name to avoid confusion with GitHub)
git remote add huggingface https://huggingface.co/spaces/Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun
```

### 4. Prepare Files for Space

Make sure these files are ready:
- ✅ `app.py` - Main app (loads from the HF model repo)
- ✅ `model.py` - Model definition
- ✅ `train.py` - Contains SmolLM2Module
- ✅ `requirements.txt` - Dependencies
- ✅ `.gitignore` - Should exclude logs/, checkpoints/, etc.

### 5. Push to Hugging Face Space

```bash
# First, disable GPG signing temporarily (if you had issues)
git config --global commit.gpgsign false

# Add and commit files
git add app.py model.py train.py requirements.txt .gitignore
git commit -m "Add Gradio app for SmolLM2-135M inference"

# Push to the Hugging Face Space
git push huggingface main

# Re-enable GPG signing if you want
git config --global commit.gpgsign true
```

### 6. Verify on Hugging Face

1. Go to your Space: https://huggingface.co/spaces/Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun
2. Check the "Files" tab - you should see `app.py`, `model.py`, `train.py`, and `requirements.txt`
3. The Space should automatically build and deploy
4. Once built, you can test the app in the web interface

## Important Notes

- **Model Loading**: The app automatically loads from the `Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun` model repo
- **Checkpoint**: Uses `smollm2-step=05000-train_loss=0.0918.ckpt`
- **First Load**: The first time the Space starts, it downloads the checkpoint from the model repo (this may take a few minutes)
- **Caching**: Subsequent loads are faster thanks to Hugging Face Hub caching

## Troubleshooting

### If push fails with "non-fast-forward":

```bash
# Fetch the latest
git fetch huggingface

# Rebase (without GPG signing)
git config --global commit.gpgsign false
git rebase huggingface/main
git push huggingface main
git config --global commit.gpgsign true
```

### If the Space build fails:
- Check the "Logs" tab in your Space
- Ensure all dependencies are listed in `requirements.txt`
- Make sure `app.py` is the entry point (it should be detected automatically)

### If model loading fails:
- Verify the model repo name is correct: `Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun`
- Verify the checkpoint name: `smollm2-step=05000-train_loss=0.0918.ckpt`
- Check that the checkpoint file exists in the model repo

app.py
ADDED

@@ -0,0 +1,259 @@

```python
"""
Gradio app for SmolLM2-135M inference with streaming output.
Loads the model from a Hugging Face model repo.
"""

import sys
from pathlib import Path
from typing import List, Optional
import os

import gradio as gr
import torch
from transformers import AutoConfig, AutoTokenizer
from huggingface_hub import hf_hub_download

from model import SmolConfig, SmolLM2
from train import SmolLM2Module

# Hugging Face model repo configuration
HF_MODEL_REPO = "Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun"
CHECKPOINT_NAME = "smollm2-step=05000-train_loss=0.0918.ckpt"

# Device setup
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    DEVICE = "mps"

# Globals
model: Optional[SmolLM2] = None
tokenizer = None

# Allow SmolConfig to be deserialized from Lightning checkpoints by torch.load
try:
    torch.serialization.add_safe_globals([SmolConfig])  # type: ignore[attr-defined]
except Exception:
    pass


def load_model_checkpoint(checkpoint_path: Optional[str] = None, use_hf: bool = True):
    """Load a Lightning checkpoint from the Hugging Face Hub or a local path."""
    global model, tokenizer

    try:
        # Load tokenizer and config from Hugging Face
        hf_cfg = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
        config = SmolConfig.from_hf(hf_cfg)
        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Determine the checkpoint path
        if use_hf and checkpoint_path is None:
            # Download from the Hugging Face Hub
            try:
                local_ckpt = hf_hub_download(
                    repo_id=HF_MODEL_REPO,
                    filename=CHECKPOINT_NAME,
                    cache_dir=None,  # Use the default cache
                )
                checkpoint_path = local_ckpt
                status_msg = f"✅ Model loaded from Hugging Face: {HF_MODEL_REPO}/{CHECKPOINT_NAME}"
            except Exception as e:
                return f"❌ Failed to download from HF Hub: {e}"
        elif checkpoint_path:
            # Use a local path
            ckpt = Path(checkpoint_path)
            if not ckpt.exists():
                return f"❌ Checkpoint not found: {ckpt}"
            status_msg = f"✅ Model loaded from local path: {checkpoint_path}"
        else:
            return "❌ No checkpoint path provided"

        # Load the Lightning module
        module = SmolLM2Module.load_from_checkpoint(
            str(checkpoint_path),
            config=config,
            tokenizer=tokenizer,
            map_location=DEVICE,
            strict=False,
        )
        module.eval()
        model = module.model.to(DEVICE).eval()
        return f"{status_msg} on {DEVICE}"
    except Exception as e:
        model = None
        return f"❌ Error loading model: {e}"


def stream_generate(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
):
    """Generator that yields only the generated text (without the prompt)."""
    global model, tokenizer
    if model is None or tokenizer is None:
        yield "⚠️ Load the model first (click Reload Model)."
        return

    if not prompt or not prompt.strip():
        yield "⚠️ Please enter a prompt."
        return

    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = inputs["input_ids"].to(DEVICE)

    # Guard against context overflow
    if input_ids.shape[1] >= model.config.max_position_embeddings:
        yield f"⚠️ Prompt too long ({input_ids.shape[1]} tokens). Max is {model.config.max_position_embeddings}."
        return

    generated = input_ids
    past_key_values: Optional[List] = None
    prompt_length = input_ids.shape[1]

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed the full prompt once, then only the newest token
            if past_key_values is None:
                current_input = generated
            else:
                current_input = generated[:, -1:]

            logits, past_key_values = model(
                current_input,
                past_key_values=past_key_values,
                use_cache=True,
            )

            next_token_logits = logits[:, -1, :] / max(temperature, 1e-6)

            # top-k: keep only the k highest logits
            if top_k > 0:
                values, _ = torch.topk(next_token_logits, top_k)
                min_keep = values[:, -1].unsqueeze(-1)
                next_token_logits = torch.where(
                    next_token_logits < min_keep,
                    torch.full_like(next_token_logits, float("-inf")),
                    next_token_logits,
                )

            # top-p: drop the tail of the nucleus
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                probs = torch.softmax(sorted_logits, dim=-1)
                cumulative = torch.cumsum(probs, dim=-1)
                sorted_mask = cumulative > top_p
                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                sorted_mask[..., 0] = 0
                mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
                next_token_logits = torch.where(mask, torch.full_like(next_token_logits, float("-inf")), next_token_logits)

            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat([generated, next_token], dim=1)
            # Decode only the generated part (skip the prompt)
            generated_text = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
            yield generated_text


# Initial load from Hugging Face
INITIAL_STATUS = load_model_checkpoint(use_hf=True)


def chat_stream(message, history, max_tokens, temperature, top_k, top_p):
    """Gradio wrapper for streaming chat."""
    if history is None:
        history = []

    # Convert history from tuple format to dict format if needed
    if history and isinstance(history[0], (list, tuple)):
        new_history = []
        for h in history:
            if isinstance(h, (list, tuple)) and len(h) >= 2:
                if h[0]:  # User message
                    new_history.append({"role": "user", "content": str(h[0])})
                if h[1]:  # Assistant message
                    new_history.append({"role": "assistant", "content": str(h[1])})
        history = new_history

    # Append the user message
    user_msg = (message or "").strip()
    if not user_msg:
        yield history
        return

    history.append({"role": "user", "content": user_msg})
    history.append({"role": "assistant", "content": ""})

    stream = stream_generate(user_msg, max_tokens, temperature, top_k, top_p)
    for partial in stream:
        # Update the last assistant message with the generated text
        if partial:
            history[-1] = {"role": "assistant", "content": str(partial)}
        yield history


def clear_chat():
    return "", []


with gr.Blocks(title="SmolLM2-135M Text Generator") as demo:
    gr.Markdown(
        """
        # 🤖 SmolLM2-135M Text Generator

        Generate text with your trained SmolLM2-135M model (streaming output).

        **Model:** Trained on the TinyShakespeare dataset
        **Source:** [Hugging Face Model Repo](https://huggingface.co/Sualeh77/smollm2-135m-trained-on-tinyShakespear-forfun)
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Model Status")
            status_text = gr.Textbox(value=INITIAL_STATUS, label="Status", interactive=False, lines=3)
            load_btn = gr.Button("🔄 Reload Model from HF", variant="secondary")
            load_btn.click(fn=lambda: load_model_checkpoint(use_hf=True), outputs=status_text)

            gr.Markdown("### Local Checkpoint (Optional)")
            ckpt_input = gr.Textbox(
                value="",
                label="Local checkpoint path (leave empty to use HF)",
                interactive=True,
            )
            load_local_btn = gr.Button("📂 Load from Local Path", variant="secondary")
            load_local_btn.click(
                fn=lambda p: load_model_checkpoint(checkpoint_path=p, use_hf=False) if p else "⚠️ Please enter a path",
                inputs=ckpt_input,
                outputs=status_text,
            )

            gr.Markdown("### Generation Parameters")
            max_tokens = gr.Slider(10, 500, value=100, step=10, label="Max Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature")
            top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K")
            top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-P")

        with gr.Column(scale=2):
            gr.Markdown("### 💬 Chat Interface")
            chatbot = gr.Chatbot(label="Conversation", height=500)
            with gr.Row():
                msg = gr.Textbox(label="Your Message", placeholder="Type your prompt here...", scale=4, lines=2)
                submit_btn = gr.Button("Send", variant="primary", scale=1)
            clear_btn = gr.Button("🗑️ Clear Chat", variant="stop")

    msg.submit(fn=chat_stream, inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p], outputs=chatbot)
    submit_btn.click(fn=chat_stream, inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p], outputs=chatbot).then(fn=lambda: "", outputs=msg)
    clear_btn.click(fn=clear_chat, outputs=[msg, chatbot])


if __name__ == "__main__":
    demo.queue().launch(share=False, server_name="0.0.0.0", server_port=7860)
```

requirements.txt
ADDED

@@ -0,0 +1,5 @@

```
torch>=2.9.1
lightning>=2.6.0
transformers>=4.57.3
gradio>=4.44.0
huggingface-hub>=0.20.0
```