JulianKrgd committed on
Commit 75f00a8 · verified · 1 Parent(s): 4e7bc2c

Upload src/model/julian.py with huggingface_hub

Files changed (1)
  1. src/model/julian.py +225 -0
src/model/julian.py ADDED
@@ -0,0 +1,225 @@
"""
Julian Model - 250M Parameter LLM.
GPT-style decoder-only transformer with modern improvements.
"""

from typing import Optional

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import initializers

from .config import JulianConfig
from .layers import RMSNorm, TransformerBlock, precompute_rope_frequencies

class JulianModel(nn.Module):
    """
    Julian Language Model.

    A GPT-style decoder-only transformer with:
    - RMSNorm (instead of LayerNorm)
    - RoPE positional encoding
    - SwiGLU activation
    - No bias terms
    """
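
    # For reference only, the building blocks imported from .layers, sketched
    # as formulas (the exact implementations live in layers.py, not shown here):
    #   RMSNorm(x) = x / sqrt(mean(x**2, axis=-1) + eps) * g   (no mean subtraction, no bias)
    #   SwiGLU FFN: y = W_down(silu(W_gate x) * (W_up x))      (gated feed-forward)
    #   RoPE: each (q, k) feature pair is rotated by angle pos * rope_theta**(-2i/head_dim)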

    config: JulianConfig

    def setup(self):
        config = self.config

        # Token embeddings
        self.embed_tokens = nn.Embed(
            num_embeddings=config.vocab_size,
            features=config.d_model,
            embedding_init=initializers.normal(config.initializer_range),
            name="embed_tokens",
        )

        # Transformer blocks
        self.layers = [
            TransformerBlock(config, name=f"layers_{i}")
            for i in range(config.n_layers)
        ]

        # Final norm
        self.norm = RMSNorm(config.d_model, config.rms_norm_eps, name="norm")

        # Output projection. Note: this is a separate Dense head, so its
        # weights are NOT tied with embed_tokens; they only share the same
        # initializer.
        self.lm_head = nn.Dense(
            config.vocab_size,
            use_bias=False,
            kernel_init=initializers.normal(config.initializer_range),
            name="lm_head",
        )
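        # (If weight tying were desired instead, the usual Flax idiom is to
        # drop lm_head and compute logits with
        # self.embed_tokens.attend(hidden_states), which reuses the embedding
        # matrix as the output projection.)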

    def __call__(
        self,
        input_ids: jnp.ndarray,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """
        Forward pass.

        Args:
            input_ids: Token IDs [batch, seq_len]
            deterministic: If True, disable dropout

        Returns:
            logits: [batch, seq_len, vocab_size]
        """
        config = self.config
        _, seq_len = input_ids.shape  # batch size is not needed below

        # Token embeddings
        hidden_states = self.embed_tokens(input_ids)

        # Precompute RoPE frequencies up to the model's maximum length
        sin, cos = precompute_rope_frequencies(
            config.head_dim,
            config.max_seq_len,
            config.rope_theta,
        )

        # Causal mask: position i may attend only to positions j <= i
        mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        mask = mask[None, None, :, :]  # [1, 1, seq, seq]
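        # For example, with seq_len = 3 the [seq, seq] part of the mask is:
        #   [[ True, False, False],
        #    [ True,  True, False],
        #    [ True,  True,  True]]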

        # Apply transformer layers
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                sin,
                cos,
                mask,
                deterministic,
            )

        # Final norm
        hidden_states = self.norm(hidden_states)

        # Project to vocabulary
        logits = self.lm_head(hidden_states)

        return logits

    def generate(
        self,
        input_ids: jnp.ndarray,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = 50,
        top_p: Optional[float] = 0.9,
        rng: Optional[jax.Array] = None,
    ) -> jnp.ndarray:
        """
        Generate text autoregressively.

        Like any Flax module method, this must be invoked through apply, e.g.
        model.apply(variables, prompt_ids, method=JulianModel.generate).

        Args:
            input_ids: Prompt token IDs [batch, seq_len]
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_k: Top-k filtering (None disables it)
            top_p: Nucleus sampling threshold (None disables it)
            rng: Random key for sampling (defaults to PRNGKey(0))

        Returns:
            Generated token IDs [batch, seq_len + up to max_new_tokens]
        """
        if rng is None:
            rng = jax.random.PRNGKey(0)

        config = self.config

        # Plain Python loop: the sequence grows by one token per step, so this
        # decode path is meant to run eagerly (it is not jit-compatible as is).
        for _ in range(max_new_tokens):
            # Truncate the context if it exceeds the model's maximum length
            if input_ids.shape[1] >= config.max_seq_len:
                context = input_ids[:, -config.max_seq_len:]
            else:
                context = input_ids

            # Forward pass over the current context
            logits = self(context, deterministic=True)

            # Next-token logits, scaled by temperature
            next_logits = logits[:, -1, :] / temperature

            # Top-k filtering: keep the k largest logits, mask out the rest
            if top_k is not None:
                top_k_logits, top_k_indices = jax.lax.top_k(next_logits, top_k)
                next_logits = jnp.full_like(next_logits, -1e9)
                next_logits = next_logits.at[
                    jnp.arange(next_logits.shape[0])[:, None],
                    top_k_indices
                ].set(top_k_logits)
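            # e.g. with top_k = 2 and logits (3.0, 1.0, 2.0): 3.0 and 2.0 are
            # kept in place and the remaining entry becomes -1e9.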

            # Top-p (nucleus) filtering
            if top_p is not None:
                # Sort logits in descending order and accumulate probabilities
                sorted_indices = jnp.argsort(next_logits, axis=-1)[:, ::-1]
                sorted_logits = jnp.take_along_axis(next_logits, sorted_indices, axis=-1)
                cumprobs = jnp.cumsum(jax.nn.softmax(sorted_logits, axis=-1), axis=-1)

                # Mask tokens whose cumulative prob exceeds top_p; the one-step
                # shift keeps the first token that crosses the threshold
                sorted_mask = cumprobs > top_p
                sorted_mask = jnp.concatenate([
                    jnp.zeros_like(sorted_mask[:, :1]),
                    sorted_mask[:, :-1]
                ], axis=-1)

                sorted_logits = jnp.where(sorted_mask, -1e9, sorted_logits)
                # Undo the sort (argsort of a permutation is its inverse)
                next_logits = jnp.take_along_axis(
                    sorted_logits,
                    jnp.argsort(sorted_indices, axis=-1),
                    axis=-1
                )
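            # e.g. with sorted probs (0.5, 0.3, 0.15, 0.05) and top_p = 0.9 the
            # cumulative sums are (0.5, 0.8, 0.95, 1.0); after the one-step
            # shift only the last token is masked, so sampling covers the top three.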

            # Sample the next token; categorical accepts unnormalized logits,
            # so no softmax/log round-trip is needed
            rng, sample_rng = jax.random.split(rng)
            next_token = jax.random.categorical(sample_rng, next_logits, axis=-1)
            next_token = next_token[:, None]

            # Append to the running sequence
            input_ids = jnp.concatenate([input_ids, next_token], axis=1)

            # Stop early once every sequence in the batch emits EOS
            if jnp.all(next_token == config.eos_token_id):
                break

        return input_ids


def create_model(config: Optional[JulianConfig] = None) -> JulianModel:
    """Create a Julian model instance."""
    if config is None:
        config = JulianConfig()
    return JulianModel(config)


def count_params(params) -> int:
    """Count total parameters in a pytree."""
    return sum(x.size for x in jax.tree_util.tree_leaves(params))


if __name__ == "__main__":
    # Smoke test: build the model, initialize it, and run a forward pass
    config = JulianConfig()
    model = create_model(config)

    # Initialize with a dummy input
    rng = jax.random.PRNGKey(0)
    dummy_input = jnp.ones((1, 128), dtype=jnp.int32)

    variables = model.init(rng, dummy_input)
    params = variables["params"]

    n_params = count_params(params)
    print("Julian Model initialized!")
    print(f"  Config estimate: {config.estimate_params():,}")
    print(f"  Actual params: {n_params:,} ({n_params/1e6:.1f}M)")

    # Test forward pass
    logits = model.apply(variables, dummy_input)
    print(f"  Output shape: {logits.shape}")  # [1, 128, 24000]
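
For a quick end-to-end check, here is a minimal decoding sketch. It assumes the file above is importable as src.model.julian (the import path is an assumption, not confirmed by the repo layout) and uses made-up prompt IDs in place of real tokenizer output, since tokenization is out of scope for this module:

import jax
import jax.numpy as jnp

from src.model.julian import JulianModel, create_model  # assumed import path

model = create_model()  # default JulianConfig

rng = jax.random.PRNGKey(42)
prompt_ids = jnp.array([[1, 42, 7]], dtype=jnp.int32)  # hypothetical token IDs

variables = model.init(rng, prompt_ids)

# Flax module methods other than __call__ are invoked via apply(..., method=...)
out_ids = model.apply(
    variables,
    prompt_ids,
    max_new_tokens=20,
    temperature=0.8,
    rng=jax.random.PRNGKey(0),
    method=JulianModel.generate,
)
print(out_ids.shape)  # (1, 3 + up to 20 generated tokens)

Because generate appends one token per Python-loop iteration, the call runs eagerly; jitting it would require a fixed-shape decoding loop (e.g. jax.lax.while_loop over a preallocated buffer), which this module does not implement.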