Sualeh Qureshi committed on
Commit 58ae689 · 1 Parent(s): c175ce3

Added Gradio app for HF space

app_smol.py ADDED
@@ -0,0 +1,223 @@
+ """
+ Gradio app for SmolLM2-135M inference with streaming output.
+ Uses a Lightning checkpoint saved from training.
+ """
+
+ from pathlib import Path
+ from typing import List, Optional
+
+ import gradio as gr
+ import torch
+ from transformers import AutoConfig, AutoTokenizer
+
+ from model import SmolConfig, SmolLM2
+ from train import SmolLM2Module
+
+ # Device setup
+ DEVICE = "cpu"
+ if torch.cuda.is_available():
+     DEVICE = "cuda"
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     DEVICE = "mps"
+
+ # Globals
+ model: Optional[SmolLM2] = None
+ tokenizer = None
+
+ # Allow SmolConfig to be deserialized from Lightning checkpoints when torch.load
+ # runs in weights-only mode (the default since PyTorch 2.6).
+ try:
+     torch.serialization.add_safe_globals([SmolConfig])  # type: ignore[attr-defined]
+ except Exception:
+     pass
+
+
+ def load_model_checkpoint(checkpoint_path: str = "checkpoints/smollm2-final-step-05000.ckpt"):
+     """Load a Lightning checkpoint and return a status string."""
+     global model, tokenizer
+
+     ckpt = Path(checkpoint_path)
+     if not ckpt.exists():
+         return f"❌ Checkpoint not found: {ckpt}"
+
+     try:
+         hf_cfg = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+         config = SmolConfig.from_hf(hf_cfg)
+         tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         module = SmolLM2Module.load_from_checkpoint(
+             str(ckpt),
+             config=config,
+             tokenizer=tokenizer,
+             map_location=DEVICE,
+             strict=False,
+         )
+         module.eval()
+         model = module.model.to(DEVICE).eval()
+         return f"✅ Model loaded from {ckpt} on {DEVICE}"
+     except Exception as e:  # pragma: no cover - interactive
+         model = None
+         return f"❌ Error loading model: {e}"
+
+
+ def stream_generate(
+     prompt: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_k: int,
+     top_p: float,
+ ):
+     """Generator that yields only the generated text (without the prompt)."""
+     global model, tokenizer
+     if model is None or tokenizer is None:
+         yield "⚠️ Load the model first (click Reload Model)."
+         return
+
+     if not prompt or not prompt.strip():
+         yield "⚠️ Please enter a prompt."
+         return
+
+     # Tokenize prompt
+     inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
+     input_ids = inputs["input_ids"].to(DEVICE)
+
+     # Guard against context overflow
+     if input_ids.shape[1] >= model.config.max_position_embeddings:
+         yield f"⚠️ Prompt too long ({input_ids.shape[1]} tokens). Max is {model.config.max_position_embeddings}."
+         return
+
+     generated = input_ids
+     past_key_values: Optional[List] = None
+     prompt_length = input_ids.shape[1]
+
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             # Stop once the context window is full
+             if generated.shape[1] >= model.config.max_position_embeddings:
+                 break
+
+             # With a KV cache, only the newest token needs a forward pass
+             if past_key_values is None:
+                 current_input = generated
+             else:
+                 current_input = generated[:, -1:]
+
+             logits, past_key_values = model(
+                 current_input,
+                 past_key_values=past_key_values,
+                 use_cache=True,
+             )
+
+             next_token_logits = logits[:, -1, :] / max(temperature, 1e-6)
+
+             # top-k: drop everything below the k-th largest logit
+             if top_k > 0:
+                 values, _ = torch.topk(next_token_logits, top_k)
+                 min_keep = values[:, -1].unsqueeze(-1)
+                 next_token_logits = torch.where(
+                     next_token_logits < min_keep,
+                     torch.full_like(next_token_logits, float("-inf")),
+                     next_token_logits,
+                 )
+
+             # top-p (nucleus): keep the smallest set of tokens whose cumulative probability exceeds top_p
+             if top_p < 1.0:
+                 sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                 probs = torch.softmax(sorted_logits, dim=-1)
+                 cumulative = torch.cumsum(probs, dim=-1)
+                 sorted_mask = cumulative > top_p
+                 # Shift right so the first token crossing the threshold is kept
+                 sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+                 sorted_mask[..., 0] = 0
+                 mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
+                 next_token_logits = torch.where(mask, torch.full_like(next_token_logits, float("-inf")), next_token_logits)
+
+             probs = torch.softmax(next_token_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+
+             # Stop early if the model emits the end-of-sequence token
+             if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
+                 break
+
+             generated = torch.cat([generated, next_token], dim=1)
+             # Decode only the generated part (skip the prompt)
+             generated_text = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
+             yield generated_text
+
+
+ # Initial load
+ INITIAL_STATUS = load_model_checkpoint()
+
+
+ def chat_stream(message, history, max_tokens, temperature, top_k, top_p):
+     """Gradio wrapper for streaming chat."""
+     if history is None:
+         history = []
+
+     # Convert history from tuple format [(user, assistant), ...] to dict format if needed
+     if history and isinstance(history[0], (list, tuple)):
+         new_history = []
+         for h in history:
+             if isinstance(h, (list, tuple)) and len(h) >= 2:
+                 if h[0]:  # User message
+                     new_history.append({"role": "user", "content": str(h[0])})
+                 if h[1]:  # Assistant message
+                     new_history.append({"role": "assistant", "content": str(h[1])})
+         history = new_history
+
+     # Append user message
+     user_msg = (message or "").strip()
+     if not user_msg:
+         yield history
+         return
+
+     history.append({"role": "user", "content": user_msg})
+     history.append({"role": "assistant", "content": ""})
+
+     stream = stream_generate(user_msg, max_tokens, temperature, top_k, top_p)
+     for partial in stream:
+         # Update the last assistant message with the generated text so far
+         if partial:
+             history[-1] = {"role": "assistant", "content": str(partial)}
+         yield history
+
+
+ def clear_chat():
+     return "", []
+
+
+ with gr.Blocks(title="SmolLM2-135M Text Generator") as demo:
+     gr.Markdown(
+         """
+         # 🤖 SmolLM2-135M Text Generator
+
+         Generate text with your trained SmolLM2-135M checkpoint (streaming output).
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### Model Status")
+             status_text = gr.Textbox(value=INITIAL_STATUS, label="Status", interactive=False, lines=2)
+             load_btn = gr.Button("🔄 Reload Model", variant="secondary")
+             ckpt_input = gr.Textbox(
+                 value="checkpoints/smollm2-step=05000-train_loss=0.0918.ckpt",
+                 label="Checkpoint path",
+                 interactive=True,
+             )
+             load_btn.click(fn=load_model_checkpoint, inputs=ckpt_input, outputs=status_text)
+
+             gr.Markdown("### Generation Parameters")
+             max_tokens = gr.Slider(10, 500, value=100, step=10, label="Max Tokens")
+             temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature")
+             top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K")
+             top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-P")
+
+         with gr.Column(scale=2):
+             gr.Markdown("### 💬 Chat Interface")
+             # type="messages" matches the dict-format history built in chat_stream
+             chatbot = gr.Chatbot(label="Conversation", height=500, type="messages")
+             with gr.Row():
+                 msg = gr.Textbox(label="Your Message", placeholder="Type your prompt here...", scale=4, lines=2)
+                 submit_btn = gr.Button("Send ➤", variant="primary", scale=1)
+             clear_btn = gr.Button("🗑️ Clear Chat", variant="stop")
+
+     msg.submit(fn=chat_stream, inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p], outputs=chatbot).then(fn=lambda: "", outputs=msg)
+     submit_btn.click(fn=chat_stream, inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p], outputs=chatbot).then(fn=lambda: "", outputs=msg)
+     clear_btn.click(fn=clear_chat, outputs=[msg, chatbot])
+
+
+ if __name__ == "__main__":
+     demo.queue().launch(share=False, server_name="0.0.0.0", server_port=7860)
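
The nucleus (top-p) filter above is the trickiest part of the sampling loop. The following standalone sketch (illustrative only, not part of the commit) runs the same masking steps on a toy four-token distribution so the threshold shift is visible:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])        # toy next-token logits
top_p = 0.8
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
probs = torch.softmax(sorted_logits, dim=-1)           # ~[0.61, 0.22, 0.14, 0.03]
cumulative = torch.cumsum(probs, dim=-1)               # ~[0.61, 0.83, 0.97, 1.00]
sorted_mask = cumulative > top_p                       # [False, True, True, True]
sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()   # shift: keep the token that crosses top_p
sorted_mask[..., 0] = 0
mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
print(mask)  # tensor([[False, False, True, True]]) -> only the top two tokens survive

Tokens marked True are set to -inf before the softmax, so sampling draws only from the smallest set whose cumulative probability reaches top_p.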
logs/tensorboard/version_2/events.out.tfevents.1765275552.MAC-QNYQPC2R2T.7768.0 CHANGED
Binary files a/logs/tensorboard/version_2/events.out.tfevents.1765275552.MAC-QNYQPC2R2T.7768.0 and b/logs/tensorboard/version_2/events.out.tfevents.1765275552.MAC-QNYQPC2R2T.7768.0 differ
 
logs/tensorboard/version_3/events.out.tfevents.1765278317.MAC-QNYQPC2R2T.13054.0 ADDED
Binary file (5.8 kB). View file
 
logs/tensorboard/version_3/hparams.yaml ADDED
@@ -0,0 +1,5 @@
+ block_size: 512
+ peak_lr: 0.0005
+ predict_every: 500
+ total_steps: 5000
+ warmup_steps: 1000
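
These hyperparameters pin down the learning-rate schedule. A minimal sketch of the implied shape, assuming the common linear-warmup-plus-cosine-decay scheme (the actual schedule lives in train.py; the decay form here is an assumption):

import math

def lr_at(step, peak_lr=5e-4, warmup_steps=1000, total_steps=5000):
    # Linear warmup to peak_lr, then (assumed) cosine decay to zero
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * peak_lr * (1 + math.cos(math.pi * progress))

print(lr_at(500), lr_at(1000), lr_at(5000))  # 0.00025 0.0005 ~0.0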
logs/training_20251209_154910.log CHANGED
@@ -33,3 +33,48 @@ First Citizen:
  None,
  2025-12-09 15:59:47,488 - INFO - ================================================================================
 
+ 2025-12-09 16:10:06,586 - INFO - Step 2500 | train_loss=0.9911
+ 2025-12-09 16:10:08,637 - INFO -
+ ================================================================================
+ 2025-12-09 16:10:08,637 - INFO - Step 2500 - Generated text:
+ 2025-12-09 16:10:08,637 - INFO - First Citizen:
+ He said he: youCLARENCE:
+ He hath nopt to die among this case,
+ Yet to flatter, shield your wit would not have not right.
+
+ LADY ANNE:
+ It is it so.
+ 2025-12-09 16:10:08,637 - INFO - ================================================================================
+
+ 2025-12-09 16:20:02,546 - INFO - Step 3000 | train_loss=0.6307
+ 2025-12-09 16:20:04,468 - INFO -
+ ================================================================================
+ 2025-12-09 16:20:04,468 - INFO - Step 3000 - Generated text:
+ 2025-12-09 16:20:04,468 - INFO - First Citizen:
+ Come, let us go in our delay: if
+ you guard guard guard Corioli, your rash a
+ more in yourple; even your need, the queen,
+ Your wives,
+ Your loving, bosom, kill into his
+ 2025-12-09 16:20:04,468 - INFO - ================================================================================
+
+ 2025-12-09 16:30:01,255 - INFO - Step 3500 | train_loss=0.1352
+ 2025-12-09 16:30:03,305 - INFO -
+ ================================================================================
+ 2025-12-09 16:30:03,305 - INFO - Step 3500 - Generated text:
+ 2025-12-09 16:30:03,305 - INFO - First Citizen:
+ Nor I.
+
+ CORIOLANUS:
+ Not now, if it be your will be here.
+
+ MENENIUS:
+ I tell thee, fellow,
+ If thou dost love to see thee,
+
+ 2025-12-09 16:30:03,305 - INFO - ================================================================================
+
+ 2025-12-09 16:30:18,743 - INFO - Final checkpoint saved: checkpoints/smollm2-final-step-03500.ckpt
+ 2025-12-09 16:31:03,806 - INFO - Training completed!
+ 2025-12-09 16:31:03,807 - INFO - Best checkpoint: /Users/qureshsu/Learning/TSAI/ERAV4/session13/smolLM-135/checkpoints/smollm2-step=03500-train_loss=0.1352.ckpt
+ 2025-12-09 16:31:03,807 - INFO - Last checkpoint: /Users/qureshsu/Learning/TSAI/ERAV4/session13/smolLM-135/checkpoints/last.ckpt
logs/training_20251209_163515.log ADDED
@@ -0,0 +1,71 @@
+ 2025-12-09 16:35:15,206 - INFO - Logging to: logs/training_20251209_163515.log
+ 2025-12-09 16:35:15,206 - INFO - Loading tokenizer...
+ 2025-12-09 16:35:16,040 - INFO - Loading model config...
+ 2025-12-09 16:35:16,277 - INFO - Loading dataset from: /Users/qureshsu/Learning/TSAI/ERAV4/session13/data/input.txt
+ 2025-12-09 16:35:16,738 - INFO - Initializing model...
+ 2025-12-09 16:35:17,466 - INFO - Starting training...
+ 2025-12-09 16:35:17,466 - INFO - Resuming from checkpoint: checkpoints/smollm2-step=03500-train_loss=0.1352.ckpt
+ 2025-12-09 16:35:35,153 - INFO -
+ ================================================================================
+ 2025-12-09 16:35:35,153 - INFO - MODEL SUMMARY
+ 2025-12-09 16:35:35,153 - INFO - ================================================================================
+ 2025-12-09 16:35:35,153 - INFO - Model: SmolLM2-135M
+ 2025-12-09 16:35:35,153 - INFO - Total parameters: 134,515,008
+ 2025-12-09 16:35:35,153 - INFO - Trainable parameters: 134,515,008
+ 2025-12-09 16:35:35,153 - INFO - Block size: 512
+ 2025-12-09 16:35:35,153 - INFO - Warmup steps: 1000
+ 2025-12-09 16:35:35,153 - INFO - Peak learning rate: 0.0005
+ 2025-12-09 16:35:35,153 - INFO - Total training steps: 5000
+ 2025-12-09 16:35:35,153 - INFO - Predict every: 500 steps
+ 2025-12-09 16:35:35,153 - INFO - ================================================================================
+
+ 2025-12-09 16:46:13,641 - INFO - Step 4000 | train_loss=0.5093
+ 2025-12-09 16:46:15,889 - INFO -
+ ================================================================================
+ 2025-12-09 16:46:15,890 - INFO - Step 4000 - Generated text:
+ 2025-12-09 16:46:15,890 - INFO - First Citizen:
+ What a strange news, what he hath done famously
+ All slain and g indeed.
+
+ KING HENRY VI:
+ Hadst thou been kill'd, I would not sh wrong;
+ And by that you are, some thou
+ 2025-12-09 16:46:15,890 - INFO - ================================================================================
+
+ 2025-12-09 16:56:44,602 - INFO - Step 4500 | train_loss=0.5634
+ 2025-12-09 16:56:46,770 - INFO -
+ ================================================================================
+ 2025-12-09 16:56:46,770 - INFO - Step 4500 - Generated text:
+ 2025-12-09 16:56:46,770 - INFO - First Citizen:
+ 'Tis a nupt in a sword's make him my
+ First Citizen:
+ Therefore.
+
+ First Citizen:
+ Is there no hope?
+
+ Third Citizen:
+ And ta'en! Suffolk, we shall bring all
+
+ 2025-12-09 16:56:46,770 - INFO - ================================================================================
+
+ 2025-12-09 17:07:03,502 - INFO - Step 5000 | train_loss=0.0918
+ 2025-12-09 17:07:06,185 - INFO -
+ ================================================================================
+ 2025-12-09 17:07:06,186 - INFO - Step 5000 - Generated text:
+ 2025-12-09 17:07:06,186 - INFO - First Citizen:
+ You must think of it?
+
+ Pedant:
+ Ay, I have
+ AUTOLYCUS:
+ Pray you, who came George to 't last once.
+
+ AUTOLYCUS:
+ I know
+ 2025-12-09 17:07:06,186 - INFO - ================================================================================
+
+ 2025-12-09 17:07:18,753 - INFO - Final checkpoint saved: checkpoints/smollm2-final-step-05000.ckpt
+ 2025-12-09 17:07:49,059 - INFO - Training completed!
+ 2025-12-09 17:07:49,060 - INFO - Best checkpoint: /Users/qureshsu/Learning/TSAI/ERAV4/session13/smolLM-135/checkpoints/smollm2-step=05000-train_loss=0.0918.ckpt
+ 2025-12-09 17:07:49,060 - INFO - Last checkpoint: /Users/qureshsu/Learning/TSAI/ERAV4/session13/smolLM-135/checkpoints/last.ckpt
pyproject.toml CHANGED
@@ -14,4 +14,5 @@ dependencies = [
      "torchvision>=0.24.1",
      "tqdm>=4.67.1",
      "transformers>=4.57.3",
+     "gradio>=4.44.0",
  ]
train.py CHANGED
@@ -234,9 +234,9 @@ def main():
      block_size = 512
      batch_size = 4
      num_workers = 8
-     max_steps = 3500
+     max_steps = 5000
      predict_every = 500
-     resume_from_checkpoint = "checkpoints/smollm2-step=01500-train_loss=3.6240.ckpt"  # Set to checkpoint path to resume, or None for fresh training
+     resume_from_checkpoint = "checkpoints/smollm2-step=03500-train_loss=0.1352.ckpt"  # Set to checkpoint path to resume, or None for fresh training
 
      # Training hyperparameters from paper
      warmup_steps = 1000
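
For reference, the checkpoint that this run resumes from can be inspected directly. A minimal sketch (hypothetical usage, assuming the checkpoint file exists locally) of what Lightning stores and restores through the resume path:

import torch

# Lightning checkpoints pickle custom classes (e.g. SmolConfig), so weights-only
# loading is disabled here; only do this for checkpoints you trust.
ckpt = torch.load(
    "checkpoints/smollm2-step=03500-train_loss=0.1352.ckpt",
    map_location="cpu",
    weights_only=False,
)
print(ckpt["global_step"])            # step counter the resumed run continues from
print(list(ckpt["state_dict"])[:3])   # model weights restored into SmolLM2Module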
uv.lock CHANGED
The diff for this file is too large to render. See raw diff