""" hf_space/app.py ────────────────────────────────────────────────────────────────────────────── Mochiva inference server — runs on HuggingFace Spaces (free CPU tier). Fixes vs original: • No longer reads special_tokens.json or generation_config.json (never created) • BOS/EOS/PAD resolved directly from tokenizer vocab • Prompt format updated to match new ......... • Generation stops on tag in addition to EOS token • Graceful fallback if config keys are missing """ from __future__ import annotations import os import json import math import time import threading import queue import re from typing import Iterator, Optional import torch import torch.nn as nn import torch.nn.functional as F from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field from huggingface_hub import snapshot_download from tokenizers import Tokenizer # ─── Config ─────────────────────────────────────────────────────────────────── MODEL_REPO = os.environ.get("MODEL_REPO", "Mochiva-team/Mochiva-model") HF_TOKEN = os.environ.get("HF_TOKEN", None) DEVICE = "cpu" MAX_CTX = int(os.environ.get("MAX_CTX", "4096")) # ─── Model ──────────────────────────────────────────────────────────────────── class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: rms = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).sqrt() return (x.float() / rms).to(x.dtype) * self.scale def precompute_freqs_cis( head_dim: int, max_seq: int, theta: float = 10_000.0, scaling_factor: float = 1.0, ) -> torch.Tensor: half = head_dim // 2 freqs = 1.0 / (theta ** (torch.arange(0, half, dtype=torch.float32) / half)) freqs = freqs / scaling_factor t = torch.arange(max_seq, dtype=torch.float32) freqs = torch.outer(t, freqs) return torch.polar(torch.ones_like(freqs), freqs) def apply_rope( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: def rotate(x): x_c = x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2) x_c = torch.view_as_complex(x_c) fc = freqs_cis.unsqueeze(0).unsqueeze(2) out = torch.view_as_real(x_c * fc).reshape(*x.shape) return out.to(x.dtype) return rotate(xq), rotate(xk) class CausalSelfAttention(nn.Module): def __init__(self, cfg: dict): super().__init__() self.nh = cfg["num_attention_heads"] self.hd = cfg["head_dim"] H = cfg["hidden_size"] self.q_proj = nn.Linear(H, self.nh * self.hd, bias=False) self.k_proj = nn.Linear(H, self.nh * self.hd, bias=False) self.v_proj = nn.Linear(H, self.nh * self.hd, bias=False) self.o_proj = nn.Linear(self.nh * self.hd, H, bias=False) def forward(self, x, freqs_cis, mask, kv_cache=None): B, T, _ = x.shape nh, hd = self.nh, self.hd q = self.q_proj(x).view(B, T, nh, hd) k = self.k_proj(x).view(B, T, nh, hd) v = self.v_proj(x).view(B, T, nh, hd) q, k = apply_rope(q, k, freqs_cis) if kv_cache is not None: if "k" in kv_cache: k = torch.cat([kv_cache["k"], k], dim=1) v = torch.cat([kv_cache["v"], v], dim=1) kv_cache["k"] = k kv_cache["v"] = v q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) scale = 1.0 / math.sqrt(hd) attn = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale Tq, Tk = attn.shape[-2], attn.shape[-1] if mask is not None: attn = attn.masked_fill(~mask[..., :Tq, :Tk], float("-inf")) attn = F.softmax(attn.float(), dim=-1).to(q.dtype) out = torch.einsum("bhqk,bhkd->bhqd", attn, v) out = out.transpose(1, 2).contiguous().view(B, Tq, nh * hd) return self.o_proj(out) class SwiGLUMLP(nn.Module): def __init__(self, cfg: dict): super().__init__() H, I = cfg["hidden_size"], cfg["intermediate_size"] self.gate_proj = nn.Linear(H, I, bias=False) self.up_proj = nn.Linear(H, I, bias=False) self.down_proj = nn.Linear(I, H, bias=False) def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) class MochivaBlock(nn.Module): def __init__(self, cfg: dict): super().__init__() eps = cfg.get("rms_norm_eps", 1e-6) self.attn_norm = RMSNorm(cfg["hidden_size"], eps) self.mlp_norm = RMSNorm(cfg["hidden_size"], eps) self.attn = CausalSelfAttention(cfg) self.mlp = SwiGLUMLP(cfg) def forward(self, x, freqs_cis, mask, kv_cache=None): x = x + self.attn(self.attn_norm(x), freqs_cis, mask, kv_cache) x = x + self.mlp(self.mlp_norm(x)) return x class MochivaForInference(nn.Module): def __init__(self, cfg: dict): super().__init__() self.cfg = cfg V, H, L = cfg["vocab_size"], cfg["hidden_size"], cfg["num_hidden_layers"] self.embed_tokens = nn.Embedding(V, H) self.layers = nn.ModuleList([MochivaBlock(cfg) for _ in range(L)]) self.norm = RMSNorm(H, cfg.get("rms_norm_eps", 1e-6)) hd = cfg["head_dim"] ctx = cfg["max_position_embeddings"] theta = cfg.get("rope_theta", 10_000.0) scale = cfg.get("rope_scaling_factor", 1.0) self.register_buffer("freqs_cis", precompute_freqs_cis(hd, ctx, theta, scale)) def forward(self, input_ids, kv_caches=None): B, T = input_ids.shape offset = 0 if kv_caches and "k" in kv_caches[0]: offset = kv_caches[0]["k"].shape[1] x = self.embed_tokens(input_ids) full_len = offset + T mask = torch.tril(torch.ones(full_len, full_len, dtype=torch.bool, device=x.device)) mask = mask.unsqueeze(0).unsqueeze(0) freqs = self.freqs_cis[offset: offset + T] for i, layer in enumerate(self.layers): kvc = kv_caches[i] if kv_caches else None x = layer(x, freqs, mask, kvc) x = self.norm(x) return x @ self.embed_tokens.weight.T @torch.inference_mode() def generate_stream(self, input_ids, max_new_tokens=256, temperature=0.8, top_p=0.9, top_k=50, repetition_penalty=1.1, eos_token_id=2, stop_token_ids=None): stop_ids = set(stop_token_ids or []) stop_ids.add(eos_token_id) kv_caches = [{} for _ in self.layers] logits = self(input_ids, kv_caches) # Pass input_ids flattened to 1D next_token = _sample(logits[:, -1, :], temperature, top_p, top_k, input_ids.reshape(-1), repetition_penalty) tok_id = int(next_token) if tok_id not in stop_ids: yield tok_id generated = input_ids.reshape(-1).tolist() + [tok_id] cur = next_token.unsqueeze(0) for _ in range(max_new_tokens - 1): logits = self(cur, kv_caches) # Pass generated as a 1D tensor next_token = _sample(logits[:, -1, :], temperature, top_p, top_k, torch.tensor(generated, dtype=torch.long), repetition_penalty) tok_id = int(next_token) if tok_id in stop_ids: break generated.append(tok_id) yield tok_id cur = next_token.unsqueeze(0) # ─── Sampling ───────────────────────────────────────────────────────────────── def _sample(logits, temperature, top_p, top_k, context_ids, repetition_penalty): logits = logits.float().squeeze(0) if repetition_penalty != 1.0: # Flatten safely regardless of whether context_ids is a tensor or nested list if isinstance(context_ids, torch.Tensor): flat_ids = context_ids.reshape(-1).tolist() else: flat_ids = context_ids if isinstance(context_ids[0], int) else [t for row in context_ids for t in row] for tok in set(flat_ids): logits[tok] = logits[tok] / repetition_penalty if logits[tok] > 0 \ else logits[tok] * repetition_penalty if temperature < 1e-4: return logits.argmax(keepdim=True) logits /= temperature if top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[-1]] = float("-inf") if top_p < 1.0: sorted_logits, sorted_idx = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_remove = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p sorted_logits[sorted_remove] = float("-inf") logits = torch.zeros_like(logits).scatter_(0, sorted_idx, sorted_logits) return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) # ─── Weight loading ─────────────────────────────────────────────────────────── def _remap_key(key: str) -> str: key = key.replace("/", ".") key = key.replace("embed_tokens.embedding", "embed_tokens.weight") key = re.sub(r"layer_(\d+)\.", r"layers.\1.", key) key = key.replace(".kernel", ".weight") return key def load_weights(model: MochivaForInference, weights_path: str): try: from safetensors.torch import load_file flat = load_file(weights_path, device=DEVICE) except Exception: import numpy as np npz = np.load(weights_path) flat = {k: torch.from_numpy(v) for k, v in npz.items()} state_dict = model.state_dict() mapped = {} for raw_key, tensor in flat.items(): pt_key = _remap_key(raw_key) if pt_key in state_dict: if ("weight" in pt_key and pt_key != "embed_tokens.weight" and len(tensor.shape) == 2): tensor = tensor.T mapped[pt_key] = tensor.to(state_dict[pt_key].dtype) else: print(f"[model] No match for: {raw_key} → {pt_key}") missing, unexpected = model.load_state_dict(mapped, strict=False) if missing: print(f"[model] Missing: {missing[:8]}") if unexpected: print(f"[model] Unexpected: {unexpected[:8]}") print(f"[model] Loaded {len(mapped)} / {len(state_dict)} tensors") # ─── Token ID helpers ───────────────────────────────────────────────────────── def _tok_id(tokenizer: Tokenizer, token: str, fallback: int) -> int: """Look up a special token id; return fallback if absent.""" tid = tokenizer.token_to_id(token) return tid if tid is not None else fallback # ─── Startup ────────────────────────────────────────────────────────────────── print(f"[startup] Downloading {MODEL_REPO} …") t0 = time.time() model_dir = snapshot_download( MODEL_REPO, token=HF_TOKEN, ignore_patterns=["*.msgpack", "flax_model*"], ) with open(f"{model_dir}/config.json") as f: hf_cfg = json.load(f) tokenizer = Tokenizer.from_file(f"{model_dir}/tokenizer.json") # ── Resolve special token IDs directly from tokenizer vocab ────────────────── # No special_tokens.json needed — everything is in the tokenizer itself. BOS_ID = _tok_id(tokenizer, "", 1) EOS_ID = _tok_id(tokenizer, "", 2) PAD_ID = _tok_id(tokenizer, "", 0) # We also want to stop generation when the model closes the tag MOCHI_CLOSE_ID = _tok_id(tokenizer, "", -1) STOP_IDS = [EOS_ID] if MOCHI_CLOSE_ID != -1: STOP_IDS.append(MOCHI_CLOSE_ID) print(f"[startup] Special tokens — bos={BOS_ID} eos={EOS_ID} pad={PAD_ID} " f"={MOCHI_CLOSE_ID}") # ── Default generation params (hardcoded since generation_config.json doesn't exist) ── DEFAULT_GEN = { "max_new_tokens": 256, "temperature": 0.8, "top_p": 0.9, "top_k": 50, "repetition_penalty": 1.1, } model = MochivaForInference(hf_cfg) model.eval() weights_file = f"{model_dir}/model.safetensors" if not os.path.exists(weights_file): weights_file = f"{model_dir}/model_weights.npz" load_weights(model, weights_file) print(f"[startup] Ready in {time.time()-t0:.1f}s " f"({sum(p.numel() for p in model.parameters())/1e6:.1f}M params)") # ─── FastAPI ────────────────────────────────────────────────────────────────── app = FastAPI(title="Mochiva Inference", version="2.0.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) class GenerateRequest(BaseModel): prompt: str persona: str = "berry" # which mochi persona hunger: float = Field(default=0.5, ge=0.0, le=1.0) happiness: float = Field(default=0.7, ge=0.0, le=1.0) bond: float = Field(default=0.5, ge=0.0, le=1.0) time_of_day: str = "afternoon" user_tone: str = "friendly" max_new_tokens: int = Field(default=256, ge=1, le=1024) temperature: float = Field(default=0.8, ge=0.01, le=2.0) top_p: float = Field(default=0.9, ge=0.0, le=1.0) top_k: int = Field(default=50, ge=0, le=500) repetition_penalty: float = Field(default=1.1, ge=1.0, le=3.0) PERSONA_STYLES = { "vanilla": "calm, soft, warm, gently poetic", "apple": "chaotic, fast, high energy, scattered but lovable", "cocoa": "sleepy, slow, cozy, dreamy", "berry": "emotional, expressive, dramatic, deeply caring", "lemon": "sarcastic, sharp, secretly soft, plays it cool", } def build_prompt(req: GenerateRequest) -> str: """ Build the full input string in the format the model was trained on: system contextuser message The trailing tag prompts the model to start its response. """ style = PERSONA_STYLES.get(req.persona.lower(), "friendly and playful") system = ( f"You are {req.persona.title()}, a {style} Mochi character. " f"hunger: {req.hunger:.2f}, happiness: {req.happiness:.2f}, bond: {req.bond:.2f}. " f"Time: {req.time_of_day}. User tone: {req.user_tone}." ) return f"{system}{req.prompt}" # ─── SSE helpers ────────────────────────────────────────────────────────────── def _sse(token: str = "", done: bool = False) -> str: return f"data: {json.dumps({'token': token, 'done': done})}\n\n" def _generate_sse(req: GenerateRequest) -> Iterator[str]: prompt = build_prompt(req) ids = [BOS_ID] + tokenizer.encode(prompt, add_special_tokens=False).ids if len(ids) > MAX_CTX - req.max_new_tokens: ids = ids[-(MAX_CTX - req.max_new_tokens):] input_ids = torch.tensor([ids], dtype=torch.long) tok_queue: queue.Queue[Optional[int]] = queue.Queue() def _worker(): try: for tok_id in model.generate_stream( input_ids, max_new_tokens = req.max_new_tokens, temperature = req.temperature, top_p = req.top_p, top_k = req.top_k, repetition_penalty = req.repetition_penalty, eos_token_id = EOS_ID, stop_token_ids = STOP_IDS, ): tok_queue.put(tok_id) finally: tok_queue.put(None) threading.Thread(target=_worker, daemon=True).start() buf = [] while True: tok_id = tok_queue.get() if tok_id is None: break buf.append(tok_id) text = tokenizer.decode(buf) # Hold back until we have a complete UTF-8 character if text.endswith("\ufffd"): continue # Strip any closing mochi tag that leaked through text = text.replace("", "").replace("", "") if text: yield _sse(token=text) buf = [] if buf: text = tokenizer.decode(buf).replace("", "").replace("", "") if text: yield _sse(token=text) yield _sse(done=True) # ─── Endpoints ──────────────────────────────────────────────────────────────── @app.post("/generate") def generate_stream(req: GenerateRequest): return StreamingResponse( _generate_sse(req), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, ) @app.post("/generate_full") def generate_full(req: GenerateRequest): tokens = [] for chunk in _generate_sse(req): if chunk.startswith("data: "): obj = json.loads(chunk[6:]) if not obj["done"]: tokens.append(obj["token"]) return {"text": "".join(tokens), "persona": req.persona, "model": MODEL_REPO} @app.get("/health") def health(): return {"status": "ok", "model": MODEL_REPO} @app.get("/info") def info(): return { "model": MODEL_REPO, "vocab_size": hf_cfg["vocab_size"], "layers": hf_cfg["num_hidden_layers"], "hidden": hf_cfg["hidden_size"], "context": hf_cfg["max_position_embeddings"], "personas": list(PERSONA_STYLES.keys()), "special_toks": {"bos": BOS_ID, "eos": EOS_ID, "pad": PAD_ID}, "device": DEVICE, } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)