# backend / app.py
# Bc-AI's picture
# Update app.py
# 1a09475 verified
"""
hf_space/app.py
──────────────────────────────────────────────────────────────────────────────
Mochiva inference server — runs on HuggingFace Spaces (free CPU tier).
Fixes vs original:
• No longer reads special_tokens.json or generation_config.json (never created)
• BOS/EOS/PAD resolved directly from tokenizer vocab
• Prompt format updated to match new <s>...</s><user>...</user><mochi>...</mochi>
• Generation stops on </mochi> tag in addition to EOS token
• Graceful fallback if config keys are missing
"""
from __future__ import annotations
import os
import json
import math
import time
import threading
import queue
import re
from typing import Iterator, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
# ─── Config ───────────────────────────────────────────────────────────────────
MODEL_REPO = os.environ.get("MODEL_REPO", "Mochiva-team/Mochiva-model")
HF_TOKEN = os.environ.get("HF_TOKEN", None)
DEVICE = "cpu"
MAX_CTX = int(os.environ.get("MAX_CTX", "4096"))
# ─── Model ────────────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The statistics are computed in float32; the normalised values are cast
    back to the input dtype before being multiplied by ``scale``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # One gain per channel, initialised to the identity.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        denom = torch.sqrt(mean_square + self.eps)
        normalised = (x32 / denom).to(x.dtype)
        return normalised * self.scale
def precompute_freqs_cis(
    head_dim: int,
    max_seq: int,
    theta: float = 10_000.0,
    scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Build the table of unit-magnitude complex rotary factors.

    Returns a complex tensor of shape ``(max_seq, head_dim // 2)`` whose entry
    ``[p, i]`` is ``exp(1j * p * f_i)`` with
    ``f_i = theta ** (-i / (head_dim // 2)) / scaling_factor``.
    """
    half = head_dim // 2
    exponents = torch.arange(0, half, dtype=torch.float32) / half
    inv_freq = (1.0 / (theta ** exponents)) / scaling_factor
    positions = torch.arange(max_seq, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
def apply_rope(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position factors to a query and a key tensor.

    Each tensor's last axis is read as consecutive (real, imag) pairs,
    multiplied by ``freqs_cis`` in float32 and cast back to its own dtype.

    Args:
        xq, xk: tensors laid out as (batch, seq, heads, head_dim), which is how
            ``CausalSelfAttention.forward`` passes them.
        freqs_cis: complex factors of shape (seq, head_dim // 2); it is
            broadcast over the batch and head axes.

    Returns:
        The rotated ``(xq, xk)`` pair, same shapes and dtypes as the inputs.
    """

    def rotate(x):
        pairs = x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
        as_complex = torch.view_as_complex(pairs)
        # (seq, half) -> (1, seq, 1, half) so it lines up with (B, T, nh, half).
        factors = freqs_cis.unsqueeze(0).unsqueeze(2)
        rotated = torch.view_as_real(as_complex * factors).reshape(*x.shape)
        return rotated.to(x.dtype)

    return rotate(xq), rotate(xk)


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings and an optional KV cache.

    Reads ``num_attention_heads``, ``head_dim`` and ``hidden_size`` from ``cfg``.
    All four projections are bias-free.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.nh = cfg["num_attention_heads"]
        self.hd = cfg["head_dim"]
        H = cfg["hidden_size"]
        self.q_proj = nn.Linear(H, self.nh * self.hd, bias=False)
        self.k_proj = nn.Linear(H, self.nh * self.hd, bias=False)
        self.v_proj = nn.Linear(H, self.nh * self.hd, bias=False)
        self.o_proj = nn.Linear(self.nh * self.hd, H, bias=False)

    def forward(self, x, freqs_cis, mask, kv_cache=None):
        """Attend over ``x`` plus any previously cached keys/values.

        Args:
            x: hidden states, (batch, T, hidden_size).
            freqs_cis: rotary factors for the T new positions.
            mask: boolean tensor, True where attention is allowed, or None for
                no masking. Its last axis indexes key positions. Its
                second-to-last axis may either already hold exactly the T query
                rows, or hold rows for every position starting at 0 (for
                example a square causal mask); in the latter case the rows of
                the newest T positions are selected.
            kv_cache: optional dict. When given, cached ``"k"``/``"v"`` are
                prepended along the sequence axis and the dict is updated in
                place with the extended tensors.

        Returns:
            Tensor of shape (batch, T, hidden_size).
        """
        B, T, _ = x.shape
        nh, hd = self.nh, self.hd
        q = self.q_proj(x).view(B, T, nh, hd)
        k = self.k_proj(x).view(B, T, nh, hd)
        v = self.v_proj(x).view(B, T, nh, hd)
        q, k = apply_rope(q, k, freqs_cis)
        if kv_cache is not None:
            if "k" in kv_cache:
                k = torch.cat([kv_cache["k"], k], dim=1)
                v = torch.cat([kv_cache["v"], v], dim=1)
            kv_cache["k"] = k
            kv_cache["v"] = v
        # (B, T, nh, hd) -> (B, nh, T, hd)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scale = 1.0 / math.sqrt(hd)
        attn = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
        Tq, Tk = attn.shape[-2], attn.shape[-1]
        if mask is not None:
            if mask.shape[-2] != Tq:
                # The queries are the *last* Tq of the Tk positions. Taking the
                # first Tq rows instead would, with a cache, let every new
                # token see only key position 0.
                mask = mask[..., Tk - Tq:Tk, :]
            attn = attn.masked_fill(~mask[..., :Tk], float("-inf"))
        attn = F.softmax(attn.float(), dim=-1).to(q.dtype)
        out = torch.einsum("bhqk,bhkd->bhqd", attn, v)
        out = out.transpose(1, 2).contiguous().view(B, Tq, nh * hd)
        return self.o_proj(out)
class SwiGLUMLP(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``, all bias-free.

    Sizes come from ``cfg["hidden_size"]`` and ``cfg["intermediate_size"]``.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        hidden = cfg["hidden_size"]
        inner = cfg["intermediate_size"]
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
class MochivaBlock(nn.Module):
    """One pre-norm transformer layer: attention and MLP, each with a residual.

    ``cfg["rms_norm_eps"]`` is optional and defaults to 1e-6.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        hidden = cfg["hidden_size"]
        eps = cfg.get("rms_norm_eps", 1e-6)
        self.attn_norm = RMSNorm(hidden, eps)
        self.mlp_norm = RMSNorm(hidden, eps)
        self.attn = CausalSelfAttention(cfg)
        self.mlp = SwiGLUMLP(cfg)

    def forward(self, x, freqs_cis, mask, kv_cache=None):
        # Residual around the normalised attention branch.
        attn_out = self.attn(self.attn_norm(x), freqs_cis, mask, kv_cache)
        x = x + attn_out
        # Residual around the normalised feed-forward branch.
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
class MochivaForInference(nn.Module):
    """Decoder-only language model used for CPU inference.

    Built from ``cfg``: ``vocab_size``, ``hidden_size``, ``num_hidden_layers``,
    ``head_dim`` and ``max_position_embeddings`` are required;
    ``rms_norm_eps``, ``rope_theta`` and ``rope_scaling_factor`` are optional.
    The output projection reuses the token-embedding matrix.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        V, H, L = cfg["vocab_size"], cfg["hidden_size"], cfg["num_hidden_layers"]
        self.embed_tokens = nn.Embedding(V, H)
        self.layers = nn.ModuleList([MochivaBlock(cfg) for _ in range(L)])
        self.norm = RMSNorm(H, cfg.get("rms_norm_eps", 1e-6))
        hd = cfg["head_dim"]
        ctx = cfg["max_position_embeddings"]
        theta = cfg.get("rope_theta", 10_000.0)
        scale = cfg.get("rope_scaling_factor", 1.0)
        # Rotary table for every position the model can address.
        self.register_buffer("freqs_cis", precompute_freqs_cis(hd, ctx, theta, scale))

    def forward(self, input_ids, kv_caches=None):
        """Return logits of shape (batch, T, vocab_size) for ``input_ids``.

        Args:
            input_ids: integer tensor of shape (batch, T).
            kv_caches: optional list with one dict per layer. When the first
                dict already holds ``"k"``, its sequence length is taken as the
                position offset of the new tokens.

        Raises:
            ValueError: if the cached plus new positions exceed the rotary
                table (``max_position_embeddings``).
        """
        _, T = input_ids.shape
        offset = 0
        if kv_caches and "k" in kv_caches[0]:
            offset = kv_caches[0]["k"].shape[1]
        full_len = offset + T
        max_pos = self.freqs_cis.shape[0]
        if full_len > max_pos:
            raise ValueError(
                f"sequence length {full_len} exceeds max_position_embeddings={max_pos}"
            )
        x = self.embed_tokens(input_ids)
        # Causal mask holding only the rows of the T new query positions:
        # query i (absolute position offset + i) may see keys 0 .. offset + i.
        # Passing the full square mask would make a row-from-zero slice pick
        # the wrong rows once a cache is in use, and costs O(full_len^2) per step.
        mask = torch.ones(T, full_len, dtype=torch.bool, device=x.device).tril(diagonal=offset)
        mask = mask.unsqueeze(0).unsqueeze(0)
        freqs = self.freqs_cis[offset:full_len]
        for i, layer in enumerate(self.layers):
            kvc = kv_caches[i] if kv_caches else None
            x = layer(x, freqs, mask, kvc)
        x = self.norm(x)
        # Tied output head: project back through the embedding matrix.
        return x @ self.embed_tokens.weight.T

    @torch.inference_mode()
    def generate_stream(self, input_ids, max_new_tokens=256, temperature=0.8,
                        top_p=0.9, top_k=50, repetition_penalty=1.1,
                        eos_token_id=2, stop_token_ids=None):
        """Yield newly sampled token ids one at a time.

        The prompt is run once to fill the per-layer KV caches; afterwards only
        the most recent token is fed back. Generation ends, without yielding
        the token, as soon as a sampled id is ``eos_token_id`` or is in
        ``stop_token_ids``. It also ends after ``max_new_tokens`` tokens or
        when the position table is exhausted.

        Args:
            input_ids: prompt ids, shape (1, T). ``int(next_token)`` requires a
                single sampled element, so only batch size 1 is supported.
            max_new_tokens: upper bound on yielded tokens; values <= 0 yield
                nothing.
            temperature, top_p, top_k, repetition_penalty: forwarded to
                ``_sample``.
            eos_token_id: always treated as a stop token.
            stop_token_ids: optional extra stop ids.
        """
        stop_ids = set(stop_token_ids or [])
        stop_ids.add(eos_token_id)
        if max_new_tokens <= 0:
            return
        max_pos = self.freqs_cis.shape[0]
        # Prompt plus everything generated so far; feeds the repetition penalty.
        history = input_ids.reshape(-1).tolist()
        kv_caches = [{} for _ in self.layers]
        cur = input_ids
        for _ in range(max_new_tokens):
            logits = self(cur, kv_caches)
            next_token = _sample(logits[:, -1, :], temperature, top_p, top_k,
                                 history, repetition_penalty)
            tok_id = int(next_token)
            if tok_id in stop_ids:
                # Applies to the very first token as well: a stop token ends
                # the stream instead of being silently skipped.
                return
            history.append(tok_id)
            yield tok_id
            if len(history) > max_pos:
                # Feeding this token back would need a position beyond the
                # rotary table; stop cleanly instead of erroring mid-stream.
                return
            cur = next_token.unsqueeze(0)
# ─── Sampling ─────────────────────────────────────────────────────────────────
def _sample(logits, temperature, top_p, top_k, context_ids, repetition_penalty):
logits = logits.float().squeeze(0)
if repetition_penalty != 1.0:
# Flatten safely regardless of whether context_ids is a tensor or nested list
if isinstance(context_ids, torch.Tensor):
flat_ids = context_ids.reshape(-1).tolist()
else:
flat_ids = context_ids if isinstance(context_ids[0], int) else [t for row in context_ids for t in row]
for tok in set(flat_ids):
logits[tok] = logits[tok] / repetition_penalty if logits[tok] > 0 \
else logits[tok] * repetition_penalty
if temperature < 1e-4:
return logits.argmax(keepdim=True)
logits /= temperature
if top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[-1]] = float("-inf")
if top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_remove = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p
sorted_logits[sorted_remove] = float("-inf")
logits = torch.zeros_like(logits).scatter_(0, sorted_idx, sorted_logits)
return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
# ─── Weight loading ───────────────────────────────────────────────────────────
def _remap_key(key: str) -> str:
key = key.replace("/", ".")
key = key.replace("embed_tokens.embedding", "embed_tokens.weight")
key = re.sub(r"layer_(\d+)\.", r"layers.\1.", key)
key = key.replace(".kernel", ".weight")
return key
def _load_npz_tensors(weights_path: str) -> dict:
    """Read every array of a NumPy ``.npz`` archive into a dict of torch tensors."""
    import numpy as np
    # The context manager closes the underlying zip file handle.
    with np.load(weights_path) as archive:
        return {name: torch.from_numpy(array) for name, array in archive.items()}


def load_weights(model: MochivaForInference, weights_path: str):
    """Load a checkpoint file into ``model`` and print a short report.

    ``.npz`` files are read with NumPy. Anything else is read with
    safetensors first; if that fails the error is logged and the file is
    retried as a NumPy archive.

    Tensor names are mapped with ``_remap_key``. Every 2-D tensor whose mapped
    name contains ``"weight"``, except ``embed_tokens.weight``, is transposed
    before loading. Names with no counterpart in the model are reported and
    skipped; loading is non-strict.

    Args:
        model: the module to populate.
        weights_path: path to the checkpoint file.
    """
    if str(weights_path).lower().endswith(".npz"):
        flat = _load_npz_tensors(weights_path)
    else:
        try:
            from safetensors.torch import load_file
            flat = load_file(weights_path, device=DEVICE)
        except Exception as exc:
            # Best-effort fallback, but say why the primary loader failed so
            # the real cause is not hidden behind a NumPy error.
            print(f"[model] safetensors load failed ({exc!r}); trying NumPy archive")
            flat = _load_npz_tensors(weights_path)
    state_dict = model.state_dict()
    mapped = {}
    for raw_key, tensor in flat.items():
        pt_key = _remap_key(raw_key)
        if pt_key in state_dict:
            if ("weight" in pt_key
                    and pt_key != "embed_tokens.weight"
                    and len(tensor.shape) == 2):
                # NOTE(review): presumably the checkpoint stores (in, out)
                # kernels while nn.Linear expects (out, in) — confirm against
                # the exporter.
                tensor = tensor.T
            mapped[pt_key] = tensor.to(state_dict[pt_key].dtype)
        else:
            print(f"[model] No match for: {raw_key} → {pt_key}")
    missing, unexpected = model.load_state_dict(mapped, strict=False)
    if missing:
        print(f"[model] Missing: {missing[:8]}")
    if unexpected:
        print(f"[model] Unexpected: {unexpected[:8]}")
    print(f"[model] Loaded {len(mapped)} / {len(state_dict)} tensors")
# ─── Token ID helpers ─────────────────────────────────────────────────────────
def _tok_id(tokenizer: Tokenizer, token: str, fallback: int) -> int:
"""Look up a special token id; return fallback if absent."""
tid = tokenizer.token_to_id(token)
return tid if tid is not None else fallback
# ─── Startup ──────────────────────────────────────────────────────────────────
# Import-time startup: download the model repo, read its config and tokenizer,
# resolve special token ids, build the model and load its weights.
print(f"[startup] Downloading {MODEL_REPO} …")
t0 = time.time()
model_dir = snapshot_download(
    MODEL_REPO,
    token=HF_TOKEN,
    ignore_patterns=["*.msgpack", "flax_model*"],
)
# JSON is UTF-8; do not depend on the platform's default encoding.
with open(f"{model_dir}/config.json", encoding="utf-8") as f:
    hf_cfg = json.load(f)
if "head_dim" not in hf_cfg:
    # The model classes index cfg["head_dim"] directly. Fall back to the usual
    # hidden_size // num_attention_heads split when the config omits it; a
    # wrong guess fails loudly on a shape mismatch while loading weights.
    hf_cfg["head_dim"] = hf_cfg["hidden_size"] // hf_cfg["num_attention_heads"]
tokenizer = Tokenizer.from_file(f"{model_dir}/tokenizer.json")
# ── Resolve special token IDs directly from tokenizer vocab ──────────────────
# No special_tokens.json needed — everything is in the tokenizer itself.
BOS_ID = _tok_id(tokenizer, "<bos>", 1)
EOS_ID = _tok_id(tokenizer, "<eos>", 2)
PAD_ID = _tok_id(tokenizer, "<pad>", 0)
# We also want to stop generation when the model closes the <mochi> tag.
# -1 is the "not in vocab" sentinel and is never added to STOP_IDS.
MOCHI_CLOSE_ID = _tok_id(tokenizer, "</mochi>", -1)
STOP_IDS = [EOS_ID]
if MOCHI_CLOSE_ID != -1:
    STOP_IDS.append(MOCHI_CLOSE_ID)
print(f"[startup] Special tokens — bos={BOS_ID} eos={EOS_ID} pad={PAD_ID} "
      f"</mochi>={MOCHI_CLOSE_ID}")
# ── Default generation params (hardcoded since generation_config.json doesn't exist) ──
DEFAULT_GEN = {
    "max_new_tokens": 256,
    "temperature": 0.8,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
}
model = MochivaForInference(hf_cfg)
model.eval()
# Prefer the safetensors checkpoint, fall back to the NumPy archive.
weights_file = f"{model_dir}/model.safetensors"
if not os.path.exists(weights_file):
    weights_file = f"{model_dir}/model_weights.npz"
if not os.path.exists(weights_file):
    # Fail with the actual problem rather than an opaque loader error.
    raise FileNotFoundError(
        f"No model.safetensors or model_weights.npz found in {model_dir}"
    )
load_weights(model, weights_file)
print(f"[startup] Ready in {time.time()-t0:.1f}s "
      f"({sum(p.numel() for p in model.parameters())/1e6:.1f}M params)")
# ─── FastAPI ──────────────────────────────────────────────────────────────────
# ASGI application object; the route decorators below register on it and the
# __main__ guard at the bottom of the file serves it with uvicorn.
app = FastAPI(title="Mochiva Inference", version="2.0.0")
# CORS is fully open: every origin, method and header is accepted.
# NOTE(review): presumably intentional so any web front-end can call the
# service — tighten allow_origins if that is not the case.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body shared by POST /generate and POST /generate_full.
# The first seven fields are interpolated into the prompt by build_prompt; the
# remaining five are passed to model.generate_stream as sampling settings.
class GenerateRequest(BaseModel):
    # The user's message, placed between <user> and </user> in the prompt.
    prompt: str
    # Lower-cased and looked up in PERSONA_STYLES; unknown names get a generic style.
    persona: str = "berry" # which mochi persona
    # Character state values in [0, 1], written into the system text with two decimals.
    hunger: float = Field(default=0.5, ge=0.0, le=1.0)
    happiness: float = Field(default=0.7, ge=0.0, le=1.0)
    bond: float = Field(default=0.5, ge=0.0, le=1.0)
    # Free-form strings copied verbatim into the system text.
    time_of_day: str = "afternoon"
    user_tone: str = "friendly"
    # Sampling controls, range-checked by pydantic.
    max_new_tokens: int = Field(default=256, ge=1, le=1024)
    temperature: float = Field(default=0.8, ge=0.01, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    top_k: int = Field(default=50, ge=0, le=500)
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=3.0)
# Persona name -> style phrase spliced into the system text by build_prompt,
# which lower-cases the requested persona before the lookup and uses
# "friendly and playful" for names not listed here. The keys are also what
# GET /info reports under "personas".
PERSONA_STYLES = {
    "vanilla": "calm, soft, warm, gently poetic",
    "apple": "chaotic, fast, high energy, scattered but lovable",
    "cocoa": "sleepy, slow, cozy, dreamy",
    "berry": "emotional, expressive, dramatic, deeply caring",
    "lemon": "sarcastic, sharp, secretly soft, plays it cool",
}
def build_prompt(req: GenerateRequest) -> str:
    """
    Assemble the model input for one request.

    Layout: <s>system text</s><user>user message</user><mochi>
    The system text names the persona and its style and lists hunger,
    happiness, bond, time of day and user tone. The string ends with an open
    <mochi> tag so that the model continues with the character's reply.
    """
    persona_key = req.persona.lower()
    style = PERSONA_STYLES.get(persona_key, "friendly and playful")
    system = "".join([
        f"You are {req.persona.title()}, a {style} Mochi character. ",
        f"hunger: {req.hunger:.2f}, happiness: {req.happiness:.2f}, bond: {req.bond:.2f}. ",
        f"Time: {req.time_of_day}. User tone: {req.user_tone}.",
    ])
    return f"<s>{system}</s><user>{req.prompt}</user><mochi>"
# ─── SSE helpers ──────────────────────────────────────────────────────────────
def _sse(token: str = "", done: bool = False) -> str:
return f"data: {json.dumps({'token': token, 'done': done})}\n\n"
def _generate_sse(req: GenerateRequest) -> Iterator[str]:
    """Generate a reply for ``req`` and yield it as server-sent event strings.

    The prompt is tokenised with ``BOS_ID`` prepended, trimmed from the left
    if it would not leave room for the requested new tokens, and handed to
    ``model.generate_stream`` on a background thread. Token ids arrive through
    a queue, are decoded to text and emitted as ``_sse`` events, followed by a
    final event with ``done=True``.

    The context limit is the smaller of ``MAX_CTX`` and the model's
    ``max_position_embeddings``. If the consumer stops iterating early, the
    worker is told to stop instead of generating the remaining tokens.
    """
    prompt = build_prompt(req)
    ids = [BOS_ID] + tokenizer.encode(prompt, add_special_tokens=False).ids
    # MAX_CTX comes from the environment and may exceed what the model's
    # position table supports, so honour both limits.
    ctx_limit = min(MAX_CTX, hf_cfg["max_position_embeddings"])
    # Keep at least one prompt token: a zero or negative budget would turn the
    # slice below into "keep everything" or "drop the wrong end".
    prompt_budget = max(1, ctx_limit - req.max_new_tokens)
    if len(ids) > prompt_budget:
        ids = ids[-prompt_budget:]
    # Never ask for more tokens than there are positions left.
    max_new = max(1, min(req.max_new_tokens, ctx_limit - len(ids)))
    input_ids = torch.tensor([ids], dtype=torch.long)
    tok_queue: queue.Queue[Optional[int]] = queue.Queue()
    cancelled = threading.Event()

    def _worker():
        try:
            for tok_id in model.generate_stream(
                input_ids,
                max_new_tokens=max_new,
                temperature=req.temperature,
                top_p=req.top_p,
                top_k=req.top_k,
                repetition_penalty=req.repetition_penalty,
                eos_token_id=EOS_ID,
                stop_token_ids=STOP_IDS,
            ):
                if cancelled.is_set():
                    break
                tok_queue.put(tok_id)
        finally:
            # Sentinel: always unblock the consumer, even after an error.
            tok_queue.put(None)

    threading.Thread(target=_worker, daemon=True).start()
    buf = []
    try:
        while True:
            tok_id = tok_queue.get()
            if tok_id is None:
                break
            buf.append(tok_id)
            text = tokenizer.decode(buf)
            # Hold back until we have a complete UTF-8 character
            if text.endswith("\ufffd"):
                continue
            # Strip any mochi tag that leaked through
            text = text.replace("</mochi>", "").replace("<mochi>", "")
            if text:
                yield _sse(token=text)
            buf = []
        if buf:
            text = tokenizer.decode(buf).replace("</mochi>", "").replace("<mochi>", "")
            if text:
                yield _sse(token=text)
        yield _sse(done=True)
    finally:
        # Runs on normal completion and when the generator is closed early
        # (e.g. the client went away); lets the worker stop after its current
        # token.
        cancelled.set()
# ─── Endpoints ────────────────────────────────────────────────────────────────
@app.post("/generate")
def generate_stream(req: GenerateRequest):
    # Streams the reply as server-sent events; the headers ask caches and
    # buffering proxies to pass each event through immediately.
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    events = _generate_sse(req)
    return StreamingResponse(events, media_type="text/event-stream", headers=sse_headers)
@app.post("/generate_full")
def generate_full(req: GenerateRequest):
    # Non-streaming variant: drains the SSE generator and joins the text of
    # every event that is not the final "done" marker.
    prefix = "data: "
    events = (
        json.loads(chunk[len(prefix):])
        for chunk in _generate_sse(req)
        if chunk.startswith(prefix)
    )
    text = "".join(event["token"] for event in events if not event["done"])
    return {"text": text, "persona": req.persona, "model": MODEL_REPO}
@app.get("/health")
def health():
    # Liveness probe: a constant status plus the configured model repo.
    return dict(status="ok", model=MODEL_REPO)
@app.get("/info")
def info():
    # Static description of the loaded model, taken from its config and from
    # the values resolved at startup.
    special = {"bos": BOS_ID, "eos": EOS_ID, "pad": PAD_ID}
    details = {
        "model": MODEL_REPO,
        "vocab_size": hf_cfg["vocab_size"],
        "layers": hf_cfg["num_hidden_layers"],
        "hidden": hf_cfg["hidden_size"],
        "context": hf_cfg["max_position_embeddings"],
        "personas": list(PERSONA_STYLES.keys()),
        "special_toks": special,
        "device": DEVICE,
    }
    return details
# Direct execution (python app.py): serve the app on all interfaces, port 7860.
if __name__ == "__main__":
    # Imported here so that merely importing this module does not need uvicorn.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)