linoyts's picture
linoyts HF Staff
Update app.py
de8cf24 verified
#!/usr/bin/env python3
"""Avatar Generator — text-prompt-driven talking avatar.
Pipeline (single Gradio Blocks app):
1. User provides a unified Dramabox-style prompt, a voice clip to clone, and
an avatar reference image (uploaded, or generated/edited with FLUX.2-klein-4B).
2. Dramabox (LTX-2.3 audio branch + IC-LoRA, in-process, warm-loaded) turns
the prompt into a watermarked WAV inside a @spaces.GPU window.
3. The WAV + image are sent to the deployed
`victor/LongCat-Video-Avatar-1.5` Space via `gradio_client`, which
returns the final lip-synced MP4. That step uses *its* GPU quota, not
ours — keeping this Space's per-call GPU window tight.
"""
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import time
# Inclusive upper bound (max signed 32-bit int) for the seed drawn by
# generate_avatar when "Randomize seed" is checked.
_MAX_SEED = 2**31 - 1
import gradio as gr
from gradio_client import Client, handle_file
import spaces
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
from inference_server import TTSServer # noqa: E402
from model_downloader import get_all_paths # noqa: E402
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Resolve local paths for every DramaBox checkpoint via src/model_downloader.py.
# Per the log line, the files come from the HF Hub and are cached after the
# first run (caching behaviour lives in model_downloader — not visible here).
logging.info("Fetching DramaBox checkpoints from HuggingFace (cached after first run)...")
PATHS = get_all_paths()
# Build one warm TTS server at import time; _run_tts reuses this `tts`
# instance for every request instead of loading weights per call.
logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
tts = TTSServer(
    checkpoint=PATHS["transformer"],
    full_checkpoint=PATHS["audio_components"],
    gemma_root=PATHS["gemma_root"],
    device="cuda",
    # Precision is overridable through the LTX_DTYPE env var (default "bf16").
    dtype=os.environ.get("LTX_DTYPE", "bf16"),
    compile_model=False,
    # NOTE(review): presumably enables bitsandbytes 4-bit quantization inside
    # TTSServer — confirm in src/inference_server.py.
    bnb_4bit=True,
)
logging.info("TTSServer ready.")
# Pre-warm the Perth watermarker on the shared `tts` instance. According to
# the original investigation, inference_server.py otherwise loads it lazily on
# the first watermark call ("loaded PerthNet (Implicit) at step 250,000",
# roughly 9 s of disk + setup), which consumed the entire GPU window of the
# first request. Loading it at boot moves that cost off the hot path.
# Best-effort: any failure is logged and the lazy path stays in effect.
try:
    import perth

    tts._perth = perth.PerthImplicitWatermarker()
except Exception as exc:
    logging.warning(f"Perth pre-warm skipped ({exc}); first request will pay the load cost.")
else:
    logging.info("Perth watermarker pre-warmed.")
# ── Remote video pipeline ────────────────────────────────────────────────────
# We don't load LongCat-Video-Avatar locally: its weights are ~20 GB and
# loading both pipelines in one ZeroGPU process is fragile. Instead we proxy
# to the public Space via gradio_client. The HF_TOKEN env var (if set) is
# forwarded so remote calls are authenticated rather than anonymous.
# NOTE(review): this token is read from this process's environment, not from
# the end user's request, so any quota/priority attribution goes to the
# token's owner rather than to the individual caller.
# Target Space id; override with the LONGCAT_SPACE env var.
LONGCAT_SPACE = os.environ.get("LONGCAT_SPACE", "victor/LongCat-Video-Avatar-1.5")
# Shared client: set by the eager pre-warm below and lazily by _video_client().
# None means "not connected yet" or "last connection attempt failed".
_VIDEO_CLIENT: Client | None = None
def _build_video_client() -> Client | None:
    """Open a gradio_client connection to ``LONGCAT_SPACE``.

    Authenticates with the ``HF_TOKEN`` env var when present; a missing token
    (``None``) is treated by gradio_client as anonymous access.

    Returns:
        The connected ``Client``, or ``None`` if the Space could not be
        reached. Failures are logged instead of raised so a brief LongCat
        outage at boot doesn't take this Space down; ``_video_client()``
        retries lazily.
    """
    hf_token = os.environ.get("HF_TOKEN")
    logging.info(f"Connecting to {LONGCAT_SPACE} via gradio_client...")
    try:
        # The auth kwarg is `token` in current gradio_client (formerly `hf_token`).
        client = Client(LONGCAT_SPACE, token=hf_token)
    except Exception as exc:
        logging.warning(f"Could not pre-warm video client at boot ({exc}); will retry on first request.")
        return None
    return client
def _video_client() -> Client:
    """Return the shared LongCat client, connecting on demand if necessary.

    Prefers the client pre-warmed at import time, which saves the TLS +
    /info handshake (~0.5–2 s) on the first user request. If that boot-time
    attempt failed, a new connection is attempted now and cached on success.

    Raises:
        gr.Error: if the on-demand connection attempt also fails.
    """
    global _VIDEO_CLIENT
    if _VIDEO_CLIENT is not None:
        return _VIDEO_CLIENT
    _VIDEO_CLIENT = _build_video_client()
    if _VIDEO_CLIENT is None:
        # The build failed again; surface a clean, user-facing error.
        raise gr.Error(f"Couldn't connect to {LONGCAT_SPACE}. Try again in a moment.")
    return _VIDEO_CLIENT
# Eager pre-warm so the first user request doesn't pay the gradio_client
# handshake against the LongCat Space. On failure _build_video_client()
# returns None, and _video_client() retries on the next request.
_VIDEO_CLIENT = _build_video_client()
# ── Optional portrait generator (FLUX.2-klein-4B) ────────────────────────────
# Lazy: most users will arrive with their own photo and never touch this tab.
# Pre-warming a third remote client at boot would just slow startup for a
# feature only a fraction of users hit. /infer accepts an empty input_images
# list for pure text→image, or a single image dict for prompt+image editing.
# Target Space id; override with the FLUX_SPACE env var.
FLUX_SPACE = os.environ.get("FLUX_SPACE", "black-forest-labs/FLUX.2-klein-4B")
# Cached client; None until the first successful connection.
_FLUX_CLIENT: Client | None = None
def _flux_client() -> Client:
    """Return the cached FLUX client, connecting on first use.

    Authenticates with the ``HF_TOKEN`` env var when present. A failed
    attempt leaves the cache empty, so the next call retries.

    Raises:
        gr.Error: if the connection fails. The failure is also logged
            server-side (matching ``_build_video_client``) and the original
            exception is chained as the cause for debugging.
    """
    global _FLUX_CLIENT
    if _FLUX_CLIENT is None:
        token = os.environ.get("HF_TOKEN")
        logging.info(f"Connecting to {FLUX_SPACE} via gradio_client...")
        try:
            _FLUX_CLIENT = Client(FLUX_SPACE, token=token)
        except Exception as e:
            # gr.Error only reaches the browser; keep a trace in the Space logs too.
            logging.warning(f"Could not connect to {FLUX_SPACE} ({e}); will retry on next request.")
            raise gr.Error(f"Couldn't connect to {FLUX_SPACE}: {e}") from e
    return _FLUX_CLIENT
def generate_portrait(
    flux_prompt: str,
    flux_edit_image: str | None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate or edit the avatar portrait with FLUX.2-klein-4B.

    Calls the remote FLUX Space's ``/infer`` endpoint with the Distilled
    4-step preset. When ``flux_edit_image`` points at an existing file it is
    sent as the single input image (prompt-guided edit); otherwise the call
    is pure text→image.

    Returns:
        ``(image_path, gr.Tabs(selected="portrait_upload"))``. The tabs update
        folds the Generate tab back to the Upload tab, so the user sees the
        new portrait in the same component they'd upload one to
        (pattern: multimodalart/wan-2-2-first-last-frame).

    Raises:
        gr.Error: if the prompt is empty, the FLUX Space is unreachable, or
            the response contains no image path.
    """
    if not flux_prompt or not flux_prompt.strip():
        raise gr.Error("Please describe the portrait you want.")
    progress(0.05, desc="Connecting to FLUX.2-klein-4B…")
    client = _flux_client()
    # Only forward the edit image if the file actually exists on disk.
    images_arg = (
        [{"image": handle_file(flux_edit_image)}]
        if flux_edit_image and os.path.exists(flux_edit_image)
        else []
    )
    mode_desc = "Editing portrait" if images_arg else "Generating portrait"
    progress(0.2, desc=f"{mode_desc} (FLUX.2-klein-4B, 4 steps)…")
    t0 = time.time()
    result = client.predict(
        prompt=flux_prompt,
        input_images=images_arg,
        mode_choice="Distilled (4 steps)",
        seed=0,
        randomize_seed=True,
        width=1024,
        height=1024,
        num_inference_steps=4,
        guidance_scale=1.0,
        prompt_upsampling=False,
        api_name="/infer",
    )
    logging.info(f"[flux] {time.time() - t0:.2f}s -> {result}")
    # Expected shape: (image_dict, seed), where image_dict has `path` (local
    # cached copy downloaded by gradio_client) and `url`. A bare path string
    # is accepted as well.
    image_dict = result[0] if isinstance(result, (list, tuple)) and result else result
    image_path = image_dict.get("path") if isinstance(image_dict, dict) else image_dict
    if not image_path:
        # Returning None here would silently clear the portrait and still
        # switch tabs; tell the user instead.
        raise gr.Error("FLUX.2-klein-4B returned no image. Please try again.")
    progress(1.0, desc="Done")
    return image_path, gr.Tabs(selected="portrait_upload")
def _video_prompt_from_script(script: str) -> str:
    """Derive a clean visual prompt for LongCat from the unified Dramabox script.

    Dramabox prompts wrap dialogue in quotes, e.g.
    `A shadowy villain speaks coldly, "You have entered my domain."` — the
    quoted text is what gets *spoken*, the lead-in is the *speaker
    description*. LongCat's prompt should describe the *visual*, so we keep
    the speaker description and drop the dialogue.

    Both straight (") and typographic (“ ”) quotes open dialogue, since
    phone keyboards and word processors often auto-convert them.

    Falls back to a neutral caption if the script is empty, unquoted, or
    begins directly with dialogue. Otherwise returns the lead-in with its
    trailing separator punctuation removed, anchored to the camera.
    """
    fallback = "A person speaks expressively, looking at the camera."
    if not script or not script.strip():
        return fallback
    parts = re.split(r'["\u201c\u201d]', script, maxsplit=1)
    if len(parts) < 2:
        # Unquoted: no separable speaker description, and using the whole
        # script would leak the spoken text into the visual prompt.
        return fallback
    # Keep the lead-in; drop whitespace and the punctuation that introduces
    # the dialogue (`speaks coldly,` / `says:` / `whispers.` / `pauses —`).
    head = re.sub(r"[\s,.:;\u2013\u2014]+$", "", parts[0]).strip()
    if not head:
        return fallback
    # Anchor it to a portrait shot so LongCat doesn't reframe the avatar.
    if "camera" not in head.lower():
        head += ", speaking to the camera"
    return head
# ── GPU window sizing (TTS step only — video runs on the remote Space) ──────
# Fixed overhead and lower bound (seconds) of the window _tts_gpu_duration requests.
_GPU_BASE_S = 10
# NOTE(review): not referenced by _tts_gpu_duration (which scales with audio
# length × steps); looks like a leftover from sentence-based sizing — confirm
# nothing else uses it before removing.
_GPU_PER_SENTENCE_S = 1
# Upper bound (seconds) on the requested GPU window.
_GPU_CAP_S = 110
def _count_sentences(prompt: str) -> int:
    """Best-effort sentence count for ``prompt``; never less than 1.

    Uses ``text_chunker.split_sentences_outside_quotes`` when it can be
    imported and succeeds. Otherwise it approximates the count by the number
    of ``.``, ``!`` and ``?`` characters. Empty or whitespace-only input
    counts as one sentence.
    """
    if not (prompt and prompt.strip()):
        return 1
    try:
        from text_chunker import split_sentences_outside_quotes
        count = len(split_sentences_outside_quotes(prompt))
    except Exception:
        # Chunker missing or failed: count terminal punctuation marks.
        count = sum(map(prompt.count, ".!?"))
    return count if count > 1 else 1
def _tts_gpu_duration(
    prompt: str,
    voice_ref: str | None,
    cfg: float,
    stg: float,
    steps: int,
    duration: float,
    seed: int,
    resolution: str,
    progress=None,
) -> int:
    """Size the ZeroGPU window, in whole seconds, for one ``_run_tts`` call.

    Passed as ``spaces.GPU(duration=...)``, so it mirrors ``_run_tts``'s
    parameters; only ``steps`` and ``duration`` affect the result.

    Denoise time scales with audio length × steps. Observed: ~0.012 s of GPU
    per (second of audio × step) at default settings; the 0.05 factor gives a
    ~4× safety margin. ``_GPU_BASE_S`` covers Gemma encode + VAE decode +
    watermark + save. The rounded estimate is padded by 2 s and clamped to
    ``[_GPU_BASE_S, _GPU_CAP_S]``.
    """
    audio_steps = float(duration) * int(steps)
    estimate = _GPU_BASE_S + audio_steps * 0.05
    window = int(round(estimate)) + 2
    if window > _GPU_CAP_S:
        window = _GPU_CAP_S
    if window < _GPU_BASE_S:
        window = _GPU_BASE_S
    return window
@spaces.GPU(duration=_tts_gpu_duration)
def _run_tts(
    prompt: str,
    voice_ref: str | None,
    cfg: float,
    stg: float,
    steps: int,
    duration: float,
    seed: int,
    resolution: str,
    progress=gr.Progress(),
) -> str:
    """TTS step: synthesize ``prompt`` with the warm Dramabox server.

    Runs inside a ZeroGPU window sized per call by ``_tts_gpu_duration``
    from the same arguments. ``voice_ref`` is used for cloning only if it
    points at an existing file. ``resolution`` is not used here; it is
    accepted so the signature matches what ``_tts_gpu_duration`` receives.

    Returns:
        Path to a new ``.wav`` written by ``tts.generate_to_file`` (the
        module notes describe it as watermarked), generated with
        ``gen_duration=duration``.

    Raises:
        gr.Error: if ``prompt`` is empty. Errors from generation propagate
            after the placeholder WAV is removed.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt is empty.")
    progress(0.05, desc="Generating speech with Dramabox…")
    # mkstemp reserves the name atomically (tempfile.mktemp is deprecated and
    # race-prone); the TTS writer then overwrites the empty placeholder.
    fd, out_wav = tempfile.mkstemp(suffix=".wav", prefix="avgen_tts_")
    os.close(fd)
    t0 = time.time()
    try:
        tts.generate_to_file(
            prompt=prompt,
            output=out_wav,
            voice_ref=voice_ref if voice_ref and os.path.exists(voice_ref) else None,
            cfg_scale=float(cfg),
            stg_scale=float(stg),
            steps=int(steps),
            duration_multiplier=1.1,
            seed=int(seed),
            gen_duration=float(duration),
            ref_duration=10.0,
            denoise_ref=False,
        )
    except Exception:
        # Don't leave an empty/partial WAV behind when generation fails.
        try:
            os.remove(out_wav)
        except OSError:
            pass
        raise
    logging.info(f"[tts] {time.time() - t0:.2f}s -> {out_wav} (steps={int(steps)}, dur={float(duration):.1f}s)")
    return out_wav
_LONGCAT_VIDEO_SECONDS = 5.0  # LongCat /generate hardcodes NUM_FRAMES=125 @ 25fps
def _trim_video(src_mp4: str, duration: float) -> str:
    """Trim ``src_mp4`` to its first ``duration`` seconds (best-effort).

    Re-encodes (instead of ``-c copy``) so the cut lands where requested
    regardless of keyframe layout. LongCat's mp4 is ~5 s, so the re-encode
    is sub-second.

    Returns:
        Path to a new trimmed mp4. Returns ``src_mp4`` unchanged when it is
        already within 50 ms of full length, or when ffmpeg fails; in the
        failure case the unused output file is removed.
    """
    if duration >= _LONGCAT_VIDEO_SECONDS - 0.05:
        return src_mp4  # already full length
    # mkstemp creates the file atomically (tempfile.mktemp is deprecated and
    # race-prone); ffmpeg's -y then overwrites the empty placeholder.
    fd, out = tempfile.mkstemp(suffix=".mp4", prefix="avgen_trim_")
    os.close(fd)
    cmd = [
        "ffmpeg", "-y", "-loglevel", "error",
        "-i", src_mp4,
        "-t", f"{duration:.3f}",
        "-c:v", "libx264", "-preset", "veryfast", "-crf", "20",
        "-c:a", "aac", "-b:a", "128k",
        out,
    ]
    try:
        subprocess.run(cmd, check=True)
    except Exception as e:
        # Serve the untrimmed clip rather than failing the whole request.
        logging.warning(f"[trim] ffmpeg trim failed ({e}); returning untrimmed clip")
        try:
            os.remove(out)
        except OSError:
            pass
        return src_mp4
    return out
def _extract_video_path(result) -> str:
    """Return the local mp4 path from a LongCat ``/generate`` response.

    Accepts a bare path, a dict with a ``video`` key (how gradio_client
    typically returns a gr.Video output) or a ``path`` key, or a list/tuple
    whose first element is one of those (mirroring generate_portrait).

    Raises:
        gr.Error: if no usable path can be found in ``result``.
    """
    if isinstance(result, (list, tuple)) and result:
        result = result[0]
    if isinstance(result, dict):
        result = result.get("video") or result.get("path")
    if not isinstance(result, (str, os.PathLike)) or not os.fspath(result):
        raise gr.Error(f"{LONGCAT_SPACE} returned no video. Please try again.")
    return os.fspath(result)
def generate_avatar(
    image_path: str,
    voice_ref: str | None,
    prompt: str,
    cfg: float,
    stg: float,
    steps: int,
    duration: float,
    seed: int,
    randomize_seed: bool,
    resolution: str,
    progress=gr.Progress(),
):
    """Full pipeline: Dramabox TTS on this Space, then lip-sync on LongCat.

    Returns:
        ``(video_path, seed)``: the final mp4 (trimmed to ``duration`` when it
        is shorter than LongCat's fixed clip length) and the seed actually
        used, which the UI writes back into the Seed field.

    Raises:
        gr.Error: on missing inputs, connection failures, or an empty
            LongCat response.
    """
    if not image_path:
        raise gr.Error("Please upload a reference portrait.")
    if not voice_ref:
        raise gr.Error("Please record or upload a voice clip (10+ seconds) to clone.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a script.")
    if randomize_seed:
        seed = random.randint(0, _MAX_SEED)
        logging.info(f"[seed] randomized -> {seed}")
    wav_path = _run_tts(prompt, voice_ref, cfg, stg, steps, duration, seed, resolution, progress)
    try:
        progress(0.55, desc="Generating talking-head video on LongCat-Video-Avatar…")
        client = _video_client()
        video_prompt = _video_prompt_from_script(prompt)
        logging.info(f"[video] prompt={video_prompt!r} resolution={resolution} seed={seed}")
        t0 = time.time()
        # Param order matches victor/LongCat-Video-Avatar-1.5 `generate(image_path,
        # audio_path, prompt, resolution, seed, vocal_mode, acceleration)`.
        # vocal_mode is forced to the fast path because our TTS output is already
        # clean studio audio — no need for vocal isolation. acceleration is the
        # 8-step DBCache faster preset which runs ~2× faster than exact 8-step at
        # negligible quality cost.
        result = client.predict(
            handle_file(image_path),
            handle_file(wav_path),
            video_prompt,
            resolution,
            int(seed),
            "Clean speech (fast)",
            "DBCache faster",
            api_name="/generate",
        )
        logging.info(f"[video] {time.time() - t0:.2f}s -> {result}")
    finally:
        # The TTS WAV is only needed until gradio_client has uploaded it;
        # remove it so each request doesn't leave a temp file behind.
        try:
            os.remove(wav_path)
        except OSError:
            pass
    video_path = _extract_video_path(result)
    if duration < _LONGCAT_VIDEO_SECONDS:
        progress(0.95, desc=f"Trimming to {duration:.1f}s…")
        video_path = _trim_video(video_path, float(duration))
    progress(1.0, desc="Done")
    return video_path, seed
# ── UI ──────────────────────────────────────────────────────────────────────
# Bundled assets live next to this file: <app dir>/assets/avatars/*.png
_ASSETS = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
)
_AVATARS_DIR = os.path.join(_ASSETS, "avatars")
def _a(name: str) -> str:
    """Return the absolute path of the bundled avatar image ``name``."""
    avatar_path = os.path.join(_AVATARS_DIR, name)
    return avatar_path
# Examples fill portrait + script only. Voice is the user's own, advanced
# settings keep their defaults. Each row is [portrait path, Dramabox script];
# dialogue sits inside closed double quotes so _video_prompt_from_script can
# separate the speaker description from the spoken text.
EXAMPLES = [
    [
        _a("orc_warrior.png"),
        'A shadowy warlord speaks with cold menace, "You have entered my domain, mortal." '
        'He chuckles darkly, "Such arrogance will be your undoing." ',
    ],
    [
        _a("photoreal_person.png"),
        'A radio host clears his throat, "Excuse me, pardon that." '
        # Closed quote: an unterminated quote would feed Dramabox an open-ended line.
        'He settles into a warm, professional tone, "Good evening everyone." ',
    ],
    [
        _a("character.png"),
        'A playful character already mid-giggle, "Hehehe, oh my gosh you should see your face!" ',
    ],
]
# Theme inspired by victor/ace-step-jam: dark slate palette, Hanken Grotesk,
# tight radius, subtle frosted surfaces. ace-step-jam itself is a custom
# HTML/CSS frontend, so this is an approximation translated to a Gradio Blocks
# theme + minimal CSS — same vibe, way less surface area.
# Every variable below is set to the same value as its `_dark` twin, so the
# palette does not change with the visitor's light/dark preference.
THEME = gr.themes.Soft(
    primary_hue=gr.themes.colors.slate,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    radius_size=gr.themes.sizes.radius_sm,
    text_size=gr.themes.sizes.text_md,
    font=[gr.themes.GoogleFont("Hanken Grotesk"), "system-ui", "sans-serif"],
).set(
    # Page background and default text color.
    body_background_fill="oklch(0.13 0.006 260)",
    body_background_fill_dark="oklch(0.13 0.006 260)",
    body_text_color="rgba(255, 255, 255, 0.87)",
    body_text_color_dark="rgba(255, 255, 255, 0.87)",
    # Surfaces: low-alpha white over the dark body for the frosted look.
    background_fill_primary="rgba(255, 255, 255, 0.04)",
    background_fill_primary_dark="rgba(255, 255, 255, 0.04)",
    background_fill_secondary="rgba(255, 255, 255, 0.06)",
    background_fill_secondary_dark="rgba(255, 255, 255, 0.06)",
    border_color_primary="rgba(255, 255, 255, 0.08)",
    border_color_primary_dark="rgba(255, 255, 255, 0.08)",
    # Component blocks and their labels/titles.
    block_background_fill="rgba(255, 255, 255, 0.04)",
    block_background_fill_dark="rgba(255, 255, 255, 0.04)",
    block_border_color="rgba(255, 255, 255, 0.08)",
    block_border_color_dark="rgba(255, 255, 255, 0.08)",
    block_label_background_fill="transparent",
    block_label_background_fill_dark="transparent",
    block_title_text_color="rgba(255, 255, 255, 0.87)",
    block_title_text_color_dark="rgba(255, 255, 255, 0.87)",
    # Form inputs.
    input_background_fill="rgba(255, 255, 255, 0.04)",
    input_background_fill_dark="rgba(255, 255, 255, 0.04)",
    input_border_color="rgba(255, 255, 255, 0.08)",
    input_border_color_dark="rgba(255, 255, 255, 0.08)",
    # Primary button: near-white fill with body-colored (dark) text.
    button_primary_background_fill="oklch(0.90 0.005 260)",
    button_primary_background_fill_dark="oklch(0.90 0.005 260)",
    button_primary_background_fill_hover="oklch(0.95 0.005 260)",
    button_primary_background_fill_hover_dark="oklch(0.95 0.005 260)",
    button_primary_text_color="oklch(0.13 0.006 260)",
    button_primary_text_color_dark="oklch(0.13 0.006 260)",
)
# Page-level CSS layered on top of THEME:
# - cap the main container at 1180px and center it;
# - enable the "ss01"/"cv11" OpenType font features;
# - style the header Markdown (elem_id="hero") and tighten primary buttons.
CUSTOM_CSS = """
main, .gradio-container, .fillable:not(.fill_width) {
width: min(100%, 1180px) !important;
max-width: 1180px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.gradio-container { font-feature-settings: "ss01", "cv11"; }
#hero h1 {
font-weight: 600;
letter-spacing: -0.02em;
margin-bottom: 0.25em;
}
#hero p { color: rgba(255, 255, 255, 0.55); margin-top: 0; }
.gr-button-primary {
letter-spacing: -0.01em;
font-weight: 600;
}
"""
# NOTE(review): the nesting below was reconstructed from an indentation-stripped
# copy. Confirm against the live layout, especially whether `voice_in` sits
# beside the portrait tabs (inside the Row) and whether `gr.Examples` lives in
# the right-hand column.
with gr.Blocks(title="Avatar Generator", theme=THEME, css=CUSTOM_CSS) as demo:
    # Header; styled by the #hero rules in CUSTOM_CSS.
    gr.Markdown(
        """
# Avatar Generator
SOTA Avatar generation with synthetic speech using [Dramabox](https://huggingface.co/ResembleAI/Dramabox) and [LongCat-Video-Avatar 1.5](https://huggingface.co/meituan-longcat/LongCat-Video-Avatar-1.5).
Upload/generate a portrait, clone your voice (or upload one), write a script — get a lip-synced talking-head✨
""",
        elem_id="hero",
    )
    with gr.Row():
        # Left column: all inputs.
        with gr.Column(scale=1):
            with gr.Row():
                # Tab ids matter: generate_portrait returns
                # gr.Tabs(selected="portrait_upload") to switch back to Upload.
                with gr.Tabs() as portrait_tabs:
                    with gr.TabItem("Upload", id="portrait_upload"):
                        image_in = gr.Image(
                            label="Reference portrait",
                            type="filepath",
                            height=260,
                            sources=["upload", "clipboard"],
                        )
                    with gr.TabItem("Generate / edit", id="portrait_generate"):
                        flux_prompt = gr.Textbox(
                            info="Describe the portrait (or the edit, if you attach one below)",
                            placeholder=(
                                "e.g. A photorealistic portrait of an elderly fisherman with "
                                "weathered skin and a wool sweater, neutral studio backdrop"
                            ),
                            lines=1,
                        )
                        # If set, generate_portrait edits this image instead of
                        # generating from scratch.
                        flux_edit_image = gr.Image(
                            label="optional: input image",
                            type="filepath",
                            height=160,
                            sources=["upload", "clipboard"],
                        )
                        flux_go = gr.Button("Generate portrait", variant="secondary")
                # Voice clip to clone; generate_avatar requires it.
                voice_in = gr.Audio(
                    label="Avatar voice",
                    type="filepath",
                    sources=["upload", "microphone"],
                )
            # Unified Dramabox script: speaker description + "quoted dialogue".
            prompt = gr.Textbox(
                label="Script",
                value=(
                    'A confident announcer speaks proudly, "And now, the moment '
                    'you have all been waiting for." He chuckles knowingly, '
                    '"Heheh, trust me, this one is going to blow you away."'
                ),
                lines=4,
            )
            with gr.Accordion("Advanced", open=False):
                # LongCat's /generate API hardcodes 5 s of video output. We
                # can shorten by pacing TTS to the requested length + trimming
                # the returned mp4, but we can't go longer from a single call.
                duration_in = gr.Slider(
                    1.0, 5.0, value=5.0, step=0.5,
                    label="Output duration (seconds, max 5)",
                )
                with gr.Row():
                    resolution = gr.Radio(["480p", "720p"], value="480p", label="Resolution")
                with gr.Row():
                    # generate_avatar returns the seed it used, and go.click
                    # writes it back here (outputs=[video_out, seed]).
                    seed = gr.Number(value=42, precision=0, label="Seed")
                    randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
                # Default 22 trades ~25% of the TTS step for negligible quality
                # cost on typical short prompts; bump back toward 30 for the
                # cleanest output, drop toward 14 for fastest iteration.
                steps_in = gr.Slider(10, 40, value=22, step=1, label="TTS steps (Euler)")
                cfg = gr.Slider(1.0, 5.0, value=2.5, step=0.1, label="TTS CFG scale")
                stg = gr.Slider(0.0, 3.0, value=1.5, step=0.1, label="TTS STG scale")
            go = gr.Button("Generate avatar", variant="primary")
        # Right column: output video and examples.
        with gr.Column(scale=1):
            video_out = gr.Video(label="Output", autoplay=True, height=420)
            # Examples only fill the portrait and script (no fn, no caching);
            # the user still supplies their own voice.
            gr.Examples(
                examples=EXAMPLES,
                inputs=[image_in, prompt],
                outputs=None,
                fn=None,
                cache_examples=False,
                examples_per_page=4,
                label="Script + portrait examples (then add your own voice above)",
            )
    # ── Event wiring ──
    # FLUX result goes into the Upload tab's image, and the tabs switch back to it.
    flux_go.click(
        generate_portrait,
        inputs=[flux_prompt, flux_edit_image],
        outputs=[image_in, portrait_tabs],
        show_progress="full",
    )
    # Input order must match generate_avatar's positional parameters.
    go.click(
        generate_avatar,
        inputs=[image_in, voice_in, prompt, cfg, stg, steps_in, duration_in, seed, randomize_seed, resolution],
        outputs=[video_out, seed],
    )
if __name__ == "__main__":
    # Hold at most 8 pending events in the queue; surface errors in the UI.
    queued_app = demo.queue(max_size=8)
    queued_app.launch(show_error=True)