Naqeeb-2424 commited on
Commit
293217d
·
verified ·
1 Parent(s): 8b57805

Create Beam_search.py

Browse files
Files changed (1) hide show
  1. Beam_search.py +375 -0
Beam_search.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import math
4
+ import numpy as np
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import flax.linen as nn
8
+ import flax.serialization
9
+ from tokenizers import Tokenizer
10
+
11
+ # ---------------------------
12
+ # Constants and File Paths
13
+ # ---------------------------
14
+ # Updated tokenizer path from provided config.
15
+ TOKENIZER_PATH = "TOKENIZER_PATH"
16
+ # Path for saved model parameters (assumed unchanged).
17
+ MODEL_PARAMS_SAVE_PATH = "MODEL_PATH"
18
+
19
+ # ---------------------------
20
+ # Global Definitions
21
+ # ---------------------------
22
+ DTYPE = jnp.bfloat16
23
+ RMSNORM_EPS = 1e-05
24
+ dense_init = nn.initializers.normal(stddev=0.02)
25
+ CTX_LEN = 2048
26
+ NUM_KV_HEADS = 4
27
+
28
+ # ---------------------------
29
+ # Configuration Values (from provided config)
30
+ # ---------------------------
31
+ config = {
32
+ "d_model": 768,
33
+ "nhead": 16,
34
+ "num_layers": 24,
35
+ "ff_hidden_dim": 3072,
36
+ "vocab_size": 49800,
37
+ "max_len": 2048, # CTX_LEN value from provided config.
38
+ "dropout_rate": 0.1, # Set dropout rate as needed.
39
+ "window_layer_indices": [2, 5, 8, 11, 14, 17, 20, 23],
40
+ "moe_layer_indices": [4, 9, 14, 19],
41
+ "window_size": 512,
42
+ "moe_params": {"num_experts": 4, "num_experts_per_tok": 2},
43
+ }
44
+
45
+ # ---------------------------
46
+ # Custom Modules (Updated Architecture)
47
+ # ---------------------------
48
+ class RMSNorm(nn.Module):
49
+ epsilon: float = RMSNORM_EPS
50
+ dtype: any = DTYPE
51
+ @nn.compact
52
+ def __call__(self, x):
53
+ dim = x.shape[-1]
54
+ scale = self.param("scale", nn.initializers.ones, (dim,))
55
+ norm = jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + self.epsilon)
56
+ return (x / norm) * scale
57
+
58
+ class RoPE(nn.Module):
59
+ d_model: int
60
+ max_len: int
61
+ dtype: any = DTYPE
62
+ def setup(self):
63
+ self.inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, self.d_model, 2, dtype=jnp.float32) / self.d_model))
64
+ def __call__(self, x):
65
+ seq_len = x.shape[-2]
66
+ pos = jnp.arange(seq_len, dtype=jnp.float32)[None, None, :, None]
67
+ inv_freq = self.inv_freq[None, None, None, :]
68
+ freqs = pos * inv_freq
69
+ cos = jnp.cos(freqs).astype(self.dtype)
70
+ sin = jnp.sin(freqs).astype(self.dtype)
71
+ x1 = x[..., ::2]
72
+ x2 = x[..., 1::2]
73
+ return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
74
+
75
+ class FeedForward(nn.Module):
76
+ d_model: int
77
+ hidden_dim: int
78
+ dropout_rate: float
79
+ dtype: any = DTYPE
80
+ @nn.compact
81
+ def __call__(self, x, deterministic: bool = True):
82
+ proj = nn.Dense(self.hidden_dim * 2, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x)
83
+ x1, x2 = jnp.split(proj, 2, axis=-1)
84
+ x_act = x1 * nn.silu(x2)
85
+ x_act = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x_act)
86
+ return nn.Dropout(rate=self.dropout_rate)(x_act, deterministic=deterministic)
87
+
88
+ # Expert module for MoE.
89
+ class ExpertFFN(nn.Module):
90
+ d_model: int
91
+ hidden_dim: int
92
+ dropout_rate: float
93
+ dtype: any = DTYPE
94
+ @nn.compact
95
+ def __call__(self, x, deterministic: bool = True):
96
+ hidden = nn.Dense(self.hidden_dim, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x)
97
+ hidden = nn.silu(hidden)
98
+ out = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(hidden)
99
+ return out
100
+
101
+ # MoE feed-forward block using nn.vmap for experts.
102
+ class MoEFeedForward(nn.Module):
103
+ d_model: int
104
+ hidden_dim: int
105
+ dropout_rate: float
106
+ num_experts: int = 4
107
+ num_experts_per_tok: int = 2
108
+ dtype: any = DTYPE
109
+ @nn.compact
110
+ def __call__(self, x, deterministic: bool = True):
111
+ gate_logits = nn.Dense(self.num_experts, use_bias=False, dtype=self.dtype)(x)
112
+ gate_scores = nn.softmax(gate_logits, axis=-1) # [B, T, num_experts]
113
+ # Create expert module using vmap (do not pass deterministic here)
114
+ expert_ffn = nn.vmap(ExpertFFN,
115
+ variable_axes={'params': 0},
116
+ split_rngs={'params': True},
117
+ in_axes=0,
118
+ out_axes=0)(d_model=self.d_model,
119
+ hidden_dim=self.hidden_dim,
120
+ dropout_rate=self.dropout_rate,
121
+ dtype=self.dtype)
122
+ x_expert = jnp.broadcast_to(x, (self.num_experts,) + x.shape)
123
+ experts = expert_ffn(x_expert) # [num_experts, B, T, d_model]
124
+ gate_scores = jnp.transpose(gate_scores, (2, 0, 1))[..., None] # [num_experts, B, T, 1]
125
+ moe_output = jnp.sum(experts * gate_scores, axis=0) # [B, T, d_model]
126
+ moe_output = nn.Dropout(rate=self.dropout_rate)(moe_output, deterministic=deterministic)
127
+ return moe_output
128
+
129
+ class LLaMAAttention(nn.Module):
130
+ d_model: int
131
+ nhead: int
132
+ num_kv_heads: int
133
+ dropout_rate: float
134
+ dtype: any = DTYPE
135
+ use_sliding_window: bool = False
136
+ window_size: int = 512
137
+ def setup(self):
138
+ self.head_dim = self.d_model // self.nhead
139
+ self.q_proj = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)
140
+ self.kv_proj = nn.Dense(2 * (self.num_kv_heads * self.head_dim),
141
+ use_bias=False, kernel_init=dense_init, dtype=self.dtype)
142
+ self.out_proj = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)
143
+ self.dropout = nn.Dropout(rate=self.dropout_rate)
144
+ self.rope = RoPE(d_model=self.head_dim, max_len=CTX_LEN, dtype=self.dtype)
145
+ self.layer_scale_attn = self.param("layer_scale_attn", nn.initializers.constant(0.1), (self.d_model,))
146
+ def __call__(self, x, deterministic: bool = True):
147
+ B, T, _ = x.shape
148
+ q = self.q_proj(x).reshape(B, T, self.nhead, self.head_dim)
149
+ kv = self.kv_proj(x).reshape(B, T, self.num_kv_heads, 2 * self.head_dim)
150
+ k, v = jnp.split(kv, 2, axis=-1)
151
+ group_factor = self.nhead // self.num_kv_heads
152
+ k = jnp.repeat(k, repeats=group_factor, axis=2)
153
+ v = jnp.repeat(v, repeats=group_factor, axis=2)
154
+ q = jnp.transpose(q, (0, 2, 1, 3))
155
+ k = jnp.transpose(k, (0, 2, 1, 3))
156
+ q = self.rope(q)
157
+ k = self.rope(k)
158
+ q = jnp.transpose(q, (0, 2, 1, 3))
159
+ k = jnp.transpose(k, (0, 2, 1, 3))
160
+ attn_weights = jnp.einsum("bthd,bThd->bthT", q, k) / jnp.sqrt(self.head_dim)
161
+ if self.use_sliding_window:
162
+ i = jnp.arange(T)[:, None]
163
+ j = jnp.arange(T)[None, :]
164
+ sliding_mask = (i - j < self.window_size) & (i >= j)
165
+ sliding_mask = sliding_mask[None, :, None, :]
166
+ attn_weights = jnp.where(sliding_mask, attn_weights, -1e10)
167
+ else:
168
+ causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))[None, :, None, :]
169
+ attn_weights = jnp.where(causal_mask, attn_weights, -1e10)
170
+ attn_probs = nn.softmax(attn_weights, axis=-1)
171
+ attn_probs = self.dropout(attn_probs, deterministic=deterministic)
172
+ attn_output = jnp.einsum("bthT,bThd->bthd", attn_probs, v)
173
+ attn_output = attn_output.reshape(B, T, self.d_model)
174
+ output = self.out_proj(attn_output)
175
+ output = self.dropout(output, deterministic=deterministic)
176
+ return output * self.layer_scale_attn
177
+
178
+ class TransformerLayer(nn.Module):
179
+ d_model: int
180
+ nhead: int
181
+ ff_hidden_dim: int
182
+ dropout_rate: float
183
+ dtype: any = DTYPE
184
+ use_sliding_window: bool = False
185
+ window_size: int = 512
186
+ use_moe: bool = False
187
+ moe_params: dict = None
188
+ def setup(self):
189
+ self.attn_norm = RMSNorm(dtype=self.dtype)
190
+ self.attn = LLaMAAttention(
191
+ d_model=self.d_model,
192
+ nhead=self.nhead,
193
+ num_kv_heads=NUM_KV_HEADS,
194
+ dropout_rate=0.0,
195
+ dtype=self.dtype,
196
+ use_sliding_window=self.use_sliding_window,
197
+ window_size=self.window_size
198
+ )
199
+ self.ff_norm = RMSNorm(dtype=self.dtype)
200
+ if self.use_moe:
201
+ self.ff = MoEFeedForward(
202
+ d_model=self.d_model,
203
+ hidden_dim=self.ff_hidden_dim,
204
+ dropout_rate=self.dropout_rate,
205
+ num_experts=self.moe_params.get("num_experts", 4) if self.moe_params else 4,
206
+ num_experts_per_tok=self.moe_params.get("num_experts_per_tok", 2) if self.moe_params else 2,
207
+ dtype=self.dtype
208
+ )
209
+ else:
210
+ self.ff = FeedForward(
211
+ d_model=self.d_model,
212
+ hidden_dim=self.ff_hidden_dim,
213
+ dropout_rate=self.dropout_rate,
214
+ dtype=self.dtype
215
+ )
216
+ self.layer_scale_ff = self.param("layer_scale_ff", nn.initializers.constant(0.1), (self.d_model,))
217
+ def __call__(self, x, deterministic: bool = True):
218
+ x = x + self.attn(self.attn_norm(x), deterministic=deterministic)
219
+ x = x + self.ff(self.ff_norm(x), deterministic=deterministic) * self.layer_scale_ff
220
+ return x
221
+
222
+ class DeepSeekModel(nn.Module):
223
+ vocab_size: int
224
+ d_model: int
225
+ nhead: int
226
+ num_layers: int
227
+ ff_hidden_dim: int
228
+ max_len: int
229
+ dropout_rate: float
230
+ dtype: any = DTYPE
231
+ window_layer_indices: list = None # For sliding window attention
232
+ moe_layer_indices: list = None # For MoE feed-forward layers
233
+ window_size: int = 512
234
+ moe_params: dict = None
235
+ def setup(self):
236
+ self.embed = nn.Embed(
237
+ num_embeddings=self.vocab_size,
238
+ features=self.d_model,
239
+ embedding_init=dense_init,
240
+ dtype=self.dtype
241
+ )
242
+ self.layers = [
243
+ TransformerLayer(
244
+ d_model=self.d_model,
245
+ nhead=self.nhead,
246
+ ff_hidden_dim=self.ff_hidden_dim,
247
+ dropout_rate=self.dropout_rate,
248
+ dtype=self.dtype,
249
+ use_sliding_window=(self.window_layer_indices is not None and i in self.window_layer_indices),
250
+ window_size=self.window_size,
251
+ use_moe=(self.moe_layer_indices is not None and i in self.moe_layer_indices),
252
+ moe_params=self.moe_params
253
+ )
254
+ for i in range(self.num_layers)
255
+ ]
256
+ self.norm = RMSNorm(dtype=self.dtype)
257
+ def __call__(self, input_ids, deterministic: bool = True):
258
+ x = self.embed(input_ids)
259
+ for layer in self.layers:
260
+ x = layer(x, deterministic=deterministic)
261
+ x = self.norm(x)
262
+ logits = x @ self.embed.embedding.T # weight tying
263
+ return logits
264
+
265
+ # ---------------------------
266
+ # Load Tokenizer and Model Parameters
267
+ # ---------------------------
268
+ tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
269
+ PAD_TOKEN_ID = tokenizer.token_to_id("<pad>")
270
+ START_TOKEN_ID = tokenizer.token_to_id("<s>")
271
+ END_SEQ_TOKEN_ID = tokenizer.token_to_id("</s>")
272
+
273
+ # Instantiate model using the provided config values.
274
+ model_instance = DeepSeekModel(
275
+ vocab_size=config["vocab_size"],
276
+ d_model=config["d_model"],
277
+ nhead=config["nhead"],
278
+ num_layers=config["num_layers"],
279
+ ff_hidden_dim=config["ff_hidden_dim"],
280
+ max_len=config["max_len"],
281
+ dropout_rate=config["dropout_rate"],
282
+ dtype=DTYPE,
283
+ window_layer_indices=config["window_layer_indices"],
284
+ moe_layer_indices=config["moe_layer_indices"],
285
+ window_size=config["window_size"],
286
+ moe_params=config["moe_params"]
287
+ )
288
+
289
+ # Initialize dummy parameters to match the model structure.
290
+ dummy_input = jnp.ones((1, config["max_len"] - 1), dtype=jnp.int32)
291
+ rng = jax.random.PRNGKey(0)
292
+ init_params = model_instance.init(rng, dummy_input, deterministic=True)
293
+
294
+ # Load saved parameters.
295
+ with open(MODEL_PARAMS_SAVE_PATH, "rb") as f:
296
+ saved_params_bytes = f.read()
297
+ saved_params = flax.serialization.from_bytes(init_params, saved_params_bytes)
298
+ print("Loaded model parameters.")
299
+
300
+ # ---------------------------
301
+ # Updated Beam Search Inference Function
302
+ # ---------------------------
303
+ def beam_search(params, prompt_ids, model, beam_size=3, max_length=50, end_token_id=END_SEQ_TOKEN_ID):
304
+ """
305
+ Performs beam search starting from prompt_ids.
306
+ At each generation step, candidate tokens for each beam are printed.
307
+ After each step, all current beams (with cumulative scores) are printed.
308
+ Finally, all predicted beams and the best beam are printed.
309
+ Returns the best generated sequence of token IDs.
310
+ """
311
+ beams = [(prompt_ids, 0.0)] # (sequence, cumulative log probability)
312
+
313
+ for step in range(max_length):
314
+ all_candidates = []
315
+ print(f"\n--- Generation Step {step+1} ---")
316
+ for seq, score in beams:
317
+ input_seq = jnp.array(seq)[None, :] # shape (1, seq_length)
318
+ logits = model.apply(params, input_seq, deterministic=True)
319
+ logits_last = logits[0, -1] # last token logits
320
+ probs = jax.nn.softmax(logits_last)
321
+ # Select top beam_size tokens for this beam.
322
+ top_indices = np.array(jnp.argsort(probs)[-beam_size:][::-1])
323
+ top_probs = np.array(probs[top_indices])
324
+ for token_idx, token_prob in zip(top_indices, top_probs):
325
+ token_id = int(token_idx)
326
+ token_str = tokenizer.decode([token_id]).strip()
327
+ print(f"Candidate token: '{token_str}' (ID: {token_id}) with probability: {token_prob:.4f}")
328
+ new_seq = seq + [token_id]
329
+ new_score = score + math.log(token_prob + 1e-10)
330
+ all_candidates.append((new_seq, new_score))
331
+ # Select top beam_size beams overall.
332
+ beams = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)[:beam_size]
333
+
334
+ print(f"\nBeams after generation step {step+1}:")
335
+ for beam in beams:
336
+ decoded = tokenizer.decode(beam[0])
337
+ print(f"Beam: {decoded} | Score: {beam[1]:.4f}")
338
+
339
+ # Check if best beam has ended.
340
+ best_seq, best_score = beams[0]
341
+ if best_seq[-1] == end_token_id:
342
+ break
343
+
344
+ print("\nFinal predicted beams:")
345
+ for beam in beams:
346
+ decoded = tokenizer.decode(beam[0])
347
+ print(f"Beam: {decoded} | Score: {beam[1]:.4f}")
348
+
349
+ best_seq, best_score = beams[0]
350
+ print("\nBest beam:", tokenizer.decode(best_seq))
351
+ return best_seq
352
+
353
+ # ---------------------------
354
+ # Interactive Chat Loop
355
+ # ---------------------------
356
+ def chat():
357
+ print("\nInteractive Chat (type 'exit' or 'quit' to end):")
358
+ while True:
359
+ user_input = input("\nUser: ").strip()
360
+ if user_input.lower() in ["exit", "quit"]:
361
+ break
362
+ if not user_input.startswith("<s>"):
363
+ user_input = "<s> " + user_input
364
+ prompt_ids = tokenizer.encode(user_input).ids
365
+ max_prompt_length = config["max_len"] - 1
366
+ if len(prompt_ids) > max_prompt_length:
367
+ prompt_ids = prompt_ids[-max_prompt_length:]
368
+ print("\nModel generating response using beam search...")
369
+ generated_ids = beam_search(saved_params, prompt_ids, model_instance,
370
+ beam_size=5, max_length=25, end_token_id=END_SEQ_TOKEN_ID)
371
+ generated_text = tokenizer.decode(generated_ids)
372
+ print("\nModel:", generated_text)
373
+
374
+ if __name__ == "__main__":
375
+ chat()