tabras / clients.py
Codex
Fix card truncation: lead JSON with effects+name, request concise flavor/art, bump card tokens
3625c13
Raw
History Blame Contribute Delete
7.83 kB
import os
from art import ArtClient, DiffusersImageClient, LazyArtClient, ModalArtClient
from boss import BossClient, NemotronBossClient
from generator import CardPackClient, LlamaCppCardClient, MiniCPMCardClient
from local_llm import (
LocalChatClient,
LocalCompletionClient,
LocalJsonChatClient,
MLXChatClient,
MiniCPMTransformersChatClient,
NemotronTransformersChatClient,
nemotron_prompt,
)
# Default HTTP endpoints for model servers running on this machine.
DEFAULT_CARD_ENDPOINT: str = "http://localhost:8080/v1/chat/completions"
DEFAULT_LLAMA_CARD_ENDPOINT: str = "http://127.0.0.1:8090/v1/chat/completions"
DEFAULT_BOSS_ENDPOINT: str = "http://localhost:8081/v1/chat/completions"
DEFAULT_BOSS_COMPLETION_ENDPOINT: str = "http://localhost:8081/v1/completions"

# Default model identifiers used when the matching TABRAS_*_MODEL is unset.
DEFAULT_CARD_MODEL: str = "minicpm-v-4.6"
DEFAULT_BOSS_MODEL: str = "nvidia/Nemotron-Mini-4B-Instruct"
DEFAULT_MINICPM_MODEL: str = "openbmb/MiniCPM-V-4"
DEFAULT_MLX_BOSS_MODEL: str = "mlx-community/Nemotron-Mini-4B-Instruct-4bit-mlx"
DEFAULT_ART_MODEL: str = "stabilityai/sdxl-turbo"

# Process-wide art client; art_client_from_env() builds it on first use and
# returns the same instance afterwards.
_art_client_cache: ArtClient | None = None

# Loaded Transformers MiniCPM chat clients, keyed by
# (model path, max new tokens, temperature). Loading the model from disk is
# expensive, so the card client reuses an entry here instead of reloading the
# model on every handler and timer tick.
_minicpm_chat_cache: dict[tuple, object] = {}
# Select model backends from the MODE environment variable. LOCAL (the default)
# runs the models in-process on this machine (Transformers + Diffusers); MODAL
# calls Modal-hosted GPU endpoints over HTTP, which is the setup used on the
# Hugging Face Space. Any TABRAS_* variable that is already set wins over the
# defaults seeded here.
def configure_mode() -> str:
    """Seed ``TABRAS_*`` environment defaults for the selected ``MODE``.

    ``MODE`` is upper-cased before comparison and defaults to ``"LOCAL"``.
    ``"MODAL"`` selects the ``llamacpp`` card backend, the ``openai`` boss
    backend and the ``modal`` art backend, with 120-second card/boss timeouts.
    Every other value selects the in-process ``transformers`` / ``diffusers``
    backends together with their default model ids. Generation defaults
    shared by both modes are applied afterwards.

    All values go through ``os.environ.setdefault``, so variables that are
    already present in the environment are never overwritten.

    Returns:
        The upper-cased ``MODE`` value that was compared.
    """
    selected_mode = os.environ.get("MODE", "LOCAL").upper()

    if selected_mode == "MODAL":
        backend_defaults = {
            "TABRAS_CARD_BACKEND": "llamacpp",
            "TABRAS_CARD_MODEL": "minicpm",
            "TABRAS_CARD_TIMEOUT": "120",
            "TABRAS_BOSS_BACKEND": "openai",
            "TABRAS_BOSS_MODEL": "nemotron",
            "TABRAS_BOSS_TIMEOUT": "120",
            "TABRAS_ART_BACKEND": "modal",
        }
    else:
        # Any MODE other than MODAL (including unrecognised values) lands here.
        backend_defaults = {
            "TABRAS_CARD_BACKEND": "transformers",
            "TABRAS_CARD_MODEL": DEFAULT_MINICPM_MODEL,
            "TABRAS_BOSS_BACKEND": "transformers",
            "TABRAS_BOSS_MODEL": DEFAULT_BOSS_MODEL,
            "TABRAS_ART_BACKEND": "diffusers",
            "TABRAS_ART_MODEL": DEFAULT_ART_MODEL,
        }

    shared_defaults = {
        "TABRAS_AI_BOSS": "1",
        "TABRAS_CARD_TEMPERATURE": "0.7",
        "TABRAS_CARD_MAX_TOKENS": "320",
        "TABRAS_BOSS_MAX_TOKENS": "96",
        "TABRAS_ART_STEPS": "4",
    }

    for defaults in (backend_defaults, shared_defaults):
        for name, value in defaults.items():
            os.environ.setdefault(name, value)

    return selected_mode
# Build a card-generation client from environment variables.
def card_client_from_env() -> CardPackClient | None:
    """Return the card-pack client selected by ``TABRAS_CARD_BACKEND``, or ``None``.

    * ``"llamacpp"``: a ``LlamaCppCardClient`` over ``LocalJsonChatClient``
      (endpoint defaults to ``DEFAULT_LLAMA_CARD_ENDPOINT``).
    * ``"transformers"``: a ``MiniCPMCardClient`` over an in-process MiniCPM
      chat client. The loaded chat client is cached in ``_minicpm_chat_cache``
      under ``(model path, max tokens, temperature)``, so repeated calls do not
      reload the model.
    * anything else: a ``MiniCPMCardClient`` over ``LocalChatClient``, but only
      when ``TABRAS_CARD_ENDPOINT`` is set; otherwise ``None``.

    Note that the fallback temperature used when ``TABRAS_CARD_TEMPERATURE`` is
    unset differs per branch: 0.0, 0.7 and 0.8 respectively.
    """
    backend = os.environ.get("TABRAS_CARD_BACKEND")

    if backend == "llamacpp":
        json_chat = LocalJsonChatClient(
            endpoint=os.environ.get("TABRAS_CARD_ENDPOINT", DEFAULT_LLAMA_CARD_ENDPOINT),
            model=os.environ.get("TABRAS_CARD_MODEL", DEFAULT_CARD_MODEL),
            timeout_seconds=int(os.environ.get("TABRAS_CARD_TIMEOUT", "60")),
            temperature=float(os.environ.get("TABRAS_CARD_TEMPERATURE", "0.0")),
            max_tokens=int(os.environ.get("TABRAS_CARD_MAX_TOKENS", "320")),
        )
        return LlamaCppCardClient(json_chat)

    if backend == "transformers":
        model_path = os.environ.get("TABRAS_CARD_MODEL", DEFAULT_MINICPM_MODEL)
        max_new_tokens = int(os.environ.get("TABRAS_CARD_MAX_TOKENS", "320"))
        temperature = float(os.environ.get("TABRAS_CARD_TEMPERATURE", "0.7"))
        cache_key = (model_path, max_new_tokens, temperature)
        cached_chat = _minicpm_chat_cache.get(cache_key)
        if cached_chat is None:
            # Loading the model is the expensive step; do it once per setting.
            cached_chat = MiniCPMTransformersChatClient.load(
                model_path, max_new_tokens=max_new_tokens, temperature=temperature
            )
            _minicpm_chat_cache[cache_key] = cached_chat
        return MiniCPMCardClient(cached_chat)

    # Generic HTTP chat backend: opt-in, used only when an endpoint is given.
    endpoint = os.environ.get("TABRAS_CARD_ENDPOINT")
    if endpoint is None:
        return None
    http_chat = LocalChatClient(
        endpoint=endpoint,
        model=os.environ.get("TABRAS_CARD_MODEL", DEFAULT_CARD_MODEL),
        timeout_seconds=int(os.environ.get("TABRAS_CARD_TIMEOUT", "60")),
        temperature=float(os.environ.get("TABRAS_CARD_TEMPERATURE", "0.8")),
    )
    return MiniCPMCardClient(http_chat)
# Loaded in-process boss chat clients (Transformers and MLX), keyed by backend
# name plus every argument passed to the loader. Like _minicpm_chat_cache for
# the card model, this keeps repeated boss_client_from_env() calls from
# reloading the model each time.
_boss_chat_cache: dict[tuple, object] = {}


# Build a boss client from environment variables.
def boss_client_from_env() -> BossClient | None:
    """Return the Nemotron boss client configured by the environment, or ``None``.

    The boss is disabled, and ``None`` is returned, unless ``TABRAS_AI_BOSS``
    is exactly ``"1"``. ``TABRAS_BOSS_BACKEND`` then selects the chat backend:

    * ``"completion"``: ``LocalCompletionClient`` using ``nemotron_prompt``.
    * ``"transformers"``: ``NemotronTransformersChatClient``, loaded once per
      (model, max tokens, temperature) and reused from ``_boss_chat_cache``.
    * ``"mlx"``: ``MLXChatClient``, loaded once per (model, max tokens) and
      reused from ``_boss_chat_cache``.
    * anything else: ``LocalChatClient``.

    Only the loaded chat clients are cached; the thin ``NemotronBossClient``
    wrapper is created fresh on every call.
    """
    if os.environ.get("TABRAS_AI_BOSS") != "1":
        return None
    backend = os.environ.get("TABRAS_BOSS_BACKEND")

    if backend == "completion":
        return NemotronBossClient(
            LocalCompletionClient(
                endpoint=os.environ.get("TABRAS_BOSS_ENDPOINT", DEFAULT_BOSS_COMPLETION_ENDPOINT),
                model=os.environ.get("TABRAS_BOSS_MODEL", DEFAULT_BOSS_MODEL),
                prompt_template=nemotron_prompt,
                timeout_seconds=int(os.environ.get("TABRAS_BOSS_TIMEOUT", "20")),
                temperature=float(os.environ.get("TABRAS_BOSS_TEMPERATURE", "0.2")),
                max_tokens=int(os.environ.get("TABRAS_BOSS_MAX_TOKENS", "256")),
            )
        )

    if backend == "transformers":
        model_path = os.environ.get("TABRAS_BOSS_MODEL", DEFAULT_BOSS_MODEL)
        max_new_tokens = int(os.environ.get("TABRAS_BOSS_MAX_TOKENS", "256"))
        temperature = float(os.environ.get("TABRAS_BOSS_TEMPERATURE", "0.2"))
        key = ("transformers", model_path, max_new_tokens, temperature)
        chat = _boss_chat_cache.get(key)
        if chat is None:
            # `.load` is the expensive step; do it once per configuration.
            chat = NemotronTransformersChatClient.load(
                model_path,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )
            _boss_chat_cache[key] = chat
        return NemotronBossClient(chat)

    if backend == "mlx":
        model_path = os.environ.get("TABRAS_BOSS_MODEL", DEFAULT_MLX_BOSS_MODEL)
        max_tokens = int(os.environ.get("TABRAS_BOSS_MAX_TOKENS", "96"))
        key = ("mlx", model_path, max_tokens)
        chat = _boss_chat_cache.get(key)
        if chat is None:
            chat = MLXChatClient.load(model_path, nemotron_prompt, max_tokens=max_tokens)
            _boss_chat_cache[key] = chat
        return NemotronBossClient(chat)

    return NemotronBossClient(
        LocalChatClient(
            endpoint=os.environ.get("TABRAS_BOSS_ENDPOINT", DEFAULT_BOSS_ENDPOINT),
            model=os.environ.get("TABRAS_BOSS_MODEL", DEFAULT_BOSS_MODEL),
            timeout_seconds=int(os.environ.get("TABRAS_BOSS_TIMEOUT", "20")),
            temperature=float(os.environ.get("TABRAS_BOSS_TEMPERATURE", "0.2")),
        )
    )
# Build an art-generation client from environment variables.
def art_client_from_env() -> ArtClient | None:
    """Return the process-wide art client for ``TABRAS_ART_BACKEND``, or ``None``.

    * ``"modal"``: builds a ``ModalArtClient``. ``TABRAS_ART_ENDPOINT`` must
      be set; indexing ``os.environ`` raises ``KeyError`` otherwise. Steps
      default to 4.
    * ``"diffusers"``: wraps a ``DiffusersImageClient.load`` call in a
      ``LazyArtClient``. Presumably the load is deferred until first use;
      confirm in ``art.LazyArtClient``. Steps default to 1.
    * any other value: returns ``None``.

    The client is stored in the module global ``_art_client_cache``. Once it
    exists, later calls for either supported backend return that same
    instance, and the ``TABRAS_ART_*`` settings are no longer re-read.
    """
    global _art_client_cache
    backend = os.environ.get("TABRAS_ART_BACKEND")
    if backend not in ("modal", "diffusers"):
        return None
    # NOTE(review): the cache is not keyed by backend, so switching between
    # "modal" and "diffusers" after the first build still returns the original
    # client.
    if _art_client_cache is not None:
        return _art_client_cache

    if backend == "modal":
        _art_client_cache = ModalArtClient(
            endpoint=os.environ["TABRAS_ART_ENDPOINT"],
            steps=int(os.environ.get("TABRAS_ART_STEPS", "4")),
            guidance_scale=float(os.environ.get("TABRAS_ART_GUIDANCE", "0.0")),
            width=int(os.environ.get("TABRAS_ART_WIDTH", "512")),
            height=int(os.environ.get("TABRAS_ART_HEIGHT", "320")),
        )
        return _art_client_cache

    # Diffusers: read the settings now and bind them into the loader closure.
    model_id = os.environ.get("TABRAS_ART_MODEL", DEFAULT_ART_MODEL)
    steps = int(os.environ.get("TABRAS_ART_STEPS", "1"))
    guidance_scale = float(os.environ.get("TABRAS_ART_GUIDANCE", "0.0"))
    width = int(os.environ.get("TABRAS_ART_WIDTH", "512"))
    height = int(os.environ.get("TABRAS_ART_HEIGHT", "320"))

    def load_pipeline():
        return DiffusersImageClient.load(
            model_id,
            steps=steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
        )

    _art_client_cache = LazyArtClient(load_pipeline)
    return _art_client_cache