agileabhi committed on
Commit
84d9ea6
·
verified ·
1 Parent(s): f0ccf17

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +212 -0
model.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from dataclasses import dataclass
6
+
7
@dataclass
class ModelConfig:
    """Hyperparameters mirroring the SmolLM2-135M architecture.

    Field order and defaults form the positional constructor signature,
    so they must stay as declared.
    """
    vocab_size: int = 49152              # tokenizer vocabulary size
    hidden_size: int = 576               # model (embedding) width
    num_hidden_layers: int = 30          # number of transformer blocks
    num_attention_heads: int = 9         # attention heads per layer
    intermediate_size: int = 1536        # MLP inner width
    max_position_embeddings: int = 2048  # longest supported sequence
    layer_norm_eps: float = 1e-5
    hidden_dropout_prob: float = 0.1
    attention_dropout_prob: float = 0.1

    @property
    def head_dim(self) -> int:
        """Per-head width: hidden_size split evenly across the heads."""
        width, heads = self.hidden_size, self.num_attention_heads
        return width // heads
23
+
24
+
25
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Precomputes cos/sin tables for positions up to
    ``max_position_embeddings`` and transparently extends the cache if a
    longer sequence is requested. The original version silently returned a
    truncated table (shorter than ``seq_len``) when the request exceeded
    the cache, which would break downstream broadcasting.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        # Inverse frequencies, one per pair of dims (standard RoPE formula).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_position_embeddings)

    def _build_cache(self, seq_len):
        """(Re)build the cos/sin tables for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        # Duplicate so the table spans the full head dim (both halves).
        emb = torch.cat((freqs, freqs), dim=-1)
        # persistent=False: derived values, not saved in the state dict.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x, seq_len):
        """Return (cos, sin), each (seq_len, dim), in x's dtype.

        x is only consulted for its dtype; positions are absolute.
        """
        if seq_len > self.max_seq_len_cached:
            # Grow the cache instead of returning a short slice.
            self._build_cache(seq_len)
        return (
            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
        )
44
+
45
+
46
def rotate_half(x):
    """Map the last-dim halves (x1, x2) to (-x2, x1), as RoPE requires."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
51
+
52
+
53
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query/key tensors by the precomputed cos/sin tables."""
    def _rotate(t):
        # Same operation as rotate_half: (t1, t2) -> (-t2, t1) on last dim.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q_embed = q * cos + _rotate(q) * sin
    k_embed = k * cos + _rotate(k) * sin
    return q_embed, k_embed
58
+
59
+
60
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE)."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size

        # Unbiased projections for queries, keys, values, and output.
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings)
        self.dropout = nn.Dropout(config.attention_dropout_prob)

    def forward(self, hidden_states, attention_mask=None):
        """Self-attend over hidden_states (batch, seq, hidden).

        attention_mask, if given, is an additive mask broadcastable to the
        (batch, heads, seq, seq) score tensor.
        """
        bsz, q_len, _ = hidden_states.shape

        def _split_heads(t):
            # (bsz, seq, hidden) -> (bsz, heads, seq, head_dim)
            return t.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        query = _split_heads(self.q_proj(hidden_states))
        key = _split_heads(self.k_proj(hidden_states))
        value = _split_heads(self.v_proj(hidden_states))

        # Rotate queries/keys by their absolute positions.
        cos, sin = self.rotary_emb(value, q_len)
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # Scaled dot-product scores, optionally masked additively.
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask

        probs = self.dropout(F.softmax(scores, dim=-1))
        context = probs @ value

        # Merge heads back and project to the model width.
        context = context.transpose(1, 2).contiguous()
        context = context.view(bsz, q_len, self.hidden_size)
        return self.o_proj(context)
111
+
112
+
113
class MLP(nn.Module):
    """Gated feed-forward network (SwiGLU-style)."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Compute dropout(down(silu(gate(x)) * up(x)))."""
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
127
+
128
+
129
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each with a
    residual connection around its normalized sub-layer."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        """Apply both sub-layers; shape (batch, seq, hidden) is preserved."""
        # Attention sub-layer: norm first (pre-norm), then residual add.
        attn_out = self.attention(self.input_layernorm(hidden_states), attention_mask)
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-layer, same pre-norm + residual pattern.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
151
+
152
+
153
class CustomSmolLM(nn.Module):
    """Custom implementation mimicking SmolLM2-135M.

    Decoder-only transformer with tied input/output embeddings. The
    original forward() dropped ALL masking whenever a padding
    ``attention_mask`` was supplied (``causal_mask = None``), letting the
    model attend to future tokens; this version always applies the causal
    mask and folds the padding mask into it.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie the output projection to the token embedding matrix.
        self.lm_head.weight = self.embed_tokens.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero biases."""
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def _build_mask(self, input_ids, attention_mask):
        """Additive mask of shape (batch-or-1, 1, seq, seq).

        Always causal; if ``attention_mask`` (batch, seq; 1 = keep,
        0 = padding) is given, padded key positions also get -inf.
        """
        seq_len = input_ids.shape[1]
        causal = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device),
            diagonal=1,
        ).unsqueeze(0).unsqueeze(0)  # (1, 1, seq, seq)

        if attention_mask is not None:
            # masked_fill avoids the 0 * -inf = NaN trap of arithmetic forms.
            pad = torch.zeros_like(attention_mask, dtype=causal.dtype)
            pad = pad.masked_fill(attention_mask == 0, float('-inf'))
            causal = causal + pad[:, None, None, :]  # broadcast over queries
        return causal

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the LM.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq), 1 = keep / 0 = padding.
            labels: optional (batch, seq) target ids; -100 entries are
                ignored in the loss. Shifting for next-token prediction
                happens internally.

        Returns:
            dict with 'loss' (tensor or None) and 'logits' (batch, seq, vocab).
        """
        causal_mask = self._build_mask(input_ids, attention_mask)

        # Embed tokens and run the transformer stack.
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, causal_mask)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,  # standard "masked label" convention
            )

        return {'loss': loss, 'logits': logits}