third-eye / modal_backend.py
mitvho09's picture
Upload folder using huggingface_hub
031e3f9 verified
Raw
History Blame Contribute Delete
9.29 kB
from __future__ import annotations
import io
import modal
# Modal application name and the Hugging Face checkpoints this backend uses.
APP_NAME = "third-eye-backend"
VISION_MODEL_ID = "openbmb/MiniCPM-V-2"
TTS_MODEL_ID = "openbmb/VoxCPM2"

app = modal.App(APP_NAME)

# Named Modal volume (created on first use) that is mounted at /cache by the
# GPU functions in this file.
model_cache = modal.Volume.from_name(
    "third-eye-model-cache",
    create_if_missing=True,
)
cache_mount = {"/cache": model_cache}

# Both Hugging Face cache variables point at the same directory on the volume.
cache_env = dict.fromkeys(
    ("HF_HOME", "TRANSFORMERS_CACHE"),
    "/cache/huggingface",
)
# Container image for the vision function: Debian slim + Python 3.11 with the
# package versions listed below, then the shared Hugging Face cache variables.
vision_image = modal.Image.debian_slim(python_version="3.11")
vision_image = vision_image.pip_install(
    "torch==2.1.2",
    "torchvision==0.16.2",
    "transformers==4.36.2",
    "accelerate>=0.25",
    "sentencepiece>=0.1.99",
    "timm==0.9.10",
    "pillow>=10",
    "peft==0.9.0",
    "numpy<2",
)
vision_image = vision_image.env(cache_env)
# Container image for the text-to-speech function: CUDA 12.8 runtime base with
# Python 3.11, system audio/build packages, then the Python dependencies.
tts_image = modal.Image.from_registry(
    "nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04",
    add_python="3.11",
)
tts_image = tts_image.apt_install("ffmpeg", "libsox-dev", "build-essential")
tts_image = tts_image.pip_install(
    "torch>=2.5",
    "voxcpm>=0.1.0",
    "misaki[zh]>=0.9",
    "soundfile>=0.12",
    "numpy>=1.26",
)
# Shared cache variables plus TORCHDYNAMO_DISABLE=1 for this image only.
tts_image = tts_image.env(dict(cache_env, TORCHDYNAMO_DISABLE="1"))
# Container image for the speech-to-text function: CUDA 12.8 runtime base with
# Python 3.11, the Python dependencies, and the local cohere_stt.py helper
# copied to /root so it can be imported inside the container.
stt_image = modal.Image.from_registry(
    "nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04",
    add_python="3.11",
)
stt_image = stt_image.pip_install(
    "torch>=2.5",
    "transformers>=5.4",
    "accelerate>=1.0",
    "librosa>=0.10",
    "sentencepiece",
    "protobuf",
    "soundfile>=0.12",
)
stt_image = stt_image.add_local_file(
    "cohere_stt.py",
    "/root/cohere_stt.py",
    copy=True,
)
stt_image = stt_image.env(cache_env)
# Per-process lazy singletons; the loader functions fill these on first use.
_vision_model = _vision_tokenizer = _tts_model = None
def _load_vision():
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    The model is moved to CUDA in bfloat16 and put in eval mode. Both module
    globals are assigned together, only after *both* loads have succeeded, so
    a failed tokenizer load cannot leave a model cached without its tokenizer
    (which would make later calls return ``(model, None)``).

    Returns:
        Tuple of the loaded vision model and its tokenizer.
    """
    global _vision_model, _vision_tokenizer
    if _vision_model is None or _vision_tokenizer is None:
        import torch
        from transformers import AutoModel, AutoTokenizer

        # The MiniCPM-V-2 model card recommends bfloat16; it is more numerically
        # stable than float16 and reduces OCR drift on small text.
        model = (
            AutoModel.from_pretrained(
                VISION_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            )
            .to(device="cuda", dtype=torch.bfloat16)
            .eval()
        )
        tokenizer = AutoTokenizer.from_pretrained(
            VISION_MODEL_ID,
            trust_remote_code=True,
        )
        # Publish atomically: either both globals are set or neither is.
        _vision_model, _vision_tokenizer = model, tokenizer

        param = next(model.parameters())
        print(
            f"[third-eye VISION] loaded {VISION_MODEL_ID} "
            f"| device={param.device} | dtype={param.dtype}",
            flush=True,
        )
    return _vision_model, _vision_tokenizer
def _chat_once(model, tokenizer, image, prompt: str) -> str:
    """Send one single-turn user prompt about ``image`` and return the reply.

    Calls ``model.chat`` with no prior context and returns the first element
    of its three-item result with surrounding whitespace removed.
    """
    conversation = [{"role": "user", "content": prompt}]
    # Low temperature (0.2) for faithful, repeatable output while avoiding
    # the refusals that greedy (sampling=False) sometimes triggers on
    # MiniCPM-V-2. Keeps hallucination low (no "$20" for an "$18" item).
    result = model.chat(
        image=image,
        msgs=conversation,
        context=None,
        tokenizer=tokenizer,
        sampling=True,
        temperature=0.2,
    )
    reply, _context, _extra = result
    return reply.strip()
def stitch_overlapping_text(parts: list[str], max_overlap_words: int = 12) -> str:
    """Join OCR results from overlapping image bands, dropping duplicated words.

    For each successive part, finds the longest run of words (up to
    ``max_overlap_words``) that is both a suffix of the text accumulated so
    far and a prefix of the next part, and appends the next part without it.
    Matching is caseless via ``str.casefold`` so Unicode variants such as
    "ß" / "SS" compare equal. The first occurrence's spelling is the one kept.
    Pure function — unit tested.

    Args:
        parts: OCR strings in reading order; empty, blank or ``None`` entries
            are ignored.
        max_overlap_words: Upper bound on the overlap length searched, in
            whitespace-separated words. Values <= 0 disable de-duplication.

    Returns:
        The stitched text with words separated by single spaces, or ``""``
        when no non-blank part is given.
    """
    cleaned = [p.strip() for p in parts if p and p.strip()]
    if not cleaned:
        return ""

    words = cleaned[0].split()
    # Casefolded shadow of ``words``, maintained incrementally so each word is
    # normalised once rather than on every candidate overlap length.
    keys = [w.casefold() for w in words]

    for nxt in cleaned[1:]:
        nwords = nxt.split()
        nkeys = [w.casefold() for w in nwords]
        limit = min(len(words), len(nwords), max_overlap_words)
        # Try the longest candidate first so the largest overlap wins.
        overlap = next(
            (k for k in range(limit, 0, -1) if keys[-k:] == nkeys[:k]),
            0,
        )
        words += nwords[overlap:]
        keys += nkeys[overlap:]

    return " ".join(words)
@app.function(
    gpu="A10G",
    image=vision_image,
    timeout=300,
    volumes=cache_mount,
)
def describe_scene(
    image_bytes: bytes, question: str, lang: str = "en", tile: bool = False
) -> str:
    """Answer ``question`` about an image with the vision model.

    Args:
        image_bytes: Encoded image data that PIL can open.
        question: Prompt for the model; when blank (or ``None``) a generic
            "describe everything" prompt is used instead.
        lang: ``"zh"`` appends an instruction to answer in Chinese; any other
            value leaves the prompt unchanged.
        tile: When true, run the prompt separately on two overlapping
            top/bottom bands of the image and stitch the two answers.

    Returns:
        The model's answer (stitched across bands when ``tile`` is true).
    """
    import time

    from PIL import Image, ImageOps

    model, tokenizer = _load_vision()

    # Cameras commonly store rotation as an EXIF orientation tag instead of
    # rotating the pixels. Apply it before anything else so the model sees an
    # upright picture and the top/bottom bands below split along the real
    # vertical axis. This is a no-op for images without the tag.
    raw_image = Image.open(io.BytesIO(image_bytes))
    image = ImageOps.exif_transpose(raw_image).convert("RGB")

    prompt = (question or "").strip() or "Describe everything visible for a blind user."
    if lang == "zh":
        prompt += " Answer in Chinese."

    # perf_counter is monotonic, so the logged duration cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()

    if not tile:
        answer = _chat_once(model, tokenizer, image, prompt)
        print(
            f"[third-eye VISION] chat: {time.perf_counter() - start:.2f}s",
            flush=True,
        )
        return answer

    # Tiled OCR for verbatim Read Text mode: small text on a full image exceeds a
    # 2.8B VLM's OCR resolution (it merged "MANGO LASSI" -> "MANGOLAISSI"). Splitting
    # into overlapping top/bottom bands enlarges the relative text per call; the
    # overlap is stitched away. Automatic — no box-drawing, which a blind user can't do.
    w, h = image.size
    bands = [(0, 0, w, int(h * 0.55)), (0, int(h * 0.45), w, h)]
    parts = [_chat_once(model, tokenizer, image.crop(box), prompt) for box in bands]
    answer = stitch_overlapping_text(parts)
    print(
        f"[third-eye VISION] tiled chat ({len(bands)} bands): "
        f"{time.perf_counter() - start:.2f}s",
        flush=True,
    )
    return answer
def _load_tts():
    """Return the cached VoxCPM model, creating it on the first call.

    The model is loaded from ``TTS_MODEL_ID`` with the denoiser disabled and
    kept in a module global for subsequent calls.
    """
    global _tts_model
    if _tts_model is not None:
        return _tts_model

    from voxcpm import VoxCPM

    _tts_model = VoxCPM.from_pretrained(TTS_MODEL_ID, load_denoiser=False)
    return _tts_model
@app.function(
    gpu="A10G",
    image=tts_image,
    timeout=300,
    volumes=cache_mount,
)
def speak(text: str, lang: str = "en") -> bytes:
    """Synthesize ``text`` to speech and return it as WAV file bytes.

    Args:
        text: Text to speak. It is stripped and cut to its first 500
            characters before synthesis.
        lang: Accepted for interface compatibility; not used by this function.

    Returns:
        The encoded WAV data.

    Raises:
        ValueError: If ``text`` is empty or only whitespace.
    """
    import numpy as np
    import soundfile as sf

    stripped = text.strip()
    if not stripped:
        raise ValueError("Cannot synthesize empty text.")
    clipped = stripped[:500]

    tts = _load_tts()
    waveform = tts.generate(
        text=clipped,
        cfg_value=2.0,
        inference_timesteps=10,
    )

    # Encode in memory so the bytes can be returned to the remote caller.
    samples = np.asarray(waveform, dtype=np.float32)
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, samples, tts.tts_model.sample_rate, format="WAV")
    return wav_buffer.getvalue()
@app.function(
    gpu="A10G",
    image=stt_image,
    timeout=300,
    volumes=cache_mount,
    secrets=[modal.Secret.from_name("third-eye-hf")],
)
def transcribe_audio(audio_bytes: bytes, language: str = "en") -> str:
    """Transcribe recorded speech by delegating to the ``cohere_stt`` helper.

    Args:
        audio_bytes: Audio data, passed through unchanged (presumably WAV,
            going by the helper's name — confirm against cohere_stt.py).
        language: Language code forwarded to the helper.

    Returns:
        Whatever ``cohere_stt.transcribe_wav_bytes`` returns for the input.
    """
    # Imported here because the module only exists inside the STT image.
    import cohere_stt

    return cohere_stt.transcribe_wav_bytes(audio_bytes, language)
@app.local_entrypoint()
def smoke_test(image_path: str = "assets/sample_menu.jpg"):
    """Describe a sample image remotely, then synthesize the answer to out.wav.

    Args:
        image_path: Local path of the image to send to ``describe_scene``.
    """
    # Context manager so the file handle is closed deterministically.
    with open(image_path, "rb") as image_file:
        image_bytes = image_file.read()

    answer = describe_scene.remote(
        image_bytes,
        "Read the menu and summarize the available items.",
        "en",
    )
    print(answer)

    audio = speak.remote(answer, "en")
    with open("out.wav", "wb") as output:
        output.write(audio)
    print("Saved out.wav")
@app.local_entrypoint()
def read_test(image_path: str = "assets/sample_menu.jpg", prompt: str = ""):
    """Test Read Text (verbatim OCR) mode for transcription distortion.

    Args:
        image_path: Local path of the image to read.
        prompt: Optional prompt override; when empty, a built-in verbatim
            reading prompt is used.
    """
    read_prompt = prompt or (
        "Read every word and number in this image exactly as written. "
        "Include all text, labels, prices, dates, directions, and signs. "
        "Do not interpret or explain - just read the text verbatim."
    )
    # Context manager so the file handle is closed deterministically.
    with open(image_path, "rb") as image_file:
        image_bytes = image_file.read()

    # The final positional True enables the tiled (two-band) OCR path.
    answer = describe_scene.remote(image_bytes, read_prompt, "en", True)
    print(f"READ [{image_path}]:\n{answer}")
@app.local_entrypoint()
def ask_test(
    image_path: str = "assets/sample_menu.jpg",
    question_text: str = "What is the cheapest item on the menu and how much does it cost?",
):
    """End-to-end 'Ask' pipeline test: speak a question -> STT -> vision answer.

    Synthesizes the question to audio so we can compare what was SPOKEN vs HEARD,
    then checks whether the vision model actually ANSWERS that question.

    Args:
        image_path: Local path of the image the question is about.
        question_text: Question to synthesize, transcribe and then ask.
    """
    print(f"SPOKEN: {question_text!r}")
    q_audio = speak.remote(question_text, "en")
    heard = transcribe_audio.remote(q_audio, "en")
    print(f"HEARD: {heard!r}")

    # Context manager so the file handle is closed deterministically.
    with open(image_path, "rb") as image_file:
        image_bytes = image_file.read()

    answer = describe_scene.remote(image_bytes, heard, "en")
    print(f"ANSWER: {answer!r}")
@app.local_entrypoint()
def stt_benchmark(audio_path: str = "test_speech.wav"):
    """Time two consecutive remote transcriptions of the same audio file.

    The first call is reported as COLD and the second as WARM; each prints
    its total round-trip time and the transcript returned.

    Args:
        audio_path: Local path of the audio file to transcribe.
    """
    import time

    # Context manager so the file handle is closed deterministically.
    with open(audio_path, "rb") as audio_file:
        audio_bytes = audio_file.read()
    print(f"Benchmarking STT on {audio_path} ({len(audio_bytes)} bytes)")

    # perf_counter is monotonic, so the measured intervals are not affected
    # by wall-clock adjustments during the run.
    t0 = time.perf_counter()
    text1 = transcribe_audio.remote(audio_bytes, "en")
    cold = time.perf_counter() - t0
    print(f"\n[COLD] total round-trip: {cold:.1f}s")
    print(f"[COLD] transcript: {text1!r}")

    t1 = time.perf_counter()
    text2 = transcribe_audio.remote(audio_bytes, "en")
    warm = time.perf_counter() - t1
    print(f"\n[WARM] total round-trip: {warm:.1f}s")
    print(f"[WARM] transcript: {text2!r}")