Spaces:
Runtime error
Runtime error
"""
hf_space/app.py
──────────────────────────────────────────────────────────────────────────────
Mochiva inference server — runs on HuggingFace Spaces (free CPU tier).
Fixes vs original:
  • No longer reads special_tokens.json or generation_config.json (never created)
  • BOS/EOS/PAD resolved directly from tokenizer vocab
  • Prompt format updated to match the new <s>...</s><user>...</user><mochi>...</mochi> layout
  • Generation stops on the </mochi> tag in addition to the EOS token
  • Graceful fallback if config keys are missing
"""
| from __future__ import annotations | |
| import os | |
| import json | |
| import math | |
| import time | |
| import threading | |
| import queue | |
| import re | |
| from typing import Iterator, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
| from huggingface_hub import snapshot_download | |
| from tokenizers import Tokenizer | |
# ─── Config ───────────────────────────────────────────────────────────────────
# Runtime settings. Values read from the environment fall back to the defaults
# shown here when the variable is unset.
MODEL_REPO = os.getenv("MODEL_REPO", "Mochiva-team/Mochiva-model")
HF_TOKEN = os.getenv("HF_TOKEN")  # None when no token is configured
DEVICE = "cpu"
MAX_CTX = int(os.getenv("MAX_CTX", "4096"))
# ─── Model ────────────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-feature gain.

    The statistics are computed in float32; the normalised values are cast
    back to the input dtype before being multiplied by ``scale``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and apply the gain."""
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        normalised = (x32 / rms).to(x.dtype)
        return normalised * self.scale
def precompute_freqs_cis(
    head_dim: int,
    max_seq: int,
    theta: float = 10_000.0,
    scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Build the table of unit-magnitude complex rotations used for RoPE.

    Returns a complex tensor of shape ``(max_seq, head_dim // 2)`` whose entry
    ``[p, i]`` has angle ``p * inv_freq[i]``, where
    ``inv_freq[i] = theta ** (-i / (head_dim // 2)) / scaling_factor``.
    """
    half = head_dim // 2
    exponents = torch.arange(0, half, dtype=torch.float32) / half
    inv_freq = 1.0 / (theta ** exponents)
    inv_freq = inv_freq / scaling_factor
    positions = torch.arange(max_seq, dtype=torch.float32)
    # Outer product position x frequency, written as a broadcasted multiply.
    angles = positions[:, None] * inv_freq[None, :]
    return torch.polar(torch.ones_like(angles), angles)
def apply_rope(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the given complex RoPE factors.

    Adjacent pairs along the last dimension are viewed as complex numbers and
    multiplied by ``freqs_cis``; the result is cast back to the input dtype.
    ``freqs_cis`` gains a leading axis and an axis at position 2 so it
    broadcasts against the inputs (called below with ``(B, T, nh, hd)``).
    """
    fc = freqs_cis.unsqueeze(0).unsqueeze(2)

    def rotate(x):
        pairs = x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
        rotated = torch.view_as_complex(pairs) * fc
        return torch.view_as_real(rotated).reshape(*x.shape).to(x.dtype)

    return rotate(xq), rotate(xk)


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with RoPE and an optional key/value cache.

    Reads ``num_attention_heads``, ``head_dim`` and ``hidden_size`` from
    ``cfg``. All projections are bias-free.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.nh = cfg["num_attention_heads"]
        self.hd = cfg["head_dim"]
        H = cfg["hidden_size"]
        self.q_proj = nn.Linear(H, self.nh * self.hd, bias=False)
        self.k_proj = nn.Linear(H, self.nh * self.hd, bias=False)
        self.v_proj = nn.Linear(H, self.nh * self.hd, bias=False)
        self.o_proj = nn.Linear(self.nh * self.hd, H, bias=False)

    def forward(self, x, freqs_cis, mask, kv_cache=None):
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: input of shape ``(B, T, hidden_size)``.
            freqs_cis: RoPE factors for the ``T`` new positions.
            mask: boolean mask, ``True`` where attention is allowed, or
                ``None``. Either already aligned to the queries
                (``(..., T, Tk)``) or a larger square causal mask, in which
                case the rows for the *last* ``T`` key positions are used.
            kv_cache: optional dict; when given, its ``"k"``/``"v"`` entries
                are prepended to the new keys/values and then updated in
                place with the concatenated tensors.

        Returns:
            Tensor of shape ``(B, T, hidden_size)``.
        """
        B, T, _ = x.shape
        nh, hd = self.nh, self.hd
        q = self.q_proj(x).view(B, T, nh, hd)
        k = self.k_proj(x).view(B, T, nh, hd)
        v = self.v_proj(x).view(B, T, nh, hd)
        q, k = apply_rope(q, k, freqs_cis)

        if kv_cache is not None:
            if "k" in kv_cache:
                k = torch.cat([kv_cache["k"], k], dim=1)
                v = torch.cat([kv_cache["v"], v], dim=1)
            kv_cache["k"] = k
            kv_cache["v"] = v

        # (B, T, nh, hd) -> (B, nh, T, hd)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn = torch.einsum("bhqd,bhkd->bhqk", q, k) * (1.0 / math.sqrt(hd))
        Tq, Tk = attn.shape[-2], attn.shape[-1]
        if mask is not None:
            if mask.shape[-2] == Tq:
                # Mask rows already correspond one-to-one to the queries.
                visible = mask[..., :, :Tk]
            else:
                # The queries are the LAST Tq positions of the Tk keys, so
                # take the bottom rows. Slicing the top rows (``[:Tq]``)
                # would let a cached-decode query see only key 0.
                visible = mask[..., Tk - Tq:Tk, :Tk]
            attn = attn.masked_fill(~visible, float("-inf"))
        attn = F.softmax(attn.float(), dim=-1).to(q.dtype)
        out = torch.einsum("bhqk,bhkd->bhqd", attn, v)
        out = out.transpose(1, 2).contiguous().view(B, Tq, nh * hd)
        return self.o_proj(out)
class SwiGLUMLP(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``.

    Sizes come from ``cfg["hidden_size"]`` and ``cfg["intermediate_size"]``;
    none of the three projections has a bias.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        hidden = cfg["hidden_size"]
        inner = cfg["intermediate_size"]
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
class MochivaBlock(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each residual.

    ``cfg["rms_norm_eps"]`` is optional and defaults to ``1e-6``.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        hidden = cfg["hidden_size"]
        eps = cfg.get("rms_norm_eps", 1e-6)
        self.attn_norm = RMSNorm(hidden, eps)
        self.mlp_norm = RMSNorm(hidden, eps)
        self.attn = CausalSelfAttention(cfg)
        self.mlp = SwiGLUMLP(cfg)

    def forward(self, x, freqs_cis, mask, kv_cache=None):
        # Each sub-layer sees the normalised stream and adds its output back.
        attn_out = self.attn(self.attn_norm(x), freqs_cis, mask, kv_cache)
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
class MochivaForInference(nn.Module):
    """Decoder-only language model with tied input/output embeddings.

    Required ``cfg`` keys: ``vocab_size``, ``hidden_size``,
    ``num_hidden_layers``, ``head_dim``, ``max_position_embeddings`` (plus
    whatever the layer classes read). Optional: ``rms_norm_eps``,
    ``rope_theta``, ``rope_scaling_factor``.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        V, H, L = cfg["vocab_size"], cfg["hidden_size"], cfg["num_hidden_layers"]
        self.embed_tokens = nn.Embedding(V, H)
        self.layers = nn.ModuleList([MochivaBlock(cfg) for _ in range(L)])
        self.norm = RMSNorm(H, cfg.get("rms_norm_eps", 1e-6))
        hd = cfg["head_dim"]
        ctx = cfg["max_position_embeddings"]
        theta = cfg.get("rope_theta", 10_000.0)
        scale = cfg.get("rope_scaling_factor", 1.0)
        # The RoPE table is derived from the config, not learned, so keep it
        # out of the state dict; otherwise every checkpoint load reports it
        # as a missing key.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(hd, ctx, theta, scale),
            persistent=False,
        )

    def forward(self, input_ids, kv_caches=None):
        """Return logits of shape ``(B, T, vocab_size)`` for ``input_ids``.

        Args:
            input_ids: ``(B, T)`` token ids for the new positions only.
            kv_caches: optional list with one dict per layer; the length of
                the cached keys in the first dict gives the position offset.

        Raises:
            ValueError: if cached + new positions exceed the RoPE table.
        """
        _, T = input_ids.shape
        offset = 0
        if kv_caches and "k" in kv_caches[0]:
            offset = kv_caches[0]["k"].shape[1]
        full_len = offset + T
        max_positions = self.freqs_cis.shape[0]
        if full_len > max_positions:
            raise ValueError(
                f"sequence length {full_len} exceeds max_position_embeddings "
                f"({max_positions})"
            )

        x = self.embed_tokens(input_ids)
        # One mask row per NEW token: query i sits at absolute position
        # offset + i and may see keys 0..offset + i. A square mask sliced
        # from the top would hide the cached keys during decoding.
        mask = torch.ones(T, full_len, dtype=torch.bool, device=x.device)
        mask = mask.tril(diagonal=offset).unsqueeze(0).unsqueeze(0)
        freqs = self.freqs_cis[offset:full_len]

        for i, layer in enumerate(self.layers):
            kvc = kv_caches[i] if kv_caches else None
            x = layer(x, freqs, mask, kvc)
        x = self.norm(x)
        # Output projection shares weights with the input embedding.
        return x @ self.embed_tokens.weight.T

    def generate_stream(self, input_ids, max_new_tokens=256, temperature=0.8,
                        top_p=0.9, top_k=50, repetition_penalty=1.1,
                        eos_token_id=2, stop_token_ids=None):
        """Yield sampled token ids one at a time.

        Generation ends, without yielding the stop token, as soon as a token
        in ``stop_token_ids`` or ``eos_token_id`` is sampled (including the
        very first one), after ``max_new_tokens`` tokens, or when the
        position table is exhausted. ``input_ids`` is expected to hold a
        single sequence (the sampler squeezes the batch axis).
        """
        if max_new_tokens <= 0:
            return
        stop_ids = set(stop_token_ids or ())
        stop_ids.add(eos_token_id)
        max_positions = self.freqs_cis.shape[0]

        kv_caches = [{} for _ in self.layers]
        history = input_ids.reshape(-1).tolist()  # prompt + generated so far
        step_input = input_ids
        for step in range(max_new_tokens):
            # An over-long prompt is reported by forward(); once decoding has
            # started, simply stop when no positions are left.
            if step and len(history) > max_positions:
                break
            # Inference only: without no_grad the autograd graph would grow
            # through the KV cache on every step.
            with torch.no_grad():
                logits = self(step_input, kv_caches)
            next_token = _sample(
                logits[:, -1, :], temperature, top_p, top_k,
                torch.tensor(history, dtype=torch.long), repetition_penalty,
            )
            tok_id = int(next_token)
            if tok_id in stop_ids:
                return
            history.append(tok_id)
            yield tok_id
            step_input = next_token.unsqueeze(0)
| # βββ Sampling βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _sample(logits, temperature, top_p, top_k, context_ids, repetition_penalty): | |
| logits = logits.float().squeeze(0) | |
| if repetition_penalty != 1.0: | |
| # Flatten safely regardless of whether context_ids is a tensor or nested list | |
| if isinstance(context_ids, torch.Tensor): | |
| flat_ids = context_ids.reshape(-1).tolist() | |
| else: | |
| flat_ids = context_ids if isinstance(context_ids[0], int) else [t for row in context_ids for t in row] | |
| for tok in set(flat_ids): | |
| logits[tok] = logits[tok] / repetition_penalty if logits[tok] > 0 \ | |
| else logits[tok] * repetition_penalty | |
| if temperature < 1e-4: | |
| return logits.argmax(keepdim=True) | |
| logits /= temperature | |
| if top_k > 0: | |
| v, _ = torch.topk(logits, min(top_k, logits.size(-1))) | |
| logits[logits < v[-1]] = float("-inf") | |
| if top_p < 1.0: | |
| sorted_logits, sorted_idx = torch.sort(logits, descending=True) | |
| cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | |
| sorted_remove = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p | |
| sorted_logits[sorted_remove] = float("-inf") | |
| logits = torch.zeros_like(logits).scatter_(0, sorted_idx, sorted_logits) | |
| return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) | |
| # βββ Weight loading βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _remap_key(key: str) -> str: | |
| key = key.replace("/", ".") | |
| key = key.replace("embed_tokens.embedding", "embed_tokens.weight") | |
| key = re.sub(r"layer_(\d+)\.", r"layers.\1.", key) | |
| key = key.replace(".kernel", ".weight") | |
| return key | |
def load_weights(model: "MochivaForInference", weights_path: str):
    """Load checkpoint tensors from ``weights_path`` into ``model``.

    ``.npz`` files are read with NumPy. Any other path is tried as
    safetensors first and as a NumPy archive second; if both fail, the NumPy
    error is raised with the safetensors error attached as its cause.

    Tensor names are mapped with ``_remap_key``. Every 2-D tensor whose
    mapped name contains ``"weight"`` (except ``embed_tokens.weight``) is
    transposed before loading. Names with no counterpart in the model are
    reported and skipped. Progress is printed; nothing is returned.
    """

    def _read_npz(path):
        import numpy as np
        # Close the archive once every array has been materialised.
        with np.load(path) as archive:
            return {name: torch.from_numpy(archive[name]) for name in archive.files}

    if weights_path.endswith(".npz"):
        flat = _read_npz(weights_path)
    else:
        try:
            from safetensors.torch import load_file
            flat = load_file(weights_path, device=DEVICE)
        except Exception as st_err:
            # Best-effort fallback kept from the original behaviour, but the
            # first failure is preserved instead of being silently dropped.
            try:
                flat = _read_npz(weights_path)
            except Exception as np_err:
                raise np_err from st_err

    state_dict = model.state_dict()
    mapped = {}
    for raw_key, tensor in flat.items():
        pt_key = _remap_key(raw_key)
        if pt_key not in state_dict:
            print(f"[model] No match for: {raw_key} → {pt_key}")
            continue
        needs_transpose = (
            "weight" in pt_key
            and pt_key != "embed_tokens.weight"
            and tensor.dim() == 2
        )
        if needs_transpose:
            tensor = tensor.T
        mapped[pt_key] = tensor.to(state_dict[pt_key].dtype)

    missing, unexpected = model.load_state_dict(mapped, strict=False)
    if missing:
        print(f"[model] Missing: {missing[:8]}")
    if unexpected:
        print(f"[model] Unexpected: {unexpected[:8]}")
    print(f"[model] Loaded {len(mapped)} / {len(state_dict)} tensors")
| # βββ Token ID helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _tok_id(tokenizer: Tokenizer, token: str, fallback: int) -> int: | |
| """Look up a special token id; return fallback if absent.""" | |
| tid = tokenizer.token_to_id(token) | |
| return tid if tid is not None else fallback | |
# ─── Startup ──────────────────────────────────────────────────────────────────
# Runs at import time: download the model snapshot, read config + tokenizer,
# resolve special-token ids, build the model and load its weights.
print(f"[startup] Downloading {MODEL_REPO} …")
t0 = time.time()
model_dir = snapshot_download(
    MODEL_REPO,
    token=HF_TOKEN,
    ignore_patterns=["*.msgpack", "flax_model*"],
)

# Read as UTF-8 explicitly rather than relying on the platform's locale.
with open(os.path.join(model_dir, "config.json"), encoding="utf-8") as f:
    hf_cfg = json.load(f)
tokenizer = Tokenizer.from_file(os.path.join(model_dir, "tokenizer.json"))

# ── Resolve special token IDs directly from tokenizer vocab ──────────────────
# No special_tokens.json needed — everything is in the tokenizer itself.
# The numeric fallbacks are used when a token is not in the vocabulary.
BOS_ID = _tok_id(tokenizer, "<bos>", 1)
EOS_ID = _tok_id(tokenizer, "<eos>", 2)
PAD_ID = _tok_id(tokenizer, "<pad>", 0)
# We also want to stop generation when the model closes the <mochi> tag
MOCHI_CLOSE_ID = _tok_id(tokenizer, "</mochi>", -1)
STOP_IDS = [EOS_ID]
if MOCHI_CLOSE_ID != -1:
    STOP_IDS.append(MOCHI_CLOSE_ID)
print(f"[startup] Special tokens → bos={BOS_ID} eos={EOS_ID} pad={PAD_ID} "
      f"</mochi>={MOCHI_CLOSE_ID}")

# ── Default generation params (hardcoded since generation_config.json doesn't exist) ──
DEFAULT_GEN = {
    "max_new_tokens": 256,
    "temperature": 0.8,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
}

model = MochivaForInference(hf_cfg)
model.eval()

# Prefer safetensors, fall back to the NumPy archive; fail with a clear
# message when the snapshot contains neither.
weights_file = os.path.join(model_dir, "model.safetensors")
if not os.path.exists(weights_file):
    weights_file = os.path.join(model_dir, "model_weights.npz")
if not os.path.exists(weights_file):
    raise FileNotFoundError(
        f"No model.safetensors or model_weights.npz found in {model_dir}"
    )
load_weights(model, weights_file)
print(f"[startup] Ready in {time.time()-t0:.1f}s "
      f"({sum(p.numel() for p in model.parameters())/1e6:.1f}M params)")
# ─── FastAPI ──────────────────────────────────────────────────────────────────
app = FastAPI(title="Mochiva Inference", version="2.0.0")

# CORS is fully open: any origin, method and header is accepted.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class GenerateRequest(BaseModel):
    """Request body for text generation.

    The mood fields (``hunger``, ``happiness``, ``bond``) are bounded to
    ``[0, 1]``; the remaining numeric fields are sampling controls with the
    bounds declared below.
    """

    prompt: str
    persona: str = "berry"  # which mochi persona
    hunger: float = Field(0.5, ge=0.0, le=1.0)
    happiness: float = Field(0.7, ge=0.0, le=1.0)
    bond: float = Field(0.5, ge=0.0, le=1.0)
    time_of_day: str = "afternoon"
    user_tone: str = "friendly"
    max_new_tokens: int = Field(256, ge=1, le=1024)
    temperature: float = Field(0.8, ge=0.01, le=2.0)
    top_p: float = Field(0.9, ge=0.0, le=1.0)
    top_k: int = Field(50, ge=0, le=500)
    repetition_penalty: float = Field(1.1, ge=1.0, le=3.0)
PERSONA_STYLES = {
    "vanilla": "calm, soft, warm, gently poetic",
    "apple": "chaotic, fast, high energy, scattered but lovable",
    "cocoa": "sleepy, slow, cozy, dreamy",
    "berry": "emotional, expressive, dramatic, deeply caring",
    "lemon": "sarcastic, sharp, secretly soft, plays it cool",
}

# Markers that structure the prompt (and the special tokens looked up at
# startup). They must never arrive inside caller-supplied text.
_CONTROL_TAG_RE = re.compile(r"</?(?:s|user|mochi)>|<(?:bos|eos|pad)>")


def _strip_control_tags(text: str) -> str:
    """Remove prompt-control tags from ``text``.

    Repeats until nothing changes, so a tag cannot be re-formed by nesting
    one inside another (``<us<user>er>``).
    """
    previous = None
    while previous != text:
        previous = text
        text = _CONTROL_TAG_RE.sub("", text)
    return text


def build_prompt(req: "GenerateRequest") -> str:
    """
    Build the full input string in the format the model was trained on:
    <s>system context</s><user>user message</user><mochi>
    The trailing <mochi> tag prompts the model to start its response.

    All free-text request fields are untrusted, so control tags are stripped
    from them first; otherwise a caller could close the ``<user>`` section
    early and write the model's turn or the system context themselves.
    Unknown personas fall back to the style "friendly and playful".
    """
    persona = _strip_control_tags(req.persona)
    time_of_day = _strip_control_tags(req.time_of_day)
    user_tone = _strip_control_tags(req.user_tone)
    user_text = _strip_control_tags(req.prompt)

    style = PERSONA_STYLES.get(persona.lower(), "friendly and playful")
    system = (
        f"You are {persona.title()}, a {style} Mochi character. "
        f"hunger: {req.hunger:.2f}, happiness: {req.happiness:.2f}, bond: {req.bond:.2f}. "
        f"Time: {time_of_day}. User tone: {user_tone}."
    )
    return f"<s>{system}</s><user>{user_text}</user><mochi>"
| # βββ SSE helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _sse(token: str = "", done: bool = False) -> str: | |
| return f"data: {json.dumps({'token': token, 'done': done})}\n\n" | |
def _generate_sse(req: GenerateRequest) -> Iterator[str]:
    """Stream the model's reply to ``req`` as server-sent events.

    Generation runs in a daemon thread that pushes token ids onto a queue;
    this generator decodes them and yields one SSE string per decodable
    chunk, followed by a final ``done`` event. If the consumer stops
    iterating early, the worker is told to stop as well.
    """
    prompt = build_prompt(req)
    ids = [BOS_ID] + tokenizer.encode(prompt, add_special_tokens=False).ids
    # Leave room for the completion. Clamp to at least one token: a zero or
    # negative budget would make the negative slice below keep the wrong
    # part of the prompt.
    budget = max(1, MAX_CTX - req.max_new_tokens)
    if len(ids) > budget:
        ids = ids[-budget:]
    input_ids = torch.tensor([ids], dtype=torch.long)

    tok_queue: queue.Queue[Optional[int]] = queue.Queue()
    cancelled = threading.Event()

    def _worker():
        try:
            for tok_id in model.generate_stream(
                input_ids,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                top_k=req.top_k,
                repetition_penalty=req.repetition_penalty,
                eos_token_id=EOS_ID,
                stop_token_ids=STOP_IDS,
            ):
                if cancelled.is_set():
                    # Nobody is reading any more; stop burning CPU.
                    break
                tok_queue.put(tok_id)
        finally:
            # Always signal end-of-stream, even if generation raised.
            tok_queue.put(None)

    def _clean(token_ids):
        # Strip any mochi tag that leaked through into the decoded text.
        return tokenizer.decode(token_ids).replace("</mochi>", "").replace("<mochi>", "")

    threading.Thread(target=_worker, daemon=True).start()
    buf = []
    try:
        while True:
            tok_id = tok_queue.get()
            if tok_id is None:
                break
            buf.append(tok_id)
            # Hold back until we have a complete UTF-8 character
            if tokenizer.decode(buf).endswith("\ufffd"):
                continue
            text = _clean(buf)
            if text:
                yield _sse(token=text)
            buf = []
        if buf:
            text = _clean(buf)
            if text:
                yield _sse(token=text)
        yield _sse(done=True)
    finally:
        # Reached on normal completion and when the generator is closed early.
        cancelled.set()
# ─── Endpoints ────────────────────────────────────────────────────────────────
# NOTE(review): no route decorator was present on this handler, so it was
# never registered on ``app``. The path below is an assumption; confirm it
# against the client before deploying.
@app.post("/generate/stream")
def generate_stream(req: GenerateRequest):
    """Stream the reply as server-sent events (``text/event-stream``)."""
    return StreamingResponse(
        _generate_sse(req),
        media_type="text/event-stream",
        # Discourage caches and reverse proxies from buffering the stream.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
# NOTE(review): no route decorator was present on this handler, so it was
# never registered on ``app``. The path below is an assumption; confirm it
# against the client before deploying.
@app.post("/generate")
def generate_full(req: GenerateRequest):
    """Run generation to completion and return the whole reply at once.

    Consumes the same event stream as the streaming endpoint and joins the
    ``token`` payload of every event that is not the final ``done`` event.
    """
    pieces = []
    for chunk in _generate_sse(req):
        if not chunk.startswith("data: "):
            continue
        event = json.loads(chunk[len("data: "):])
        if not event["done"]:
            pieces.append(event["token"])
    return {"text": "".join(pieces), "persona": req.persona, "model": MODEL_REPO}
# NOTE(review): no route decorator was present on this handler, so it was
# never registered on ``app``. The path below is an assumption; confirm it
# against whatever probes this service.
@app.get("/health")
def health():
    """Liveness check: report status and the configured model repo."""
    return {"status": "ok", "model": MODEL_REPO}
# NOTE(review): no route decorator was present on this handler, so it was
# never registered on ``app``. The path below is an assumption; confirm it
# against the client before deploying.
@app.get("/info")
def info():
    """Describe the loaded model: sizes from its config, personas, token ids."""
    return {
        "model": MODEL_REPO,
        "vocab_size": hf_cfg["vocab_size"],
        "layers": hf_cfg["num_hidden_layers"],
        "hidden": hf_cfg["hidden_size"],
        "context": hf_cfg["max_position_embeddings"],
        "personas": list(PERSONA_STYLES),
        "special_toks": {"bos": BOS_ID, "eos": EOS_ID, "pad": PAD_ID},
        "device": DEVICE,
    }
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    import uvicorn

    _serve_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **_serve_options)