File size: 4,252 Bytes
b96d4b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f408942
b96d4b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd5800b
 
b96d4b9
 
 
 
 
 
fd5800b
 
 
 
b96d4b9
 
 
 
 
 
 
fd5800b
 
b96d4b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Celebrity Deathmatch β€” runtime client.

Talks to the Modal backend (modal.App("deathmatch")) over HTTP, OR serves canned
data when DEATHMATCH_MOCK is set — same function signatures either way, so the UI
never branches on mode. The HF Space deploy injects DEATHMATCH_API_URL.
"""
from __future__ import annotations

import base64
import io
import os
import tempfile

import httpx

# Base URL of the inference backend. DEATHMATCH_API_URL overrides the default
# Modal endpoint; trailing slashes are dropped so request paths can be appended.
API_URL = os.getenv(
    "DEATHMATCH_API_URL",
    "https://pablo-pisarski--deathmatch-api.modal.run",
).rstrip("/")

# Mock mode is enabled when DEATHMATCH_MOCK is 1/true/yes/on
# (case-insensitive, surrounding whitespace ignored).
MOCK = os.getenv("DEATHMATCH_MOCK", "").strip().lower() in {"1", "true", "yes", "on"}

# Default timeout passed to httpx for backend POST requests.
TIMEOUT_S = 900


class BackendError(RuntimeError):
    """Raised when the inference backend cannot be reached or reports a failure.

    The exception message is written to be shown to the user as-is.
    """


def _pil_to_b64(img) -> str:
    """Return ``img`` as base64 text of its PNG encoding, after conversion to RGB."""
    rgb = img.convert("RGB")
    with io.BytesIO() as sink:
        rgb.save(sink, format="PNG")
        png_bytes = sink.getvalue()
    return base64.b64encode(png_bytes).decode()


def _b64_to_pil(data: str):
    """Decode base64-encoded image bytes and return them as an RGB PIL image."""
    # Imported lazily, as in the rest of this module's optional dependencies.
    from PIL import Image

    raw = base64.b64decode(data)
    decoded = Image.open(io.BytesIO(raw))
    return decoded.convert("RGB")


def _post(path: str, payload: dict, timeout: float = TIMEOUT_S) -> dict:
    """POST ``payload`` as JSON to ``API_URL + path`` and return the decoded JSON body.

    Redirects are followed. Every transport failure, non-2xx status and
    undecodable body is re-raised as ``BackendError`` (chained to the original
    exception) so callers only have one exception type to handle.

    Args:
        path: Endpoint path appended verbatim to ``API_URL`` (e.g. ``"/fightcard"``).
        payload: JSON-serialisable request body.
        timeout: Timeout handed to httpx; defaults to ``TIMEOUT_S``.

    Returns:
        The parsed JSON response.

    Raises:
        BackendError: The backend was unreachable, timed out, answered with an
            error status, or returned a body that is not valid JSON.
    """
    url = f"{API_URL}{path}"
    try:
        resp = httpx.post(url, json=payload, timeout=timeout, follow_redirects=True)
        resp.raise_for_status()
    except httpx.ConnectError as e:
        raise BackendError(
            f"Cannot reach the Deathmatch backend at {API_URL} — is the Modal app "
            f"deployed? ({e})"
        ) from e
    except httpx.ReadTimeout as e:
        raise BackendError(
            "Backend timed out — likely a GPU cold start pulling weights. "
            "Try again in ~1 minute."
        ) from e
    except httpx.TimeoutException as e:
        # Connect/write/pool timeouts are not ReadTimeout and used to escape unwrapped.
        raise BackendError(
            f"Timed out talking to the Deathmatch backend at {API_URL} "
            f"({type(e).__name__})."
        ) from e
    except httpx.HTTPStatusError as e:
        raise BackendError(
            f"Backend error {e.response.status_code}: {e.response.text[:300]}"
        ) from e
    except httpx.RequestError as e:
        # Remaining transport-level failures (protocol errors, read errors,
        # too many redirects, ...).
        raise BackendError(
            f"Request to the Deathmatch backend at {API_URL} failed: {e!r}"
        ) from e

    # Kept outside the request try-block so only a body-decoding failure is
    # reported as invalid JSON.
    try:
        return resp.json()
    except ValueError as e:
        raise BackendError(
            f"Backend returned a non-JSON response: {resp.text[:300]}"
        ) from e


def health() -> dict:
    """Report backend status for the UI banner; never raises.

    Returns:
        In mock mode, a fixed ``{"status": "mock", ...}`` dict. Otherwise the
        decoded JSON body of ``GET {API_URL}/health`` (10 s timeout, redirects
        followed). If the request, status check or decoding fails for any
        reason, ``{"status": "unreachable", "error": <message>, "url": API_URL}``.
    """
    if MOCK:
        return {"status": "mock", "service": "deathmatch (mock mode — no GPU)"}
    try:
        resp = httpx.get(f"{API_URL}/health", timeout=10, follow_redirects=True)
        resp.raise_for_status()
        return resp.json()
    except Exception as e:  # noqa: BLE001 — banner only, never crash the UI
        return {"status": "unreachable", "error": str(e), "url": API_URL}


def generate_fightcard(image_a, image_b, storyline: str, arena: str, style: str,
                       name_a: str = "", name_b: str = "") -> dict:
    """Stage 1: turn two fighter photos into a fight card dict.

    Real mode posts both images (base64 PNG), the storyline (capped at 500
    characters), arena, style and the stripped names (capped at 60 characters
    each) to ``/fightcard`` and returns the response's ``"fightcard"`` entry.

    Mock mode returns a deep copy of ``mock.MOCK_FIGHTCARD`` with the arena
    and fighter names overridden when non-empty values are supplied; the
    images, storyline and style are ignored there.
    """
    if not MOCK:
        request = {
            "image_a_b64": _pil_to_b64(image_a),
            "image_b_b64": _pil_to_b64(image_b),
            "storyline": (storyline or "")[:500],
            "arena": arena,
            "style": style,
            "name_a": (name_a or "").strip()[:60],
            "name_b": (name_b or "").strip()[:60],
        }
        response = _post("/fightcard", request)
        return response["fightcard"]

    import copy
    from mock import MOCK_FIGHTCARD

    # Work on a copy so the shared canned card is never mutated.
    card = copy.deepcopy(MOCK_FIGHTCARD)
    card["arena"] = arena or card["arena"]
    for slot, raw_name in (("fighter_a", name_a), ("fighter_b", name_b)):
        cleaned = raw_name.strip() if raw_name else ""
        if cleaned:
            card[slot]["name"] = cleaned
    return card


def generate_keyframes(card: dict, style: str, aspect: str = "16:9") -> list:
    """Stage 2: turn a fight card into its keyframe images.

    Real mode posts the card, style and aspect ratio to ``/keyframes`` and
    decodes each entry of the response's ``"images_b64"`` list into an RGB PIL
    image (the pipeline is described as producing 5 keyframes). Mock mode
    returns ``mock.placeholder_reel(card)`` and ignores style and aspect.
    """
    if not MOCK:
        request = {"fightcard": card, "style": style, "aspect": aspect}
        response = _post("/keyframes", request)
        return list(map(_b64_to_pil, response["images_b64"]))

    from mock import placeholder_reel

    return placeholder_reel(card)


def animate(card: dict, keyframes: list, style: str) -> str:
    """Stage 3 — keyframes -> one chained, captioned fight clip. Returns a file path.

    Mock: returns ``mock.mock_fight_video(keyframes)`` (per the original notes,
    an animated GIF cycling the keyframes with captions burned in).
    Real: posts the card, base64 PNG keyframes and style to ``/animate`` and
    writes the decoded ``"video_b64"`` payload (an MP4 from LTX-Video, Slice 4)
    to a new temp file named ``deathmatch_*.mp4``.

    The temp file is not deleted here; the caller owns the returned path.
    """
    if MOCK:
        from mock import mock_fight_video
        return mock_fight_video(keyframes)
    data = _post("/animate", {
        "fightcard": card,
        "keyframes_b64": [_pil_to_b64(f) for f in keyframes],
        "style": style,
    })
    # Decode first so a malformed payload fails before any file is created.
    video_bytes = base64.b64decode(data["video_b64"])
    fd, path = tempfile.mkstemp(prefix="deathmatch_", suffix=".mp4")
    try:
        # A buffered file object writes the whole payload (bare os.write may
        # write only part of it) and closes the descriptor even on error.
        with os.fdopen(fd, "wb") as out:
            out.write(video_bytes)
    except OSError:
        # Do not leave a truncated clip behind; the original error is what matters.
        try:
            os.unlink(path)
        except OSError:
            pass
        raise
    return path