File size: 4,265 Bytes
d3a7a1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce5cab3
 
 
 
 
 
 
 
 
 
 
d3a7a1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""In-process GPU inference for Hugging Face Spaces ZeroGPU.

Mirrors modal_app.py, but instead of calling Modal over HTTP the models run
locally inside the Space behind @spaces.GPU (which allocates a ZeroGPU for the
duration of each call). Selected by TINYWORLD_INFER=local.

Imported LAZILY (only when the local backend is active) because it pulls in
torch / transformers / voxcpm / whisper — heavy deps the offline test suite and
the Modal-backed path never need. Models are loaded on CPU once and moved to CUDA
inside the GPU-decorated functions (CUDA is only available there on ZeroGPU).
"""

import os
import re
import tempfile

import spaces
import torch

# Model selection. Each value is read from the environment once at import time
# and falls back to the default shown when the variable is unset.
LLM_MODEL_ID = os.environ.get("TINYWORLD_LLM", "nvidia/Nemotron-Mini-4B-Instruct")
VOICE_MODEL_ID = os.environ.get("TINYWORLD_VOICE_MODEL", "openbmb/VoxCPM2")
WHISPER_SIZE = os.environ.get("TINYWORLD_WHISPER", "base")

# Process-wide caches. They start empty and are filled on first use by the
# matching _load_* helper, so each model is loaded at most once per process.
_llm = None      # causal LM, set by _load_llm()
_tok = None      # tokenizer paired with _llm, set by _load_llm()
_voice = None    # VoxCPM TTS model, set by _load_voice()
_whisper = None  # Whisper ASR model, set by _load_whisper()


# --------------------------------------------------------------------------- LLM
def _load_llm():
    """Return the cached ``(model, tokenizer)`` pair, loading both on first use.

    ``transformers`` is imported inside the function so that importing this
    module does not pull it in until an LLM is actually requested. The model is
    loaded with bfloat16 weights and no device is requested here; callers move
    it to CUDA themselves.
    """
    global _llm, _tok
    if _llm is not None:
        return _llm, _tok

    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"[inference] loading {LLM_MODEL_ID} …")
    _tok = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
    _llm = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    return _llm, _tok


def warmup():
    """Pre-load the LLM into CPU RAM so the first GPU call skips the cold load.

    No GPU is needed for this step: it only downloads and loads the weights, so
    the first @spaces.GPU call just moves the model to CUDA and generates
    instead of racing the ZeroGPU duration limit with a cold 8GB load.

    Best-effort by design: any failure is reported on stdout and swallowed, so
    it is safe to call from a background thread.
    """
    try:
        _load_llm()
        status = "[inference] LLM warmed (CPU RAM)"
    except Exception as exc:
        status = f"[inference] warmup failed: {exc}"
    print(status)


def _strip_think(text):
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    if "<think>" in text:
        parts = text.split("</think>")
        text = parts[-1].strip() if len(parts) > 1 else text
    return text.strip()


@spaces.GPU(duration=120)
def generate_batch(prompts):
    """Produce one raw completion per prompt inside a single GPU allocation.

    Every prompt is wrapped as a lone user turn, rendered through the
    tokenizer's chat template and sampled on its own (prompts are processed one
    after another, not as a padded batch).

    Returns a list of strings aligned with ``prompts``: the newly generated
    text only, decoded without special tokens and passed through
    ``_strip_think``. The reaction engine parses these strings.
    """
    # NOTE(review): the templated string is re-tokenized with the tokenizer's
    # default special-token handling — confirm this does not duplicate BOS for
    # the configured model.
    model, tokenizer = _load_llm()
    model.to("cuda")  # CUDA is only available inside the GPU-decorated call
    sampling = {
        "max_new_tokens": 400,
        "do_sample": True,
        "temperature": 0.8,
        "top_p": 0.9,
    }

    completions = []
    for prompt in prompts:
        chat = [{"role": "user", "content": prompt}]
        rendered = tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=False
        )
        encoded = tokenizer(rendered, return_tensors="pt").to("cuda")
        prompt_len = encoded["input_ids"].shape[1]
        with torch.no_grad():
            generated = model.generate(**encoded, **sampling)
        # generate() returns prompt + continuation; keep the continuation only.
        reply = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
        completions.append(_strip_think(reply.strip()))
    return completions


# --------------------------------------------------------------------------- TTS
def _load_voice():
    """Return the cached VoxCPM model, loading ``VOICE_MODEL_ID`` on first use.

    ``voxcpm`` is imported inside the function so the dependency is only needed
    once speech synthesis is actually requested.
    """
    global _voice
    if _voice is not None:
        return _voice

    from voxcpm import VoxCPM

    print(f"[inference] loading {VOICE_MODEL_ID} …")
    _voice = VoxCPM.from_pretrained(VOICE_MODEL_ID)
    return _voice


@spaces.GPU(duration=60)
def synthesize_voice(text, voice_desc):
    """Synthesize ``text`` and return the path to a WAV file.

    Matches voice.generate_voice's contract: the return value is a filesystem
    path to the rendered audio.

    ``voice_desc`` is prepended verbatim to ``text`` to form the model input.
    Each call writes to its own freshly created temp file, so a path returned
    by an earlier call is never overwritten by a later one. The caller owns the
    returned file and is responsible for deleting it when done.

    Raises whatever the model or ``soundfile`` raises; on a failed write the
    placeholder temp file is removed before the error propagates.
    """
    import soundfile as sf

    model = _load_voice()
    wav = model.generate(text=f"{voice_desc}{text}", cfg_value=2.0, inference_timesteps=10)

    # Reserve a unique path only after generation succeeded. A fixed per-PID
    # name would let successive calls clobber each other's output.
    fd, path = tempfile.mkstemp(prefix="tinyworld_voice_", suffix=".wav")
    os.close(fd)  # soundfile reopens the file by path
    try:
        sf.write(path, wav, model.tts_model.sample_rate)
    except Exception:
        os.remove(path)  # don't leave an empty placeholder behind
        raise
    return path


# --------------------------------------------------------------------------- ASR
def _load_whisper():
    """Return the cached Whisper model, loading size ``WHISPER_SIZE`` on first use.

    ``whisper`` is imported inside the function so the dependency is only needed
    once transcription is actually requested.
    """
    global _whisper
    if _whisper is not None:
        return _whisper

    import whisper

    print(f"[inference] loading Whisper {WHISPER_SIZE} …")
    _whisper = whisper.load_model(WHISPER_SIZE)
    return _whisper


@spaces.GPU(duration=60)
def transcribe_audio(audio_path):
    """Transcribe the audio file at ``audio_path`` with the cached Whisper model.

    Returns the transcript with surrounding whitespace removed, or ``""`` when
    the result carries no text.
    """
    transcript = _load_whisper().transcribe(audio_path, fp16=True).get("text")
    return transcript.strip() if transcript else ""