import os import math import numpy as np import jax import jax.numpy as jnp import flax.linen as nn import flax.serialization from tokenizers import Tokenizer # --------------------------- # Constants and File Paths # --------------------------- TOKENIZER_PATH = "Path to tokenizer.json file" MODEL_PARAMS_SAVE_PATH = "Path to model file" # --------------------------- # Global Definitions # --------------------------- DTYPE = jnp.bfloat16 RMSNORM_EPS = 1e-05 dense_init = nn.initializers.normal(stddev=0.02) CTX_LEN = 2048 NUM_KV_HEADS = 4 # --------------------------- # Configuration Values (from provided config) # --------------------------- config = { "d_model": 768, "nhead": 16, "num_layers": 24, "ff_hidden_dim": 3072, "vocab_size": 49800, "max_len": 2048, "dropout_rate": 0.1, "window_layer_indices": [2, 5, 8, 11, 14, 17, 20, 23], "moe_layer_indices": [4, 9, 14, 19], "window_size": 512, "moe_params": {"num_experts": 4, "num_experts_per_tok": 2}, } # --------------------------- # Custom Modules (Updated Architecture) # --------------------------- class RMSNorm(nn.Module): epsilon: float = RMSNORM_EPS dtype: any = DTYPE @nn.compact def __call__(self, x): dim = x.shape[-1] scale = self.param("scale", nn.initializers.ones, (dim,)) norm = jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + self.epsilon) return (x / norm) * scale class RoPE(nn.Module): d_model: int max_len: int dtype: any = DTYPE def setup(self): self.inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, self.d_model, 2, dtype=jnp.float32) / self.d_model)) def __call__(self, x): seq_len = x.shape[-2] pos = jnp.arange(seq_len, dtype=jnp.float32)[None, None, :, None] inv_freq = self.inv_freq[None, None, None, :] freqs = pos * inv_freq cos = jnp.cos(freqs).astype(self.dtype) sin = jnp.sin(freqs).astype(self.dtype) x1 = x[..., ::2] x2 = x[..., 1::2] return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1) class FeedForward(nn.Module): d_model: int hidden_dim: int dropout_rate: float dtype: any = DTYPE @nn.compact def __call__(self, x, deterministic: bool = True): proj = nn.Dense(self.hidden_dim * 2, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x) x1, x2 = jnp.split(proj, 2, axis=-1) x_act = x1 * nn.silu(x2) x_act = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x_act) return nn.Dropout(rate=self.dropout_rate)(x_act, deterministic=deterministic) class ExpertFFN(nn.Module): d_model: int hidden_dim: int dropout_rate: float dtype: any = DTYPE @nn.compact def __call__(self, x, deterministic: bool = True): hidden = nn.Dense(self.hidden_dim, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x) hidden = nn.silu(hidden) out = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(hidden) return out class MoEFeedForward(nn.Module): d_model: int hidden_dim: int dropout_rate: float num_experts: int = 4 num_experts_per_tok: int = 2 dtype: any = DTYPE @nn.compact def __call__(self, x, deterministic: bool = True): gate_logits = nn.Dense(self.num_experts, use_bias=False, dtype=self.dtype)(x) gate_scores = nn.softmax(gate_logits, axis=-1) expert_ffn = nn.vmap(ExpertFFN, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0, out_axes=0)(d_model=self.d_model, hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate, dtype=self.dtype) x_expert = jnp.broadcast_to(x, (self.num_experts,) + x.shape) experts = expert_ffn(x_expert) gate_scores = jnp.transpose(gate_scores, (2, 0, 1))[..., None] moe_output = jnp.sum(experts * gate_scores, axis=0) moe_output = nn.Dropout(rate=self.dropout_rate)(moe_output, deterministic=deterministic) return moe_output class LLaMAAttention(nn.Module): d_model: int nhead: int num_kv_heads: int dropout_rate: float dtype: any = DTYPE use_sliding_window: bool = False window_size: int = 512 def setup(self): self.head_dim = self.d_model // self.nhead self.q_proj = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype) self.kv_proj = nn.Dense(2 * (self.num_kv_heads * self.head_dim), use_bias=False, kernel_init=dense_init, dtype=self.dtype) self.out_proj = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.dropout_rate) self.rope = RoPE(d_model=self.head_dim, max_len=CTX_LEN, dtype=self.dtype) self.layer_scale_attn = self.param("layer_scale_attn", nn.initializers.constant(0.1), (self.d_model,)) def __call__(self, x, deterministic: bool = True): B, T, _ = x.shape q = self.q_proj(x).reshape(B, T, self.nhead, self.head_dim) kv = self.kv_proj(x).reshape(B, T, self.num_kv_heads, 2 * self.head_dim) k, v = jnp.split(kv, 2, axis=-1) group_factor = self.nhead // self.num_kv_heads k = jnp.repeat(k, repeats=group_factor, axis=2) v = jnp.repeat(v, repeats=group_factor, axis=2) q = jnp.transpose(q, (0, 2, 1, 3)) k = jnp.transpose(k, (0, 2, 1, 3)) q = self.rope(q) k = self.rope(k) q = jnp.transpose(q, (0, 2, 1, 3)) k = jnp.transpose(k, (0, 2, 1, 3)) attn_weights = jnp.einsum("bthd,bThd->bthT", q, k) / jnp.sqrt(self.head_dim) if self.use_sliding_window: i = jnp.arange(T)[:, None] j = jnp.arange(T)[None, :] sliding_mask = (i - j < self.window_size) & (i >= j) sliding_mask = sliding_mask[None, :, None, :] attn_weights = jnp.where(sliding_mask, attn_weights, -1e10) else: causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))[None, :, None, :] attn_weights = jnp.where(causal_mask, attn_weights, -1e10) attn_probs = nn.softmax(attn_weights, axis=-1) attn_probs = self.dropout(attn_probs, deterministic=deterministic) attn_output = jnp.einsum("bthT,bThd->bthd", attn_probs, v) attn_output = attn_output.reshape(B, T, self.d_model) output = self.out_proj(attn_output) output = self.dropout(output, deterministic=deterministic) return output * self.layer_scale_attn class TransformerLayer(nn.Module): d_model: int nhead: int ff_hidden_dim: int dropout_rate: float dtype: any = DTYPE use_sliding_window: bool = False window_size: int = 512 use_moe: bool = False moe_params: dict = None def setup(self): self.attn_norm = RMSNorm(dtype=self.dtype) self.attn = LLaMAAttention( d_model=self.d_model, nhead=self.nhead, num_kv_heads=NUM_KV_HEADS, dropout_rate=0.0, dtype=self.dtype, use_sliding_window=self.use_sliding_window, window_size=self.window_size ) self.ff_norm = RMSNorm(dtype=self.dtype) if self.use_moe: self.ff = MoEFeedForward( d_model=self.d_model, hidden_dim=self.ff_hidden_dim, dropout_rate=self.dropout_rate, num_experts=self.moe_params.get("num_experts", 4) if self.moe_params else 4, num_experts_per_tok=self.moe_params.get("num_experts_per_tok", 2) if self.moe_params else 2, dtype=self.dtype ) else: self.ff = FeedForward( d_model=self.d_model, hidden_dim=self.ff_hidden_dim, dropout_rate=self.dropout_rate, dtype=self.dtype ) self.layer_scale_ff = self.param("layer_scale_ff", nn.initializers.constant(0.1), (self.d_model,)) def __call__(self, x, deterministic: bool = True): x = x + self.attn(self.attn_norm(x), deterministic=deterministic) x = x + self.ff(self.ff_norm(x), deterministic=deterministic) * self.layer_scale_ff return x class DeepSeekModel(nn.Module): vocab_size: int d_model: int nhead: int num_layers: int ff_hidden_dim: int max_len: int dropout_rate: float dtype: any = DTYPE window_layer_indices: list = None moe_layer_indices: list = None window_size: int = 512 moe_params: dict = None def setup(self): self.embed = nn.Embed( num_embeddings=self.vocab_size, features=self.d_model, embedding_init=dense_init, dtype=self.dtype ) self.layers = [ TransformerLayer( d_model=self.d_model, nhead=self.nhead, ff_hidden_dim=self.ff_hidden_dim, dropout_rate=self.dropout_rate, dtype=self.dtype, use_sliding_window=(self.window_layer_indices is not None and i in self.window_layer_indices), window_size=self.window_size, use_moe=(self.moe_layer_indices is not None and i in self.moe_layer_indices), moe_params=self.moe_params ) for i in range(self.num_layers) ] self.norm = RMSNorm(dtype=self.dtype) def __call__(self, input_ids, deterministic: bool = True): x = self.embed(input_ids) for layer in self.layers: x = layer(x, deterministic=deterministic) x = self.norm(x) logits = x @ self.embed.embedding.T return logits # --------------------------- # Load Tokenizer and Model Parameters # --------------------------- tokenizer = Tokenizer.from_file(TOKENIZER_PATH) PAD_TOKEN_ID = tokenizer.token_to_id("") START_TOKEN_ID = tokenizer.token_to_id("") END_SEQ_TOKEN_ID = tokenizer.token_to_id("") model_instance = DeepSeekModel( vocab_size=config["vocab_size"], d_model=config["d_model"], nhead=config["nhead"], num_layers=config["num_layers"], ff_hidden_dim=config["ff_hidden_dim"], max_len=config["max_len"], dropout_rate=config["dropout_rate"], dtype=DTYPE, window_layer_indices=config["window_layer_indices"], moe_layer_indices=config["moe_layer_indices"], window_size=config["window_size"], moe_params=config["moe_params"] ) dummy_input = jnp.ones((1, config["max_len"] - 1), dtype=jnp.int32) rng = jax.random.PRNGKey(0) init_params = model_instance.init(rng, dummy_input, deterministic=True) with open(MODEL_PARAMS_SAVE_PATH, "rb") as f: saved_params_bytes = f.read() saved_params = flax.serialization.from_bytes(init_params, saved_params_bytes) print("Loaded model parameters.") # --------------------------- # Temperature Sampling Function with Fixed Parameters # --------------------------- def temperature_sample(params, prompt_ids, model, max_length=15, temperature=0.7, top_p=0.9, end_token_id=END_SEQ_TOKEN_ID): """ Generates text token-by-token using temperature scaling and nucleus (top-p) sampling. Args: params: Model parameters. prompt_ids: List of token IDs for the prompt. model: The language model. max_length: Maximum number of tokens to generate. temperature: Temperature for scaling logits. top_p: Nucleus sampling threshold. end_token_id: End-of-sequence token ID. Returns: A list of token IDs representing the generated text. """ generated = list(prompt_ids) for step in range(max_length): input_seq = jnp.array(generated)[None, :] logits = model.apply(params, input_seq, deterministic=True) logits_last = logits[0, -1] scaled_logits = logits_last / temperature probs = jax.nn.softmax(scaled_logits) probs_np = np.array(probs) sorted_indices = np.argsort(probs_np)[::-1] sorted_probs = probs_np[sorted_indices] cumulative_probs = np.cumsum(sorted_probs) cutoff_idx = np.where(cumulative_probs > top_p)[0] cutoff = cutoff_idx[0] + 1 if len(cutoff_idx) > 0 else len(sorted_probs) nucleus_indices = sorted_indices[:cutoff] nucleus_probs = sorted_probs[:cutoff] nucleus_probs /= np.sum(nucleus_probs) token_id = int(np.random.choice(nucleus_indices, p=nucleus_probs)) generated.append(token_id) token_str = tokenizer.decode([token_id]).strip() print(f"Step {step+1}: Generated token '{token_str}' (ID: {token_id})") if token_id == end_token_id: break return generated # --------------------------- # Interactive Chat Loop using Fixed Temperature Sampling # --------------------------- def chat(): print("\nInteractive Chat (type 'exit' or 'quit' to end):") while True: user_input = input("\nUser: ").strip() if user_input.lower() in ["exit", "quit"]: break if not user_input.startswith(""): user_input = " " + user_input prompt_ids = tokenizer.encode(user_input).ids max_prompt_length = config["max_len"] - 1 if len(prompt_ids) > max_prompt_length: prompt_ids = prompt_ids[-max_prompt_length:] print("\nModel generating response using temperature sampling (temp=0.7, top-p=0.9, max tokens=15)...") generated_ids = temperature_sample( saved_params, prompt_ids, model_instance, max_length=15, temperature=0.7, top_p=0.9, end_token_id=END_SEQ_TOKEN_ID ) generated_text = tokenizer.decode(generated_ids) print("\nModel:", generated_text) if __name__ == "__main__": chat()