#!/usr/bin/env python3
"""
DPSNR Inference — Fully self-contained single-file GPU inference for the Large model.

This file contains the ENTIRE model architecture, checkpoint loading, and
generation logic. It has ZERO dependencies on the dpsn_r_jax package.

Usage:
    source .venv/bin/activate

    # Single prompt
    python infer.py --prompt "Once upon a time"

    # Interactive mode (default)
    python infer.py

    # Adjust generation parameters
    python infer.py --prompt "The future of AI" --max_tokens 200 --temp 0.8 --top_k 50
"""

import os
import sys
import time
import argparse
from dataclasses import dataclass, field
from collections import namedtuple
from typing import Any, Callable, Optional

# Set before the `transformers` import below; presumably to silence the HF
# tokenizers fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import jax
import jax.numpy as jnp
import flax.linen as nn
from jax import lax
from flax.training import train_state
from flax import struct, traverse_util
import optax
import orbax.checkpoint
from functools import partial
from transformers import AutoTokenizer


# ═══════════════════════════════════════════════════════════════════════════
# DEVICE
# ═══════════════════════════════════════════════════════════════════════════

# The first device JAX reports; params are moved here after loading unless it
# is the CPU. Printing happens at import time.
DEVICE = jax.devices()[0]
PLATFORM = DEVICE.platform
print(f"[Device] {DEVICE} (platform: {PLATFORM})")


# ═══════════════════════════════════════════════════════════════════════════
# CONFIG — Large model, hardcoded
# ═══════════════════════════════════════════════════════════════════════════

TOKENIZER_NAME = "EleutherAI/gpt-neo-125M"
# Default checkpoint location: `checkpoints_dir/` next to this script.
CHECKPOINT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints_dir")


@dataclass
class PoolConfig:
    """Shape of the 1-D memory pool: `total_vectors` rows of width `hidden_dim`."""

    total_vectors: int
    hidden_dim: int


@dataclass
class DPSNRConfig:
    """Hyper-parameters of the DPSNR model (defaults = the Large model).

    NOTE(review): pool_hidden_dim, librarian_hidden_dim, min_reasoning_loops,
    min_k, num_clusters_to_search, pad_token_id, use_bf16, sigma_anneal_steps,
    sigma_target and precision_loss_weight are not referenced anywhere else in
    this file. In particular the pool width comes from controller_hidden_dim,
    not pool_hidden_dim.
    """

    vocab_size: int = 50257
    controller_hidden_dim: int = 768
    controller_num_layers: int = 12
    controller_num_heads: int = 12
    controller_ff_multiplier: float = 2.0
    max_seq_len: int = 1024  # also the size of the learned position table
    dropout: float = 0.0
    pool_total_vectors: int = 262144  # rows of the 1-D pool
    pool_hidden_dim: int = 768
    librarian_hidden_dim: int = 32
    max_reasoning_loops: int = 6  # fixed scan length of the reasoning loop
    min_reasoning_loops: int = 1
    halt_threshold: float = 0.99
    min_k: int = 4
    max_k: int = 32  # 1-D pool window; the 2-D pool uses ~sqrt(max_k) per axis
    num_clusters_to_search: int = 4
    pad_token_id: int = 0
    learning_rate: float = 3e-4
    gradient_checkpointing: bool = False
    use_bf16: bool = False
    num_indexer_heads: int = 1
    sigma_min: float = 0.01
    sigma_max: float = 5.0
    use_2d_pool: bool = False
    pool_grid_rows: int = 512
    pool_grid_cols: int = 512
    sigma_anneal_steps: int = 0
    sigma_target: float = 0.05
    precision_loss_weight: float = 0.0
    # Fields needed by create_train_state but unused for inference
    streaming: bool = True
    hf_dataset_name: Optional[str] = None
    hf_tokenizer_name: Optional[str] = None
    max_steps: Optional[int] = None
    generation_steps: Optional[int] = None
    generation_max_tokens: int = 20
    generation_prompts: Optional[list] = None
    num_workers: int = 4
    loss_chunk_size: int = 0
    finetune: Optional[Any] = None


CONFIG = DPSNRConfig()


# ═══════════════════════════════════════════════════════════════════════════
# MODEL LAYERS
#
# NOTE: Flax auto-names `@nn.compact` submodules by creation order (Dense_0,
# Dense_1, ...) and `setup` submodules by attribute name. The checkpoint's
# parameter tree depends on both, so do not reorder or rename them.
# ═══════════════════════════════════════════════════════════════════════════

class FlashCausalSelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Despite the name, this calls the standard `nn.dot_product_attention`.
    Causality comes entirely from the additive `mask` bias supplied by the
    caller (TinyController.encode builds it from `nn.make_causal_mask`).
    """

    hidden_dim: int
    num_heads: int
    dropout_rate: float = 0.0

    @nn.compact
    def __call__(self, x, mask=None, deterministic=True):
        # x: (batch, seq, hidden_dim) — the reshapes below fix this rank.
        head_dim = self.hidden_dim // self.num_heads
        qkv = nn.Dense(3 * self.hidden_dim, use_bias=False)(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)
        # Split the model dimension into (num_heads, head_dim).
        q = q.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim)
        k = k.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim)
        v = v.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim)
        # Only request a dropout RNG when dropout can actually fire.
        dropout_rng = (
            self.make_rng("dropout")
            if not deterministic and self.dropout_rate > 0
            else None
        )
        y = nn.dot_product_attention(
            q,
            k,
            v,
            bias=mask,
            dropout_rate=self.dropout_rate,
            deterministic=deterministic,
            dropout_rng=dropout_rng,
        )
        # Merge heads back and apply the output projection.
        y = y.reshape(x.shape[0], x.shape[1], self.hidden_dim)
        y = nn.Dense(self.hidden_dim, use_bias=False)(y)
        if not deterministic:
            y = nn.Dropout(self.dropout_rate)(y, deterministic=deterministic)
        return y


class TinyFFN(nn.Module):
    """Position-wise feed-forward block: Dense(ff_dim) -> GELU -> Dense(hidden_dim)."""

    hidden_dim: int
    ff_dim: int
    dropout_rate: float = 0.0

    @nn.compact
    def __call__(self, x, deterministic=True):
        x = nn.Dense(self.ff_dim)(x)
        x = nn.gelu(x)
        if not deterministic:
            x = nn.Dropout(self.dropout_rate)(x, deterministic=deterministic)
        x = nn.Dense(self.hidden_dim)(x)
        if not deterministic:
            x = nn.Dropout(self.dropout_rate)(x, deterministic=deterministic)
        return x


class TinyTransformerLayer(nn.Module):
    """Pre-LayerNorm transformer block: x + Attn(LN(x)), then x + FFN(LN(x))."""

    hidden_dim: int
    num_heads: int
    ff_dim: int
    dropout_rate: float = 0.0

    @nn.compact
    def __call__(self, x, mask=None, deterministic=True):
        norm1 = nn.LayerNorm()(x)
        attn_out = FlashCausalSelfAttention(
            self.hidden_dim, self.num_heads, self.dropout_rate
        )(norm1, mask=mask, deterministic=deterministic)
        x = x + attn_out
        norm2 = nn.LayerNorm()(x)
        ffn_out = TinyFFN(self.hidden_dim, self.ff_dim, self.dropout_rate)(
            norm2, deterministic=deterministic
        )
        x = x + ffn_out
        return x


# ═══════════════════════════════════════════════════════════════════════════
# CONTROLLER
# ═══════════════════════════════════════════════════════════════════════════

class TinyController(nn.Module):
    """Causal transformer encoder with learned absolute positions plus an LM head.

    `encode` maps token ids to hidden states. `decode` applies the final
    LayerNorm and the vocabulary projection; DPSNR calls it on the state left
    by the reasoning loop.
    """

    config: DPSNRConfig

    def setup(self):
        self.embedding = nn.Embed(
            self.config.vocab_size, self.config.controller_hidden_dim
        )
        # Learned position table; sequence length must not exceed max_seq_len.
        self.pos_encoding = nn.Embed(
            self.config.max_seq_len, self.config.controller_hidden_dim
        )
        ff_dim = int(
            self.config.controller_hidden_dim * self.config.controller_ff_multiplier
        )
        layer_cls = TinyTransformerLayer
        if self.config.gradient_checkpointing:
            # Index 3 presumably refers to `deterministic` in
            # (self, x, mask, deterministic), with `self` counted as argument
            # 0 by nn.remat — confirm against the Flax docs.
            layer_cls = nn.remat(TinyTransformerLayer, static_argnums=(3,))
        self.layers = [
            layer_cls(
                self.config.controller_hidden_dim,
                self.config.controller_num_heads,
                ff_dim,
                self.config.dropout,
            )
            for _ in range(self.config.controller_num_layers)
        ]
        self.final_norm = nn.LayerNorm()
        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)

    def __call__(self, input_ids, deterministic=True):
        return self.encode(input_ids, deterministic)

    def encode(self, input_ids, deterministic=True):
        """Embed `input_ids` (batch, seq) and run all transformer layers causally."""
        B, T = input_ids.shape
        embed = self.embedding(input_ids)
        pos_ids = jnp.arange(T)[None, :]
        pos_embed = self.pos_encoding(pos_ids)
        x = embed + pos_embed
        # Boolean causal mask -> additive bias: 0 where attention is allowed,
        # -1e4 where it is blocked.
        mask = nn.make_causal_mask(input_ids)
        mask = jnp.where(mask, 0, -1e4)
        for layer in self.layers:
            x = layer(x, mask, deterministic)
        return x

    def decode(self, hidden):
        """Project hidden states to vocabulary logits (final LayerNorm + lm_head)."""
        x = self.final_norm(hidden)
        logits = self.lm_head(x)
        return logits


# ═══════════════════════════════════════════════════════════════════════════
# MEMORY — Learned Indexer + 1D/2D Pool
# ═══════════════════════════════════════════════════════════════════════════

class LearnedIndexer(nn.Module):
    """Map a whole sequence to per-head pool coordinates (mu) and widths (sigma).

    Returns:
        mu: sigmoid outputs in (0, 1), shape (batch, num_heads).
        sigma: values in (sigma_min, sigma_max * sigma_max_scale), shape
            (batch, num_heads).

    NOTE(review): the attention pooling below is a softmax over *every*
    sequence position (axis=1) with no causal or padding mask. The pooled
    summary, and therefore the retrieval, depends on all positions, including
    any padding or later tokens in the buffer.
    """

    hidden_dim: int
    num_heads: int = 1
    sigma_min: float = 0.01
    sigma_max: float = 5.0

    @nn.compact
    def __call__(self, hidden_states, sigma_max_scale: float = 1.0):
        # hidden_states: (batch, seq, dim) as passed by DPSNR._encode_hidden.
        attn_logits = nn.Dense(1, use_bias=False)(hidden_states)
        attn_weights = jax.nn.softmax(attn_logits, axis=1)
        pooled = jnp.sum(attn_weights * hidden_states, axis=1)
        x = nn.Dense(self.hidden_dim)(pooled)
        x = nn.gelu(x)
        x = nn.Dense(self.hidden_dim // 2)(x)
        x = nn.gelu(x)
        mu_raw = nn.Dense(self.num_heads)(x)
        sigma_raw = nn.Dense(self.num_heads)(x)
        mu = jax.nn.sigmoid(mu_raw)
        # sigma_max_scale lets callers shrink or grow the upper bound on sigma.
        effective_sigma_max = self.sigma_max * sigma_max_scale
        sigma = (
            self.sigma_min
            + (effective_sigma_max - self.sigma_min) * jax.nn.sigmoid(sigma_raw)
        )
        return mu, sigma


class CoordinateMassivePool(nn.Module):
    """1-D memory pool read through a Gaussian-weighted window.

    For each batch item, a window of `window_size` consecutive rows is sliced
    around position mu * (total_vectors - 1). The rows are averaged with
    weights exp(-d^2 / (2 * sigma^2)), where d is the distance to the centre.
    """

    config: PoolConfig
    window_size: int

    def setup(self):
        self.params_storage = self.param(
            "params_storage",
            nn.initializers.normal(),
            (self.config.total_vectors, self.config.hidden_dim),
        )

    def __call__(self, mu, sigma):
        """Read one vector per batch item.

        Args:
            mu: (batch,) coordinates, presumably in [0, 1] (LearnedIndexer
                produces sigmoid outputs).
            sigma: (batch,) Gaussian widths measured in rows.

        Returns:
            (aggregated, start_indices): a (batch, hidden_dim) weighted sum and
            the int32 (batch,) first row of each window.
        """
        B = mu.shape[0]  # NOTE(review): unused
        Total = self.config.total_vectors
        D = self.config.hidden_dim
        W = self.window_size
        center_idx = mu * (Total - 1)
        # Clamp so the whole W-row window stays inside the pool.
        start_indices = jnp.clip(center_idx - W // 2, 0, Total - W).astype(jnp.int32)

        def slice_fn(start):
            return lax.dynamic_slice(self.params_storage, (start, 0), (W, D))

        selected = jax.vmap(slice_fn)(start_indices)
        relative_indices = jnp.arange(W)[None, :] + start_indices[:, None]
        distances = relative_indices - center_idx[:, None]
        # The +1e-6 terms guard against sigma == 0 and all-zero weights.
        weights = jnp.exp(-(distances**2) / (2 * (sigma[:, None] + 1e-6) ** 2)) + 1e-6
        weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
        aggregated = jnp.einsum("bw,bwd->bd", weights, selected)
        return aggregated, start_indices


class CoordinateMassivePool2D(nn.Module):
    """2-D variant: a (rows, cols, hidden_dim) grid read through a WxW window.

    The weights are separable Gaussians (row weight times column weight),
    normalized over the whole window.
    """

    rows: int
    cols: int
    hidden_dim: int
    window_size: int

    def setup(self):
        self.params_storage = self.param(
            "params_storage",
            nn.initializers.normal(),
            (self.rows, self.cols, self.hidden_dim),
        )

    def __call__(self, mu_row, mu_col, sigma):
        """Read one vector per batch item from the grid.

        Args:
            mu_row, mu_col: (batch,) coordinates, presumably in [0, 1].
            sigma: (batch,) width shared by both axes.

        Returns:
            (aggregated, flat_start): a (batch, hidden_dim) weighted sum and the
            row-major flat index (r_start * cols + c_start) of each window's
            top-left cell.
        """
        B = mu_row.shape[0]  # NOTE(review): unused
        R = self.rows
        C = self.cols
        D = self.hidden_dim
        W = self.window_size
        r_center = mu_row * (R - 1)
        r_start = jnp.clip(r_center - W // 2, 0, R - W).astype(jnp.int32)
        c_center = mu_col * (C - 1)
        c_start = jnp.clip(c_center - W // 2, 0, C - W).astype(jnp.int32)

        def fetch_window(r_s, c_s):
            return lax.dynamic_slice(self.params_storage, (r_s, c_s, 0), (W, W, D))

        windows = jax.vmap(fetch_window)(r_start, c_start)
        r_idx = jnp.arange(W)[None, :] + r_start[:, None]
        c_idx = jnp.arange(W)[None, :] + c_start[:, None]
        r_dist = r_idx - r_center[:, None]
        c_dist = c_idx - c_center[:, None]
        sigma_sq = (sigma + 1e-6) ** 2
        r_w = jnp.exp(-r_dist ** 2 / (2 * sigma_sq[:, None]))
        c_w = jnp.exp(-c_dist ** 2 / (2 * sigma_sq[:, None]))
        # Outer product -> (batch, W, W) weights, normalized per batch item.
        w_2d = jnp.einsum("bi,bj->bij", r_w, c_w) + 1e-6
        w_2d = w_2d / jnp.sum(w_2d, axis=(-2, -1), keepdims=True)
        aggregated = jnp.einsum("bij,bijd->bd", w_2d, windows)
        flat_start = r_start * C + c_start
        return aggregated, flat_start


# ═══════════════════════════════════════════════════════════════════════════
# REASONING — Adaptive Compute Controller
# ═══════════════════════════════════════════════════════════════════════════

class AdaptiveComputeController(nn.Module):
    """Gated state update plus a cumulative halting probability, one loop step at a time.

    NOTE(review): `max_loops` is not read inside this module. `loop_embed` has
    32 rows, so `loop_count` should stay below 32.
    """

    hidden_dim: int
    max_loops: int = 8
    halt_threshold: float = 0.99

    def setup(self):
        self.halt_net = nn.Sequential(
            [nn.Dense(self.hidden_dim // 4), nn.gelu, nn.Dense(1), nn.sigmoid]
        )
        self.state_gate = nn.Sequential([nn.Dense(self.hidden_dim), nn.sigmoid])
        self.state_transform = nn.Dense(self.hidden_dim)
        self.state_norm = nn.LayerNorm()
        self.loop_embed = nn.Embed(32, self.hidden_dim)

    def __call__(self, state_hidden, step_output, loop_count, current_halt_prob, halted_mask):
        """Return (candidate_state, new_halt_prob, final_halted_mask).

        The halt probability only accumulates for positions not yet halted.
        A position becomes halted once its accumulated probability reaches
        `halt_threshold`, and halted positions stay halted (elementwise max).
        """
        # Add a learned embedding of the loop index (broadcast over batch/seq).
        loop_idx = jnp.array([loop_count], dtype=jnp.int32)
        emb = self.loop_embed(loop_idx)
        step_output = step_output + emb
        combined = jnp.concatenate([step_output, state_hidden], axis=-1)
        # GRU-like gate: blend the transformed step output with the old state.
        g = self.state_gate(combined)
        candidate_state = g * self.state_transform(step_output) + (1 - g) * state_hidden
        candidate_state = self.state_norm(candidate_state)
        hp = self.halt_net(candidate_state)
        still_running_mask = 1.0 - halted_mask
        new_halt_prob = current_halt_prob + hp * still_running_mask
        is_halted_now = (new_halt_prob >= self.halt_threshold).astype(jnp.float32)
        final_halted_mask = jnp.maximum(halted_mask, is_halted_now)
        return candidate_state, new_halt_prob, final_halted_mask


# ═══════════════════════════════════════════════════════════════════════════
# DPSNR — Full model
# ═══════════════════════════════════════════════════════════════════════════

class DPSNR(nn.Module):
    """Controller transformer + iterative retrieval from a learned memory pool.

    Pipeline:
    1. Encode tokens with TinyController.
    2. Run `max_reasoning_loops` scan steps. Each step indexes the pool,
       integrates the retrieved vector into every position, and updates the
       state via the adaptive compute controller.
    3. Decode the final state to logits.
    """

    config: DPSNRConfig

    def setup(self):
        self.controller = TinyController(self.config)
        self.indexer = LearnedIndexer(
            self.config.controller_hidden_dim,
            num_heads=self.config.num_indexer_heads,
            sigma_min=self.config.sigma_min,
            sigma_max=self.config.sigma_max,
        )
        if self.config.use_2d_pool:
            # Per-axis window ~ sqrt(max_k) so the WxW window covers ~max_k cells.
            axis_window = max(2, int(self.config.max_k ** 0.5))
            self.pool = CoordinateMassivePool2D(
                rows=self.config.pool_grid_rows,
                cols=self.config.pool_grid_cols,
                hidden_dim=self.config.controller_hidden_dim,
                window_size=axis_window,
            )
        else:
            self.pool = CoordinateMassivePool(
                PoolConfig(
                    self.config.pool_total_vectors,
                    self.config.controller_hidden_dim,
                ),
                window_size=self.config.max_k,
            )
        self.acc = AdaptiveComputeController(
            self.config.controller_hidden_dim,
            self.config.max_reasoning_loops,
            self.config.halt_threshold,
        )
        self.retrieval_integrator = nn.Sequential(
            [
                nn.Dense(self.config.controller_hidden_dim),
                nn.gelu,
                nn.Dense(self.config.controller_hidden_dim),
                nn.LayerNorm(),
            ]
        )

    def __call__(self, input_ids, deterministic=True, sigma_max_scale: float = 1.0):
        """Return (logits, (max_reasoning_loops, all_indices, mean_sigma)).

        The first aux entry is always the configured loop count: the scan runs
        a fixed number of steps, and halting only freezes per-position state.
        """
        state_hidden, all_indices, mean_sigma = self._encode_hidden(
            input_ids, deterministic, sigma_max_scale
        )
        logits = self.controller.decode(state_hidden)
        return logits, (self.config.max_reasoning_loops, all_indices, mean_sigma)

    def encode_to_hidden(self, input_ids, deterministic=True, sigma_max_scale: float = 1.0):
        """Same as __call__ but returns the final hidden state instead of logits."""
        state_hidden, all_indices, mean_sigma = self._encode_hidden(
            input_ids, deterministic, sigma_max_scale
        )
        return state_hidden, (self.config.max_reasoning_loops, all_indices, mean_sigma)

    def _encode_hidden(self, input_ids, deterministic=True, sigma_max_scale: float = 1.0):
        """Encode, then run the fixed-length reasoning scan.

        Returns:
            state_hidden: (batch, seq, dim) state after the last loop.
            all_indices: window start indices, shape (heads * batch, loops).
                Per-step indices are concatenated along axis 0 across heads,
                then transposed after the scan.
            mean_sigma: mean of the per-loop mean sigma.
        """
        hidden = self.controller(input_ids, deterministic)
        state_hidden = hidden
        B, T, D = hidden.shape
        halt_prob = jnp.zeros((B, T, 1), dtype=hidden.dtype)
        halted_mask = jnp.zeros((B, T, 1), dtype=hidden.dtype)
        # Warm-up calls: force Flax to trace all sub-modules before scan
        # (presumably so lazily set-up submodules and params are bound outside
        # the raw jax.lax.scan trace). Their outputs are discarded.
        _mu, _sigma = self.indexer(
            jnp.zeros((B, T, D)), sigma_max_scale=sigma_max_scale
        )
        if self.config.use_2d_pool:
            H = self.config.num_indexer_heads
            h_per_dim = max(1, H // 2)  # NOTE(review): unused
            _ = self.pool(jnp.zeros((B,)), jnp.zeros((B,)), jnp.zeros((B,)))
        else:
            _ = self.pool(jnp.zeros((B,)), jnp.zeros((B,)))
        _ = self.retrieval_integrator(
            jnp.zeros((B, T, D + self.config.controller_hidden_dim))
        )
        _ = self.acc(state_hidden, state_hidden, 0, halt_prob, halted_mask)
        use_2d = self.config.use_2d_pool
        H = self.config.num_indexer_heads

        def reasoning_step(carry, i):
            """One reasoning loop: retrieve, integrate, gated update, halting.

            carry = (state, halt_prob, halted_mask); `i` is the loop index.
            """
            s_hidden, h_prob, h_mask = carry
            prev_s_hidden = s_hidden
            mu, sigma = self.indexer(s_hidden, sigma_max_scale=sigma_max_scale)
            all_retrieved = []
            all_start_indices = []
            if use_2d:
                # Pair heads: head h drives the row, head h + H//2 the column.
                # With H == 1 both coordinates come from head 0.
                heads_per_dim = max(1, H // 2)
                for h in range(heads_per_dim):
                    h_row = h
                    h_col = min(h + heads_per_dim, H - 1)
                    sigma_h = (sigma[:, h_row] + sigma[:, h_col]) / 2.0
                    retrieved_h, start_idx_h = self.pool(
                        mu[:, h_row], mu[:, h_col], sigma_h
                    )
                    all_retrieved.append(retrieved_h)
                    all_start_indices.append(start_idx_h)
            else:
                for h in range(H):
                    retrieved_h, start_idx_h = self.pool(mu[:, h], sigma[:, h])
                    all_retrieved.append(retrieved_h)
                    all_start_indices.append(start_idx_h)
            # Average over heads; one retrieved vector per batch item.
            retrieved = jnp.mean(jnp.stack(all_retrieved, axis=1), axis=1)
            start_indices = jnp.concatenate(all_start_indices, axis=0)
            mean_sigma_step = jnp.mean(sigma)
            # Broadcast the retrieved vector to every sequence position.
            retrieved_expanded = jnp.expand_dims(retrieved, 1).repeat(T, axis=1)
            combined = jnp.concatenate([s_hidden, retrieved_expanded], axis=-1)
            integrated = self.retrieval_integrator(combined)
            new_s_hidden, h_prob, new_h_mask = self.acc(
                s_hidden,
                s_hidden + integrated,
                i,
                h_prob,
                h_mask,
            )
            # Freeze positions that were already halted *before* this step.
            update_mask = 1.0 - h_mask
            s_hidden = update_mask * new_s_hidden + h_mask * prev_s_hidden
            # lax.scan requires the carry dtype to stay constant across steps.
            carry_dtype = prev_s_hidden.dtype
            s_hidden = s_hidden.astype(carry_dtype)
            h_prob = h_prob.astype(carry_dtype)
            new_h_mask = new_h_mask.astype(carry_dtype)
            return (s_hidden, h_prob, new_h_mask), (start_indices, mean_sigma_step)

        _scan_fn = reasoning_step
        if self.config.gradient_checkpointing:
            _scan_fn = jax.checkpoint(reasoning_step)
        init_carry = (state_hidden, halt_prob, halted_mask)
        (state_hidden, halt_prob, halted_mask), (all_indices, sigma_per_loop) = (
            jax.lax.scan(
                _scan_fn,
                init_carry,
                jnp.arange(self.config.max_reasoning_loops),
            )
        )
        # (loops, heads*batch) -> (heads*batch, loops)
        all_indices = jnp.transpose(all_indices, (1, 0))
        mean_sigma = jnp.mean(sigma_per_loop)
        return state_hidden, all_indices, mean_sigma


# ═══════════════════════════════════════════════════════════════════════════
# TRAIN STATE — Minimal, just enough to restore the checkpoint pytree
# ═══════════════════════════════════════════════════════════════════════════

class TrainState(train_state.TrainState):
    """Training-state pytree layout used as the orbax restore target.

    `pool_m` / `pool_v` are presumably the pool's optimizer moment buffers.
    The pool parameter is kept out of the optax state when the dummy state is
    built. Fields marked pytree_node=False are static metadata, not arrays.
    """

    rng: Any
    pool_m: jnp.ndarray
    pool_v: jnp.ndarray
    window_size: int = struct.field(pytree_node=False)
    learning_rate_fn: Callable[[int], float] = struct.field(pytree_node=False)
    sigma_anneal_fn: Callable[[int], float] = struct.field(pytree_node=False)
def _create_dummy_state(rng, config):
    """Create a dummy TrainState with the correct pytree structure for checkpoint restore.

    The checkpoint holds the full training state, so the restore target must
    reproduce its structure:
    - the model params;
    - an optax (clip + AdamW) state over every parameter *except* the pool
      storage;
    - `pool_m` / `pool_v` buffers shaped like the pool.

    Only structure, shapes and dtypes matter; the values are placeholders.

    Args:
        rng: PRNG key used for `model.init`.
        config: DPSNRConfig describing the model.

    Returns:
        A TrainState whose array leaves are freshly initialized or zeros.
    """
    model = DPSNR(config)
    dummy_input = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)
    variables = model.init(rng, dummy_input)
    params = variables["params"]

    # Split the pool storage out: it is not part of the optax state.
    flat_params = traverse_util.flatten_dict(params)
    pool_key = ("pool", "params_storage")
    pool_params = flat_params[pool_key]
    dense_flat_params = {k: v for k, v in flat_params.items() if k != pool_key}
    dense_params = traverse_util.unflatten_dict(dense_flat_params)

    def learning_rate_fn(step):
        # Constant schedule; it only has to give the optax state the right shape.
        return config.learning_rate

    def sigma_anneal_fn(step):
        return 1.0

    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=learning_rate_fn),
    )
    opt_state = tx.init(dense_params)
    pool_m = jnp.zeros_like(pool_params)
    pool_v = jnp.zeros_like(pool_params)

    return TrainState(
        step=jnp.array(0, dtype=jnp.int32),
        apply_fn=model.apply,
        params=params,
        tx=tx,
        opt_state=opt_state,
        rng=rng,
        pool_m=pool_m,
        pool_v=pool_v,
        window_size=config.max_k,
        learning_rate_fn=learning_rate_fn,
        sigma_anneal_fn=sigma_anneal_fn,
    )


# ═══════════════════════════════════════════════════════════════════════════
# INFERENCE CONTAINER
# ═══════════════════════════════════════════════════════════════════════════

# What generation needs: the module's apply function, trained params, and the
# training step the checkpoint was saved at.
InferenceModel = namedtuple("InferenceModel", ["apply_fn", "params", "step"])


# ═══════════════════════════════════════════════════════════════════════════
# TOKENIZER
# ═══════════════════════════════════════════════════════════════════════════

def load_tokenizer():
    """Load the HF tokenizer, falling back to EOS as the pad token if none is set."""
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


# ═══════════════════════════════════════════════════════════════════════════
# CHECKPOINT LOADING
# ═══════════════════════════════════════════════════════════════════════════

def load_checkpoint():
    """Load trained weights from checkpoint. Returns only params + apply_fn.

    Restore order:
    1. The latest step managed by an orbax CheckpointManager in CHECKPOINT_DIR.
    2. Otherwise, a bare PyTree checkpoint at CHECKPOINT_DIR/default or
       CHECKPOINT_DIR itself (detected by its `_METADATA` file).

    The full training state is restored on the CPU. Only the params are kept;
    they are moved to DEVICE unless it is the CPU.

    Returns:
        InferenceModel(apply_fn, params, step).

    Raises:
        FileNotFoundError: if CHECKPOINT_DIR does not exist or contains no
            restorable checkpoint.
    """
    abs_ckpt = os.path.abspath(CHECKPOINT_DIR)
    # Fail fast. Building the dummy state runs a full model init and allocates
    # the pool plus two same-sized moment buffers. Also, a CheckpointManager on
    # a missing path would create it as an empty directory (orbax's default
    # CheckpointManagerOptions(create=True)) before we could report the error.
    if not os.path.isdir(abs_ckpt):
        raise FileNotFoundError(f"Checkpoint directory not found: {abs_ckpt}")

    rng = jax.random.PRNGKey(0)
    cpu = jax.devices("cpu")[0]
    print("[Init] Creating model skeleton on CPU...")
    with jax.default_device(cpu):
        dummy_state = _create_dummy_state(rng, CONFIG)
        dummy_state = jax.device_put(dummy_state, cpu)

    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(dummy_state)
    mgr = orbax.checkpoint.CheckpointManager(abs_ckpt, checkpointer)
    latest_step = mgr.latest_step()

    if latest_step is not None:
        print(f"[Checkpoint] Restoring step {latest_step} from {abs_ckpt}")
        state = mgr.restore(
            latest_step,
            items=dummy_state,
            restore_kwargs={"restore_args": restore_args},
        )
    else:
        # No managed steps: look for a bare PyTree checkpoint.
        target = None
        for sub in ("default", ""):
            p = os.path.join(abs_ckpt, sub) if sub else abs_ckpt
            if os.path.exists(os.path.join(p, "_METADATA")):
                target = p
                break
        if target is None:
            raise FileNotFoundError(f"No valid checkpoint found in {abs_ckpt}")
        print(f"[Checkpoint] Restoring directly from {target}")
        state = checkpointer.restore(target, item=dummy_state, restore_args=restore_args)

    step = int(state.step)
    apply_fn = state.apply_fn
    params = state.params
    # Drop optimizer state and pool moments; only the params are needed.
    del dummy_state, state

    if PLATFORM != "cpu":
        print(f"[Device] Moving model params to {DEVICE}...")
        params = jax.device_put(params, DEVICE)

    print(f"[Checkpoint] Loaded at training step {step}")
    return InferenceModel(apply_fn=apply_fn, params=params, step=step)


# ═══════════════════════════════════════════════════════════════════════════
# JIT FORWARD PASS
# ═══════════════════════════════════════════════════════════════════════════

@partial(jax.jit, static_argnums=(0,))
def _forward(apply_fn, params, input_ids):
    """Jitted deterministic forward pass that returns only the logits.

    `apply_fn` is a static (hashed) argument, so the compiled program is
    reused for as long as the same apply_fn and input shape/dtype are passed.
    """
    logits, _ = apply_fn({"params": params}, input_ids, deterministic=True)
    return logits


# ═══════════════════════════════════════════════════════════════════════════
# TEXT GENERATION
# ═══════════════════════════════════════════════════════════════════════════
def generate( model: InferenceModel, prompt: str, tokenizer, rng, max_tokens: int = 100, temperature: float = 0.7, top_k: int = 40, repetition_penalty: float = 1.2, ): """Autoregressive generation with fixed-size buffers (no XLA recompilation).""" input_ids = tokenizer.encode(prompt, return_tensors="np") eos_id = tokenizer.eos_token_id prompt_len = input_ids.shape[1] max_seq = CONFIG.max_seq_len if prompt_len > max_seq: input_ids = input_ids[:, :max_seq] prompt_len = max_seq buf = jnp.zeros((1, max_seq), dtype=jnp.int32) buf = buf.at[:, :prompt_len].set(input_ids) gen_buf = jnp.zeros((max_tokens,), dtype=jnp.int32) n_gen = 0 for step in range(max_tokens): pos = prompt_len + step if pos >= max_seq: break logits = _forward(model.apply_fn, model.params, buf) next_logits = logits[0, pos - 1, :] # Repetition penalty if n_gen > 0: prev = gen_buf[:n_gen] vocab = next_logits.shape[-1] mask = jnp.zeros(vocab, dtype=jnp.bool_) mask = mask.at[prev].set(True) penalized = jnp.where( next_logits > 0, next_logits / repetition_penalty, next_logits * repetition_penalty, ) next_logits = jnp.where(mask, penalized, next_logits) # Top-k filtering k = min(top_k, next_logits.shape[-1]) vals, _ = jax.lax.top_k(next_logits, k=k) threshold = vals[-1] next_logits = jnp.where(next_logits < threshold, -1e10, next_logits) # Temperature sampling rng, key = jax.random.split(rng) token = jax.random.categorical(key, next_logits / max(temperature, 1e-8)) token_int = int(token) buf = buf.at[0, pos].set(token_int) gen_buf = gen_buf.at[n_gen].set(token_int) n_gen += 1 if token_int == eos_id: break return tokenizer.decode( buf[0, prompt_len : prompt_len + n_gen].tolist(), skip_special_tokens=True, ) # ═══════════════════════════════════════════════════════════════════════════════ # MAIN # ═══════════════════════════════════════════════════════════════════════════════ def main(): parser = argparse.ArgumentParser(description="DPSNR Large — Inference") parser.add_argument("--prompt", type=str, 
default=None, help="Input prompt (omit for interactive mode)") parser.add_argument("--max_tokens", type=int, default=100, help="Max tokens to generate (default: 100)") parser.add_argument("--temp", type=float, default=0.7, help="Sampling temperature (default: 0.7)") parser.add_argument("--top_k", type=int, default=40, help="Top-k sampling (default: 40)") parser.add_argument("--penalty", type=float, default=1.2, help="Repetition penalty (default: 1.2)") parser.add_argument("--checkpoint_dir", type=str, default=None, help="Override checkpoint path") args = parser.parse_args() if args.checkpoint_dir: global CHECKPOINT_DIR CHECKPOINT_DIR = args.checkpoint_dir print("=" * 60) print(" DPSNR Large — Loading Model") print("=" * 60) tokenizer = load_tokenizer() model = load_checkpoint() # Warmup: compile forward pass once print("[Warmup] Compiling forward pass...") t0 = time.time() warmup_ids = jnp.zeros((1, CONFIG.max_seq_len), dtype=jnp.int32) _ = _forward(model.apply_fn, model.params, warmup_ids) jax.effects_barrier() print(f"[Warmup] Done in {time.time() - t0:.1f}s") rng = jax.random.PRNGKey(42) def run(prompt: str): nonlocal rng rng, key = jax.random.split(rng) t0 = time.time() output = generate( model, prompt, tokenizer, key, max_tokens=args.max_tokens, temperature=args.temp, top_k=args.top_k, repetition_penalty=args.penalty, ) elapsed = time.time() - t0 print(f"\n{'─' * 50}") print(f"Prompt: {prompt}") print(f"Generated: {output}") print(f"Time: {elapsed:.2f}s") print(f"{'─' * 50}\n") if args.prompt: run(args.prompt) else: print("\n╔══════════════════════════════════════════════════╗") print("║ DPSNR Interactive Inference ║") print("║ Type 'exit' or 'quit' to stop ║") print("╚══════════════════════════════════════════════════╝\n") while True: try: user_input = input(">>> ") if user_input.strip().lower() in ("exit", "quit"): break if not user_input.strip(): continue run(user_input) except (EOFError, KeyboardInterrupt): print("\nExiting...") break if __name__ == 
"__main__": main()