Spaces:
Sleeping
Sleeping
| """In-process GPU inference for Hugging Face Spaces ZeroGPU. | |
| Mirrors modal_app.py, but instead of calling Modal over HTTP the models run | |
| locally inside the Space behind @spaces.GPU (which allocates a ZeroGPU for the | |
| duration of each call). Selected by TINYWORLD_INFER=local. | |
| Imported LAZILY (only when the local backend is active) because it pulls in | |
| torch / transformers / voxcpm / whisper — heavy deps the offline test suite and | |
| the Modal-backed path never need. Models are loaded on CPU once and moved to CUDA | |
| inside the GPU-decorated functions (CUDA is only available there on ZeroGPU). | |
| """ | |
| import os | |
| import re | |
| import tempfile | |
| import spaces | |
| import torch | |
# Model selection; every id can be overridden through an environment variable.
LLM_MODEL_ID = os.getenv("TINYWORLD_LLM", "nvidia/Nemotron-Mini-4B-Instruct")
VOICE_MODEL_ID = os.getenv("TINYWORLD_VOICE_MODEL", "openbmb/VoxCPM2")
WHISPER_SIZE = os.getenv("TINYWORLD_WHISPER", "base")  # passed to whisper.load_model

# Lazily-populated model caches; each stays None until its _load_* helper runs.
_llm = _tok = None  # causal LM and its tokenizer
_voice = None  # VoxCPM text-to-speech model
_whisper = None  # Whisper speech-recognition model
| # --------------------------------------------------------------------------- LLM | |
def _load_llm():
    """Return the cached ``(model, tokenizer)`` pair, loading both on first use.

    The model is loaded in bfloat16 without a device map, so it lands on the
    default device; GPU callers move it to CUDA themselves. ``transformers`` is
    imported here so the heavy dependency is only pulled in when actually needed.
    """
    global _llm, _tok
    if _llm is not None:
        return _llm, _tok
    from transformers import AutoModelForCausalLM, AutoTokenizer
    print(f"[inference] loading {LLM_MODEL_ID} …")
    _tok = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
    _llm = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    return _llm, _tok
def warmup():
    """Preload the LLM into CPU RAM so the first GPU call skips the cold load.

    Needs no GPU: it only downloads and loads the weights, leaving the first
    ``@spaces.GPU`` call to just move the model to CUDA and generate, so the
    ~8GB load does not eat into the ZeroGPU duration limit. Meant to be run from
    a background thread; any failure is printed instead of raised.

    NOTE(review): _load_llm takes no lock, so a GPU call that arrives while this
    is still loading sees the model unset and starts a second load of its own.
    """
    try:
        _load_llm()
    except Exception as err:
        print(f"[inference] warmup failed: {err}")
    else:
        print("[inference] LLM warmed (CPU RAM)")
def _strip_think(text):
    """Remove ``<think>…</think>`` reasoning from a model completion.

    Three shapes are handled:

    * closed ``<think>…</think>`` blocks anywhere in the text are deleted;
    * an orphan ``</think>`` (the opening tag was not generated, e.g. because
      the chat template itself opened the block) — everything up to and
      including the last one is dropped;
    * an unclosed ``<think>`` (generation stopped mid-reasoning) — everything
      from it onward is dropped.

    Returns:
        str: the remaining answer, stripped of surrounding whitespace. It may be
        ``""`` when the completion consisted only of reasoning.
    """
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    # Reasoning that ends in a close tag we never saw open: keep only what follows.
    _, closed, tail = text.rpartition("</think>")
    if closed:
        text = tail
    # Reasoning that opened but never closed: keep only what precedes it.
    text = text.split("<think>", 1)[0]
    return text.strip()
def generate_batch(prompts, *, max_new_tokens=400, temperature=0.8, top_p=0.9):
    """Generate one raw completion per prompt, all within a single GPU allocation.

    Each prompt is sent as a single user message through the tokenizer's chat
    template and sampled independently.

    Args:
        prompts: sequence of prompt strings.
        max_new_tokens: generation budget per prompt.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.

    Returns:
        list[str]: completions aligned with ``prompts``, with any ``<think>``
        reasoning removed (the reaction engine parses them). An empty
        ``prompts`` returns ``[]`` without loading the model or touching CUDA.

    NOTE(review): no ``@spaces.GPU`` decorator is applied here; confirm the
    caller wraps this function, since on ZeroGPU CUDA is only available inside
    GPU-decorated calls.
    """
    if not prompts:
        return []
    mdl, tok = _load_llm()
    mdl.to("cuda")
    # Without a pad token, generate() falls back to EOS and warns on every call.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    outputs = []
    for prompt in prompts:
        messages = [{"role": "user", "content": prompt}]
        input_text = tok.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False,
        )
        # The rendered template already holds the special tokens the model expects;
        # adding them again (e.g. a second BOS) would corrupt the prompt. This mirrors
        # apply_chat_template(tokenize=True), which also uses add_special_tokens=False.
        inputs = tok(input_text, return_tensors="pt", add_special_tokens=False).to("cuda")
        with torch.no_grad():
            out = mdl.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=pad_id,
            )
        # generate() returns prompt + continuation; keep only the new tokens.
        new_tokens = out[0][inputs["input_ids"].shape[1]:]
        text = tok.decode(new_tokens, skip_special_tokens=True).strip()
        outputs.append(_strip_think(text))
    return outputs
| # --------------------------------------------------------------------------- TTS | |
def _load_voice():
    """Return the cached VoxCPM text-to-speech model, loading it on first use.

    ``voxcpm`` is imported lazily so the dependency is only required once TTS
    is actually requested.
    """
    global _voice
    if _voice is not None:
        return _voice
    from voxcpm import VoxCPM
    print(f"[inference] loading {VOICE_MODEL_ID} …")
    _voice = VoxCPM.from_pretrained(VOICE_MODEL_ID)
    return _voice
def synthesize_voice(text, voice_desc):
    """Synthesize ``text`` with VoxCPM and save the audio to a new WAV file.

    Matches voice.generate_voice's contract of returning a path to a WAV file.

    Args:
        text: the words to speak.
        voice_desc: voice-description prefix prepended verbatim to ``text``;
            ``None`` is treated as no description.

    Returns:
        str: path to a freshly created WAV file in the system temp directory.
        Every call gets its own file, so concurrent calls cannot overwrite one
        another's audio.

    NOTE(review): files are no longer reused, so they accumulate in the temp
    directory unless the caller deletes them once consumed.
    """
    import soundfile as sf
    model = _load_voice()
    wav = model.generate(text=f"{voice_desc or ''}{text}", cfg_value=2.0, inference_timesteps=10)
    # Unique file per call: a fixed per-process name would let two requests in the
    # same process overwrite each other's audio before it is read.
    fd, path = tempfile.mkstemp(prefix="tinyworld_voice_", suffix=".wav")
    os.close(fd)
    try:
        sf.write(path, wav, model.tts_model.sample_rate)
    except Exception:
        os.unlink(path)  # don't leave an empty placeholder behind
        raise
    return path
| # --------------------------------------------------------------------------- ASR | |
def _load_whisper():
    """Return the cached Whisper ASR model of size ``WHISPER_SIZE``, loading it on first use.

    ``whisper`` is imported lazily so the dependency is only required once
    transcription is actually requested.
    """
    global _whisper
    if _whisper is not None:
        return _whisper
    import whisper
    print(f"[inference] loading Whisper {WHISPER_SIZE} …")
    _whisper = whisper.load_model(WHISPER_SIZE)
    return _whisper
def transcribe_audio(audio_path):
    """Transcribe an audio file with Whisper.

    Args:
        audio_path: path of the audio file to transcribe. ``None`` or ``""``
            (e.g. no recording was captured) yields ``""`` without loading
            the model.

    Returns:
        str: the stripped transcription, or ``""`` when Whisper produced no text.
    """
    if audio_path is None or (isinstance(audio_path, str) and not audio_path):
        return ""
    model = _load_whisper()
    # FP16 only applies on CUDA; on a CPU model Whisper would warn and fall back
    # to FP32 anyway, so request it only when the model sits on a GPU.
    result = model.transcribe(audio_path, fp16=model.device.type == "cuda")
    return (result.get("text") or "").strip()