# virtual-characters/scripts/run_character_generation_spike.py
# Hugging Face Space file view (uploader: ShadowInk; commit 6bcddd0, verified:
# "Upload complete Space runtime files"; 22.9 kB).
from __future__ import annotations
import argparse
import json
import shutil
import sys
import time
from pathlib import Path
from typing import Any
# Two directory levels above this file; put on sys.path (once, at the front)
# so the `src.*` and `modal_apps.*` imports used by this script resolve.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from src.character_spike.assets import (
create_asset_package_from_probe_outputs,
create_mock_asset_package,
make_thumbnail_grid,
write_report,
)
from src.character_spike.schema import (
EIGHT_ASSET_LIMIT_SECONDS,
EXPRESSIONS,
MIN_USABLE_ASSET_COUNT,
MODEL_CANDIDATES,
WARM_FOUR_IMAGE_LIMIT_SECONDS,
candidate_by_id,
slugify_identifier,
)
from src.character_spike.tavern_import import convert_tavern_card, load_tavern_json, write_character_draft
# Base prompt used by the probe and benchmark commands when --prompt is not
# given. The adjacent literals are concatenated into one comma-separated string.
DEFAULT_PROMPT: str = (
    "single original anime-style virtual character portrait, young adult woman, silver short hair, "
    "teal eyes, light sci-fi communicator outfit, half body character sprite, centered composition, "
    "plain light background, clean silhouette, no text, no logo, no poster, not a landscape, "
    "not an existing commercial character"
)

# Prompt fragment per expression name; the benchmark appends the fragment to
# the base prompt (separated by ", ") when running expression probes.
EXPRESSION_PROMPTS: dict[str, str] = {
    "idle": "neutral idle expression, relaxed posture",
    "listening": "listening expression, attentive eyes, slight head tilt",
    "thinking": "thinking expression, one hand near chin, focused eyes",
    "worried": "worried expression, gentle concern, tense shoulders",
    "smile": "small warm smile, calm and friendly",
    "happy": "bright happy expression, open eyes, energetic",
    "talk": "speaking expression, mouth open naturally, conversational pose",
    "focus": "battle focus expression, determined gaze, stable posture",
}
def require_modal() -> None:
    """Abort with an installation hint when the ``modal`` package cannot be imported.

    Raises:
        SystemExit: If ``import modal`` raises any ``Exception``; the original
            exception is chained as the cause.
    """
    install_hint = (
        "Modal Python package is not installed. Run `python -m pip install -r requirements.txt` "
        "and `modal setup` before remote checks."
    )
    try:
        import modal  # noqa: F401
    except Exception as import_error:
        raise SystemExit(install_hint) from import_error
def list_models(_: argparse.Namespace) -> int:
    """Print every entry of ``MODEL_CANDIDATES`` as a JSON array.

    Args:
        _: Parsed CLI arguments (unused).

    Returns:
        Always 0.
    """
    # (output key, candidate attribute) pairs, in output order.
    output_fields = (
        ("id", "id"),
        ("label", "label"),
        ("model_id", "model_id"),
        ("mode", "mode"),
        ("steps", "default_steps"),
        ("gpu", "default_gpu"),
        ("implemented", "implemented"),
        ("notes", "notes"),
    )
    rows = []
    for candidate in MODEL_CANDIDATES:
        rows.append({key: getattr(candidate, attribute) for key, attribute in output_fields})
    print(json.dumps(rows, ensure_ascii=False, indent=2))
    return 0
def mock_assets(args: argparse.Namespace) -> int:
    """Create a mock asset package and print the manifest's ``paths`` entry as JSON.

    The character id is slugified from ``args.character_id``, falling back to
    ``args.display_name`` when the id is empty.

    Args:
        args: Parsed CLI arguments (``character_id``, ``display_name``,
            ``output_root``, ``seed``).

    Returns:
        Always 0.
    """
    requested_id = args.character_id or args.display_name
    manifest = create_mock_asset_package(
        character_id=slugify_identifier(requested_id, fallback="spike_character"),
        display_name=args.display_name,
        output_root=args.output_root,
        seed=args.seed,
    )
    paths_json = json.dumps(manifest["paths"], ensure_ascii=False, indent=2)
    print(paths_json)
    return 0
def package_probe_assets(args: argparse.Namespace) -> int:
    """Build an asset package from probe outputs and print its ``paths`` entry as JSON.

    Args:
        args: Parsed CLI arguments (``source_run_dir``, ``candidate``,
            ``character_id``, ``display_name``, ``output_root``, ``seed``,
            ``keep_background``).

    Returns:
        Always 0.
    """
    requested_id = args.character_id or args.display_name
    package_options = {
        "source_run_dir": args.source_run_dir,
        "candidate_id": args.candidate,
        "character_id": slugify_identifier(requested_id, fallback="spike_character"),
        "display_name": args.display_name,
        "output_root": args.output_root,
        "seed": args.seed,
        # --keep-background is the opt-out; background removal is requested by default.
        "remove_background": not args.keep_background,
    }
    manifest = create_asset_package_from_probe_outputs(**package_options)
    print(json.dumps(manifest["paths"], ensure_ascii=False, indent=2))
    return 0
def import_tavern(args: argparse.Namespace) -> int:
    """Convert a Tavern JSON card into a character draft.

    With ``args.dry_run`` the converted package is printed as JSON and nothing
    is written; otherwise the draft is written under ``args.output_dir`` and
    the resulting path is printed.

    Args:
        args: Parsed CLI arguments (``input``, ``id``, ``dry_run``, ``output_dir``).

    Returns:
        Always 0.
    """
    package = convert_tavern_card(load_tavern_json(args.input), forced_id=args.id)
    if not args.dry_run:
        written_path = write_character_draft(package, args.output_dir)
        print(f"wrote {written_path}")
    else:
        print(json.dumps(package, ensure_ascii=False, indent=2))
    return 0
def modal_health(_: argparse.Namespace) -> int:
    """Call the remote ``spike_health`` function and print its result as JSON.

    Args:
        _: Parsed CLI arguments (unused).

    Returns:
        0 when the result's ``ok`` entry is truthy, otherwise 1.

    Raises:
        SystemExit: If the ``modal`` package cannot be imported.
    """
    require_modal()
    from modal_apps.modal_character_spike import app, spike_health

    with app.run():
        health = spike_health.remote()
    print(json.dumps(health, ensure_ascii=False, indent=2))
    if health.get("ok"):
        return 0
    return 1
def modal_probe(args: argparse.Namespace) -> int:
    """Run a single remote probe, store its images, and append it to the run manifest.

    Args:
        args: Parsed CLI arguments from the ``modal-probe`` sub-command.

    Returns:
        0 when the probe's ``status`` is ``"ok"`` or ``"skipped"``, otherwise 1.

    Raises:
        SystemExit: If ``--confirm-gpu`` was not passed.
    """
    if not args.confirm_gpu:
        raise SystemExit("Refusing to run GPU inference without --confirm-gpu.")

    candidate = candidate_by_id(args.candidate)
    run_dir = _run_dir(args.output_root, args.character_id, args.run_name)
    run_dir.mkdir(parents=True, exist_ok=True)

    init_image = _read_optional_bytes(args.init_image)
    control_image = _read_optional_bytes(args.control_image)
    prompt = args.prompt or DEFAULT_PROMPT
    probe_request = {
        "candidate_id": candidate.id,
        "prompt": prompt,
        "batch_size": args.batch_size,
        "steps": args.steps,
        "seed": args.seed,
        "width": args.width,
        "height": args.height,
        "init_image_bytes": init_image,
        "control_image_bytes": control_image,
    }
    result = _remote_probe(**probe_request)

    # Image payloads go to disk; only their run-relative paths stay in the manifest.
    saved = _write_probe_images(run_dir, candidate.id, "probe", result.pop("images", []))
    result["image_paths"] = [str(image_path.relative_to(run_dir)) for image_path in saved]
    result["manual_score"] = None
    result["prompt"] = prompt
    result["source"] = "modal_probe"

    manifest = _load_or_new_manifest(run_dir, args.character_id, args.display_name, args.seed)
    manifest["model_results"].append(result)
    _refresh_manifest(manifest, run_dir)

    printable = {key: value for key, value in result.items() if key != "images"}
    print(json.dumps(printable, ensure_ascii=False, indent=2))
    if result.get("status") in {"ok", "skipped"}:
        return 0
    return 1
def modal_benchmark(args: argparse.Namespace) -> int:
    """Benchmark each selected candidate on Modal and record the results.

    For every candidate in ``args.candidates`` that is marked implemented,
    three probes run in order (``cold_single``, ``warm_single``,
    ``warm_four``). With ``args.include_expressions`` one extra probe runs per
    entry of ``EXPRESSIONS`` (seed offset by the expression index), followed
    by an ``eight_expression_total`` summary entry; without it the summary is
    recorded as ``skipped``. Unimplemented candidates are recorded as
    ``skipped`` and no probe is issued for them.

    The manifest is rewritten after every recorded entry so partial results
    survive a failure part-way through the run.

    Args:
        args: Parsed CLI arguments from the ``modal-benchmark`` sub-command.

    Returns:
        Always 0.

    Raises:
        SystemExit: If ``--confirm-gpu`` was not passed or ``modal`` cannot be
            imported.
    """
    if not args.confirm_gpu:
        raise SystemExit("Refusing to run GPU benchmark without --confirm-gpu.")
    require_modal()
    from modal_apps.modal_character_spike import CharacterGenerationSpike, app

    selected = [candidate_by_id(item) for item in args.candidates]
    run_dir = _run_dir(args.output_root, args.character_id, args.run_name)
    run_dir.mkdir(parents=True, exist_ok=True)
    manifest = _load_or_new_manifest(run_dir, args.character_id, args.display_name, args.seed)

    # Loop-invariant inputs: resolve the prompt and read the optional images
    # once, before the remote app starts, instead of once per probe.
    base_prompt = args.prompt or DEFAULT_PROMPT
    init_image = _read_optional_bytes(args.init_image)
    control_image = _read_optional_bytes(args.control_image)

    def record(entry: dict[str, Any]) -> None:
        """Append one entry to the manifest and persist it immediately."""
        manifest["model_results"].append(entry)
        _refresh_manifest(manifest, run_dir)

    def run_case(
        runner: Any,
        candidate_id: str,
        case: str,
        prompt: str,
        batch_size: int,
        seed: int,
        **extra: Any,
    ) -> dict[str, Any]:
        """Run one remote probe, write its images, and record the result."""
        result = _call_remote_probe(
            runner,
            candidate_id=candidate_id,
            prompt=prompt,
            batch_size=batch_size,
            steps=args.steps,
            seed=seed,
            width=args.width,
            height=args.height,
            init_image_bytes=init_image,
            control_image_bytes=control_image,
        )
        image_paths = _write_probe_images(run_dir, candidate_id, case, result.pop("images", []))
        result.update(
            {
                "source": "modal_benchmark",
                "benchmark_case": case,
                **extra,
                "image_paths": [str(path.relative_to(run_dir)) for path in image_paths],
                "manual_score": None,
                "prompt": prompt,
            }
        )
        record(result)
        return result

    with app.run():
        runner = CharacterGenerationSpike()
        for candidate in selected:
            if not candidate.implemented:
                record(
                    {
                        "candidate_id": candidate.id,
                        "label": candidate.label,
                        "model_id": candidate.model_id,
                        "mode": candidate.mode,
                        "status": "skipped",
                        "failure_reason": candidate.notes,
                        "source": "modal_benchmark",
                    }
                )
                continue

            for label, batch_size in (("cold_single", 1), ("warm_single", 1), ("warm_four", 4)):
                run_case(runner, candidate.id, label, base_prompt, batch_size, args.seed)

            if args.include_expressions:
                expression_started = time.perf_counter()
                expression_results = [
                    run_case(
                        runner,
                        candidate.id,
                        f"expression_{expression}",
                        f"{base_prompt}, {EXPRESSION_PROMPTS[expression]}",
                        1,
                        args.seed + index,
                        expression=expression,
                    )
                    for index, expression in enumerate(EXPRESSIONS)
                ]
                record(
                    {
                        "candidate_id": candidate.id,
                        "label": candidate.label,
                        "mode": candidate.mode,
                        "source": "modal_benchmark",
                        "benchmark_case": "eight_expression_total",
                        "status": "ok"
                        if all(item.get("status") == "ok" for item in expression_results)
                        else "failed",
                        "image_count": sum(item.get("image_count", 0) for item in expression_results),
                        # Client-side wall time for the whole expression loop,
                        # including image writes and manifest refreshes.
                        "duration_seconds": round(time.perf_counter() - expression_started, 3),
                        "manual_score": None,
                    }
                )
            else:
                record(
                    {
                        "candidate_id": candidate.id,
                        "label": candidate.label,
                        "mode": candidate.mode,
                        "source": "modal_benchmark",
                        "benchmark_case": "eight_expression_total",
                        "status": "skipped",
                        "failure_reason": "not run; pass --include-expressions to consume GPU for 8 expression probes",
                        "manual_score": None,
                    }
                )

    # Final write so the manifest exists even when no entry was recorded.
    _refresh_manifest(manifest, run_dir)
    print(f"benchmark manifest: {run_dir / 'generated' / 'manifest.json'}")
    return 0
def stage_smoke(args: argparse.Namespace) -> int:
    """Render a character package through the stage driver into one HTML file.

    The stage driver's asset and background roots are pointed at the run
    directory, four fixed stage states are rendered, and the snippets are
    written to ``<run_dir>/generated/stage_smoke.html``.

    Args:
        args: Parsed CLI arguments (``run_dir``, ``character_id``).

    Returns:
        Always 0.

    Raises:
        SystemExit: If ``<run_dir>/characters/<character_id>.json`` is missing.
    """
    run_dir = Path(args.run_dir)
    character_path = run_dir / "characters" / f"{args.character_id}.json"
    if not character_path.exists():
        raise SystemExit(f"missing character package: {character_path}")
    character = json.loads(character_path.read_text(encoding="utf-8"))

    from src import stage_driver

    # Redirect the driver at this run's assets and drop cached data URIs so
    # they are rebuilt from the new roots.
    assets_dir = run_dir / "assets"
    stage_driver.ASSET_ROOT = assets_dir / "characters"
    stage_driver.BACKGROUND_ROOT = assets_dir / "backgrounds"
    stage_driver._asset_data_uri.cache_clear()
    stage_driver._background_data_uri.cache_clear()

    state_table = (
        ("idle", "breathe", 0.25),
        ("listening", "gentle_blink", 0.35),
        ("smile", "talk", 0.72),
        ("thinking", "focus", 0.68),
    )
    snippets = []
    for expression, motion, intensity in state_table:
        state = {"expression": expression, "motion": motion, "intensity": intensity}
        snippets.append(stage_driver.render_character_stage(character, state))

    page_head = "<!doctype html><meta charset='utf-8'><body style='margin:0;background:#020617'>"
    html = "".join([page_head, "\n".join(snippets), "</body>"])

    output = run_dir / "generated" / "stage_smoke.html"
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(html, encoding="utf-8")
    print(f"stage smoke html: {output}")
    return 0
def install_package(args: argparse.Namespace) -> int:
    """Copy a generated spike package from a run directory into the project tree.

    The character id is read from ``<run_dir>/generated/manifest.json``. Three
    items are installed under ``PROJECT_ROOT``: the character asset directory,
    the ``<id>_spike_background.png`` background, and the ``<id>.json``
    character draft. With ``args.dry_run`` the source/destination pairs are
    printed as JSON and nothing is copied.

    Args:
        args: Parsed CLI arguments (``run_dir``, ``dry_run``).

    Returns:
        Always 0.

    Raises:
        SystemExit: If the manifest is missing, has no usable ``character_id``,
            or (when not a dry run) any of the three source items is missing.
    """
    run_dir = Path(args.run_dir)
    manifest_path = run_dir / "generated" / "manifest.json"
    if not manifest_path.exists():
        raise SystemExit(f"missing manifest: {manifest_path}")
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))

    character_id = manifest.get("character_id") if isinstance(manifest, dict) else None
    if not character_id or not isinstance(character_id, str):
        raise SystemExit(f"manifest has no character_id: {manifest_path}")

    background_name = f"{character_id}_spike_background.png"
    draft_name = f"{character_id}.json"
    src_character_assets = run_dir / "assets" / "characters" / character_id
    src_background = run_dir / "assets" / "backgrounds" / background_name
    src_character = run_dir / "characters" / draft_name

    if not args.dry_run:
        # Validate every source before touching the project tree so a partial
        # package can never leave a half-installed character behind.
        missing = []
        if not src_character_assets.is_dir():
            missing.append(str(src_character_assets))
        missing.extend(str(path) for path in (src_background, src_character) if not path.is_file())
        if missing:
            raise SystemExit(f"missing package files for {character_id}: " + ", ".join(missing))

    dst_character_assets = PROJECT_ROOT / "assets" / "characters" / character_id
    dst_background = PROJECT_ROOT / "assets" / "backgrounds" / background_name
    dst_character = PROJECT_ROOT / "characters" / draft_name

    if args.dry_run:
        print(
            json.dumps(
                {
                    "character_assets": [str(src_character_assets), str(dst_character_assets)],
                    "background": [str(src_background), str(dst_background)],
                    "character_draft": [str(src_character), str(dst_character)],
                },
                ensure_ascii=False,
                indent=2,
            )
        )
        return 0

    # dirs_exist_ok lets a re-install overwrite an earlier copy in place.
    shutil.copytree(src_character_assets, dst_character_assets, dirs_exist_ok=True)
    dst_background.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src_background, dst_background)
    dst_character.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src_character, dst_character)
    print(f"installed spike package for {character_id}")
    return 0
def _remote_probe(**kwargs: Any) -> dict[str, Any]:
    """Start the Modal app, run one probe with ``kwargs``, and return its result.

    Raises:
        SystemExit: If the ``modal`` package cannot be imported.
    """
    require_modal()
    from modal_apps.modal_character_spike import CharacterGenerationSpike, app

    with app.run():
        runner = CharacterGenerationSpike()
        probe_result = _call_remote_probe(runner, **kwargs)
    return probe_result
def _call_remote_probe(runner: Any, **kwargs: Any) -> dict[str, Any]:
    """Invoke ``runner.probe.remote(**kwargs)`` and stamp the client-side wall time.

    Returns:
        The remote result with ``client_wall_seconds`` added (elapsed
        ``time.perf_counter`` seconds around the call, rounded to 3 decimals).
    """
    clock_start = time.perf_counter()
    probe_result = runner.probe.remote(**kwargs)
    elapsed = time.perf_counter() - clock_start
    probe_result["client_wall_seconds"] = round(elapsed, 3)
    return probe_result
def _write_probe_images(run_dir: Path, candidate_id: str, label: str, images: list[bytes]) -> list[Path]:
    """Write probe image payloads under ``<run_dir>/generated/<candidate_id>/<label>``.

    Files are named ``00.png``, ``01.png``, ... in input order. When at least
    one image was written, a ``grid.png`` thumbnail grid is produced next to
    them. The output directory is created even when ``images`` is empty.

    Returns:
        The paths of the written image files (the grid is not included).
    """
    output_dir = run_dir / "generated" / candidate_id / label
    output_dir.mkdir(parents=True, exist_ok=True)

    written: list[Path] = []
    for position, image_bytes in enumerate(images):
        target = output_dir / f"{position:02d}.png"
        target.write_bytes(image_bytes)
        written.append(target)

    if not written:
        return written
    make_thumbnail_grid(written, output_dir / "grid.png")
    return written
def _load_or_new_manifest(run_dir: Path, character_id: str, display_name: str, seed: int) -> dict[str, Any]:
    """Load ``<run_dir>/generated/manifest.json`` or build a fresh manifest.

    When the manifest file exists it is returned as parsed and the
    ``character_id``, ``display_name`` and ``seed`` arguments are ignored. A
    loaded manifest is guaranteed to carry a ``model_results`` list, because
    callers append to it unconditionally.

    Args:
        run_dir: Run directory that holds (or will hold) ``generated/``.
        character_id: Character id stored in a newly created manifest.
        display_name: Display name stored in a newly created manifest.
        seed: Seed stored in a newly created manifest.

    Returns:
        The loaded or newly built manifest dictionary. Nothing is written.
    """
    manifest_path = run_dir / "generated" / "manifest.json"
    if manifest_path.exists():
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        # A manifest produced elsewhere may lack the results list (or hold
        # null there); normalise it so `manifest["model_results"].append` works.
        if isinstance(manifest, dict) and not isinstance(manifest.get("model_results"), list):
            manifest["model_results"] = []
        return manifest
    return {
        "schema_version": 1,
        "run_type": "modal_character_generation_spike",
        "character_id": character_id,
        "display_name": display_name,
        "seed": seed,
        "created_at_unix": int(time.time()),
        "duration_seconds": None,
        "paths": {
            "run_dir": str(run_dir),
            "manifest": "generated/manifest.json",
            "report": "generated/report.md",
        },
        "assets": [],
        "model_results": [],
        "qa": {
            "usable_assets": 0,
            "total_assets": len(EXPRESSIONS),
            "needs_manual_review": [],
            "notes": ["Remote probes only; generate a package with `mock-assets` or copy accepted outputs manually."],
        },
    }
def _refresh_manifest(manifest: dict[str, Any], run_dir: Path) -> None:
    """Recompute the gate, then rewrite the manifest JSON and the report.

    ``manifest["gate"]`` is updated in place. Output goes to
    ``<run_dir>/generated/manifest.json`` (with a trailing newline) and
    ``<run_dir>/generated/report.md``.
    """
    generated_dir = run_dir / "generated"
    manifest["gate"] = _evaluate_gate(manifest)
    generated_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(manifest, ensure_ascii=False, indent=2)
    (generated_dir / "manifest.json").write_text(f"{serialized}\n", encoding="utf-8")
    write_report(manifest, generated_dir / "report.md")
def _duration_or_inf(item: dict[str, Any]) -> float:
    """Return the entry's ``duration_seconds`` as a float, or infinity when unusable."""
    try:
        return float(item.get("duration_seconds"))
    except (TypeError, ValueError):
        # Missing, null or non-numeric durations can never satisfy a time limit.
        return float("inf")


def _evaluate_gate(manifest: dict[str, Any]) -> dict[str, Any]:
    """Evaluate the three readiness checks from a manifest's results and assets.

    The checks are:

    * some ``warm_four`` result with status ``ok`` and ``loaded_before is True``
      finished within ``WARM_FOUR_IMAGE_LIMIT_SECONDS``;
    * some ``eight_expression_total`` result with status ``ok`` finished within
      ``EIGHT_ASSET_LIMIT_SECONDS``;
    * at least ``MIN_USABLE_ASSET_COUNT`` assets are not explicitly marked
      ``usable: False``.

    Entries whose ``duration_seconds`` is missing, ``None`` or non-numeric are
    treated as exceeding the limit instead of raising.

    Args:
        manifest: Manifest dictionary; ``model_results`` and ``assets`` may be
            missing or ``None``.

    Returns:
        A dictionary with the three check flags, the combined
        ``ready_for_character_workshop`` flag, and a ``fallback_recommendation``
        string that is ``None`` when all checks pass.
    """
    results = manifest.get("model_results") or []
    assets = manifest.get("assets") or []

    warm_four = [
        item
        for item in results
        if item.get("benchmark_case") == "warm_four"
        and item.get("status") == "ok"
        and item.get("loaded_before") is True
    ]
    expression_totals = [
        item
        for item in results
        if item.get("benchmark_case") == "eight_expression_total" and item.get("status") == "ok"
    ]
    # Only an explicit `usable: False` disqualifies an asset.
    usable_assets = sum(1 for item in assets if item.get("usable") is not False)

    four_ok = any(_duration_or_inf(item) <= WARM_FOUR_IMAGE_LIMIT_SECONDS for item in warm_four)
    eight_ok = any(_duration_or_inf(item) <= EIGHT_ASSET_LIMIT_SECONDS for item in expression_totals)
    asset_ok = usable_assets >= MIN_USABLE_ASSET_COUNT
    ready = bool(four_ok and eight_ok and asset_ok)

    return {
        "warm_four_under_60s": four_ok,
        "eight_assets_under_180s": eight_ok,
        "usable_assets_at_least_6": asset_ok,
        "ready_for_character_workshop": ready,
        "fallback_recommendation": None
        if ready
        else "导入角色卡 + 单张立绘 + 背景 + 对话;多表情先用现有静态图和 stage 事件驱动。",
    }
def _run_dir(output_root: str | Path, character_id: str, run_name: str | None) -> Path:
    """Return ``<output_root>/<slugified character id>/<slugified run name>``.

    When ``run_name`` is empty, the current local time formatted as
    ``%Y%m%d_%H%M%S`` is used as the run name. No directory is created.
    """
    raw_name = run_name if run_name else time.strftime("%Y%m%d_%H%M%S")
    character_part = slugify_identifier(character_id, fallback="spike_character")
    run_part = slugify_identifier(raw_name, fallback="run")
    return Path(output_root) / character_part / run_part
def _read_optional_bytes(path: str | None) -> bytes | None:
    """Return the file's raw bytes, or ``None`` when ``path`` is empty or ``None``."""
    return Path(path).read_bytes() if path else None
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with one sub-command per spike operation.

    Each sub-parser stores its handler function in the ``func`` default so
    the caller can dispatch with ``args.func(args)``.
    """
    candidate_ids = [item.id for item in MODEL_CANDIDATES]
    default_output_root = str(PROJECT_ROOT / "assets" / "generated" / "character_spike")

    parser = argparse.ArgumentParser(description="Run the automated character-generation risk spike.")
    commands = parser.add_subparsers(dest="command", required=True)

    list_cmd = commands.add_parser("list-models", help="Print model candidates and implementation status.")
    list_cmd.set_defaults(func=list_models)

    mock_cmd = commands.add_parser("mock-assets", help="Create a no-GPU minimal character asset package.")
    mock_cmd.add_argument("--character-id", default="spike_character")
    mock_cmd.add_argument("--display-name", default="星核")
    mock_cmd.add_argument("--output-root", default=default_output_root)
    mock_cmd.add_argument("--seed", type=int, default=42)
    mock_cmd.set_defaults(func=mock_assets)

    package_cmd = commands.add_parser(
        "package-probe-assets",
        help="Create a stage asset package from Modal probe expression outputs.",
    )
    package_cmd.add_argument("--source-run-dir", required=True)
    package_cmd.add_argument("--candidate", required=True, choices=candidate_ids)
    package_cmd.add_argument("--character-id", default="spike_character")
    package_cmd.add_argument("--display-name", default="星核")
    package_cmd.add_argument("--output-root", default=default_output_root)
    package_cmd.add_argument("--seed", type=int, default=42)
    package_cmd.add_argument("--keep-background", action="store_true")
    package_cmd.set_defaults(func=package_probe_assets)

    tavern_cmd = commands.add_parser("import-tavern", help="Import a Tavern JSON card as a draft character package.")
    tavern_cmd.add_argument("--input", required=True)
    tavern_cmd.add_argument("--output-dir", default=str(PROJECT_ROOT / "characters"))
    tavern_cmd.add_argument("--id")
    tavern_cmd.add_argument("--dry-run", action="store_true")
    tavern_cmd.set_defaults(func=import_tavern)

    health_cmd = commands.add_parser(
        "modal-health",
        help="Check the Modal spike container without loading model weights.",
    )
    health_cmd.set_defaults(func=modal_health)

    probe_cmd = commands.add_parser("modal-probe", help="Run one remote model probe and write images/metrics.")
    _add_modal_probe_args(probe_cmd)
    probe_cmd.set_defaults(func=modal_probe)

    benchmark_cmd = commands.add_parser(
        "modal-benchmark",
        help="Run cold/warm/4-image probes, optionally 8 expression probes.",
    )
    _add_modal_probe_args(benchmark_cmd)
    benchmark_cmd.add_argument("--candidates", nargs="+", default=["flux_schnell"], choices=candidate_ids)
    benchmark_cmd.add_argument("--include-expressions", action="store_true")
    benchmark_cmd.set_defaults(func=modal_benchmark)

    smoke_cmd = commands.add_parser("stage-smoke", help="Render generated assets through the existing stage driver.")
    smoke_cmd.add_argument("--run-dir", required=True)
    smoke_cmd.add_argument("--character-id", required=True)
    smoke_cmd.set_defaults(func=stage_smoke)

    install_cmd = commands.add_parser(
        "install-package",
        help="Explicitly copy a generated spike package into app asset paths.",
    )
    install_cmd.add_argument("--run-dir", required=True)
    install_cmd.add_argument("--dry-run", action="store_true")
    install_cmd.set_defaults(func=install_package)

    return parser
def _add_modal_probe_args(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by the ``modal-probe`` and ``modal-benchmark`` commands."""
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--candidate", {"default": "flux_schnell", "choices": [item.id for item in MODEL_CANDIDATES]}),
        ("--prompt", {}),
        ("--batch-size", {"type": int, "default": 1}),
        ("--steps", {"type": int}),
        ("--seed", {"type": int, "default": 42}),
        ("--width", {"type": int, "default": 768}),
        ("--height", {"type": int, "default": 1024}),
        ("--init-image", {}),
        ("--control-image", {}),
        ("--character-id", {"default": "spike_character"}),
        ("--display-name", {"default": "星核"}),
        ("--run-name", {}),
        ("--output-root", {"default": str(PROJECT_ROOT / "assets" / "generated" / "character_spike")}),
        # Explicit opt-in checked by the GPU commands before any remote call.
        ("--confirm-gpu", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
def main(argv: list[str] | None = None) -> int:
    """Parse ``argv`` (``sys.argv[1:]`` when ``None``) and run the selected sub-command.

    Returns:
        The handler's return value, used as the process exit code.
    """
    args = build_parser().parse_args(argv)
    handler = args.func
    return handler(args)
# Script entry point: exit the process with the sub-command's return code.
if __name__ == "__main__":
    sys.exit(main())