agileabhi committed
Commit e9fa755 · verified · 1 Parent(s): 1d61fc9
Files changed (3)
  1. app.py +113 -116
  2. model.py +212 -212
  3. requirements.txt +4 -4
app.py CHANGED
@@ -1,116 +1,113 @@
- import gradio as gr
- import torch
- import torch.nn.functional as F
- from transformers import AutoTokenizer
- from model import CustomSmolLM, ModelConfig
- import os
-
- from huggingface_hub import hf_hub_download
-
- # Configuration
- DEVICE = "cpu" # Spaces usually run on CPU unless GPU is requested
- # We will host the heavy model weights in a separate Model Repository
- MODEL_REPO_ID = "AmolDuse/SmolLM2-135M-Disecting-Model"
- MODEL_FILENAME = "final.pt" # Using the stripped version
-
- print("Downloading model weights from Hub...")
- try:
-     MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
-     print(f"✅ Model downloaded to {MODEL_PATH}")
- except Exception as e:
-     print(f"⚠️ Could not download model: {e}")
-     MODEL_PATH = "final.pt" # Fallback to local
-
- print("Loading model and tokenizer...")
- tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
-
- config = ModelConfig()
- model = CustomSmolLM(config)
-
- # Load weights
- if os.path.exists(MODEL_PATH):
-     state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
-     model.load_state_dict(state_dict)
-     print("✅ Model weights loaded successfully")
- else:
-     print(f"⚠️ Warning: {MODEL_PATH} not found. Running with random weights.")
-
- model.to(DEVICE)
- model.eval()
-
-
- def generate_text(prompt, max_length=50, temperature=0.8):
-     try:
-         if not prompt:
-             return "Please enter a prompt."
-
-         # Ensure inputs are correct types
-         max_length = int(max_length)
-         temperature = float(temperature)
-
-         print(f"Generating: prompt='{prompt}', max_length={max_length}, temp={temperature}")
-
-         input_ids = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE)
-
-         with torch.no_grad():
-             for i in range(max_length):
-                 outputs = model(input_ids)
-                 logits = outputs['logits']
-
-                 # Get next token logits
-                 next_token_logits = logits[:, -1, :] / temperature
-
-                 # Sample next token
-                 probs = F.softmax(next_token_logits, dim=-1)
-                 next_token = torch.multinomial(probs, num_samples=1)
-
-                 # Append to sequence
-                 input_ids = torch.cat([input_ids, next_token], dim=1)
-
-                 # Stop if we hit end of sequence
-                 if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
-                     break
-
-         return tokenizer.decode(input_ids[0], skip_special_tokens=True)
-
-     except Exception as e:
-         import traceback
-         traceback.print_exc()
-         return f"Error during generation: {str(e)}"
-
-
- # Gradio Interface
- with gr.Blocks(title="SmolLM2-135M Dissecting Demo") as demo:
-     gr.Markdown("# SmolLM2-135M Dissecting Demo")
-     gr.Markdown(
-         "A custom implementation of SmolLM2-135M trained from scratch. Enter a prompt to see what it generates!")
-
-     with gr.Row():
-         with gr.Column():
-             prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt")
-             with gr.Row():
-                 max_len_input = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Max Length")
-                 temp_input = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
-             generate_btn = gr.Button("Generate", variant="primary")
-
-         with gr.Column():
-             output_text = gr.Textbox(label="Generated Text", lines=10)
-
-     generate_btn.click(
-         fn=generate_text,
-         inputs=[prompt_input, max_len_input, temp_input],
-         outputs=output_text
-     )
-
-     # Add examples
-     gr.Examples(
-         examples=[
-             ["The quick brown fox"],
-             ["Once upon a time"],
-             ["What is English"]
-         ],
-         inputs=prompt_input
-     )
-
- if __name__ == "__main__":
-     demo.launch()
 
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer
+ from model import CustomSmolLM, ModelConfig
+ import os
+
+ from huggingface_hub import hf_hub_download
+
+ # Configuration
+ DEVICE = "cpu" # Spaces usually run on CPU unless GPU is requested
+ # We will host the heavy model weights in a separate Model Repository
+ MODEL_REPO_ID = "AmolDuse/SmolLM2-135M-Disecting-Model"
+ MODEL_FILENAME = "model.pt" # Using the stripped version
+
+ print("Downloading model weights from Hub...")
+ try:
+     MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
+     print(f"✅ Model downloaded to {MODEL_PATH}")
+ except Exception as e:
+     print(f"⚠️ Could not download model: {e}")
+     MODEL_PATH = "model.pt" # Fallback to local
+
+ print("Loading model and tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+
+ config = ModelConfig()
+ model = CustomSmolLM(config)
+
+ # Load weights
+ if os.path.exists(MODEL_PATH):
+     state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
+     model.load_state_dict(state_dict)
+     print("✅ Model weights loaded successfully")
+ else:
+     print(f"⚠️ Warning: {MODEL_PATH} not found. Running with random weights.")
+
+ model.to(DEVICE)
+ model.eval()
+
+ def generate_text(prompt, max_length=50, temperature=0.8):
+     try:
+         if not prompt:
+             return "Please enter a prompt."
+
+         # Ensure inputs are correct types
+         max_length = int(max_length)
+         temperature = float(temperature)
+
+         print(f"Generating: prompt='{prompt}', max_length={max_length}, temp={temperature}")
+
+         input_ids = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE)
+
+         with torch.no_grad():
+             for i in range(max_length):
+                 outputs = model(input_ids)
+                 logits = outputs['logits']
+
+                 # Get next token logits
+                 next_token_logits = logits[:, -1, :] / temperature
+
+                 # Sample next token
+                 probs = F.softmax(next_token_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+
+                 # Append to sequence
+                 input_ids = torch.cat([input_ids, next_token], dim=1)
+
+                 # Stop if we hit end of sequence
+                 if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
+                     break
+
+         return tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+     except Exception as e:
+         import traceback
+         traceback.print_exc()
+         return f"Error during generation: {str(e)}"
+
+ # Gradio Interface
+ with gr.Blocks(title="SmolLM2-135M Dissecting Demo") as demo:
+     gr.Markdown("# SmolLM2-135M Dissecting Demo")
+     gr.Markdown("A custom implementation of SmolLM2-135M trained from scratch. Enter a prompt to see what it generates!")
+
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt")
+             with gr.Row():
+                 max_len_input = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Max Length")
+                 temp_input = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
+             generate_btn = gr.Button("Generate", variant="primary")
+
+         with gr.Column():
+             output_text = gr.Textbox(label="Generated Text", lines=10)
+
+     generate_btn.click(
+         fn=generate_text,
+         inputs=[prompt_input, max_len_input, temp_input],
+         outputs=output_text
+     )
+
+     # Add examples
+     gr.Examples(
+         examples=[
+             ["The quick brown fox"],
+             ["Once upon a time"],
+             ["What is English"]
+         ],
+         inputs=prompt_input
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
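
A minimal smoke test for the new app.py (a sketch, not part of the commit) can help when debugging the Space outside the Gradio UI. It assumes app.py and model.py are in the working directory and that the weight download or the local model.pt fallback succeeds; importing app runs the module-level setup but does not start the server, because demo.launch() is guarded by the __main__ check.

from app import generate_text

# Exercise the sampling loop directly. Temperature below 1.0 sharpens the
# softmax over next-token logits; values above 1.0 flatten it.
print(generate_text("Once upon a time", max_length=30, temperature=0.7))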
 
 
 
model.py CHANGED
@@ -1,212 +1,212 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import math
- from dataclasses import dataclass
-
- @dataclass
- class ModelConfig:
-     """Configuration matching SmolLM2-135M"""
-     vocab_size: int = 49152
-     hidden_size: int = 576
-     num_hidden_layers: int = 30
-     num_attention_heads: int = 9
-     intermediate_size: int = 1536
-     max_position_embeddings: int = 2048
-     layer_norm_eps: float = 1e-5
-     hidden_dropout_prob: float = 0.1
-     attention_dropout_prob: float = 0.1
-
-     @property
-     def head_dim(self):
-         return self.hidden_size // self.num_attention_heads
-
-
- class RotaryEmbedding(nn.Module):
-     """Rotary Position Embedding (RoPE)"""
-     def __init__(self, dim, max_position_embeddings=2048, base=10000):
-         super().__init__()
-         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-         self.register_buffer("inv_freq", inv_freq)
-         self.max_seq_len_cached = max_position_embeddings
-
-         t = torch.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype)
-         freqs = torch.outer(t, self.inv_freq)
-         emb = torch.cat((freqs, freqs), dim=-1)
-         self.register_buffer("cos_cached", emb.cos(), persistent=False)
-         self.register_buffer("sin_cached", emb.sin(), persistent=False)
-
-     def forward(self, x, seq_len):
-         return (
-             self.cos_cached[:seq_len, ...],
-             self.sin_cached[:seq_len, ...],
-         )
-
-
- def rotate_half(x):
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2 :]
-     return torch.cat((-x2, x1), dim=-1)
-
-
- def apply_rotary_pos_emb(q, k, cos, sin):
-     """Apply rotary position embedding to query and key tensors."""
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed
-
-
- class MultiHeadAttention(nn.Module):
-     """Multi-head attention with RoPE"""
-     def __init__(self, config: ModelConfig):
-         super().__init__()
-         self.num_heads = config.num_attention_heads
-         self.head_dim = config.head_dim
-         self.hidden_size = config.hidden_size
-
-         self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-         self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-         self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-         self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-
-         self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings)
-         self.dropout = nn.Dropout(config.attention_dropout_prob)
-
-     def forward(self, hidden_states, attention_mask=None):
-         batch_size, seq_len, _ = hidden_states.shape
-
-         # Project to Q, K, V
-         q = self.q_proj(hidden_states)
-         k = self.k_proj(hidden_states)
-         v = self.v_proj(hidden_states)
-
-         # Reshape for multi-head attention
-         q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-         k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-         v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-         # Apply rotary embeddings
-         cos, sin = self.rotary_emb(v, seq_len)
-         q, k = apply_rotary_pos_emb(q, k, cos, sin)
-
-         # Attention scores
-         attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
-
-         if attention_mask is not None:
-             attn_weights = attn_weights + attention_mask
-
-         attn_weights = F.softmax(attn_weights, dim=-1)
-         attn_weights = self.dropout(attn_weights)
-
-         # Apply attention to values
-         attn_output = torch.matmul(attn_weights, v)
-
-         # Reshape and project
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
-         attn_output = self.o_proj(attn_output)
-
-         return attn_output
-
-
- class MLP(nn.Module):
-     """Feed-forward network"""
-     def __init__(self, config: ModelConfig):
-         super().__init__()
-         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
-         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
-     def forward(self, x):
-         # SwiGLU activation
-         gate = F.silu(self.gate_proj(x))
-         up = self.up_proj(x)
-         return self.dropout(self.down_proj(gate * up))
-
-
- class TransformerBlock(nn.Module):
-     """Single transformer block"""
-     def __init__(self, config: ModelConfig):
-         super().__init__()
-         self.attention = MultiHeadAttention(config)
-         self.mlp = MLP(config)
-         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-
-     def forward(self, hidden_states, attention_mask=None):
-         # Pre-norm architecture
-         residual = hidden_states
-         hidden_states = self.input_layernorm(hidden_states)
-         hidden_states = self.attention(hidden_states, attention_mask)
-         hidden_states = residual + hidden_states
-
-         residual = hidden_states
-         hidden_states = self.post_attention_layernorm(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = residual + hidden_states
-
-         return hidden_states
-
-
- class CustomSmolLM(nn.Module):
-     """Custom implementation mimicking SmolLM2-135M"""
-     def __init__(self, config: ModelConfig):
-         super().__init__()
-         self.config = config
-
-         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-         self.layers = nn.ModuleList([
-             TransformerBlock(config) for _ in range(config.num_hidden_layers)
-         ])
-         self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-         # Tie weights
-         self.lm_head.weight = self.embed_tokens.weight
-
-         self.apply(self._init_weights)
-
-     def _init_weights(self, module):
-         std = 0.02
-         if isinstance(module, nn.Linear):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-         elif isinstance(module, nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=std)
-
-     def forward(self, input_ids, attention_mask=None, labels=None):
-         batch_size, seq_len = input_ids.shape
-
-         # Create causal mask
-         if attention_mask is None:
-             causal_mask = torch.triu(
-                 torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device),
-                 diagonal=1
-             )
-             causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
-         else:
-             causal_mask = None # Simplified for this example
-
-         # Embed tokens
-         hidden_states = self.embed_tokens(input_ids)
-
-         # Pass through transformer blocks
-         for layer in self.layers:
-             hidden_states = layer(hidden_states, causal_mask)
-
-         hidden_states = self.norm(hidden_states)
-         logits = self.lm_head(hidden_states)
-
-         loss = None
-         if labels is not None:
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             loss = F.cross_entropy(
-                 shift_logits.view(-1, self.config.vocab_size),
-                 shift_labels.view(-1)
-             )
-
-         return {'loss': loss, 'logits': logits}
 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from dataclasses import dataclass
+
+ @dataclass
+ class ModelConfig:
+     """Configuration matching SmolLM2-135M"""
+     vocab_size: int = 49152
+     hidden_size: int = 576
+     num_hidden_layers: int = 30
+     num_attention_heads: int = 9
+     intermediate_size: int = 1536
+     max_position_embeddings: int = 2048
+     layer_norm_eps: float = 1e-5
+     hidden_dropout_prob: float = 0.1
+     attention_dropout_prob: float = 0.1
+
+     @property
+     def head_dim(self):
+         return self.hidden_size // self.num_attention_heads
+
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary Position Embedding (RoPE)"""
+     def __init__(self, dim, max_position_embeddings=2048, base=10000):
+         super().__init__()
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self.max_seq_len_cached = max_position_embeddings
+
+         t = torch.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype)
+         freqs = torch.outer(t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos(), persistent=False)
+         self.register_buffer("sin_cached", emb.sin(), persistent=False)
+
+     def forward(self, x, seq_len):
+         return (
+             self.cos_cached[:seq_len, ...],
+             self.sin_cached[:seq_len, ...],
+         )
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     """Apply rotary position embedding to query and key tensors."""
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class MultiHeadAttention(nn.Module):
+     """Multi-head attention with RoPE"""
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.num_heads = config.num_attention_heads
+         self.head_dim = config.head_dim
+         self.hidden_size = config.hidden_size
+
+         self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+         self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+         self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+         self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+
+         self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings)
+         self.dropout = nn.Dropout(config.attention_dropout_prob)
+
+     def forward(self, hidden_states, attention_mask=None):
+         batch_size, seq_len, _ = hidden_states.shape
+
+         # Project to Q, K, V
+         q = self.q_proj(hidden_states)
+         k = self.k_proj(hidden_states)
+         v = self.v_proj(hidden_states)
+
+         # Reshape for multi-head attention
+         q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # Apply rotary embeddings
+         cos, sin = self.rotary_emb(v, seq_len)
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         # Attention scores
+         attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask
+
+         attn_weights = F.softmax(attn_weights, dim=-1)
+         attn_weights = self.dropout(attn_weights)
+
+         # Apply attention to values
+         attn_output = torch.matmul(attn_weights, v)
+
+         # Reshape and project
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output
+
+
+ class MLP(nn.Module):
+     """Feed-forward network"""
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, x):
+         # SwiGLU activation
+         gate = F.silu(self.gate_proj(x))
+         up = self.up_proj(x)
+         return self.dropout(self.down_proj(gate * up))
+
+
+ class TransformerBlock(nn.Module):
+     """Single transformer block"""
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.attention = MultiHeadAttention(config)
+         self.mlp = MLP(config)
+         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states, attention_mask=None):
+         # Pre-norm architecture
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states = self.attention(hidden_states, attention_mask)
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         return hidden_states
+
+
+ class CustomSmolLM(nn.Module):
+     """Custom implementation mimicking SmolLM2-135M"""
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([
+             TransformerBlock(config) for _ in range(config.num_hidden_layers)
+         ])
+         self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Tie weights
+         self.lm_head.weight = self.embed_tokens.weight
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         std = 0.02
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+
+     def forward(self, input_ids, attention_mask=None, labels=None):
+         batch_size, seq_len = input_ids.shape
+
+         # Create causal mask
+         if attention_mask is None:
+             causal_mask = torch.triu(
+                 torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device),
+                 diagonal=1
+             )
+             causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+         else:
+             causal_mask = None # Simplified for this example
+
+         # Embed tokens
+         hidden_states = self.embed_tokens(input_ids)
+
+         # Pass through transformer blocks
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, causal_mask)
+
+         hidden_states = self.norm(hidden_states)
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss = F.cross_entropy(
+                 shift_logits.view(-1, self.config.vocab_size),
+                 shift_labels.view(-1)
+             )
+
+         return {'loss': loss, 'logits': logits}
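
A quick shape check (a sketch, not part of the commit) showing how the classes in model.py fit together, assuming the file is importable from the working directory. With random weights the outputs are meaningless, but the tensor shapes and the shifted cross-entropy path can be verified on CPU.

import torch
from model import CustomSmolLM, ModelConfig

config = ModelConfig()
model = CustomSmolLM(config)
model.eval()

# Two sequences of 16 random token IDs drawn from the 49152-token vocabulary.
input_ids = torch.randint(0, config.vocab_size, (2, 16))

with torch.no_grad():
    out = model(input_ids, labels=input_ids)

# Logits cover every position and the full vocabulary: (2, 16, 49152).
print(out['logits'].shape)
# Passing labels triggers the shifted next-token cross-entropy; with untrained
# weights the value sits near ln(49152) ≈ 10.8.
print(out['loss'].item())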
requirements.txt CHANGED
@@ -1,4 +1,4 @@
- torch
- transformers
- gradio
- huggingface_hub
 
+ torch
+ transformers
+ gradio>=4.0.0
+ huggingface_hub