# Celestia / usage.py
from typing import Any, Optional
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import flax.serialization
from tokenizers import Tokenizer
# ---------------------------
# Constants and File Paths
# ---------------------------
# Set these to real local paths before running.
TOKENIZER_PATH = "Path to tokenizer.json file"
MODEL_PARAMS_SAVE_PATH = "Path to model file"
# ---------------------------
# Global Definitions
# ---------------------------
DTYPE = jnp.bfloat16
RMSNORM_EPS = 1e-05
dense_init = nn.initializers.normal(stddev=0.02)
CTX_LEN = 2048
NUM_KV_HEADS = 4
# ---------------------------
# Configuration Values (from provided config)
# ---------------------------
config = {
"d_model": 768,
"nhead": 16,
"num_layers": 24,
"ff_hidden_dim": 3072,
"vocab_size": 49800,
"max_len": 2048,
"dropout_rate": 0.1,
"window_layer_indices": [2, 5, 8, 11, 14, 17, 20, 23],
"moe_layer_indices": [4, 9, 14, 19],
"window_size": 512,
"moe_params": {"num_experts": 4, "num_experts_per_tok": 2},
}
# ---------------------------
# Custom Modules (Updated Architecture)
# ---------------------------
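# RMSNorm: normalizes by the root-mean-square of the features (no mean
# subtraction) and applies a learned per-channel scale.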
class RMSNorm(nn.Module):
epsilon: float = RMSNORM_EPS
    dtype: Any = DTYPE
@nn.compact
def __call__(self, x):
dim = x.shape[-1]
scale = self.param("scale", nn.initializers.ones, (dim,))
norm = jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + self.epsilon)
return (x / norm) * scale
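# Rotary position embeddings (RoPE). Note the layout: even/odd channels are
# split and the rotated halves are concatenated rather than re-interleaved.
# Because q and k pass through the same transform, relative-position dot
# products are unaffected.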
class RoPE(nn.Module):
d_model: int
max_len: int
    dtype: Any = DTYPE
def setup(self):
self.inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, self.d_model, 2, dtype=jnp.float32) / self.d_model))
def __call__(self, x):
seq_len = x.shape[-2]
pos = jnp.arange(seq_len, dtype=jnp.float32)[None, None, :, None]
inv_freq = self.inv_freq[None, None, None, :]
freqs = pos * inv_freq
cos = jnp.cos(freqs).astype(self.dtype)
sin = jnp.sin(freqs).astype(self.dtype)
x1 = x[..., ::2]
x2 = x[..., 1::2]
return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
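# SwiGLU-style gated feed-forward: a single fused projection is split into a
# value half and a SiLU gate half.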
class FeedForward(nn.Module):
d_model: int
hidden_dim: int
dropout_rate: float
    dtype: Any = DTYPE
@nn.compact
def __call__(self, x, deterministic: bool = True):
proj = nn.Dense(self.hidden_dim * 2, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x)
x1, x2 = jnp.split(proj, 2, axis=-1)
x_act = x1 * nn.silu(x2)
x_act = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x_act)
return nn.Dropout(rate=self.dropout_rate)(x_act, deterministic=deterministic)
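# A single expert MLP used by the MoE block below. It applies no dropout
# itself; dropout is applied once on the combined expert output.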
class ExpertFFN(nn.Module):
d_model: int
hidden_dim: int
dropout_rate: float
    dtype: Any = DTYPE
@nn.compact
def __call__(self, x, deterministic: bool = True):
hidden = nn.Dense(self.hidden_dim, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(x)
hidden = nn.silu(hidden)
out = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)(hidden)
return out
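# Mixture-of-experts feed-forward. This is a dense (soft) mixture: every
# expert runs on every token and outputs are combined with softmax gate
# weights. num_experts_per_tok is kept for config compatibility but no
# top-k routing is performed.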
class MoEFeedForward(nn.Module):
d_model: int
hidden_dim: int
dropout_rate: float
num_experts: int = 4
num_experts_per_tok: int = 2
    dtype: Any = DTYPE
@nn.compact
def __call__(self, x, deterministic: bool = True):
gate_logits = nn.Dense(self.num_experts, use_bias=False, dtype=self.dtype)(x)
gate_scores = nn.softmax(gate_logits, axis=-1)
expert_ffn = nn.vmap(ExpertFFN,
variable_axes={'params': 0},
split_rngs={'params': True},
in_axes=0,
out_axes=0)(d_model=self.d_model,
hidden_dim=self.hidden_dim,
dropout_rate=self.dropout_rate,
dtype=self.dtype)
x_expert = jnp.broadcast_to(x, (self.num_experts,) + x.shape)
experts = expert_ffn(x_expert)
gate_scores = jnp.transpose(gate_scores, (2, 0, 1))[..., None]
moe_output = jnp.sum(experts * gate_scores, axis=0)
moe_output = nn.Dropout(rate=self.dropout_rate)(moe_output, deterministic=deterministic)
return moe_output
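# Grouped-query attention in the LLaMA style: nhead query heads share
# num_kv_heads key/value heads, with an optional sliding-window causal mask
# and a learned per-channel layer scale on the output.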
class LLaMAAttention(nn.Module):
d_model: int
nhead: int
num_kv_heads: int
dropout_rate: float
    dtype: Any = DTYPE
use_sliding_window: bool = False
window_size: int = 512
def setup(self):
self.head_dim = self.d_model // self.nhead
self.q_proj = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)
self.kv_proj = nn.Dense(2 * (self.num_kv_heads * self.head_dim),
use_bias=False, kernel_init=dense_init, dtype=self.dtype)
self.out_proj = nn.Dense(self.d_model, use_bias=False, kernel_init=dense_init, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.dropout_rate)
self.rope = RoPE(d_model=self.head_dim, max_len=CTX_LEN, dtype=self.dtype)
self.layer_scale_attn = self.param("layer_scale_attn", nn.initializers.constant(0.1), (self.d_model,))
def __call__(self, x, deterministic: bool = True):
B, T, _ = x.shape
q = self.q_proj(x).reshape(B, T, self.nhead, self.head_dim)
kv = self.kv_proj(x).reshape(B, T, self.num_kv_heads, 2 * self.head_dim)
k, v = jnp.split(kv, 2, axis=-1)
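        # Repeat the shared K/V heads so each query head has a matching head.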
group_factor = self.nhead // self.num_kv_heads
k = jnp.repeat(k, repeats=group_factor, axis=2)
v = jnp.repeat(v, repeats=group_factor, axis=2)
q = jnp.transpose(q, (0, 2, 1, 3))
k = jnp.transpose(k, (0, 2, 1, 3))
q = self.rope(q)
k = self.rope(k)
q = jnp.transpose(q, (0, 2, 1, 3))
k = jnp.transpose(k, (0, 2, 1, 3))
attn_weights = jnp.einsum("bthd,bThd->bthT", q, k) / jnp.sqrt(self.head_dim)
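        # Causal masking; sliding-window layers additionally drop keys more
        # than window_size positions in the past.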
if self.use_sliding_window:
i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
sliding_mask = (i - j < self.window_size) & (i >= j)
sliding_mask = sliding_mask[None, :, None, :]
attn_weights = jnp.where(sliding_mask, attn_weights, -1e10)
else:
causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))[None, :, None, :]
attn_weights = jnp.where(causal_mask, attn_weights, -1e10)
attn_probs = nn.softmax(attn_weights, axis=-1)
attn_probs = self.dropout(attn_probs, deterministic=deterministic)
attn_output = jnp.einsum("bthT,bThd->bthd", attn_probs, v)
attn_output = attn_output.reshape(B, T, self.d_model)
output = self.out_proj(attn_output)
output = self.dropout(output, deterministic=deterministic)
return output * self.layer_scale_attn
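# Pre-norm residual block. The attention branch applies its layer scale
# internally; the feed-forward branch is scaled here.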
class TransformerLayer(nn.Module):
d_model: int
nhead: int
ff_hidden_dim: int
dropout_rate: float
    dtype: Any = DTYPE
use_sliding_window: bool = False
window_size: int = 512
use_moe: bool = False
    moe_params: Optional[dict] = None
def setup(self):
self.attn_norm = RMSNorm(dtype=self.dtype)
self.attn = LLaMAAttention(
d_model=self.d_model,
nhead=self.nhead,
num_kv_heads=NUM_KV_HEADS,
dropout_rate=0.0,
dtype=self.dtype,
use_sliding_window=self.use_sliding_window,
window_size=self.window_size
)
self.ff_norm = RMSNorm(dtype=self.dtype)
if self.use_moe:
self.ff = MoEFeedForward(
d_model=self.d_model,
hidden_dim=self.ff_hidden_dim,
dropout_rate=self.dropout_rate,
num_experts=self.moe_params.get("num_experts", 4) if self.moe_params else 4,
num_experts_per_tok=self.moe_params.get("num_experts_per_tok", 2) if self.moe_params else 2,
dtype=self.dtype
)
else:
self.ff = FeedForward(
d_model=self.d_model,
hidden_dim=self.ff_hidden_dim,
dropout_rate=self.dropout_rate,
dtype=self.dtype
)
self.layer_scale_ff = self.param("layer_scale_ff", nn.initializers.constant(0.1), (self.d_model,))
def __call__(self, x, deterministic: bool = True):
x = x + self.attn(self.attn_norm(x), deterministic=deterministic)
x = x + self.ff(self.ff_norm(x), deterministic=deterministic) * self.layer_scale_ff
return x
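# Decoder-only transformer with tied input/output embeddings. Per-layer
# index lists select which layers use sliding-window attention and which
# replace the FFN with a mixture of experts.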
class DeepSeekModel(nn.Module):
vocab_size: int
d_model: int
nhead: int
num_layers: int
ff_hidden_dim: int
max_len: int
dropout_rate: float
    dtype: Any = DTYPE
    window_layer_indices: Optional[list] = None
    moe_layer_indices: Optional[list] = None
    window_size: int = 512
    moe_params: Optional[dict] = None
def setup(self):
self.embed = nn.Embed(
num_embeddings=self.vocab_size,
features=self.d_model,
embedding_init=dense_init,
dtype=self.dtype
)
self.layers = [
TransformerLayer(
d_model=self.d_model,
nhead=self.nhead,
ff_hidden_dim=self.ff_hidden_dim,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
use_sliding_window=(self.window_layer_indices is not None and i in self.window_layer_indices),
window_size=self.window_size,
use_moe=(self.moe_layer_indices is not None and i in self.moe_layer_indices),
moe_params=self.moe_params
)
for i in range(self.num_layers)
]
self.norm = RMSNorm(dtype=self.dtype)
def __call__(self, input_ids, deterministic: bool = True):
x = self.embed(input_ids)
for layer in self.layers:
x = layer(x, deterministic=deterministic)
x = self.norm(x)
logits = x @ self.embed.embedding.T
return logits
# ---------------------------
# Load Tokenizer and Model Parameters
# ---------------------------
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
PAD_TOKEN_ID = tokenizer.token_to_id("<pad>")
START_TOKEN_ID = tokenizer.token_to_id("<s>")
END_SEQ_TOKEN_ID = tokenizer.token_to_id("</s>")
model_instance = DeepSeekModel(
vocab_size=config["vocab_size"],
d_model=config["d_model"],
nhead=config["nhead"],
num_layers=config["num_layers"],
ff_hidden_dim=config["ff_hidden_dim"],
max_len=config["max_len"],
dropout_rate=config["dropout_rate"],
dtype=DTYPE,
window_layer_indices=config["window_layer_indices"],
moe_layer_indices=config["moe_layer_indices"],
window_size=config["window_size"],
moe_params=config["moe_params"]
)
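# flax.serialization.from_bytes needs a target pytree with the same
# structure as the saved parameters, so initialize once on dummy input to
# build that template. The sequence length used here only affects tracing;
# no parameter shape depends on it.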
dummy_input = jnp.ones((1, config["max_len"] - 1), dtype=jnp.int32)
rng = jax.random.PRNGKey(0)
init_params = model_instance.init(rng, dummy_input, deterministic=True)
with open(MODEL_PARAMS_SAVE_PATH, "rb") as f:
saved_params_bytes = f.read()
saved_params = flax.serialization.from_bytes(init_params, saved_params_bytes)
print("Loaded model parameters.")
# ---------------------------
# Temperature Sampling Function with Fixed Parameters
# ---------------------------
def temperature_sample(params, prompt_ids, model, max_length=15, temperature=0.7, top_p=0.9, end_token_id=END_SEQ_TOKEN_ID):
"""
Generates text token-by-token using temperature scaling and nucleus (top-p) sampling.
Args:
params: Model parameters.
prompt_ids: List of token IDs for the prompt.
model: The language model.
max_length: Maximum number of tokens to generate.
temperature: Temperature for scaling logits.
top_p: Nucleus sampling threshold.
end_token_id: End-of-sequence token ID.
Returns:
A list of token IDs representing the generated text.
"""
generated = list(prompt_ids)
for step in range(max_length):
input_seq = jnp.array(generated)[None, :]
logits = model.apply(params, input_seq, deterministic=True)
logits_last = logits[0, -1]
scaled_logits = logits_last / temperature
probs = jax.nn.softmax(scaled_logits)
        # Cast to float64 before renormalizing: in low precision the nucleus
        # probabilities may fail np.random.choice's sum-to-1 check.
        probs_np = np.asarray(probs, dtype=np.float64)
sorted_indices = np.argsort(probs_np)[::-1]
sorted_probs = probs_np[sorted_indices]
cumulative_probs = np.cumsum(sorted_probs)
cutoff_idx = np.where(cumulative_probs > top_p)[0]
cutoff = cutoff_idx[0] + 1 if len(cutoff_idx) > 0 else len(sorted_probs)
nucleus_indices = sorted_indices[:cutoff]
nucleus_probs = sorted_probs[:cutoff]
nucleus_probs /= np.sum(nucleus_probs)
token_id = int(np.random.choice(nucleus_indices, p=nucleus_probs))
generated.append(token_id)
token_str = tokenizer.decode([token_id]).strip()
print(f"Step {step+1}: Generated token '{token_str}' (ID: {token_id})")
if token_id == end_token_id:
break
return generated
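# A minimal one-off generation sketch (assuming TOKENIZER_PATH and
# MODEL_PARAMS_SAVE_PATH above point at real files; the prompt string is
# illustrative):
#
#   prompt_ids = tokenizer.encode("<s> Hello").ids
#   out_ids = temperature_sample(saved_params, prompt_ids, model_instance)
#   print(tokenizer.decode(out_ids))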
# ---------------------------
# Interactive Chat Loop using Fixed Temperature Sampling
# ---------------------------
def chat():
print("\nInteractive Chat (type 'exit' or 'quit' to end):")
while True:
user_input = input("\nUser: ").strip()
if user_input.lower() in ["exit", "quit"]:
break
if not user_input.startswith("<s>"):
user_input = "<s> " + user_input
prompt_ids = tokenizer.encode(user_input).ids
max_prompt_length = config["max_len"] - 1
if len(prompt_ids) > max_prompt_length:
prompt_ids = prompt_ids[-max_prompt_length:]
print("\nModel generating response using temperature sampling (temp=0.7, top-p=0.9, max tokens=15)...")
generated_ids = temperature_sample(
saved_params, prompt_ids, model_instance,
max_length=15, temperature=0.7, top_p=0.9, end_token_id=END_SEQ_TOKEN_ID
)
generated_text = tokenizer.decode(generated_ids)
print("\nModel:", generated_text)
if __name__ == "__main__":
chat()