Bc-AI committed on
Commit
321b635
·
verified ·
1 Parent(s): 52bcb69

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +41 -0
  2. app.py +545 -0
  3. requirements.txt +17 -0
Dockerfile ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ── Mochiva HF Space — CPU inference server ──────────────────────────────────
# Base: Python 3.11 slim (small image, fast startup on HF free tier)
FROM python:3.11-slim

# HF Spaces runs as user 1000 — set up a non-root user with that uid
RUN useradd -m -u 1000 mochiva
WORKDIR /app
RUN chown mochiva /app

# ── System dependencies ────────────────────────────────────────────────────
# Only what we strictly need: no CUDA, no build tools for heavy packages.
# The apt lists are removed in the same layer to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends \
        git \
        curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# ── Python dependencies ────────────────────────────────────────────────────
# requirements.txt is copied before app.py so that editing the app does not
# invalidate the (slow) dependency layer.
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

# ── App code ───────────────────────────────────────────────────────────────
COPY app.py .

# ── HF Spaces metadata ────────────────────────────────────────────────────
# Port 7860 is the standard HF Space port
EXPOSE 7860

# ── Run as non-root ────────────────────────────────────────────────────────
USER mochiva

# ── Startup ────────────────────────────────────────────────────────────────
# --workers 1: the model is loaded once, at import of app.py, in a single
#   server process; app.py runs each generation in its own thread.
# --timeout-keep-alive 30: intended to keep SSE connections alive.
#   NOTE(review): in uvicorn this option is the idle HTTP keep-alive timeout;
#   confirm it has the intended effect on in-flight SSE streams.
CMD ["uvicorn", "app:app", \
     "--host", "0.0.0.0", \
     "--port", "7860", \
     "--workers", "1", \
     "--timeout-keep-alive", "30", \
     "--log-level", "info"]
app.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
hf_space/app.py
──────────────────────────────────────────────────────────────────────────────
Mochiva inference server — runs on HuggingFace Spaces (free CPU tier).

Architecture
  • PyTorch re-implementation of the Mochiva model (mirrors the training
    model.py) — loads weights from safetensors exported by export.py
  • FastAPI + Server-Sent Events (SSE) for streaming token-by-token responses
  • Model + tokeniser loaded from HF Hub at startup
  • Thread-safe: uses a queue to stream tokens from the generation thread

Endpoints
  POST /generate       — streaming SSE generation
  POST /generate_full  — non-streaming, returns full response JSON
  GET  /health         — liveness probe
  GET  /info           — model metadata

Environment variables
  MODEL_REPO : HF repo id (default: "my-username/Mochiva-model")
  HF_TOKEN   : optional HF token for private repos
  MAX_CTX    : maximum context length in tokens (default: 4096)

SSE protocol (matching the frontend expectation)
  data: {"token": "...", "done": false}\n\n
  data: {"token": "", "done": true}\n\n
"""
27
+
28
+ from __future__ import annotations
29
+ import os
30
+ import json
31
+ import math
32
+ import time
33
+ import threading
34
+ import queue
35
+ from typing import Iterator, Optional
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.nn.functional as F
40
+
41
+ from fastapi import FastAPI, HTTPException
42
+ from fastapi.middleware.cors import CORSMiddleware
43
+ from fastapi.responses import StreamingResponse
44
+ from pydantic import BaseModel, Field
45
+
46
+ from huggingface_hub import hf_hub_download, snapshot_download
47
+ from tokenizers import Tokenizer
48
+
49
+
50
# ─── Config ───────────────────────────────────────────────────────────────────

# HF Hub repo that the config, tokenizer and weights are downloaded from.
MODEL_REPO = os.environ.get("MODEL_REPO", "my-username/Mochiva-model")
# Optional access token forwarded to the Hub download; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Inference is CPU-only; also used as the target device when loading weights.
DEVICE = "cpu"
# Upper bound on context length (tokens); used to truncate over-long prompts.
MAX_CTX = int(os.environ.get("MAX_CTX", "4096"))
56
+
57
+
58
+ # ─── PyTorch model (mirrors Flax model in mochiva_training/model.py) ─────────
59
+
60
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The statistics are computed in float32 and the normalised value is cast
    back to the input dtype before the per-feature ``scale`` is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to 1 (identity scaling).
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        denom = (mean_sq + self.eps).sqrt()
        normed = x32 / denom
        return normed.to(x.dtype) * self.scale
69
+
70
+
71
def precompute_freqs_cis(
    head_dim: int,
    max_seq: int,
    theta: float = 10_000.0,
    scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Build the table of unit complex rotations used by rotary embeddings.

    Args:
        head_dim: Per-head feature size; ``head_dim // 2`` frequencies are made.
        max_seq: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.
        scaling_factor: Divides every frequency (1.0 leaves them unchanged).

    Returns:
        A ``(max_seq, head_dim // 2)`` complex64 tensor whose entry
        ``[t, i]`` is ``exp(1j * t * freq_i)``.
    """
    n_freqs = head_dim // 2
    exponents = torch.arange(0, n_freqs, dtype=torch.float32) / n_freqs
    inv_freq = 1.0 / (theta ** exponents)
    inv_freq = inv_freq / scaling_factor
    positions = torch.arange(max_seq, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)  # (seq, half)
    unit = torch.ones_like(angles)
    return torch.polar(unit, angles)  # complex64
83
+
84
+
85
def apply_rope(
    xq: torch.Tensor,            # (B, T, nh, hd)
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,     # (T, hd//2) complex
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by their position-dependent complex phases.

    Adjacent feature pairs of the last axis are treated as (real, imag)
    components, multiplied by ``freqs_cis`` and unpacked again. The maths is
    done in float32 and the result is cast back to each input's dtype.

    Returns:
        ``(rotated_q, rotated_k)`` with the same shapes as the inputs.
    """
    # Broadcast over batch and heads: (1, T, 1, half)
    phases = freqs_cis.unsqueeze(0).unsqueeze(2)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        pairs = t.float().reshape(*t.shape[:-1], t.shape[-1] // 2, 2)
        turned = torch.view_as_complex(pairs) * phases  # (..., half)
        return torch.view_as_real(turned).reshape(*t.shape).to(t.dtype)

    return _rotate(xq), _rotate(xk)
97
+
98
+
99
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings and a KV cache.

    Expects ``cfg`` to provide ``num_attention_heads``, ``head_dim`` and
    ``hidden_size``. All projections are bias-free.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.nh = cfg["num_attention_heads"]
        self.hd = cfg["head_dim"]
        H = cfg["hidden_size"]
        self.q_proj = nn.Linear(H, self.nh * self.hd, bias=False)
        self.k_proj = nn.Linear(H, self.nh * self.hd, bias=False)
        self.v_proj = nn.Linear(H, self.nh * self.hd, bias=False)
        self.o_proj = nn.Linear(self.nh * self.hd, H, bias=False)

    def forward(
        self,
        x: torch.Tensor,                    # (B, T, H)
        freqs_cis: torch.Tensor,            # (T, hd//2)
        mask: Optional[torch.Tensor],       # (1, 1, rows, cols) bool, True = may attend
        kv_cache: Optional[dict] = None,
    ) -> torch.Tensor:
        """Attend from the ``T`` new positions to every cached + new position.

        Args:
            x: Hidden states of the new positions.
            freqs_cis: Rotary phases for exactly those new positions.
            mask: Boolean causal mask. Its rows index query positions; either
                only the new queries (``rows == T``) or the whole sequence
                (``rows == cached + T``) may be supplied. ``None`` disables
                masking.
            kv_cache: Per-layer dict. When given, previously stored ``"k"`` /
                ``"v"`` are prepended and the dict is updated in place.

        Returns:
            ``(B, T, H)`` attention output.
        """
        B, T, _ = x.shape
        nh, hd = self.nh, self.hd

        q = self.q_proj(x).view(B, T, nh, hd)
        k = self.k_proj(x).view(B, T, nh, hd)
        v = self.v_proj(x).view(B, T, nh, hd)

        q, k = apply_rope(q, k, freqs_cis)

        if kv_cache is not None:
            # Keys are cached post-rotation, so they never need re-rotating.
            if "k" in kv_cache:
                k = torch.cat([kv_cache["k"], k], dim=1)
                v = torch.cat([kv_cache["v"], v], dim=1)
            kv_cache["k"] = k
            kv_cache["v"] = v

        # (B, nh, T, hd)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scale = 1.0 / math.sqrt(hd)
        attn = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale

        Tq, Tk = attn.shape[-2], attn.shape[-1]
        if mask is not None:
            # The current queries are the LAST Tq positions of the sequence.
            # Taking the first Tq rows (as before) made every cached decode
            # step attend to position 0 only.
            m = mask[..., -Tq:, :Tk]
            attn = attn.masked_fill(~m, float("-inf"))

        attn = F.softmax(attn.float(), dim=-1).to(q.dtype)
        out = torch.einsum("bhqk,bhkd->bhqd", attn, v)
        out = out.transpose(1, 2).contiguous().view(B, Tq, nh * hd)
        return self.o_proj(out)
152
+
153
+
154
class SwiGLUMLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``, bias-free."""

    def __init__(self, cfg: dict):
        super().__init__()
        hidden = cfg["hidden_size"]
        inner = cfg["intermediate_size"]
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate * lifted)
164
+
165
+
166
class MochivaBlock(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each residual."""

    def __init__(self, cfg: dict):
        super().__init__()
        hidden = cfg["hidden_size"]
        eps = cfg.get("rms_norm_eps", 1e-6)
        self.attn_norm = RMSNorm(hidden, eps)
        self.mlp_norm = RMSNorm(hidden, eps)
        self.attn = CausalSelfAttention(cfg)
        self.mlp = SwiGLUMLP(cfg)

    def forward(self, x, freqs_cis, mask, kv_cache=None):
        # Attention sub-layer (normalise first, then add back the input).
        attended = self.attn(self.attn_norm(x), freqs_cis, mask, kv_cache)
        h = x + attended
        # Feed-forward sub-layer, same pre-norm + residual pattern.
        return h + self.mlp(self.mlp_norm(h))
179
+
180
+
181
class MochivaForInference(nn.Module):
    """
    Causal LM for inference.
    Weights loaded from safetensors (exported by export.py).
    Uses a KV-cache so each decode step only runs the newest token through
    the network; the cache itself grows by one position per generated token.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        vocab = cfg["vocab_size"]
        hidden = cfg["hidden_size"]
        n_layers = cfg["num_hidden_layers"]

        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList([MochivaBlock(cfg) for _ in range(n_layers)])
        self.norm = RMSNorm(hidden, cfg.get("rms_norm_eps", 1e-6))
        # LM head is tied to embeddings — no extra parameter

        freqs = precompute_freqs_cis(
            cfg["head_dim"],
            cfg["max_position_embeddings"],
            cfg.get("rope_theta", 10_000.0),
            cfg.get("rope_scaling_factor", 1.0),
        )
        self.register_buffer("freqs_cis", freqs)   # (ctx, hd//2)

    def forward(
        self,
        input_ids: torch.Tensor,            # (B, T)
        kv_caches: Optional[list] = None,
    ) -> torch.Tensor:                      # (B, T, V)
        """Run the new tokens through the model and return their logits.

        Args:
            input_ids: Token ids of the positions to process.
            kv_caches: One dict per layer. When the first dict already holds
                ``"k"``, its length is used as the position offset of
                ``input_ids``. The dicts are updated in place by the layers.

        Raises:
            ValueError: If cached + new positions exceed the rotary table
                (``max_position_embeddings``).
        """
        _, T = input_ids.shape

        # If we have a KV cache, the position offset is the cached length
        offset = 0
        if kv_caches and "k" in kv_caches[0]:
            offset = kv_caches[0]["k"].shape[1]

        full_len = offset + T
        max_len = self.freqs_cis.shape[0]
        if full_len > max_len:
            raise ValueError(
                f"sequence length {full_len} exceeds "
                f"max_position_embeddings ({max_len})"
            )

        x = self.embed_tokens(input_ids)    # (B, T, H)

        # Causal mask for the NEW queries only: row i is absolute position
        # offset + i and may see keys 0 .. offset + i. Building just these T
        # rows avoids an O(full_len^2) allocation on every decode step.
        mask = torch.tril(
            torch.ones(T, full_len, dtype=torch.bool, device=x.device),
            diagonal=offset,
        )
        mask = mask.unsqueeze(0).unsqueeze(0)   # (1, 1, T, full)

        freqs = self.freqs_cis[offset:full_len]

        for i, layer in enumerate(self.layers):
            kvc = kv_caches[i] if kv_caches else None
            x = layer(x, freqs, mask, kvc)

        x = self.norm(x)
        logits = x @ self.embed_tokens.weight.T   # (B, T, V), tied LM head
        return logits

    @torch.inference_mode()
    def generate_stream(
        self,
        input_ids: torch.Tensor,            # (1, prompt_len)
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        eos_token_id: int = 2,
    ) -> Iterator[int]:
        """
        Yields sampled token IDs one by one.

        Generation stops at the first ``eos_token_id`` (which is never
        yielded, including when it is the very first sample), after
        ``max_new_tokens`` tokens, or when the rotary position table is
        exhausted, whichever comes first.
        """
        kv_caches = [{} for _ in self.layers]

        history = input_ids[0].tolist()
        # Positions still available after the prompt has been encoded.
        room = self.freqs_cis.shape[0] - len(history)

        # encode prompt
        logits = self(input_ids, kv_caches)     # (1, T, V)

        for step in range(max_new_tokens):
            # The penalty context is passed as a flat 1-D id tensor.
            next_token = _sample(
                logits[:, -1, :], temperature, top_p, top_k,
                torch.tensor(history, dtype=torch.long), repetition_penalty,
            )
            tok_id = int(next_token)
            if tok_id == eos_token_id:
                return
            history.append(tok_id)
            yield tok_id

            # Skip the forward pass when nothing more will be sampled or when
            # feeding this token would run past the positional table.
            if step + 1 >= max_new_tokens or step >= room:
                return
            logits = self(next_token.view(1, 1), kv_caches)     # (1, 1, V)
277
+
278
+
279
+ # ─── Sampling ─────────────────────────────────────────────────────────────────
280
+
281
+ def _sample(
282
+ logits: torch.Tensor, # (1, V)
283
+ temperature: float,
284
+ top_p: float,
285
+ top_k: int,
286
+ context_ids: torch.Tensor,
287
+ repetition_penalty: float,
288
+ ) -> torch.Tensor:
289
+ logits = logits.float().squeeze(0) # (V,)
290
+
291
+ # repetition penalty
292
+ if repetition_penalty != 1.0:
293
+ for tok in set(context_ids.tolist()):
294
+ if logits[tok] < 0:
295
+ logits[tok] *= repetition_penalty
296
+ else:
297
+ logits[tok] /= repetition_penalty
298
+
299
+ if temperature < 1e-4:
300
+ return logits.argmax(keepdim=True)
301
+
302
+ logits = logits / temperature
303
+
304
+ # top-k
305
+ if top_k > 0:
306
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
307
+ logits[logits < v[-1]] = float("-inf")
308
+
309
+ # top-p (nucleus)
310
+ if top_p < 1.0:
311
+ sorted_logits, sorted_idx = torch.sort(logits, descending=True)
312
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
313
+ sorted_remove = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p
314
+ sorted_logits[sorted_remove] = float("-inf")
315
+ logits = torch.zeros_like(logits).scatter_(0, sorted_idx, sorted_logits)
316
+
317
+ probs = F.softmax(logits, dim=-1)
318
+ return torch.multinomial(probs, num_samples=1)
319
+
320
+
321
+ # ─── Weight loading ───────────────────────────────────────────────────────────
322
+
323
+ def _remap_key(key: str) -> str:
324
+ """
325
+ Map flattened safetensors key β†’ PyTorch nn.Module attribute path.
326
+ E.g. "embed_tokens/embedding" β†’ "embed_tokens.weight"
327
+ "layer_0/attn/q_proj/kernel" β†’ "layers.0.attn.q_proj.weight"
328
+ """
329
+ key = key.replace("/", ".")
330
+ key = key.replace("embed_tokens.embedding", "embed_tokens.weight")
331
+ # layer_N β†’ layers.N
332
+ import re
333
+ key = re.sub(r"layer_(\d+)\.", r"layers.\1.", key)
334
+ # Flax kernel β†’ PyTorch weight
335
+ key = key.replace(".kernel", ".weight")
336
+ # norms: scale β†’ scale (already matches RMSNorm)
337
+ return key
338
+
339
+
340
def load_weights(model: MochivaForInference, weights_path: str):
    """Load exported weights into ``model`` in place.

    The file format is chosen from the extension: ``.npz`` is read with
    NumPy, anything else with safetensors. (Previously every safetensors
    failure silently fell through to ``np.load``, which hid the real error.)

    Keys are translated with ``_remap_key``. Two-dimensional ``weight``
    tensors other than the embedding are transposed, because Flax Dense
    kernels are stored (in, out) while ``nn.Linear`` expects (out, in).

    Args:
        model: Module whose state dict receives the tensors.
        weights_path: Path to a ``.safetensors`` or ``.npz`` file.
    """
    if weights_path.endswith(".npz"):
        import numpy as np
        # Context manager closes the underlying zip file once tensors are read.
        with np.load(weights_path) as npz:
            flat = {name: torch.from_numpy(npz[name]) for name in npz.files}
    else:
        from safetensors.torch import load_file
        flat = load_file(weights_path, device=DEVICE)

    state_dict = model.state_dict()
    mapped = {}
    unmatched = []

    for raw_key, tensor in flat.items():
        pt_key = _remap_key(raw_key)
        target = state_dict.get(pt_key)
        if target is None:
            unmatched.append(pt_key)
            continue
        # Transpose: Flax Dense kernels are (in, out), PyTorch Linear (out, in)
        if "weight" in pt_key and pt_key != "embed_tokens.weight" \
                and tensor.dim() == 2:
            tensor = tensor.T
        mapped[pt_key] = tensor.to(target.dtype)

    # strict=False: the LM head is tied (no separate parameter) and buffers
    # such as freqs_cis are not expected to be present in the checkpoint.
    missing, unexpected = model.load_state_dict(mapped, strict=False)
    if missing:
        print(f"[model] Missing keys: {missing[:5]}")
    if unexpected:
        print(f"[model] Unexpected keys: {unexpected[:5]}")
    if unmatched:
        # Checkpoint entries with no counterpart in the model; they were
        # collected before but never surfaced.
        print(f"[model] Unmatched checkpoint keys: {unmatched[:5]}")
    print(f"[model] Loaded {len(mapped)} tensors")
372
+
373
+
374
+ # ─── Startup: load model ─────────────────────────────────────────────────────
375
+
376
# Everything below runs at import time, so the server only starts accepting
# requests once the download and weight loading have finished.
print(f"[startup] Downloading {MODEL_REPO} from HF Hub …")
t0 = time.time()

# Fetch the whole repo snapshot except Flax checkpoints, which are not read here.
model_dir = snapshot_download(
    MODEL_REPO,
    token=HF_TOKEN,
    ignore_patterns=["*.msgpack", "flax_model*"],
)

# Model hyper-parameters (consumed by MochivaForInference and /info).
with open(f"{model_dir}/config.json") as f:
    hf_cfg = json.load(f)

# Ids of the special tokens: must contain "bos_id", "eos_id" and "pad_id".
with open(f"{model_dir}/special_tokens.json") as f:
    special = json.load(f)

tokenizer = Tokenizer.from_file(f"{model_dir}/tokenizer.json")
BOS_ID = special["bos_id"]
EOS_ID = special["eos_id"]
PAD_ID = special["pad_id"]  # NOTE(review): not referenced in the visible code

# NOTE(review): gen_cfg is loaded but not referenced in the visible code;
# request defaults come from GenerateRequest instead.
with open(f"{model_dir}/generation_config.json") as f:
    gen_cfg = json.load(f)

model = MochivaForInference(hf_cfg)
model.eval()

# Prefer safetensors; fall back to the NumPy archive when it is absent.
weights_file = f"{model_dir}/model.safetensors"
if not os.path.exists(weights_file):
    weights_file = f"{model_dir}/model_weights.npz"

load_weights(model, weights_file)
print(f"[startup] Model ready in {time.time()-t0:.1f}s "
      f"(params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M)")
409
+
410
+
411
+ # ─── FastAPI ──────────────────────────────────────────────────────────────────
412
+
413
app = FastAPI(title="Mochiva Inference", version="1.0.0")

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): presumably so a separately hosted frontend can call the
# Space directly; tighten allow_origins if the caller set is known.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
421
+
422
+
423
+ # ─── Request / Response schemas ───────────────────────────────────────────────
424
+
425
class GenerateRequest(BaseModel):
    """Request body shared by POST /generate and POST /generate_full."""

    # Raw user text; a persona preamble is prepended when mochi_name is set.
    prompt: str
    # Upper bound on generated tokens; also reserved out of MAX_CTX when the
    # prompt is truncated.
    max_new_tokens: int = Field(default=256, ge=1, le=1024)
    # The 0.01 floor means the greedy branch of _sample (temperature < 1e-4)
    # cannot be reached through the API.
    temperature: float = Field(default=0.8, ge=0.01, le=2.0)
    # Nucleus mass; 1.0 disables the filter.
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    # 0 disables top-k filtering.
    top_k: int = Field(default=50, ge=0, le=500)
    # 1.0 disables the penalty.
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=3.0)
    mochi_name: str = ""  # injected persona context
433
+
434
+
435
+ # ─── Streaming SSE endpoint ──────────────────────────────────────────────────
436
+
437
+ def _sse_event(token: str = "", done: bool = False) -> str:
438
+ payload = json.dumps({"token": token, "done": done})
439
+ return f"data: {payload}\n\n"
440
+
441
+
442
def _generate_sse(req: GenerateRequest) -> Iterator[str]:
    """Generate a reply for ``req`` and yield it as SSE-formatted strings.

    Generation runs in a daemon thread that pushes token ids onto a queue;
    this generator drains the queue, decodes the ids and yields one
    ``_sse_event`` per decodable chunk, followed by a final ``done`` event.

    If the consumer stops iterating early (e.g. the HTTP client disconnects)
    the worker is told to stop instead of generating to completion.

    Raises:
        RuntimeError: If generation fails in the worker thread; the original
            exception is attached as ``__cause__``.
    """
    # Build prompt with persona context if provided
    prompt = req.prompt
    if req.mochi_name:
        prompt = (
            f"<mochi>You are {req.mochi_name}, a cute and playful virtual pet "
            f"called a Mochi. You are friendly, energetic, and love the person "
            f"who takes care of you.</mochi> {prompt}"
        )

    ids = [BOS_ID] + tokenizer.encode(prompt).ids
    # Never plan past the model's positional table, whatever MAX_CTX says,
    # and keep at least one prompt token (a budget of 0 would make
    # ids[-0:] keep the WHOLE prompt).
    ctx = min(MAX_CTX, hf_cfg["max_position_embeddings"])
    budget = max(1, ctx - req.max_new_tokens)
    if len(ids) > budget:
        ids = ids[-budget:]

    input_ids = torch.tensor([ids], dtype=torch.long)

    tok_queue: queue.Queue[Optional[int]] = queue.Queue()
    stop = threading.Event()
    failure: list[BaseException] = []

    def _worker():
        try:
            for tok_id in model.generate_stream(
                input_ids,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                top_k=req.top_k,
                repetition_penalty=req.repetition_penalty,
                eos_token_id=EOS_ID,
            ):
                if stop.is_set():
                    break
                tok_queue.put(tok_id)
        except Exception as exc:
            # Hand the error to the consumer instead of losing it in the thread.
            failure.append(exc)
        finally:
            tok_queue.put(None)  # sentinel

    threading.Thread(target=_worker, daemon=True).start()

    # Decoded text ending in one of these is not emitted yet: a bare word
    # boundary marker ("▁" / "Ġ") or U+FFFD, which signals a multi-byte
    # UTF-8 character whose remaining bytes are in upcoming tokens.
    incomplete = ("▁", "Ġ", "\ufffd")

    held: list[int] = []
    try:
        while True:
            tok_id = tok_queue.get()
            if tok_id is None:
                break
            held.append(tok_id)
            text = tokenizer.decode(held)
            if text.endswith(incomplete):
                continue
            yield _sse_event(token=text)
            held = []
    finally:
        # Reached on normal completion and when the generator is closed early.
        stop.set()

    if failure:
        raise RuntimeError("generation failed") from failure[0]

    if held:
        yield _sse_event(token=tokenizer.decode(held))
    yield _sse_event(done=True)
495
+
496
+
497
@app.post("/generate")
def generate_stream(req: GenerateRequest):
    """Stream the generated reply to the client as Server-Sent Events."""
    # Ask clients and intermediaries not to cache or buffer the stream.
    sse_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
    }
    events = _generate_sse(req)
    return StreamingResponse(
        events,
        media_type="text/event-stream",
        headers=sse_headers,
    )
507
+
508
+
509
+ # ─── Non-streaming endpoint ───────────────────────────────────────────────────
510
+
511
@app.post("/generate_full")
def generate_full(req: GenerateRequest):
    """Run the same pipeline as /generate but return the whole text at once."""
    # Re-parse each SSE frame ("data: <json>") produced by the streaming path.
    events = (
        json.loads(frame[6:])
        for frame in _generate_sse(req)
        if frame.startswith("data: ")
    )
    pieces = [event["token"] for event in events if not event["done"]]
    return {"text": "".join(pieces), "model": MODEL_REPO}
520
+
521
+
522
+ # ─── Health / info ────────────────────────────────────────────────────────────
523
+
524
@app.get("/health")
def health():
    """Liveness probe: always reports ``ok`` together with the model repo id."""
    return dict(status="ok", model=MODEL_REPO)
527
+
528
+
529
@app.get("/info")
def info():
    """Return static model metadata taken from the loaded config."""
    # response field -> key in the model's config.json
    from_config = {
        "vocab_size": "vocab_size",
        "layers": "num_hidden_layers",
        "hidden": "hidden_size",
        "context": "max_position_embeddings",
    }
    payload = {"model": MODEL_REPO}
    for field, cfg_key in from_config.items():
        payload[field] = hf_cfg[cfg_key]
    payload["device"] = DEVICE
    return payload
539
+
540
+
541
+ # ─── Entrypoint ───────────────────────────────────────────────────────────────
542
+
543
# Local/dev entry point: `python app.py` serves on the same host and port
# that the Dockerfile's uvicorn command uses.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ── Mochiva HF Space — inference requirements ────────────────────────────────
# CPU-only PyTorch (much smaller image than CUDA build).
# Index options apply to the whole requirements file, so the extra index is
# declared on its own line rather than appended to the torch requirement.
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.3.0+cpu

# Web server
fastapi==0.111.0
uvicorn[standard]==0.30.1
pydantic==2.7.1

# HF Hub for downloading the model at startup
huggingface_hub==0.23.2

# Fast BPE tokeniser (same library used at training time)
tokenizers==0.19.1

# Weights format
safetensors==0.4.3
# app.py imports NumPy for the .npz weight fallback; the torch 2.3 wheels
# were built against NumPy 1.x.
numpy<2