virtual-characters / src /model_status.py
ShadowInk's picture
Upload complete Space runtime files
6bcddd0 verified
Raw
History Blame Contribute Delete
20.9 kB
from __future__ import annotations
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
# Built-in remote endpoints. configured_llm_url() / configured_tts_url() fall
# back to these when the corresponding environment variables are unset or empty.
DEFAULT_VLLM_BASE_URL = "https://veronicaulises0--virtual-characters-vllm-gemma-serve.modal.run"
DEFAULT_TTS_URL = "https://veronicaulises0--virtual-characters-tts-charactertts-tts.modal.run"
# Directory two levels above this file (the parent of the directory holding it).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Service kind -> local Modal app source file; file existence is used to decide
# whether a service definition is available locally.
MODAL_SERVICE_FILES = {
    "tts": PROJECT_ROOT / "modal_apps" / "modal_tts.py",
    "image_generation": PROJECT_ROOT / "modal_apps" / "modal_character_spike.py",
}
# User-facing hint reused by the image generation health check when the
# service is not ready (loading or unreachable).
IMAGE_GENERATION_WAIT_MESSAGE = "Modal 图像生成服务可能已休眠或正在冷启动,请等待容器启动和模型载入后重试。"
@dataclass
class ModelStatus:
    """Snapshot of one backing service's availability.

    Attributes:
        kind: Service identifier, e.g. ``"llm"``, ``"tts"`` or
            ``"image_generation"``.
        state: Machine-readable state key (``"ready"``, ``"loading"``, ...).
        label: Human-readable label for ``state``.
        url: Endpoint or service file the status refers to, if any.
        model: Model/backend name reported for the service, if known.
        latency_s: Seconds the probe took, if one was performed.
        message: Free-form explanation shown to the user.
    """

    kind: str
    state: str
    label: str
    url: str | None = None
    model: str | None = None
    latency_s: float | None = None
    message: str = ""

    def as_dict(self) -> dict[str, Any]:
        """Return the seven public fields as a plain dictionary."""
        # Listed explicitly (rather than derived from the dataclass fields) so
        # subclasses that add fields still serialise to exactly this shape.
        exported = ("kind", "state", "label", "url", "model", "latency_s", "message")
        return {name: getattr(self, name) for name in exported}
# Display label for every state key a ModelStatus may carry; status
# constructors throughout this module look labels up here by state.
STATE_LABELS = {
    "ready": "可用",
    "loading": "载入中",
    "sleeping": "已休眠",
    "error": "错误",
    "unconfigured": "未配置",
    "local": "本地服务",
    "mock": "Mock",
    "unknown": "待检测",
}
def configured_llm_url() -> str | None:
    """Resolve the LLM endpoint URL from the environment.

    Returns:
        ``None`` when ``VC_USE_MOCK`` is ``"1"``; otherwise the first non-empty
        value of ``VC_MODAL_LLM_URL`` and ``VC_MODAL_VLLM_URL``, falling back
        to ``DEFAULT_VLLM_BASE_URL``.
    """
    env = os.environ
    if env.get("VC_USE_MOCK") == "1":
        return None
    for variable in ("VC_MODAL_LLM_URL", "VC_MODAL_VLLM_URL"):
        candidate = env.get(variable)
        if candidate:
            return candidate
    return DEFAULT_VLLM_BASE_URL
def configured_tts_url() -> str | None:
    """Return ``VC_MODAL_TTS_URL`` when set and non-empty, else ``DEFAULT_TTS_URL``."""
    override = os.environ.get("VC_MODAL_TTS_URL")
    if override:
        return override
    return DEFAULT_TTS_URL
def initial_model_statuses() -> list[ModelStatus]:
    """Build the placeholder statuses shown before any probe has run.

    Returns:
        Three statuses in fixed order: LLM, TTS, image generation. No network
        request is made here.
    """
    mock_mode = os.environ.get("VC_USE_MOCK") == "1"
    if mock_mode:
        llm_status = ModelStatus(
            "llm",
            "mock",
            STATE_LABELS["mock"],
            message="当前使用本地 mock 对话。",
        )
    else:
        llm_status = ModelStatus(
            "llm",
            "unknown",
            STATE_LABELS["unknown"],
            url=configured_llm_url(),
            message="点击刷新检测 Modal LLM。",
        )

    tts_status = _initial_endpoint_status("tts", configured_tts_url(), "VC_MODAL_TTS_URL")
    image_status = _initial_image_generation_status()
    return [llm_status, tts_status, image_status]
def llm_loading_status(message: str = "正在启动主模型;首次加载可能需要 1-3 分钟。") -> ModelStatus:
    """Return the LLM status to display while the main model is starting.

    Args:
        message: Text shown alongside the loading state.

    Returns:
        A ``"mock"`` status when ``VC_USE_MOCK`` is ``"1"`` (``message`` is
        ignored in that case), otherwise a ``"loading"`` status for the
        configured LLM URL.
    """
    if os.environ.get("VC_USE_MOCK") != "1":
        return ModelStatus(
            "llm",
            "loading",
            STATE_LABELS["loading"],
            url=configured_llm_url(),
            message=message,
        )
    return ModelStatus("llm", "mock", STATE_LABELS["mock"], message="当前使用本地 mock 对话。")
def statuses_with_llm_status(llm_status: ModelStatus) -> list[ModelStatus]:
    """Return the initial status list with its first (LLM) entry replaced.

    Args:
        llm_status: Status to put in the LLM slot.
    """
    non_llm = initial_model_statuses()[1:]
    return [llm_status] + non_llm
def check_all_statuses(timeout_s: float = 4.0) -> list[ModelStatus]:
    """Probe every service and return fresh statuses.

    Args:
        timeout_s: Timeout passed to the LLM and TTS probes.

    Returns:
        Statuses in fixed order: LLM, TTS, image generation. The probes run
        one after another.
    """
    env = os.environ
    direct_llm_url = env.get("VC_MODAL_LLM_URL")

    if env.get("VC_USE_MOCK") == "1":
        llm_status = ModelStatus("llm", "mock", STATE_LABELS["mock"], message="当前使用本地 mock 对话。")
    elif direct_llm_url:
        # An explicit VC_MODAL_LLM_URL is probed as-is, without the vLLM
        # model-list paths.
        llm_status = _check_simple_health("llm", direct_llm_url, timeout_s, health_path=None)
    else:
        llm_status = _check_vllm(configured_llm_url(), timeout_s)

    tts_status = _check_tts_endpoint(configured_tts_url(), timeout_s)
    image_status = check_image_generation_status()
    return [llm_status, tts_status, image_status]
def warm_llm_model(timeout_s: float = 600.0) -> ModelStatus:
    """Send a short request to the LLM so the remote model gets loaded.

    Args:
        timeout_s: Read timeout for the warm-up request.

    Returns:
        A ``"mock"`` status when ``VC_USE_MOCK`` is ``"1"``; otherwise the
        result of warming either the ``VC_MODAL_LLM_URL`` endpoint or the
        configured vLLM server.
    """
    env = os.environ
    if env.get("VC_USE_MOCK") == "1":
        return ModelStatus(
            "llm",
            "mock",
            STATE_LABELS["mock"],
            message="当前使用本地 mock 对话,不需要启动远端模型。",
        )

    direct_llm_url = env.get("VC_MODAL_LLM_URL")
    if not direct_llm_url:
        return _warm_vllm(configured_llm_url(), timeout_s)
    return _warm_modal_llm(direct_llm_url, timeout_s)
def check_image_generation_status() -> ModelStatus:
    """Run the Modal image generation health check and classify the outcome.

    Returns:
        * ``"unconfigured"`` when the service file is missing or an
          ``ImportError`` is raised while importing/running the Modal app;
        * ``"ready"`` when the remote health call returns a truthy ``"ok"``;
        * ``"loading"`` when it returns without a truthy ``"ok"``;
        * ``"sleeping"`` for any other exception.
    """
    service_file = MODAL_SERVICE_FILES["image_generation"]
    if not service_file.exists():
        return ModelStatus(
            "image_generation",
            "unconfigured",
            STATE_LABELS["unconfigured"],
            message="未找到 modal_apps/modal_character_spike.py。",
        )

    service_ref = "modal_apps/modal_character_spike.py"

    def report(state: str, latency: float, message: str, model: str | None = None) -> ModelStatus:
        # Every post-probe result shares kind, url and the label lookup.
        return ModelStatus(
            "image_generation",
            state,
            STATE_LABELS[state],
            url=service_ref,
            model=model,
            latency_s=latency,
            message=message,
        )

    t0 = time.perf_counter()
    try:
        # Imported lazily so a missing Modal package surfaces as a status
        # instead of breaking module import.
        from modal_apps.modal_character_spike import app, spike_health

        with app.run():
            health = spike_health.remote()
        took = time.perf_counter() - t0
        if not health.get("ok"):
            return report("loading", took, IMAGE_GENERATION_WAIT_MESSAGE)
        return report(
            "ready",
            took,
            "Modal 图像生成 health check 通过;生成时仍可能需要等待 GPU 模型载入。",
            model="Qwen/Qwen-Image",
        )
    except ImportError as exc:
        return report(
            "unconfigured",
            time.perf_counter() - t0,
            f"Modal Python 包或依赖不可用:{exc}",
        )
    except Exception as exc:
        return report(
            "sleeping",
            time.perf_counter() - t0,
            f"{IMAGE_GENERATION_WAIT_MESSAGE} ({type(exc).__name__}: {exc})",
        )
def statuses_markdown(statuses: list[ModelStatus]) -> str:
    """Render statuses as an HTML grid of "pill" elements.

    Status fields can carry text that originates outside this process (model
    ids from server JSON, exception messages, URLs from the environment), so
    every interpolated field is HTML-escaped before being placed in markup.

    Args:
        statuses: Statuses to render, in display order.

    Returns:
        One ``<div class="vc-model-grid">`` containing a pill per status.
    """
    import html

    def esc(value: object) -> str:
        return html.escape(str(value))

    pills = []
    for status in statuses:
        latency = f" · {status.latency_s:.2f}s" if status.latency_s is not None else ""
        model = f" · `{esc(status.model)}`" if status.model else ""
        # Without a URL, show a fixed placeholder chosen by state.
        url_text = esc(status.url) if status.url else _empty_url_label(status.state)
        pills.append(
            f'<div class="vc-model-pill vc-model-{esc(status.state)}">'
            f"<b>{esc(_kind_label(status.kind))}</b>"
            f"<span>{esc(status.label)}{latency}{model}</span>"
            f"<small>{url_text}</small>"
            f"<em>{esc(status.message)}</em></div>"
        )
    return '<div class="vc-model-grid">' + "".join(pills) + "</div>"
def statuses_json(statuses: list[ModelStatus]) -> dict[str, Any]:
    """Wrap the dictionary form of each status under a ``"models"`` key."""
    serialized = [entry.as_dict() for entry in statuses]
    return {"models": serialized}
def _check_vllm(base_url: str | None, timeout_s: float) -> ModelStatus:
    """Probe a vLLM server's model-list endpoint.

    Tries ``/v1/models`` and then ``/models`` under ``base_url``. The first
    probe that is ready, sleeping or loading decides the result; a probe that
    ends in ``"error"`` moves on to the next path.

    Args:
        base_url: Server base URL; a trailing slash is ignored.
        timeout_s: Timeout for each GET request.

    Returns:
        ``"unconfigured"`` without a URL; ``"ready"`` (with the first listed
        model id when the response provides one); ``"sleeping"``/``"loading"``
        as reported by the probe; otherwise ``"error"`` carrying the last
        probe's latency and failure reason.
    """
    if not base_url:
        return ModelStatus("llm", "unconfigured", STATE_LABELS["unconfigured"], message="未设置 vLLM URL。")

    root = base_url.rstrip("/")
    last_probe = None
    for path in ("/v1/models", "/models"):
        probe = _get_json("llm", root + path, timeout_s)
        last_probe = probe

        if probe.state == "ready":
            payload = probe.message_json or {}
            models = payload.get("data") if isinstance(payload, dict) else None
            first = models[0] if isinstance(models, list) and models else None
            model_id = ""
            # A well-formed entry is an object; anything else yields no model
            # name instead of raising on `.get`.
            if isinstance(first, dict):
                model_id = str(first.get("id") or first.get("root") or "")
            return ModelStatus(
                "llm",
                "ready",
                STATE_LABELS["ready"],
                url=base_url,
                model=model_id or None,
                latency_s=probe.latency_s,
                message="vLLM 模型列表可访问。",
            )

        if probe.state in {"sleeping", "loading"}:
            return ModelStatus(
                "llm",
                probe.state,
                STATE_LABELS[probe.state],
                url=base_url,
                latency_s=probe.latency_s,
                message="Modal 模型服务已休眠或正在冷启动,如需体验请等待模型载入。",
            )

    # Both paths failed: surface why (e.g. the HTTP status) rather than a bare
    # failure notice.
    detail = last_probe.message if last_probe is not None else ""
    failure = f"vLLM 模型状态检测失败:{detail}" if detail else "vLLM 模型状态检测失败。"
    return ModelStatus(
        "llm",
        "error",
        STATE_LABELS["error"],
        url=base_url,
        latency_s=last_probe.latency_s if last_probe is not None else None,
        message=failure,
    )
def _warm_vllm(base_url: str | None, timeout_s: float) -> ModelStatus:
    """Warm a vLLM server with one tiny chat completion request.

    Args:
        base_url: Server base URL; ``/v1/chat/completions`` is appended.
        timeout_s: Read timeout; connect/write/pool use 30 seconds.

    Returns:
        ``"unconfigured"`` without a URL; ``"ready"`` on HTTP 200;
        ``"loading"`` on a retryable HTTP status or a timeout; ``"sleeping"``
        on connection-level failures; ``"error"`` otherwise.
    """
    if not base_url:
        return ModelStatus("llm", "unconfigured", STATE_LABELS["unconfigured"], message="未设置 vLLM URL。")

    import httpx

    served_model = os.environ.get("VC_VLLM_SERVED_MODEL", "llm")
    endpoint = base_url.rstrip("/") + "/v1/chat/completions"
    body = {
        "model": served_model,
        "messages": [
            {"role": "system", "content": "你是模型启动检查。"},
            {"role": "user", "content": "只回复:已就绪"},
        ],
        "max_tokens": 4,
        "temperature": 0,
        "stream": False,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    # Statuses treated as "the request arrived but the service is not ready".
    retryable = {408, 425, 429, 500, 502, 503, 504}

    def outcome(state: str, latency: float, message: str, model: str | None = None) -> ModelStatus:
        return ModelStatus(
            "llm",
            state,
            STATE_LABELS[state],
            url=base_url,
            model=model,
            latency_s=latency,
            message=message,
        )

    t0 = time.perf_counter()
    limits = httpx.Timeout(connect=30, read=timeout_s, write=30, pool=30)
    try:
        reply = httpx.post(endpoint, json=body, timeout=limits, trust_env=False)
        took = time.perf_counter() - t0
        code = reply.status_code
        if code == 200:
            return outcome(
                "ready",
                took,
                "主模型已完成短请求,接下来几分钟内对话会更快。",
                model=served_model,
            )
        if code in retryable:
            return outcome("loading", took, f"启动请求已到达,但服务仍在冷启动或排队:HTTP {code}")
        return outcome("error", took, f"启动失败:HTTP {code}")
    except httpx.TimeoutException:
        return outcome(
            "loading",
            time.perf_counter() - t0,
            "启动请求超时;Modal 可能仍在拉起容器或加载权重,请稍后刷新状态。",
        )
    except (httpx.ConnectError, httpx.RemoteProtocolError, httpx.ReadError) as exc:
        return outcome("sleeping", time.perf_counter() - t0, f"服务暂不可达:{exc}")
    except Exception as exc:
        return outcome("error", time.perf_counter() - t0, f"启动失败:{exc}")
def _warm_modal_llm(url: str, timeout_s: float) -> ModelStatus:
    """Warm the streaming Modal LLM endpoint with one short request.

    Args:
        url: Endpoint that receives the POST.
        timeout_s: Read timeout; connect/write/pool use 30 seconds.

    Returns:
        ``"ready"`` on HTTP 200 (after reading up to the first non-empty
        streamed line); ``"loading"`` on a retryable HTTP status or a timeout;
        ``"sleeping"`` on connection-level failures; ``"error"`` otherwise.
    """
    import httpx

    body = {
        "text": "请只回复:已就绪",
        "character": {"display_name": "启动检查"},
        "max_new_tokens": 4,
    }
    retryable = {408, 425, 429, 500, 502, 503, 504}

    def outcome(state: str, latency: float, message: str) -> ModelStatus:
        return ModelStatus("llm", state, STATE_LABELS[state], url=url, latency_s=latency, message=message)

    t0 = time.perf_counter()
    limits = httpx.Timeout(connect=30, read=timeout_s, write=30, pool=30)
    try:
        with httpx.stream("POST", url, json=body, timeout=limits, trust_env=False) as reply:
            headers_at = time.perf_counter() - t0
            code = reply.status_code
            if code == 200:
                # Consume the stream only up to the first non-empty line.
                next((chunk for chunk in reply.iter_lines() if chunk), None)
                return outcome(
                    "ready",
                    time.perf_counter() - t0,
                    "主模型已完成短请求,接下来几分钟内对话会更快。",
                )
            if code in retryable:
                return outcome("loading", headers_at, f"服务仍在冷启动或排队:HTTP {code}")
            return outcome("error", headers_at, f"启动失败:HTTP {code}")
    except httpx.TimeoutException:
        return outcome(
            "loading",
            time.perf_counter() - t0,
            "启动请求超时;Modal 可能仍在拉起容器或加载权重,请稍后刷新状态。",
        )
    except (httpx.ConnectError, httpx.RemoteProtocolError, httpx.ReadError) as exc:
        return outcome("sleeping", time.perf_counter() - t0, f"服务暂不可达:{exc}")
    except Exception as exc:
        return outcome("error", time.perf_counter() - t0, f"启动失败:{exc}")
def _check_simple_health(kind: str, url: str | None, timeout_s: float, health_path: str | None) -> ModelStatus:
    """Probe a service with plain GET requests and classify the result.

    Args:
        kind: Service identifier stored on the returned status.
        url: Configured endpoint; when empty, the local service definition
            (if any) is reported instead.
        timeout_s: Timeout for each GET request.
        health_path: Path used to derive candidate health URLs, or ``None``
            to probe ``url`` itself.

    Returns:
        The status of the first candidate that does not end in ``"error"``,
        or the last error when all candidates fail.
    """
    if not url:
        fallback = _local_service_status(kind)
        if fallback is not None:
            return fallback
        env_name = "VC_MODAL_TTS_URL" if kind == "tts" else "URL"
        return ModelStatus(kind, "unconfigured", STATE_LABELS["unconfigured"], message=f"未设置 {env_name}。")

    # Placeholder that survives only if there are no candidates to probe.
    probe = _HttpProbeResult(kind, "error", STATE_LABELS["error"], url=url, message="未执行检测")
    for candidate in _health_targets(url, health_path):
        probe = _get_json(kind, candidate, timeout_s)
        if probe.state != "error":
            break

    if probe.state == "ready":
        payload = probe.message_json if isinstance(probe.message_json, dict) else {}
        backend = str(payload.get("backend") or payload.get("model") or "")
        return ModelStatus(
            kind,
            "ready",
            STATE_LABELS["ready"],
            url=url,
            model=backend or None,
            latency_s=probe.latency_s,
            message="服务健康检查通过。",
        )

    if probe.state in {"sleeping", "loading"}:
        return ModelStatus(
            kind,
            probe.state,
            STATE_LABELS[probe.state],
            url=url,
            latency_s=probe.latency_s,
            message="Modal 模型服务已休眠或正在冷启动,如需体验请等待模型载入。",
        )

    return ModelStatus(
        kind,
        "error",
        STATE_LABELS["error"],
        url=url,
        latency_s=probe.latency_s,
        message=probe.message,
    )
def _check_tts_endpoint(url: str | None, timeout_s: float) -> ModelStatus:
    """Probe the TTS endpoint with a ``{"probe_only": true}`` POST.

    Args:
        url: Configured TTS URL; when empty, the local service definition
            (if any) is reported instead.
        timeout_s: Request timeout; raised to at least 15 seconds.

    Returns:
        ``"ready"`` on HTTP 200 (the backend/model name is taken from the JSON
        body when it is an object that provides one); ``"loading"`` on a
        retryable HTTP status; ``"sleeping"`` on timeout or connection-level
        failures; ``"error"`` otherwise.
    """
    if not url:
        local = _local_service_status("tts")
        if local:
            return local
        return ModelStatus("tts", "unconfigured", STATE_LABELS["unconfigured"], message="未设置 VC_MODAL_TTS_URL。")

    import httpx

    target = _tts_endpoint_url(url)
    timeout_s = max(timeout_s, 15.0)
    started = time.perf_counter()
    try:
        response = httpx.post(target, json={"probe_only": True}, timeout=timeout_s, trust_env=False)
        elapsed = time.perf_counter() - started
        if response.status_code == 200:
            # A 200 already shows the endpoint is reachable. A body that is
            # not JSON, or not a JSON object, only means no model name is
            # available; it is not a probe failure.
            try:
                data = response.json()
            except ValueError:
                data = {}
            if not isinstance(data, dict):
                data = {}
            return ModelStatus(
                "tts",
                "ready",
                STATE_LABELS["ready"],
                url=url,
                model=str(data.get("backend") or data.get("model") or "") or None,
                latency_s=elapsed,
                message="TTS endpoint 可访问;首次合成仍可能需要等待模型载入。",
            )
        if response.status_code in {408, 425, 429, 500, 502, 503, 504}:
            return ModelStatus(
                "tts",
                "loading",
                STATE_LABELS["loading"],
                url=url,
                latency_s=elapsed,
                message=f"Modal TTS 服务已触达,但可能正在冷启动:HTTP {response.status_code}",
            )
        return ModelStatus(
            "tts",
            "error",
            STATE_LABELS["error"],
            url=url,
            latency_s=elapsed,
            message=f"HTTP {response.status_code}",
        )
    except (httpx.TimeoutException, httpx.ConnectError, httpx.RemoteProtocolError, httpx.ReadError) as exc:
        return ModelStatus(
            "tts",
            "sleeping",
            STATE_LABELS["sleeping"],
            url=url,
            latency_s=time.perf_counter() - started,
            message=f"Modal TTS 服务可能已休眠或正在冷启动:{exc}",
        )
    except Exception as exc:
        return ModelStatus(
            "tts",
            "error",
            STATE_LABELS["error"],
            url=url,
            latency_s=time.perf_counter() - started,
            message=f"TTS 状态检测失败:{exc}",
        )
@dataclass
class _HttpProbeResult(ModelStatus):
    """A ``ModelStatus`` that also carries the decoded JSON body of a probe.

    ``message_json`` is not part of ``as_dict()`` output; it is only consumed
    by the callers that interpret probe responses.
    """

    message_json: Any = None
def _get_json(kind: str, url: str, timeout_s: float) -> _HttpProbeResult:
    """GET ``url`` and classify the response as a probe result.

    Args:
        kind: Service identifier stored on the result.
        url: Address to request.
        timeout_s: Request timeout.

    Returns:
        ``"ready"`` on HTTP 200, with ``message_json`` set to the decoded body
        (``{}`` when it is not valid JSON); ``"loading"`` on a retryable HTTP
        status; ``"sleeping"`` on timeout or connection-level failures;
        ``"error"`` otherwise.
    """
    import httpx

    def probe(state: str, latency: float, message: str) -> _HttpProbeResult:
        return _HttpProbeResult(kind, state, STATE_LABELS[state], url=url, latency_s=latency, message=message)

    t0 = time.perf_counter()
    try:
        reply = httpx.get(url, timeout=timeout_s, trust_env=False)
        took = time.perf_counter() - t0
        code = reply.status_code
        if code == 200:
            ok = probe("ready", took, "OK")
            try:
                ok.message_json = reply.json()
            except ValueError:
                ok.message_json = {}
            return ok
        state = "loading" if code in {408, 425, 429, 500, 502, 503, 504} else "error"
        return probe(state, took, f"HTTP {code}")
    except (httpx.TimeoutException, httpx.ConnectError, httpx.RemoteProtocolError, httpx.ReadError) as exc:
        return probe("sleeping", time.perf_counter() - t0, str(exc))
    except Exception as exc:
        return probe("error", time.perf_counter() - t0, str(exc))
def _kind_label(kind: str) -> str:
return {"llm": "LLM", "tts": "TTS", "image_generation": "Image Generation"}.get(kind, kind.upper())
def _empty_url_label(state: str) -> str:
if state == "mock":
return "本地模拟"
if state == "local":
return "本地服务定义"
if state == "unconfigured":
return "未绑定 endpoint"
return "待检测"
def _initial_endpoint_status(kind: str, url: str | None, env_name: str) -> ModelStatus:
    """Build the pre-probe status for a URL-configured service.

    Args:
        kind: Service identifier.
        url: Configured endpoint, if any.
        env_name: Environment variable named in the "not configured" message.

    Returns:
        ``"unknown"`` when a URL is configured; otherwise the local service
        status when a definition file exists, else ``"unconfigured"``.
    """
    if not url:
        fallback = _local_service_status(kind)
        if fallback is not None:
            return fallback
        return ModelStatus(
            kind,
            "unconfigured",
            STATE_LABELS["unconfigured"],
            message=f"未设置 {env_name}。",
        )
    return ModelStatus(
        kind,
        "unknown",
        STATE_LABELS["unknown"],
        url=url,
        message=f"点击刷新检测 {_kind_label(kind)}。",
    )
def _initial_image_generation_status() -> ModelStatus:
    """Build the pre-probe status for the image generation service.

    Returns:
        ``"unknown"`` (keeping the local definition's kind and path) when the
        service file exists, otherwise ``"unconfigured"``.
    """
    definition = _local_service_status("image_generation")
    if definition is None:
        return ModelStatus(
            "image_generation",
            "unconfigured",
            STATE_LABELS["unconfigured"],
            message="未找到 Modal 图像生成服务定义。",
        )
    # Same values as overwriting state/label/model/message on `definition`.
    return ModelStatus(
        definition.kind,
        "unknown",
        STATE_LABELS["unknown"],
        url=definition.url,
        model="Qwen/Qwen-Image",
        latency_s=definition.latency_s,
        message="点击刷新检测 Modal 图像生成服务。",
    )
def _local_service_status(kind: str) -> ModelStatus | None:
    """Describe a service whose Modal definition file exists locally.

    Args:
        kind: Service identifier looked up in ``MODAL_SERVICE_FILES``.

    Returns:
        A ``"local"`` status whose ``url`` is the file path relative to
        ``PROJECT_ROOT``, or ``None`` when the kind has no registered file or
        the file does not exist.
    """
    service_file = MODAL_SERVICE_FILES.get(kind)
    if service_file is None or not service_file.exists():
        return None

    relative = service_file.relative_to(PROJECT_ROOT).as_posix()
    setting_names = {"tts": "VC_MODAL_TTS_URL", "image_generation": "Modal app"}
    env_name = setting_names.get(kind, "URL")
    hint = f"{relative} 已存在;部署后设置 {env_name},或从 Modal 输出复制 endpoint URL。"
    return ModelStatus(kind, "local", STATE_LABELS["local"], url=relative, message=hint)
def _health_targets(url: str, health_path: str | None) -> list[str]:
if not health_path:
return [url]
base = url.rstrip("/")
tail = base.rsplit("/", 1)[-1]
service_base = base.rsplit("/", 1)[0] if tail in {"tts", "persona_events"} else base
targets = [service_base + health_path]
if health_path == "/health":
targets.append(service_base + "/health_http")
if service_base != base:
targets.append(base + health_path)
return list(dict.fromkeys(targets))
def _tts_endpoint_url(url: str) -> str:
base = url.rstrip("/")
parsed = urlparse(base)
if not parsed.path or parsed.path == "/":
return base
if parsed.path.rstrip("/").rsplit("/", 1)[-1] == "tts":
return base
return base + "/tts"