tabras / modal_app.py
Codex
Size Modal pools under the 10-GPU plan cap (card 3/4, art 3/4, boss 1/2)
8f039b3
Raw
History Blame Contribute Delete
5.24 kB
"""Modal GPU endpoints for Tabras: MiniCPM card authoring, Nemotron boss play,
and SDXL-Turbo art. The Gradio Space calls these over HTTP, so all heavy compute
runs on dedicated, autoscaled Modal GPUs instead of the Space.
Deploy: modal deploy modal_app.py
Then set the printed URLs as Space variables:
TABRAS_CARD_ENDPOINT -> CardModel.chat URL
TABRAS_BOSS_ENDPOINT -> BossModel.chat URL
TABRAS_ART_ENDPOINT -> ArtModel.generate URL
"""
import modal
# Path where the shared Hugging Face cache volume is mounted in each container;
# HF_HOME is pointed at it below so model downloads land on the volume.
CACHE = "/cache"
hf_cache = modal.Volume.from_name("tabras-hf-cache", create_if_missing=True)

# Hugging Face repo ids of the three models served by this app.
MINICPM = "openbmb/MiniCPM-V-4"
NEMOTRON = "nvidia/Nemotron-Mini-4B-Instruct"
SDXL = "stabilityai/sdxl-turbo"

# Python packages baked into the image shared by both chat (LLM) endpoints.
_LLM_PACKAGES = (
    "torch",
    "transformers==4.49.0",
    "accelerate",
    "sentencepiece",
    "torchvision",
    "einops",
    "pillow",
    "fastapi[standard]",
)

# Python packages for the art image.
# diffusers 0.31 supports SDXL-Turbo and is compatible with transformers 4.49;
# newer diffusers (0.35+) imports flux2 which needs Qwen3ForCausalLM.
_ART_PACKAGES = (
    "torch",
    "diffusers==0.31.0",
    "transformers==4.49.0",
    "accelerate",
    "safetensors",
    "pillow",
    "fastapi[standard]",
)

# Container image for the MiniCPM and Nemotron endpoints.
llm_image = modal.Image.debian_slim(python_version="3.11")
llm_image = llm_image.pip_install(*_LLM_PACKAGES)
llm_image = llm_image.env({"HF_HOME": CACHE})

# Container image for the SDXL-Turbo endpoint.
art_image = modal.Image.debian_slim(python_version="3.11")
art_image = art_image.pip_install(*_ART_PACKAGES)
art_image = art_image.env({"HF_HOME": CACHE})

# The Modal app that all three endpoint classes register with.
app = modal.App("tabras-models")
# ---- MiniCPM: card authoring (OpenAI-compatible chat) ----
@app.cls(
    image=llm_image,
    gpu="A10G",
    volumes={CACHE: hf_cache},
    min_containers=3,
    max_containers=4,
    scaledown_window=600,
    timeout=600,
)
class CardModel:
    """MiniCPM chat endpoint used for card authoring.

    Runs on an A10G with the Hugging Face cache volume mounted at ``CACHE``.
    The pool is configured for 3 to 4 containers.
    """

    @modal.enter()
    def load(self) -> None:
        """Load the MiniCPM tokenizer and fp16 model onto the GPU.

        Registered with ``modal.enter()``; the results are kept on ``self``
        and reused by :meth:`chat`.
        """
        import torch
        from transformers import AutoModel, AutoTokenizer

        self.tok = AutoTokenizer.from_pretrained(MINICPM, trust_remote_code=True)
        self.model = (
            AutoModel.from_pretrained(
                MINICPM,
                trust_remote_code=True,
                attn_implementation="sdpa",
                torch_dtype=torch.float16,
            )
            .eval()
            .cuda()
        )

    @modal.fastapi_endpoint(method="POST")
    def chat(self, item: dict) -> dict:
        """Answer a chat request and return an OpenAI-shaped response.

        Args:
            item: JSON body. Recognised keys:
                ``messages``: list of ``{"role", "content"}`` dicts. All
                    ``system`` contents are joined into the system prompt and
                    all ``user`` contents are joined into one user turn;
                    messages with any other role are ignored.
                ``temperature``: float, default 0.7. Values ``<= 0`` disable
                    sampling.
                ``max_tokens``: int, default 128.
                A key that is missing or JSON ``null`` uses its default.

        Returns:
            ``{"choices": [{"message": {"role": "assistant", "content": str}}]}``
        """

        def as_text(content) -> str:
            # Content may be a plain string, missing/null, or an OpenAI-style
            # list of parts; only the text of such parts is kept.
            if content is None:
                return ""
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                pieces = []
                for part in content:
                    if isinstance(part, str):
                        pieces.append(part)
                    elif isinstance(part, dict) and isinstance(part.get("text"), str):
                        pieces.append(part["text"])
                return " ".join(pieces)
            return str(content)

        def joined(role: str) -> str:
            return " ".join(
                as_text(m.get("content"))
                for m in msgs
                if isinstance(m, dict) and m.get("role") == role
            )

        # `or []` also covers an explicit `"messages": null`.
        msgs = item.get("messages") or []
        system = joined("system")
        user = joined("user")

        # Compare against None explicitly: a temperature of 0 is a valid
        # request for greedy decoding and must not fall back to the default.
        raw_temp = item.get("temperature")
        temp = 0.7 if raw_temp is None else float(raw_temp)
        raw_max = item.get("max_tokens")
        max_new_tokens = 128 if raw_max is None else int(raw_max)

        text = str(
            self.model.chat(
                msgs=[{"role": "user", "content": user}],
                image=None,
                tokenizer=self.tok,
                system_prompt=system,
                sampling=temp > 0,
                # Floor the value so a zero/negative temperature is never
                # handed to the model.
                temperature=max(temp, 0.01),
                max_new_tokens=max_new_tokens,
            )
        )
        return {"choices": [{"message": {"role": "assistant", "content": text}}]}
# ---- Nemotron: boss play (OpenAI-compatible chat) ----
@app.cls(
    image=llm_image,
    gpu="A10G",
    volumes={CACHE: hf_cache},
    min_containers=1,
    max_containers=2,
    scaledown_window=600,
    timeout=600,
)
class BossModel:
    """Nemotron chat endpoint used for boss play.

    Runs on an A10G with the Hugging Face cache volume mounted at ``CACHE``.
    The pool is configured for 1 to 2 containers.
    """

    @modal.enter()
    def load(self) -> None:
        """Load the Nemotron tokenizer and fp16 causal LM onto the GPU.

        Registered with ``modal.enter()``; the results are kept on ``self``
        and reused by :meth:`chat`.
        """
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.tok = AutoTokenizer.from_pretrained(NEMOTRON)
        self.model = (
            AutoModelForCausalLM.from_pretrained(NEMOTRON, torch_dtype=torch.float16)
            .eval()
            .cuda()
        )

    @modal.fastapi_endpoint(method="POST")
    def chat(self, item: dict) -> dict:
        """Generate a reply for a chat request and return an OpenAI-shaped response.

        Args:
            item: JSON body. Recognised keys:
                ``messages``: list of chat messages, rendered with the
                    tokenizer's chat template.
                ``temperature``: float, default 0.2. Values ``<= 0`` select
                    greedy decoding.
                ``max_tokens``: int, default 96.
                A key that is missing or JSON ``null`` uses its default.

        Returns:
            ``{"choices": [{"message": {"role": "assistant", "content": str}}]}``
            where ``content`` is only the newly generated text.
        """
        import torch

        # `or []` also covers an explicit `"messages": null`.
        msgs = item.get("messages") or []

        # Compare against None explicitly: a temperature of 0 is a valid
        # request for greedy decoding and must not fall back to the default.
        raw_temp = item.get("temperature")
        temp = 0.2 if raw_temp is None else float(raw_temp)
        raw_max = item.get("max_tokens")
        max_new_tokens = 96 if raw_max is None else int(raw_max)

        prompt_ids = self.tok.apply_chat_template(
            msgs,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to("cuda")

        sampling = temp > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": sampling,
            # One unpadded prompt, so every position is a real token; passing
            # the mask explicitly avoids generate() having to guess it.
            "attention_mask": torch.ones_like(prompt_ids),
        }
        if sampling:
            # Temperature only matters when sampling; it is omitted for greedy
            # decoding so generate() does not receive an unused sampling flag.
            gen_kwargs["temperature"] = max(temp, 0.01)

        with torch.no_grad():
            out = self.model.generate(prompt_ids, **gen_kwargs)

        # Drop the prompt tokens so only the completion is decoded.
        completion_ids = out[0][prompt_ids.shape[-1]:]
        text = self.tok.decode(completion_ids, skip_special_tokens=True)
        return {"choices": [{"message": {"role": "assistant", "content": text}}]}
# ---- SDXL-Turbo: card art (returns a JPEG data URI) ----
@app.cls(
    image=art_image,
    gpu="A10G",
    volumes={CACHE: hf_cache},
    min_containers=3,
    max_containers=4,
    scaledown_window=600,
    timeout=600,
)
class ArtModel:
    """SDXL-Turbo endpoint that renders card art as a JPEG data URI.

    Runs on an A10G with the Hugging Face cache volume mounted at ``CACHE``.
    The pool is configured for 3 to 4 containers.
    """

    @modal.enter()
    def load(self) -> None:
        """Load the fp16 text-to-image pipeline onto the GPU.

        Registered with ``modal.enter()``; the pipeline is kept on ``self``
        and reused by :meth:`generate`. Progress bars are disabled.
        """
        import torch
        from diffusers import AutoPipelineForText2Image

        self.pipe = AutoPipelineForText2Image.from_pretrained(
            SDXL, torch_dtype=torch.float16
        ).to("cuda")
        self.pipe.set_progress_bar_config(disable=True)

    @modal.fastapi_endpoint(method="POST")
    def generate(self, item: dict) -> dict:
        """Render one image for ``item["prompt"]``.

        Args:
            item: JSON body. Recognised keys:
                ``prompt`` (required): text prompt.
                ``steps``: int, default 4; values below 1 are raised to 1.
                ``guidance``: float, default 0.0.
                ``width`` / ``height``: ints, default 512 / 320. Each is
                    rounded down to a multiple of 8 (minimum 8), because the
                    pipeline rejects other sizes.
                ``negative_prompt``: optional text, passed through unchanged.
                A numeric key that is missing or JSON ``null`` uses its
                default.

        Returns:
            ``{"image": "data:image/jpeg;base64,<...>"}``

        Raises:
            KeyError: if ``prompt`` is missing.
        """
        import base64
        from io import BytesIO

        def number(key: str, default, cast):
            # Treat a missing key and an explicit null the same way.
            raw = item.get(key)
            return default if raw is None else cast(raw)

        def snap8(pixels: int) -> int:
            # Largest multiple of 8 not above `pixels`, but never below 8.
            return max(8, pixels - pixels % 8)

        result = self.pipe(
            prompt=item["prompt"],
            num_inference_steps=max(1, number("steps", 4, int)),
            guidance_scale=number("guidance", 0.0, float),
            width=snap8(number("width", 512, int)),
            height=snap8(number("height", 320, int)),
            negative_prompt=item.get("negative_prompt"),
        )

        buffer = BytesIO()
        result.images[0].save(buffer, format="JPEG", quality=85)
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
        return {"image": "data:image/jpeg;base64," + encoded}