third-eye / zerogpu_backend.py
mitvho09's picture
Upload folder using huggingface_hub
031e3f9 verified
Raw
History Blame Contribute Delete
7.19 kB
"""In-process inference backend for Hugging Face ZeroGPU Spaces.
All three stages run in one Python environment on a ZeroGPU slice, exposed as
plain functions (``describe_scene``, ``transcribe_audio``, ``speak``) so the
Gradio app can call them exactly like the Modal backend.
Model stack (single, Transformers >= 5.4 compatible environment):
* Vision / OCR -> Qwen/Qwen2.5-VL-3B-Instruct (bilingual EN/ZH, < 4B)
* Speech-to-text -> CohereLabs/cohere-transcribe-03-2026 (via cohere_stt)
* Text-to-speech -> openbmb/VoxCPM2
Models are lazy-loaded once and cached; loading happens inside the GPU context
so it works under ZeroGPU's on-demand allocation.
"""
from __future__ import annotations
import io
# ``spaces`` only exists on a Hugging Face Space. Fall back to a no-op decorator
# so this module still imports in a plain environment (e.g. for unit tests).
try:
    import spaces

    GPU = spaces.GPU
except Exception:  # pragma: no cover - exercised only off-Space

    def GPU(*args, **kwargs):
        """Stand-in for ``spaces.GPU`` that leaves the wrapped function untouched.

        Supports both spellings of the decorator:

        * ``@GPU`` -- called with the function itself, which is returned as-is.
        * ``@GPU(duration=...)`` -- called with options, returning a
          pass-through decorator that ignores those options.
        """
        used_bare = not kwargs and len(args) == 1 and callable(args[0])
        if used_bare:
            return args[0]
        return lambda fn: fn
# Model identifiers handed to the respective ``from_pretrained`` loaders.
VISION_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
TTS_MODEL_ID = "openbmb/VoxCPM2"
# Module-level caches, populated on first use by ``_load_vision()`` (a
# ``(model, processor)`` tuple) and ``_load_tts()`` (the VoxCPM instance).
_vision = None
_tts = None
# --------------------------------------------------------------------------- #
# Pure helper (unit-testable, no GPU): stitch overlapping OCR bands.
# --------------------------------------------------------------------------- #
def stitch_overlapping_text(parts: list[str], max_overlap_words: int = 12) -> str:
    """Merge text fragments whose ends overlap into one string.

    Fragments are compared word by word (split on whitespace). For each new
    fragment, the longest run of words -- at most ``max_overlap_words`` -- that
    both ends the text gathered so far and starts the fragment is emitted only
    once. Matching ignores case; the casing of the earlier occurrence is kept.

    Empty, whitespace-only and falsy entries are skipped. The result is the
    surviving words joined by single spaces, or ``""`` when nothing remains.
    """
    cleaned = [
        stripped for stripped in (p.strip() for p in parts if p) if stripped
    ]
    if not cleaned:
        return ""

    first, *rest = cleaned
    merged = first.split()
    for fragment in rest:
        incoming = fragment.split()
        longest = min(len(merged), len(incoming), max_overlap_words)
        shared = 0
        if longest > 0:
            # Lower-case the only words that can take part in a match once,
            # then probe candidate overlap sizes from largest to smallest.
            tail = [word.lower() for word in merged[-longest:]]
            head = [word.lower() for word in incoming[:longest]]
            shared = next(
                (
                    size
                    for size in range(longest, 0, -1)
                    if tail[-size:] == head[:size]
                ),
                0,
            )
        merged.extend(incoming[shared:])
    return " ".join(merged)
# --------------------------------------------------------------------------- #
# Vision / OCR — Qwen2.5-VL
# --------------------------------------------------------------------------- #
def _load_vision():
    """Return the cached ``(model, processor)`` pair, loading it on first call.

    The first call loads ``VISION_MODEL_ID`` in bfloat16 (float16 when bf16 is
    reported as unsupported), moves it to ``"cuda"``, puts it in eval mode,
    builds the matching processor, and logs the resulting device and dtype.
    Later calls return the stored tuple without reloading.
    """
    global _vision
    if _vision is not None:
        return _vision

    import torch
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    if torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        VISION_MODEL_ID,
        torch_dtype=dtype,
    )
    model = model.to("cuda").eval()

    # Cap visual tokens so menus/labels stay fast without losing legible text.
    pixel_unit = 28 * 28
    processor = AutoProcessor.from_pretrained(
        VISION_MODEL_ID,
        min_pixels=256 * pixel_unit,
        max_pixels=1280 * pixel_unit,
    )

    # Report where the weights actually ended up, using the first parameter.
    first_param = next(model.parameters())
    print(
        f"[third-eye VISION] loaded {VISION_MODEL_ID} "
        f"| device={first_param.device} | dtype={first_param.dtype}",
        flush=True,
    )

    _vision = (model, processor)
    return _vision
def _chat_once(model, processor, image, prompt: str) -> str:
    """Ask the vision model one question about one image and return its reply.

    Builds a single user turn holding ``image`` and ``prompt``, renders it with
    the processor's chat template, generates up to 512 new tokens (sampling at
    temperature 0.2) and returns the decoded completion with surrounding
    whitespace stripped.
    """
    import torch

    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }
    chat_text = processor.apply_chat_template(
        [user_turn], tokenize=False, add_generation_prompt=True
    )

    batch = processor(
        text=[chat_text],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    batch = batch.to(model.device)

    with torch.inference_mode():
        sequences = model.generate(
            **batch,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.2,
        )

    # Keep only what follows the first len(input_ids) tokens of each output
    # (presumably the echoed prompt), so the decoded text is the reply alone.
    completions = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(batch.input_ids, sequences)
    ]
    decoded = processor.batch_decode(
        completions, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0].strip()
@GPU(duration=120)
def describe_scene(
    image_bytes: bytes, question: str, lang: str = "en", tile: bool = False
) -> str:
    """Answer ``question`` about an encoded image using the vision model.

    Args:
        image_bytes: Encoded image data; decoded with PIL and converted to RGB.
        question: Prompt for the model. When empty, whitespace-only or
            ``None``, a default "describe everything" prompt is used instead.
        lang: ``"zh"`` appends an instruction to answer in Chinese; any other
            value leaves the prompt unchanged.
        tile: When ``False`` the whole image is sent in a single call. When
            ``True`` the image is split into two overlapping horizontal bands
            (top 55 % and bottom 55 %), each band is queried separately with
            the same prompt, and the answers are merged with
            ``stitch_overlapping_text`` to drop the duplicated overlap.

    Returns:
        The model's answer as a stripped string.

    Exceptions raised while decoding the image or running the model are not
    caught here and propagate to the caller.
    """
    import time

    from PIL import Image

    model, processor = _load_vision()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # ``question or ""`` keeps a missing (None) question from crashing .strip().
    prompt = (question or "").strip() or "Describe everything visible for a blind user."
    if lang == "zh":
        prompt += " Answer in Chinese."

    # perf_counter() is monotonic, so the logged duration cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    if not tile:
        answer = _chat_once(model, processor, image, prompt)
        print(
            f"[third-eye VISION] chat: {time.perf_counter() - start:.2f}s",
            flush=True,
        )
        return answer

    # Tiled OCR for verbatim Read Text mode: splitting into overlapping top/bottom
    # bands enlarges the relative text per call; the overlap is stitched away.
    w, h = image.size
    bands = [(0, 0, w, int(h * 0.55)), (0, int(h * 0.45), w, h)]
    # On a 1-pixel-tall image int(h * 0.55) is 0, making the top band empty;
    # skip such boxes rather than sending a zero-area crop to the model.
    bands = [
        (left, top, right, bottom)
        for left, top, right, bottom in bands
        if right > left and bottom > top
    ]
    parts = [_chat_once(model, processor, image.crop(box), prompt) for box in bands]
    answer = stitch_overlapping_text(parts)
    print(
        f"[third-eye VISION] tiled chat ({len(bands)} bands): "
        f"{time.perf_counter() - start:.2f}s",
        flush=True,
    )
    return answer
# --------------------------------------------------------------------------- #
# Speech-to-text — Cohere Transcribe (shared with the Modal backend)
# --------------------------------------------------------------------------- #
@GPU(duration=120)
def transcribe_audio(audio_bytes: bytes, language: str = "en") -> str:
    """Transcribe ``audio_bytes`` by delegating to ``cohere_stt``.

    Both arguments are forwarded unchanged to
    ``cohere_stt.transcribe_wav_bytes`` and its result is returned as-is.
    ``cohere_stt`` is imported on first call rather than at module import.
    """
    import cohere_stt

    return cohere_stt.transcribe_wav_bytes(audio_bytes, language)
# --------------------------------------------------------------------------- #
# Text-to-speech — VoxCPM2
# --------------------------------------------------------------------------- #
def _load_tts():
    """Return the cached VoxCPM model, creating it on the first call.

    The model is loaded from ``TTS_MODEL_ID`` with ``load_denoiser=False`` and
    stored in the module-level ``_tts`` so later calls reuse the same instance.
    """
    global _tts
    if _tts is not None:
        return _tts

    from voxcpm import VoxCPM

    _tts = VoxCPM.from_pretrained(TTS_MODEL_ID, load_denoiser=False)
    return _tts
@GPU(duration=120)
def speak(text: str, lang: str = "en") -> bytes:
    """Synthesize ``text`` with the TTS model and return it as WAV bytes.

    Args:
        text: Text to speak. Surrounding whitespace is stripped and only the
            first 500 characters of the result are synthesized.
        lang: Accepted as part of the signature; not used in this function.

    Returns:
        The WAV-encoded audio, written as float32 samples at the sample rate
        reported by the loaded model.

    Raises:
        ValueError: If ``text`` is empty or whitespace-only.
    """
    import numpy as np
    import soundfile as sf

    cleaned = text.strip()
    if not cleaned:
        raise ValueError("Cannot synthesize empty text.")

    tts = _load_tts()
    audio = tts.generate(
        text=cleaned[:500],
        cfg_value=2.0,
        inference_timesteps=10,
    )

    # Encode into an in-memory buffer so nothing touches the filesystem.
    buffer = io.BytesIO()
    samples = np.asarray(audio, dtype=np.float32)
    sf.write(buffer, samples, tts.tts_model.sample_rate, format="WAV")
    return buffer.getvalue()