File size: 8,975 Bytes
8de8ab2
 
8820f69
 
 
8de8ab2
 
ef005d9
8de8ab2
febacca
 
 
8de8ab2
febacca
ef005d9
8de8ab2
 
 
 
 
ef005d9
 
c75b16d
ef005d9
 
 
c75b16d
ef005d9
 
 
c75b16d
 
 
ef005d9
ed1bb3b
 
 
 
 
 
 
 
 
 
dba126e
 
 
 
 
c75b16d
ed1bb3b
 
 
 
dba126e
 
ed1bb3b
c75b16d
 
ef005d9
ed1bb3b
 
 
 
 
 
 
 
 
 
 
 
c75b16d
 
ef005d9
c75b16d
ed1bb3b
ef005d9
dba126e
 
ef005d9
f3ed552
 
 
c75b16d
f3ed552
 
dba126e
 
c75b16d
dba126e
 
 
 
f3ed552
dba126e
c75b16d
 
ef005d9
dba126e
 
ef005d9
dba126e
 
c75b16d
dba126e
 
 
 
c75b16d
ef005d9
 
 
 
 
8de8ab2
aeb37c0
 
5149b6a
aeb37c0
 
5149b6a
febacca
 
 
 
 
fc5f622
8f6bbb3
 
 
dba126e
aeb37c0
ef005d9
aeb37c0
5149b6a
b397c92
aeb37c0
 
b397c92
aeb37c0
 
 
 
 
 
b397c92
aeb37c0
 
 
 
 
 
 
 
febacca
8de8ab2
fc5f622
aeb37c0
8de8ab2
 
 
aeb37c0
8de8ab2
8f6bbb3
fc5f622
aeb37c0
febacca
 
8de8ab2
 
 
febacca
c75b16d
ef005d9
c75b16d
febacca
8de8ab2
febacca
 
 
 
 
 
 
 
 
 
8de8ab2
febacca
 
 
c75b16d
 
656aab0
febacca
 
8de8ab2
 
 
 
 
febacca
 
 
 
8de8ab2
febacca
 
 
 
 
 
c75b16d
 
ef005d9
 
 
8de8ab2
febacca
 
aeb37c0
8de8ab2
aeb37c0
 
 
 
 
febacca
 
 
8820f69
 
 
0c3b0ba
f4ca1dd
0c3b0ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""daVinci-MagiHuman WebUI — Gradio frontend for HF Spaces.

Pure frontend, no queuing, no load management.
All requests are sent immediately to the router.
If the router rejects (503), the error is shown directly to the user.

Architecture:
    User prompt ──LLM rewrite──▶ refined prompt
    HF Space (this app) ──HTTP──▶ Router (public IP) ──▶ 4x inference servers
"""

import os

import gradio as gr
from openai import OpenAI

from api_client import generate

# Directory passed to the API client as the output location for generated videos.
OUTPUT_DIR = "/tmp/magihuman_webui_outputs"

# ── Prompt rewrite via LLM ───────────────────────────────────────────

# OpenAI-compatible client used only for prompt rewriting. The key is read
# from the REWRITE_API_KEY environment variable; an unset variable yields
# an empty string rather than an error at import time.
_enhance_client = OpenAI(
    api_key=os.getenv("REWRITE_API_KEY", ""),
    base_url="https://apicz.boyuerichdata.com/v1/",
)
# Model name sent with every rewrite request.
_ENHANCE_MODEL = "gemini-3-flash-preview"

# The rewrite system prompt lives in prompt.txt next to this module.
_PROMPT_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompt.txt")
# Decode explicitly as UTF-8: without an encoding argument open() falls back
# to the platform locale, so the same file could load differently (or fail)
# depending on the host. NOTE(review): assumes prompt.txt is stored as UTF-8.
with open(_PROMPT_FILE, "r", encoding="utf-8") as f:
    _ENHANCE_SYSTEM_PROMPT = f.read()
# Startup diagnostics: confirm which prompt was loaded and show its beginning.
print(f"[Enhance] Loaded system prompt from {os.path.basename(_PROMPT_FILE)}, length={len(_ENHANCE_SYSTEM_PROMPT)} chars")
print(f"[Enhance] System prompt preview: {_ENHANCE_SYSTEM_PROMPT[:200]}...")

def _pil_to_base64_url(image) -> str:
    """Convert a PIL Image to a base64 data URL for the vision API."""
    import base64
    import io
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{b64}"


class ContentBlockedError(Exception):
    """Signals that the LLM safety filter rejected the request."""


# User-facing text for every "blocked by safety filter" outcome, kept in one
# place so both raise sites stay in sync.
_CONTENT_BLOCKED_MESSAGE = (
    "Your request was blocked by the content safety filter. "
    "Please modify your prompt or image and try again."
)


def enhance_prompt(user_prompt: str, image=None) -> str:
    """Rewrite user prompt into the model's required format via LLM API.

    Sends the text prompt and, when given, the reference image (as a
    base64 data URL) to the LLM so it can describe the character and
    scene accurately.

    Args:
        user_prompt: Raw prompt text entered by the user.
        image: Optional reference image; passed through
            ``_pil_to_base64_url`` when not None.

    Returns:
        The stripped rewritten prompt, or ``user_prompt`` unchanged when
        the rewrite fails for a reason not classified as a content block.

    Raises:
        ContentBlockedError: If the API reports ``finish_reason ==
            "content_filter"``, returns empty/whitespace-only content, or
            raises an error whose message mentions "block", "safety" or
            "content".
    """
    print(f"[Enhance] Starting rewrite, input length={len(user_prompt)} chars, has_image={image is not None}")
    print(f"[Enhance] User prompt: {user_prompt[:100]}...")
    try:
        # Build the multimodal user message: image first (if any), then text.
        user_content = []
        if image is not None:
            user_content.append({
                "type": "image_url",
                "image_url": {"url": _pil_to_base64_url(image)},
            })
        user_content.append({
            "type": "text",
            "text": user_prompt,
        })

        resp = _enhance_client.chat.completions.create(
            model=_ENHANCE_MODEL,
            messages=[
                {"role": "system", "content": _ENHANCE_SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        choice = resp.choices[0]
        raw_content = choice.message.content
        finish_reason = choice.finish_reason
        print(f"[Enhance] API returned: finish_reason={finish_reason}, "
              f"content_length={len(raw_content) if raw_content else 0}")

        # An empty reply is treated the same as an explicit filter hit.
        if finish_reason == "content_filter" or not (raw_content or "").strip():
            print(f"[Enhance] Content blocked by safety filter (finish_reason={finish_reason})")
            raise ContentBlockedError(_CONTENT_BLOCKED_MESSAGE)

        rewritten = raw_content.strip()
        print(f"[Enhance] Done, output length={len(rewritten)} chars")
        print(f"[Enhance] Rewritten: {rewritten[:150]}...")
        return rewritten
    except ContentBlockedError:
        # Raised above on purpose; must not be swallowed by the generic handler.
        raise
    except Exception as e:
        # Best-effort classification of API errors by message text.
        # NOTE(review): the "content" keyword is broad and may also match
        # unrelated errors; kept as-is to preserve the conservative behavior.
        err_str = str(e).lower()
        if any(marker in err_str for marker in ("block", "safety", "content")):
            print(f"[Enhance] Content blocked by API error: {e}")
            # Chain the original error so the root cause shows in tracebacks.
            raise ContentBlockedError(_CONTENT_BLOCKED_MESSAGE) from e
        # Any other failure degrades gracefully to the un-rewritten prompt.
        print(f"[Enhance] FAILED: {e}, using original prompt")
        return user_prompt


# ── Generation ───────────────────────────────────────────────────────


def step1_enhance(image, prompt, seed, seconds):
    """Step 1: Validate inputs and enhance prompt via LLM.

    Args:
        image: Reference image from the UI; required.
        prompt: Raw prompt text from the UI; required and must contain
            non-whitespace characters.
        seed: Seed value from the UI; only logged here.
        seconds: Requested duration from the UI; only logged here.

    Returns:
        The enhanced prompt string produced by ``enhance_prompt``.

    Raises:
        gr.Error: If the image or prompt is missing, or if
            ``enhance_prompt`` raises ``ContentBlockedError``.
    """
    if image is None:
        raise gr.Error("Please upload a reference image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a text prompt.")

    print(f"[Generate] Request received: seed={seed} seconds={seconds} prompt={prompt[:50]!r}")
    try:
        return enhance_prompt(prompt.strip(), image=image)
    except ContentBlockedError as e:
        print("[Generate] Blocked by content filter, aborting generation")
        # Surface the filter message in the UI; chain the cause for tracebacks.
        raise gr.Error(str(e)) from e


def step2_generate(image, enhanced, seed, seconds):
    """Step 2: Send generation request to router.

    Regular (non-generator) function — per the original author's note,
    Gradio's queue mode keeps the connection alive on its own, so the
    processing animation on the video output stays active for the whole
    call.

    Args:
        image: Reference image forwarded to ``generate``.
        enhanced: Enhanced prompt text; must contain non-whitespace.
        seed: Seed from the UI. ``None`` (a cleared number field) is
            replaced with -1, the value the UI labels as "random".
        seconds: Requested duration, converted with ``int``.

    Returns:
        A ``(video_path, status)`` tuple. ``video_path`` is ``None`` and
        ``status`` starts with ``"Error: "`` when there is no prompt, the
        router reports an error, or the returned file does not exist.
    """
    if not enhanced or not enhanced.strip():
        return None, "Error: No enhanced prompt."

    # A cleared Seed field arrives as None; int(None) would raise TypeError.
    seed_value = -1 if seed is None else int(seed)

    print("[Generate] Sending to router ...")
    result = generate(
        image=image,
        video_prompt=enhanced,
        seed=seed_value,
        output_dir=OUTPUT_DIR,
        seconds=int(seconds),
    )

    if result["error"]:
        print(f"[Generate] Error from router: {result['error']}")
        return None, f"Error: {result['error']}"

    video_path = result["video_path"]
    # Guard against a success response that points at a missing file.
    if not video_path or not os.path.isfile(video_path):
        return None, "Error: Video file not found."

    status = f"Done. seed={result['seed']}"
    print(f"[Generate] Success: {video_path}")
    return video_path, status


# ── Gradio UI ────────────────────────────────────────────────────────

# User-facing page title and intro text (rendered as Markdown in the UI).
TITLE = "daVinci-MagiHuman — Audio-Video Generation"
DESCRIPTION = (
    "Upload a reference image, describe what you want in the video, choose the "
    "duration (4–10 s), and click **Generate**. Your prompt will be automatically "
    "enhanced into the optimal format before generation.\n\n"
    "**Model**: 15B single-stream Transformer (distilled, 8-step inference) "
    "| **Resolution**: 448×256 → 540p | **FPS**: 25"
)

with gr.Blocks(title=TITLE, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: everything the user provides.
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Reference Image",
                type="pil",
                height=300,
            )
            prompt_input = gr.Textbox(
                label="Video Description (will be auto-enhanced)",
                placeholder="Describe the scene, character actions, dialogue, etc. Your prompt will be automatically enhanced for optimal generation.",
                lines=6,
            )
            with gr.Row():
                seed_input = gr.Number(
                    label="Seed (-1 = random)",
                    value=-1,
                    precision=0,
                )
                seconds_slider = gr.Slider(
                    minimum=4,
                    maximum=10,
                    step=1,
                    value=5,
                    label="Duration (seconds)",
                )
            generate_btn = gr.Button("Generate", variant="primary")

        # Right column: results of the two pipeline steps.
        with gr.Column(scale=1):
            video_output = gr.Video(label="Generated Video")
            enhanced_prompt_box = gr.Textbox(
                label="Enhanced Prompt (sent to model)",
                interactive=False,
                lines=8,
            )
            status_box = gr.Textbox(label="Status", interactive=False, lines=2)

    # Two-step pipeline: enhance the prompt, then generate the video.
    # The second step is chained with .success() rather than .then():
    # .then() also fires after step 1 raised gr.Error (missing input or a
    # content-filter block), which would generate from whatever stale text
    # is still in enhanced_prompt_box. .success() only continues when
    # step 1 finished without error.
    generate_btn.click(
        fn=step1_enhance,
        inputs=[image_input, prompt_input, seed_input, seconds_slider],
        outputs=[enhanced_prompt_box],
    ).success(
        fn=step2_generate,
        inputs=[image_input, enhanced_prompt_box, seed_input, seconds_slider],
        outputs=[video_output, status_box],
    )

if __name__ == "__main__":
    # No throttling on the Gradio side: default_concurrency_limit=None lifts
    # the queue's default concurrency gate, so every click is forwarded to
    # the router immediately.
    queued_demo = demo.queue(default_concurrency_limit=None)
    # Listen on all interfaces at port 7860.
    queued_demo.launch(server_name="0.0.0.0", server_port=7860)