gate369 committed
Commit 9125ea0 · verified · 1 Parent(s): e33232e

Create README.md

Files changed (1): README.md (+338, -0)
```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import json


# ----------------------
# MoR Model Components
# ----------------------
class ExpertChoiceRouter(nn.Module):
    """Expert Choice Routing (note: as implemented, each token selects its top-k experts)"""
    def __init__(self, dim, num_experts, k=2):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.gate = nn.Linear(dim, num_experts, bias=False)

    def forward(self, x):
        # x: (batch, seq_len, dim)
        scores = self.gate(x)  # (batch, seq_len, num_experts)
        expert_weights, expert_indices = torch.topk(scores, self.k, dim=-1)
        return expert_weights.softmax(dim=-1), expert_indices

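# Illustrative shape check (an addition to this README, not part of the released
# code): with k=2, each token should receive two expert indices and two
# normalized weights.
def _demo_expert_router():
    router = ExpertChoiceRouter(dim=64, num_experts=4, k=2)
    weights, indices = router(torch.randn(2, 5, 64))
    assert weights.shape == indices.shape == (2, 5, 2)
    assert torch.allclose(weights.sum(dim=-1), torch.ones(2, 5))
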
class Quantizer4Bit(nn.Module):
    """4-bit quantization utilities (values live in [-8, 7] inside an int8 container)"""
    def __init__(self):
        super().__init__()

    @staticmethod
    def quantize(tensor):
        """Quantize a float tensor to signed 4-bit integers with a single absmax scale"""
        scale = tensor.abs().max() / 7.5
        scale = torch.clamp(scale, min=1e-8)
        quantized = torch.clamp(torch.round(tensor / scale), -8, 7)
        return quantized.to(torch.int8), scale

    @staticmethod
    def dequantize(quantized, scale):
        """Dequantize 4-bit integers back to float"""
        return quantized.float() * scale

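# Illustrative round trip (an addition for this README): 4-bit quantization is
# lossy, but the reconstruction should stay within half a quantization step.
def _demo_quantizer():
    x = torch.randn(8, 8)
    q, scale = Quantizer4Bit.quantize(x)
    x_hat = Quantizer4Bit.dequantize(q, scale)
    assert (x - x_hat).abs().max() <= scale * 0.5 + 1e-6
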
class QuantizedRecursiveTransformerBlock(nn.Module):
    """Recursive Transformer Block with (simulated) 4-bit K/V quantization"""
    def __init__(self, dim, num_heads, ffn_expansion=4):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Attention projections
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.attn_out = nn.Linear(dim, dim)

        # FFN layers
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_expansion * dim),
            nn.GELU(),
            nn.Linear(ffn_expansion * dim, dim),
        )

        # Normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # x: (batch, seq_len, dim)
        residual = x
        x = self.norm1(x)

        # Projections
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Quantize K and V, then immediately dequantize: this simulates the
        # 4-bit K/V path rather than storing a compressed cache
        k_quant, k_scale = Quantizer4Bit.quantize(k)
        v_quant, v_scale = Quantizer4Bit.quantize(v)
        k = Quantizer4Bit.dequantize(k_quant, k_scale)
        v = Quantizer4Bit.dequantize(v_quant, v_scale)

        # Full (non-causal) multi-head self-attention
        B, T, _ = q.shape
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn_out = (attn @ v).transpose(1, 2).contiguous().view(B, T, self.dim)
        attn_out = self.attn_out(attn_out)

        # Residual connection
        x = residual + attn_out

        # FFN with pre-norm residual
        x = x + self.ffn(self.norm2(x))
        return x

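# Quick check (an illustrative addition): the block is shape-preserving, which
# is what allows it to be applied recursively.
def _demo_block():
    block = QuantizedRecursiveTransformerBlock(dim=64, num_heads=4)
    assert block(torch.randn(2, 5, 64)).shape == (2, 5, 64)
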
class RecursionDepthRouter(nn.Module):
    """Lightweight Router for Dynamic Recursion Depth"""
    def __init__(self, dim, max_depth=4):
        super().__init__()
        self.max_depth = max_depth
        self.router = nn.Sequential(
            nn.Linear(dim, 32),
            nn.ReLU(),
            nn.Linear(32, max_depth),
        )

    def forward(self, x):
        # x: (batch, seq_len, dim), mean-pooled over the sequence
        router_logits = self.router(x.mean(dim=1))  # (batch, max_depth)
        return router_logits.softmax(dim=-1)

class QuantizedMoRModel(nn.Module):
    """Main MoR Architecture"""
    def __init__(self, vocab_size, dim, num_layers, num_heads, max_recursion,
                 num_experts, max_position_embeddings):
        super().__init__()
        self.dim = dim
        self.max_recursion = max_recursion
        self.num_experts = num_experts
        self.max_position_embeddings = max_position_embeddings

        # Embedding layers
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Embedding(max_position_embeddings, dim)

        # Initial unique layers
        self.init_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])

        # Middle-cycle shared layers
        self.cycle_depth = 3
        self.recursive_blocks = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(self.cycle_depth)
        ])

        # Recursion routers
        self.recursion_routers = nn.ModuleList([
            RecursionDepthRouter(dim, max_depth=max_recursion)
            for _ in range(num_layers - 4)
        ])

        # Expert choice routing
        self.expert_routers = nn.ModuleList([
            ExpertChoiceRouter(dim, num_experts)
            for _ in range(max_recursion)
        ])

        # Final unique layers
        self.final_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])

        # Output head
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        # Token + position embeddings
        pos = torch.arange(0, x.shape[1], device=x.device)
        x = self.embedding(x) + self.pos_embed(pos)

        # Initial unique layers
        for layer in self.init_layers:
            x = layer(x)

        # Middle-cycle with recursion
        all_x = [x]
        batch_size, seq_len, _ = x.shape

        for router in self.recursion_routers:
            # Get recursion depth probabilities
            depth_probs = router(x)

            # Sample a recursion depth (note: the sampled depth is not used
            # below; all max_recursion levels are executed)
            depth = torch.multinomial(depth_probs, 1).squeeze()

            # Process through recursive blocks
            for d in range(self.max_recursion):
                # Expert routing
                expert_weights, expert_indices = self.expert_routers[d](x)

                # Scatter top-k weights into a dense (batch, seq, experts) matrix
                full_weights = torch.zeros(
                    (batch_size, seq_len, self.num_experts), device=x.device
                )
                full_weights.scatter_(2, expert_indices, expert_weights)

                # Process each expert
                expert_outputs = []
                for expert_idx in range(self.num_experts):
                    # Tokens routed to this expert
                    expert_mask = full_weights[:, :, expert_idx] > 0

                    if expert_mask.any():
                        # Zero out tokens not routed to this expert
                        expert_x = torch.zeros_like(x)
                        expert_x[expert_mask] = x[expert_mask]

                        # Process through the shared block for this depth
                        out = self.recursive_blocks[d % self.cycle_depth](expert_x)
                        expert_outputs.append(
                            out * full_weights[:, :, expert_idx].unsqueeze(-1)
                        )
                    else:
                        expert_outputs.append(torch.zeros_like(x))

                # Combine expert outputs
                x = sum(expert_outputs)

            all_x.append(x)

        # Average the states collected after each router pass
        x = torch.stack(all_x).mean(dim=0)

        # Final unique layers
        for layer in self.final_layers:
            x = layer(x)

        # Output
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    def generate(self, input_ids, max_length=100, temperature=0.8, top_k=50,
                 eos_token_id=None):
        """Simple text generation function"""
        generated = input_ids.clone()

        with torch.no_grad():
            for _ in range(max_length):
                # Truncate the context to max_position_embeddings
                inputs = generated[:, -self.max_position_embeddings:] \
                    if generated.shape[1] > self.max_position_embeddings \
                    else generated

                # Forward pass; keep only the last position's logits
                logits = self(inputs)[:, -1, :]

                # Apply temperature
                logits = logits / temperature

                # Top-k filtering
                if top_k > 0:
                    top_values, _ = torch.topk(logits, top_k)
                    min_value = top_values[:, -1]
                    logits[logits < min_value.unsqueeze(-1)] = -float('Inf')

                # Sample next token
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)

                # Append to sequence
                generated = torch.cat([generated, next_token], dim=-1)

                # Stop at the EOS token, if one was provided
                if eos_token_id is not None and next_token.item() == eos_token_id:
                    break

        return generated

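# End-to-end smoke test with a tiny random configuration (an illustrative
# addition; these config values are arbitrary, not the released checkpoint's).
def _demo_model():
    model = QuantizedMoRModel(
        vocab_size=100, dim=32, num_layers=6, num_heads=4,
        max_recursion=2, num_experts=2, max_position_embeddings=16,
    )
    logits = model(torch.randint(0, 100, (2, 8)))
    assert logits.shape == (2, 8, 100)
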
# ----------------------
# Load Model from Hugging Face Hub
# ----------------------
def load_model_from_hub(repo_id="liminerity/MoR-v1"):
    # 1. Download config
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    print("Model Config:", config)

    # 2. Initialize model from the config (including max_position_embeddings)
    model = QuantizedMoRModel(
        vocab_size=config["vocab_size"],
        dim=config["dim"],
        num_layers=config["num_layers"],
        num_heads=config["num_heads"],
        max_recursion=config["max_recursion"],
        num_experts=config["num_experts"],
        max_position_embeddings=config["max_position_embeddings"],
    )

    # 3. Download and load weights
    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    weights = load_file(weights_path)
    model.load_state_dict(weights)

    # 4. Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    return model, tokenizer

# ----------------------
# Run Inference
# ----------------------
def run_inference(model, tokenizer, prompt, max_length=100):
    # Encode prompt
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = inputs["input_ids"]

    # Move to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    input_ids = input_ids.to(device)

    # Generate text, stopping at the tokenizer's EOS token
    output_ids = model.generate(
        input_ids, max_length=max_length, eos_token_id=tokenizer.eos_token_id
    )

    # Decode and return
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# ----------------------
# Main Execution
# ----------------------
if __name__ == "__main__":
    # Load model and tokenizer
    print("Loading model from Hugging Face Hub...")
    model, tokenizer = load_model_from_hub()

    # Run inference
    prompt = "The future of artificial intelligence"
    print(f"\nPrompt: {prompt}")

    generated_text = run_inference(model, tokenizer, prompt, max_length=100)
    print("\nGenerated Text:")
    print(generated_text)
```
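
The script assumes the packages imported at the top are available; if they are not already installed, something like the following should work (exact versions are not pinned here):

```bash
pip install torch transformers huggingface_hub safetensors
```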