from typing import Any, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import flax.serialization
from tokenizers import Tokenizer

# Paths to the tokenizer and the serialized model parameters (set before running).
TOKENIZER_PATH = "Path to tokenizer.json file"
MODEL_PARAMS_SAVE_PATH = "Path to model file"

# Global hyperparameters.
DTYPE = jnp.bfloat16
RMSNORM_EPS = 1e-5
dense_init = nn.initializers.normal(stddev=0.02)
CTX_LEN = 2048
NUM_KV_HEADS = 4  # key/value heads for grouped-query attention

config = {
    "d_model": 768,
    "nhead": 16,
    "num_layers": 24,
    "ff_hidden_dim": 3072,
    "vocab_size": 49800,
    "max_len": 2048,
    "dropout_rate": 0.1,
    "window_layer_indices": [2, 5, 8, 11, 14, 17, 20, 23],  # layers using sliding-window attention
    "moe_layer_indices": [4, 9, 14, 19],                    # layers using the MoE feed-forward block
    "window_size": 512,
    "moe_params": {"num_experts": 4, "num_experts_per_tok": 2},
}
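
# Quick sanity checks on the shapes this config implies (illustrative, not part
# of the original pipeline): each head gets head_dim = d_model // nhead =
# 768 // 16 = 48 dimensions, and grouped-query attention shares each of the
# NUM_KV_HEADS = 4 key/value heads across nhead // NUM_KV_HEADS = 4 query
# heads. Note that layer 14 appears in both window_layer_indices and
# moe_layer_indices, so that layer is windowed *and* MoE.
assert config["d_model"] % config["nhead"] == 0   # head_dim = 48
assert config["nhead"] % NUM_KV_HEADS == 0        # group factor = 4

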
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescales by the RMS, no mean-centering or bias."""
    epsilon: float = RMSNORM_EPS
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x):
        dim = x.shape[-1]
        scale = self.param("scale", nn.initializers.ones, (dim,))
        norm = jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + self.epsilon)
        return (x / norm) * scale


class RoPE(nn.Module):
    """Rotary position embeddings (RoPE), applied per attention head."""
    d_model: int
    max_len: int  # kept for the constructor interface; frequencies use the runtime length
    dtype: Any = DTYPE

    def setup(self):
        self.inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, self.d_model, 2, dtype=jnp.float32) / self.d_model))

    def __call__(self, x):
        # x: (batch, heads, seq, head_dim)
        seq_len = x.shape[-2]
        pos = jnp.arange(seq_len, dtype=jnp.float32)[None, None, :, None]
        inv_freq = self.inv_freq[None, None, None, :]
        freqs = pos * inv_freq
        cos = jnp.cos(freqs).astype(self.dtype)
        sin = jnp.sin(freqs).astype(self.dtype)
        # Rotate (even, odd) feature pairs, then concatenate the two halves.
        # This permutes features relative to the interleaved RoPE layout, but
        # the q.k dot products are unchanged as long as q and k share it.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
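
# Illustrative property check (not in the original script): RoPE is a rotation,
# so it preserves per-position vector norms. With float32 this can be verified
# directly, e.g.:
#
#   rope = RoPE(d_model=48, max_len=CTX_LEN, dtype=jnp.float32)
#   x = jax.random.normal(jax.random.PRNGKey(0), (1, 1, 8, 48))
#   y = rope.apply({}, x)  # RoPE defines no learnable parameters
#   assert jnp.allclose(jnp.linalg.norm(x, axis=-1), jnp.linalg.norm(y, axis=-1), atol=1e-3)

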
class FeedForward(nn.Module):
    """SwiGLU-style feed-forward block: a gated linear unit with SiLU gating."""
    d_model: int
    hidden_dim: int
    dropout_rate: float
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        # One fused projection produces both the value and gate halves.
        proj = nn.Dense(self.hidden_dim * 2, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x)
        x1, x2 = jnp.split(proj, 2, axis=-1)
        x_act = x1 * nn.silu(x2)
        x_act = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x_act)
        return nn.Dropout(rate=self.dropout_rate)(x_act, deterministic=deterministic)


class ExpertFFN(nn.Module):
    """A single MoE expert: Dense -> SiLU -> Dense. Note dropout_rate is
    declared for interface parity but not applied inside the expert."""
    d_model: int
    hidden_dim: int
    dropout_rate: float
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        hidden = nn.Dense(self.hidden_dim, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x)
        hidden = nn.silu(hidden)
        out = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(hidden)
        return out


class MoEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward block.

    Note: the forward pass below computes a *dense* (soft) mixture -- every
    expert runs on every token and outputs are combined by the softmax gate.
    num_experts_per_tok is declared for the config interface but is not used
    to sparsify routing here.
    """
    d_model: int
    hidden_dim: int
    dropout_rate: float
    num_experts: int = 4
    num_experts_per_tok: int = 2
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        # Gate: per-token softmax over experts.
        # Shapes: x is (B, T, d_model); gate_scores is (B, T, num_experts).
        gate_logits = nn.Dense(self.num_experts, use_bias=False, dtype=self.dtype)(x)
        gate_scores = nn.softmax(gate_logits, axis=-1)
        # Vectorize ExpertFFN over a leading expert axis, with independent
        # parameters per expert.
        expert_ffn = nn.vmap(
            ExpertFFN,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=0,
            out_axes=0,
        )(d_model=self.d_model, hidden_dim=self.hidden_dim,
          dropout_rate=self.dropout_rate, dtype=self.dtype)
        # Run all experts on all tokens: (num_experts, B, T, d_model).
        x_expert = jnp.broadcast_to(x, (self.num_experts,) + x.shape)
        experts = expert_ffn(x_expert)
        # Weight expert outputs by the gate: gate becomes (num_experts, B, T, 1).
        gate_scores = jnp.transpose(gate_scores, (2, 0, 1))[..., None]
        moe_output = jnp.sum(experts * gate_scores, axis=0)
        moe_output = nn.Dropout(rate=self.dropout_rate)(moe_output, deterministic=deterministic)
        return moe_output
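
# For reference, a sparse top-k gate of the kind num_experts_per_tok suggests
# could be sketched as below (an assumption about intent, not the behavior of
# the block above -- swapping it in would change the forward pass):
#
#   def topk_gate(gate_logits, k):
#       """Keep the k largest gate logits per token; softmax puts zero mass elsewhere."""
#       top_vals, _ = jax.lax.top_k(gate_logits, k)   # (B, T, k)
#       threshold = top_vals[..., -1:]                # k-th largest logit per token
#       masked = jnp.where(gate_logits >= threshold, gate_logits, -jnp.inf)
#       return nn.softmax(masked, axis=-1)

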
class LLaMAAttention(nn.Module):
    """Multi-head attention with RoPE, grouped-query KV heads, and an optional
    causal sliding window."""
    d_model: int
    nhead: int
    num_kv_heads: int
    dropout_rate: float
    dtype: Any = DTYPE
    use_sliding_window: bool = False
    window_size: int = 512

    def setup(self):
        self.head_dim = self.d_model // self.nhead
        self.q_proj = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)
        # One projection emits keys and values for the (fewer) KV heads.
        self.kv_proj = nn.Dense(2 * (self.num_kv_heads * self.head_dim),
                                use_bias=False, kernel_init=dense_init, dtype=self.dtype)
        self.out_proj = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.dropout_rate)
        self.rope = RoPE(d_model=self.head_dim, max_len=CTX_LEN, dtype=self.dtype)
        self.layer_scale_attn = self.param("layer_scale_attn", nn.initializers.constant(0.1), (self.d_model,))

    def __call__(self, x, deterministic: bool = True):
        B, T, _ = x.shape
        q = self.q_proj(x).reshape(B, T, self.nhead, self.head_dim)
        kv = self.kv_proj(x).reshape(B, T, self.num_kv_heads, 2 * self.head_dim)
        k, v = jnp.split(kv, 2, axis=-1)
        # Grouped-query attention: replicate each KV head across its query group.
        group_factor = self.nhead // self.num_kv_heads
        k = jnp.repeat(k, repeats=group_factor, axis=2)
        v = jnp.repeat(v, repeats=group_factor, axis=2)
        # RoPE expects (B, heads, T, head_dim): transpose, rotate, transpose back.
        q = jnp.transpose(q, (0, 2, 1, 3))
        k = jnp.transpose(k, (0, 2, 1, 3))
        q = self.rope(q)
        k = self.rope(k)
        q = jnp.transpose(q, (0, 2, 1, 3))
        k = jnp.transpose(k, (0, 2, 1, 3))
        # Scores indexed as (batch, query_pos, head, key_pos).
        attn_weights = jnp.einsum("bthd,bThd->bthT", q, k) / jnp.sqrt(self.head_dim)
        if self.use_sliding_window:
            # Causal sliding window: position i attends to j with 0 <= i - j < window_size.
            i = jnp.arange(T)[:, None]
            j = jnp.arange(T)[None, :]
            sliding_mask = (i - j < self.window_size) & (i >= j)
            sliding_mask = sliding_mask[None, :, None, :]
            attn_weights = jnp.where(sliding_mask, attn_weights, -1e10)
        else:
            causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))[None, :, None, :]
            attn_weights = jnp.where(causal_mask, attn_weights, -1e10)
        attn_probs = nn.softmax(attn_weights, axis=-1)
        attn_probs = self.dropout(attn_probs, deterministic=deterministic)
        attn_output = jnp.einsum("bthT,bThd->bthd", attn_probs, v)
        attn_output = attn_output.reshape(B, T, self.d_model)
        output = self.out_proj(attn_output)
        output = self.dropout(output, deterministic=deterministic)
        # Per-channel layer scale (applied inside this module, unlike the FF branch).
        return output * self.layer_scale_attn
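
# To make the sliding-window mask concrete (illustrative only): with T = 5 and
# window_size = 2, position i attends to j where 0 <= i - j < 2, giving
#
#        j=0  j=1  j=2  j=3  j=4
#   i=0   1    .    .    .    .
#   i=1   1    1    .    .    .
#   i=2   .    1    1    .    .
#   i=3   .    .    1    1    .
#   i=4   .    .    .    1    1
#
# i.e. each token sees itself and the previous window_size - 1 tokens.

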
class TransformerLayer(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention and RMSNorm -> feed-forward,
    each wrapped in a residual connection."""
    d_model: int
    nhead: int
    ff_hidden_dim: int
    dropout_rate: float
    dtype: Any = DTYPE
    use_sliding_window: bool = False
    window_size: int = 512
    use_moe: bool = False
    moe_params: Optional[dict] = None

    def setup(self):
        self.attn_norm = RMSNorm(dtype=self.dtype)
        self.attn = LLaMAAttention(
            d_model=self.d_model,
            nhead=self.nhead,
            num_kv_heads=NUM_KV_HEADS,
            dropout_rate=0.0,  # attention dropout disabled; dropout_rate applies to the FF branch
            dtype=self.dtype,
            use_sliding_window=self.use_sliding_window,
            window_size=self.window_size,
        )
        self.ff_norm = RMSNorm(dtype=self.dtype)
        if self.use_moe:
            self.ff = MoEFeedForward(
                d_model=self.d_model,
                hidden_dim=self.ff_hidden_dim,
                dropout_rate=self.dropout_rate,
                num_experts=self.moe_params.get("num_experts", 4) if self.moe_params else 4,
                num_experts_per_tok=self.moe_params.get("num_experts_per_tok", 2) if self.moe_params else 2,
                dtype=self.dtype,
            )
        else:
            self.ff = FeedForward(
                d_model=self.d_model,
                hidden_dim=self.ff_hidden_dim,
                dropout_rate=self.dropout_rate,
                dtype=self.dtype,
            )
        self.layer_scale_ff = self.param("layer_scale_ff", nn.initializers.constant(0.1), (self.d_model,))

    def __call__(self, x, deterministic: bool = True):
        # The attention branch applies its layer scale internally; the FF scale
        # is applied here.
        x = x + self.attn(self.attn_norm(x), deterministic=deterministic)
        x = x + self.ff(self.ff_norm(x), deterministic=deterministic) * self.layer_scale_ff
        return x


class DeepSeekModel(nn.Module):
    """Decoder-only transformer with tied input/output embeddings."""
    vocab_size: int
    d_model: int
    nhead: int
    num_layers: int
    ff_hidden_dim: int
    max_len: int
    dropout_rate: float
    dtype: Any = DTYPE
    window_layer_indices: Optional[list] = None
    moe_layer_indices: Optional[list] = None
    window_size: int = 512
    moe_params: Optional[dict] = None

    def setup(self):
        self.embed = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.d_model,
            embedding_init=dense_init,
            dtype=self.dtype,
        )
        self.layers = [
            TransformerLayer(
                d_model=self.d_model,
                nhead=self.nhead,
                ff_hidden_dim=self.ff_hidden_dim,
                dropout_rate=self.dropout_rate,
                dtype=self.dtype,
                use_sliding_window=(self.window_layer_indices is not None and i in self.window_layer_indices),
                window_size=self.window_size,
                use_moe=(self.moe_layer_indices is not None and i in self.moe_layer_indices),
                moe_params=self.moe_params,
            )
            for i in range(self.num_layers)
        ]
        self.norm = RMSNorm(dtype=self.dtype)

    def __call__(self, input_ids, deterministic: bool = True):
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x, deterministic=deterministic)
        x = self.norm(x)
        # Weight tying: project back onto the vocabulary with the embedding matrix.
        logits = x @ self.embed.embedding.T
        return logits
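
# With the config above, layers 2, 5, 8, 11, 14, 17, 20 and 23 use sliding-window
# attention and layers 4, 9, 14 and 19 use MoE feed-forward blocks; every other
# layer uses full causal attention with the dense SwiGLU FFN. A quick way to
# list the per-layer kinds (illustrative):
#
#   for i in range(config["num_layers"]):
#       kinds = []
#       if i in config["window_layer_indices"]: kinds.append("window")
#       if i in config["moe_layer_indices"]: kinds.append("moe")
#       print(i, kinds or ["dense"])

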
# Load the tokenizer and look up the special-token IDs.
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
PAD_TOKEN_ID = tokenizer.token_to_id("<pad>")
START_TOKEN_ID = tokenizer.token_to_id("<s>")
END_SEQ_TOKEN_ID = tokenizer.token_to_id("</s>")

model_instance = DeepSeekModel(
    vocab_size=config["vocab_size"],
    d_model=config["d_model"],
    nhead=config["nhead"],
    num_layers=config["num_layers"],
    ff_hidden_dim=config["ff_hidden_dim"],
    max_len=config["max_len"],
    dropout_rate=config["dropout_rate"],
    dtype=DTYPE,
    window_layer_indices=config["window_layer_indices"],
    moe_layer_indices=config["moe_layer_indices"],
    window_size=config["window_size"],
    moe_params=config["moe_params"],
)

# Initialize a parameter tree with the right structure, then overwrite it with
# the saved weights; from_bytes uses init_params as the template.
dummy_input = jnp.ones((1, config["max_len"] - 1), dtype=jnp.int32)
rng = jax.random.PRNGKey(0)
init_params = model_instance.init(rng, dummy_input, deterministic=True)

with open(MODEL_PARAMS_SAVE_PATH, "rb") as f:
    saved_params_bytes = f.read()
saved_params = flax.serialization.from_bytes(init_params, saved_params_bytes)
print("Loaded model parameters.")
def temperature_sample(params, prompt_ids, model, max_length=15, temperature=0.7, top_p=0.9, end_token_id=END_SEQ_TOKEN_ID):
    """
    Generates text token-by-token using temperature scaling and nucleus (top-p) sampling.

    Args:
        params: Model parameters.
        prompt_ids: List of token IDs for the prompt.
        model: The language model.
        max_length: Maximum number of tokens to generate.
        temperature: Temperature for scaling logits.
        top_p: Nucleus sampling threshold.
        end_token_id: End-of-sequence token ID.

    Returns:
        A list of token IDs for the generated text (the prompt is included).
    """
    generated = list(prompt_ids)
    for step in range(max_length):
        # Re-run the full (growing) sequence each step; no KV cache is kept.
        input_seq = jnp.array(generated)[None, :]
        logits = model.apply(params, input_seq, deterministic=True)
        logits_last = logits[0, -1]
        scaled_logits = logits_last / temperature
        probs = jax.nn.softmax(scaled_logits)

        # Nucleus (top-p) filtering: keep the smallest set of tokens whose
        # cumulative probability exceeds top_p, then renormalize within it.
        probs_np = np.array(probs)
        sorted_indices = np.argsort(probs_np)[::-1]
        sorted_probs = probs_np[sorted_indices]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff_idx = np.where(cumulative_probs > top_p)[0]
        cutoff = cutoff_idx[0] + 1 if len(cutoff_idx) > 0 else len(sorted_probs)
        nucleus_indices = sorted_indices[:cutoff]
        nucleus_probs = sorted_probs[:cutoff]
        nucleus_probs /= np.sum(nucleus_probs)

        token_id = int(np.random.choice(nucleus_indices, p=nucleus_probs))
        generated.append(token_id)

        token_str = tokenizer.decode([token_id]).strip()
        print(f"Step {step+1}: Generated token '{token_str}' (ID: {token_id})")

        if token_id == end_token_id:
            break
    return generated
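
# Example usage (illustrative; assumes the tokenizer and weights loaded above):
#
#   prompt_ids = tokenizer.encode("<s> Hello").ids
#   ids = temperature_sample(saved_params, prompt_ids, model_instance, max_length=15)
#   print(tokenizer.decode(ids))

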
def chat():
    print("\nInteractive Chat (type 'exit' or 'quit' to end):")
    while True:
        user_input = input("\nUser: ").strip()
        if user_input.lower() in ["exit", "quit"]:
            break
        if not user_input.startswith("<s>"):
            user_input = "<s> " + user_input
        prompt_ids = tokenizer.encode(user_input).ids
        # Truncate from the left so the prompt fits in the context window.
        max_prompt_length = config["max_len"] - 1
        if len(prompt_ids) > max_prompt_length:
            prompt_ids = prompt_ids[-max_prompt_length:]

        print("\nModel generating response using temperature sampling (temp=0.7, top-p=0.9, max tokens=15)...")
        generated_ids = temperature_sample(
            saved_params, prompt_ids, model_instance,
            max_length=15, temperature=0.7, top_p=0.9, end_token_id=END_SEQ_TOKEN_ID
        )
        generated_text = tokenizer.decode(generated_ids)
        print("\nModel:", generated_text)


if __name__ == "__main__":
    chat()