LisaMegaWatts committed on
Commit
8a15026
Β·
verified Β·
1 Parent(s): a2de593

Add model definition

Browse files
Files changed (1) hide show
  1. juliaflux_model.py +253 -0
juliaflux_model.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """JuliaFluxGPT β€” PyTorch reimplementation of JuliaFluxGPT (Flux.jl).
2
+
3
+ LLaMA-style decoder with Grouped Query Attention (8Q/2KV), RMSNorm,
4
+ SwiGLU, RoPE, and weight-tied output. Matches model.jl exactly.
5
+
6
+ Config: d_model=512, n_layers=8, n_heads=8, n_kv_heads=2, head_dim=64,
7
+ ctx=256, vocab=2000, ~23M params.
8
+ """
9
+ import math
10
+ from dataclasses import dataclass
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ # ═══════════════════════════════════════════════════════════════════
18
+ # Configuration
19
+ # ═══════════════════════════════════════════════════════════════════
20
+
21
+
22
@dataclass
class JuliaFluxConfig:
    """Hyperparameters for JuliaFluxGPT; defaults give ~23M parameters."""

    d_model: int = 512            # residual / embedding width
    n_layers: int = 8             # number of transformer blocks
    n_heads: int = 8              # query heads per attention layer
    n_kv_heads: int = 2           # shared key/value heads (GQA: 8Q / 2KV)
    head_dim: int = 64            # per-head width; d_model == n_heads * head_dim
    context_length: int = 256     # maximum sequence length (sizes the RoPE cache)
    vocab_size: int = 2000
    dropout: float = 0.0          # NOTE(review): declared but not read by any visible module — confirm
    weight_tying: bool = True     # tie the output head to the token embedding matrix
    rope_base: float = 10000.0    # RoPE frequency base
34
+
35
+
36
+ # ═══════════════════════════════════════════════════════════════════
37
+ # Building blocks
38
+ # ═══════════════════════════════════════════════════════════════════
39
+
40
+
41
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean-centering, no bias.

    Scales the last dimension by the reciprocal of its RMS, then applies a
    learned per-channel gain (initialized to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # numerical floor inside the sqrt
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x / torch.sqrt(mean_sq + self.eps)
        return normed * self.weight
50
+
51
+
52
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (rotate-half form) with precomputed tables.

    Caches cos/sin for every position up to ``max_seq_len`` at construction
    time; ``forward`` rotates pairs (x[i], x[i + dim/2]) of the last axis by
    a position-dependent angle.
    """

    def __init__(self, dim: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        pos = torch.arange(max_seq_len).float()
        table = pos[:, None] * inv_freq[None, :]  # (max_seq_len, dim // 2)
        self.register_buffer("cos_cache", table.cos())
        self.register_buffer("sin_cache", table.sin())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` of shape (B, heads, T, head_dim); same shape out."""
        seq_len = x.size(2)
        half = x.size(-1) // 2
        cos = self.cos_cache[None, None, :seq_len, :half]
        sin = self.sin_cache[None, None, :seq_len, :half]
        lo, hi = x[..., :half], x[..., half:]
        rotated_lo = lo * cos - hi * sin
        rotated_hi = lo * sin + hi * cos
        return torch.cat((rotated_lo, rotated_hi), dim=-1)
69
+
70
+
71
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: ``down(silu(gate(x)) * up(x))``, no biases.

    The hidden width is 4 * d_model * 2/3 rounded to the nearest multiple
    of 64 (minimum 64), mirroring the Julia implementation.
    """

    def __init__(self, d_model: int):
        super().__init__()
        hidden = int(4 * d_model * 2 / 3)
        hidden = max(64, 64 * ((hidden + 32) // 64))  # round to nearest 64
        self.w_gate = nn.Linear(d_model, hidden, bias=False)
        self.w_up = nn.Linear(d_model, hidden, bias=False)
        self.w_down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w_gate(x))
        return self.w_down(gated * self.w_up(x))
82
+
83
+
84
class GQACausalAttention(nn.Module):
    """Grouped Query Attention with a fused key/value projection.

    Mirrors JuliaFluxGPT's CausalSelfAttention:
      - wq   : d_model -> n_heads * head_dim (queries)
      - wkv  : d_model -> 2 * n_kv_heads * head_dim (keys and values, fused)
      - proj : d_model -> d_model (output projection)
    Each KV head serves ``groups = n_heads // n_kv_heads`` query heads.
    """

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int, head_dim: int):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.groups = n_heads // n_kv_heads
        kv_width = n_kv_heads * head_dim

        self.wq = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.wkv = nn.Linear(d_model, 2 * kv_width, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, rope: RotaryEmbedding,
                mask: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        hd = self.head_dim

        # Queries: (B, T, H*HD) -> (B, H, T, HD).
        q = self.wq(x).view(bsz, seqlen, self.n_heads, hd).transpose(1, 2)

        # Fused K+V projection, split into two (B, KVH, T, HD) tensors.
        keys, values = self.wkv(x).chunk(2, dim=-1)
        k = keys.view(bsz, seqlen, self.n_kv_heads, hd).transpose(1, 2)
        v = values.view(bsz, seqlen, self.n_kv_heads, hd).transpose(1, 2)

        # Rotary position embedding on queries and keys only.
        q, k = rope(q), rope(k)

        # Duplicate each KV head consecutively across its query-head group
        # (equivalent to the unsqueeze/expand/reshape pattern).
        if self.groups > 1:
            k = k.repeat_interleave(self.groups, dim=1)
            v = v.repeat_interleave(self.groups, dim=1)

        # Masked scaled dot-product attention (mask is additive, -inf above
        # the diagonal).
        scale = 1.0 / math.sqrt(hd)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale + mask
        weights = F.softmax(scores, dim=-1)
        ctx = torch.matmul(weights, v)

        # (B, H, T, HD) -> (B, T, H*HD) -> output projection.
        ctx = ctx.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_heads * hd)
        return self.proj(ctx)
141
+
142
+
143
+ # ═══════════════════════════════════════════════════════════════════
144
+ # Transformer block and model
145
+ # ═══════════════════════════════════════════════════════════════════
146
+
147
+
148
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: RMSNorm -> GQA attention and RMSNorm ->
    SwiGLU, each wrapped in a residual connection."""

    def __init__(self, config: JuliaFluxConfig):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model)
        self.attn = GQACausalAttention(
            config.d_model, config.n_heads, config.n_kv_heads, config.head_dim
        )
        self.ln2 = RMSNorm(config.d_model)
        self.ffn = SwiGLU(config.d_model)

    def forward(self, x: torch.Tensor, rope: RotaryEmbedding,
                mask: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.ln1(x), rope, mask)
        h = x + attn_out
        ffn_out = self.ffn(self.ln2(h))
        return h + ffn_out
163
+
164
+
165
class JuliaFluxGPT(nn.Module):
    """LLaMA-style decoder-only transformer (see module docstring).

    Token embedding -> ``n_layers`` pre-norm blocks (GQA + SwiGLU) ->
    final RMSNorm -> vocabulary projection. When ``config.weight_tying``
    is set, the output head reuses the token-embedding matrix.
    """

    def __init__(self, config: JuliaFluxConfig):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        # Single RoPE cache shared by all layers, sized to the max context.
        self.rope = RotaryEmbedding(config.head_dim, config.context_length, config.rope_base)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.ln_f = RMSNorm(config.d_model)
        if config.weight_tying:
            # Tied head: forward() projects with tok_emb.weight directly.
            self.head = None
        else:
            self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, T, vocab_size) for token ids (B, T).

        Raises:
            ValueError: if T exceeds ``config.context_length`` — the RoPE
                cache only covers that many positions, so longer inputs
                would otherwise fail with an opaque broadcasting error.
        """
        B, T = input_ids.shape
        if T > self.config.context_length:
            raise ValueError(
                f"sequence length {T} exceeds context_length "
                f"{self.config.context_length}"
            )
        x = self.tok_emb(input_ids)
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.triu(
            torch.full((T, T), float("-inf"), device=x.device, dtype=x.dtype),
            diagonal=1,
        )
        for block in self.blocks:
            x = block(x, self.rope, mask)
        x = self.ln_f(x)
        if self.head is not None:
            return self.head(x)
        # Weight-tied projection: logits = x @ tok_emb.weight.T
        return F.linear(x, self.tok_emb.weight)

    @property
    def num_parameters(self) -> int:
        """Total parameter count (tied weights counted once)."""
        return sum(p.numel() for p in self.parameters())

    @property
    def weight_entropy(self) -> float:
        """Shannon entropy (bits) of all weights pooled into 100 histogram bins."""
        all_w = torch.cat([p.detach().flatten() for p in self.parameters()])
        if all_w.numel() == 0:
            return 0.0
        hist = torch.histc(all_w.float(), bins=100)
        probs = hist / hist.sum()
        probs = probs[probs > 0]  # drop empty bins; 0*log(0) treated as 0
        return -(probs * torch.log2(probs)).sum().item()

    @property
    def effective_rank(self) -> float:
        """Average effective rank over all Linear weights (SVD, >1% of top
        singular value counts toward the rank)."""
        ranks = []
        for module in self.modules():
            if isinstance(module, nn.Linear):
                w = module.weight.detach()
                try:
                    s = torch.linalg.svdvals(w)
                    threshold = 0.01 * s[0] if s.numel() > 0 and s[0] > 0 else 0.0
                    ranks.append((s > threshold).sum().item())
                except Exception:
                    # SVD can fail to converge; fall back to full rank.
                    ranks.append(float(min(w.shape)))
        return sum(ranks) / len(ranks) if ranks else 0.0
223
+
224
+
225
def load_from_npz(npz_path: str, config: "JuliaFluxConfig | None" = None) -> JuliaFluxGPT:
    """Load JuliaFluxGPT from an NPZ file exported by convert_juliaflux.jl.

    Args:
        npz_path: path to the ``.npz`` archive. Arrays named ``_hp_*`` hold
            hyperparameters; every other array is a state-dict entry keyed
            by its PyTorch parameter name.
        config: explicit configuration; when omitted it is reconstructed
            from the ``_hp_*`` arrays.

    Returns:
        A JuliaFluxGPT with the exported weights loaded.
    """
    import numpy as np

    # Context manager closes the archive's file handle deterministically
    # (a bare np.load leaves the NpzFile open).
    with np.load(npz_path) as data:
        if config is None:
            d_model = int(data["_hp_n_embd"][0])
            n_heads = int(data["_hp_n_head"][0])
            config = JuliaFluxConfig(
                vocab_size=int(data["_hp_vocab_size"][0]),
                d_model=d_model,
                context_length=int(data["_hp_block_size"][0]),
                n_layers=int(data["_hp_n_layer"][0]),
                n_heads=n_heads,
                n_kv_heads=int(data["_hp_n_kv_head"][0]),
                # Derive head_dim from the export instead of trusting the
                # dataclass default, so non-default geometries still load.
                head_dim=d_model // n_heads,
            )

        model = JuliaFluxGPT(config)

        # Every non-hyperparameter array maps 1:1 onto a state-dict key.
        state_dict = {
            key: torch.from_numpy(data[key].copy())
            for key in data.files
            if not key.startswith("_hp_")
        }

    # NOTE: strict=False tolerates the tied output head (absent when
    # weight_tying) and extra arrays; shape mismatches still raise.
    model.load_state_dict(state_dict, strict=False)
    return model