# llaa33219's picture
# Update app.py
# 919bd60 verified
"""
SteadyDancer-14B - ZeroGPU ์ตœ์ ํ™” ๋ฒ„์ „
=====================================
์ฃผ์š” ๋ณ€๊ฒฝ์‚ฌํ•ญ:
1. subprocess ์ œ๊ฑฐ โ†’ ์ง์ ‘ Python import ์‚ฌ์šฉ (ZeroGPU ํ˜ธํ™˜์„ฑ)
2. ๋ชจ๋ธ ๋กœ๋”ฉ ์ตœ์ ํ™” (์ „์—ญ ์บ์‹ฑ + GPU ํ•จ์ˆ˜ ๋‚ด ์ด๋™)
3. duration ์กฐ์ • (300์ดˆ = ZeroGPU ์ตœ๋Œ€๊ฐ’)
4. ํ”„๋ ˆ์ž„ ์ˆ˜ ์ œํ•œ์œผ๋กœ ํƒ€์ž„์•„์›ƒ ๋ฐฉ์ง€
5. ํฌ์ฆˆ ์ถ”์ถœ์„ GPU ํ•จ์ˆ˜ ๋ฐ–์œผ๋กœ ๋ถ„๋ฆฌ
6. ๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™” (torch.cuda.empty_cache, gc)
"""
import gradio as gr
import spaces
import torch
import os
import gc
import tempfile
import sys
import shutil
from pathlib import Path
from PIL import Image
import cv2
import numpy as np
from huggingface_hub import snapshot_download
# ========== ์ƒ์ˆ˜ ์ •์˜ ==========
REPO_DIR = Path("SteadyDancer")
MODEL_DIR = Path("SteadyDancer-14B")
MAX_FRAMES = 49 # ZeroGPU ํ† ํฐ ๋งŒ๋ฃŒ ๋ฐฉ์ง€๋ฅผ ์œ„ํ•ด ์ œํ•œ
MAX_DURATION_SECONDS = 300 # ZeroGPU ์ตœ๋Œ€ ํ—ˆ์šฉ ์‹œ๊ฐ„
# ========== ์ „์—ญ ์บ์‹œ (CPU์— ์œ ์ง€) ==========
_pipe = None
_pose_detector = None
_repo_ready = False
def ensure_repo():
"""SteadyDancer ๋ ˆํฌ์ง€ํ† ๋ฆฌ ํด๋ก  ๋ฐ ์˜์กด์„ฑ ์„ค์น˜ (1ํšŒ๋งŒ)"""
global _repo_ready
if _repo_ready:
return
if not REPO_DIR.exists():
print("๐Ÿ“ฅ Cloning SteadyDancer repository...")
import git
git.Repo.clone_from(
"https://github.com/MCG-NJU/SteadyDancer.git",
str(REPO_DIR),
depth=1
)
# ๋ ˆํฌ์˜ requirements.txt ์„ค์น˜
repo_requirements = REPO_DIR / "requirements.txt"
if repo_requirements.exists():
print("๐Ÿ“ฆ Installing SteadyDancer requirements...")
import subprocess
subprocess.run([
sys.executable, "-m", "pip", "install", "-q",
"-r", str(repo_requirements)
], check=False)
if str(REPO_DIR) not in sys.path:
sys.path.insert(0, str(REPO_DIR))
_repo_ready = True
print("โœ… Repository ready")
def ensure_model():
"""๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋‹ค์šด๋กœ๋“œ (1ํšŒ๋งŒ)"""
if not MODEL_DIR.exists():
print("๐Ÿ“ฅ Downloading SteadyDancer-14B model weights...")
snapshot_download(
repo_id="MCG-NJU/SteadyDancer-14B",
local_dir=str(MODEL_DIR),
resume_download=True
)
print("โœ… Model weights downloaded")
def get_pose_detector():
"""ํฌ์ฆˆ ๋””ํ…ํ„ฐ ๋กœ๋“œ (CPU์— ์œ ์ง€, ํ•„์š”์‹œ GPU๋กœ ์ด๋™)"""
global _pose_detector
if _pose_detector is None:
print("๐Ÿ“ฅ Loading DWPose detector...")
try:
from controlnet_aux import DWposeDetector
_pose_detector = DWposeDetector.from_pretrained("lllyasviel/Annotators")
except Exception as e:
print(f"โš ๏ธ DWPose ๋กœ๋“œ ์‹คํŒจ, OpenPose๋กœ ๋Œ€์ฒด: {e}")
from controlnet_aux import OpenposeDetector
_pose_detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
print("โœ… Pose detector loaded")
return _pose_detector
def extract_poses_from_video(video_path, output_dir, max_frames=MAX_FRAMES, progress_callback=None):
"""
๋“œ๋ผ์ด๋น™ ๋น„๋””์˜ค์—์„œ ํฌ์ฆˆ ์ถ”์ถœ (CPU์—์„œ ์‹คํ–‰)
- GPU ์‹œ๊ฐ„ ์ ˆ์•ฝ์„ ์œ„ํ•ด @spaces.GPU ๋ฐ–์—์„œ ์‹คํ–‰
- ํ”„๋ ˆ์ž„ ์ˆ˜ ์ œํ•œ์œผ๋กœ ํƒ€์ž„์•„์›ƒ ๋ฐฉ์ง€
"""
pose_detector = get_pose_detector()
# GPU๊ฐ€ ์žˆ์œผ๋ฉด ์ด๋™ (ZeroGPU๊ฐ€ ์•„๋‹Œ ํ™˜๊ฒฝ์—์„œ)
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
pose_detector = pose_detector.to(device)
except:
pass # ZeroGPU์—์„œ๋Š” ์‹คํŒจํ•  ์ˆ˜ ์žˆ์Œ
cap = cv2.VideoCapture(str(video_path))
fps = cap.get(cv2.CAP_PROP_FPS) or 24
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# ํ”„๋ ˆ์ž„ ์ˆ˜ ์ œํ•œ
target_frames = min(total_frames, max_frames)
# ํ”„๋ ˆ์ž„ ์ƒ˜ํ”Œ๋ง (์›๋ณธ ํ”„๋ ˆ์ž„์ด ๋งŽ์œผ๋ฉด ๊ท ๋“ฑ ์ƒ˜ํ”Œ๋ง)
if total_frames > max_frames:
frame_indices = np.linspace(0, total_frames - 1, max_frames, dtype=int)
else:
frame_indices = list(range(total_frames))
pos_dir = Path(output_dir) / "positive"
neg_dir = Path(output_dir) / "negative"
pos_dir.mkdir(parents=True, exist_ok=True)
neg_dir.mkdir(parents=True, exist_ok=True)
extracted_count = 0
for idx, frame_idx in enumerate(frame_indices):
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
ret, frame = cap.read()
if not ret:
continue
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(frame_rgb)
try:
with torch.inference_mode():
pose_image = pose_detector(pil_image)
except Exception as e:
print(f"โš ๏ธ Frame {idx} pose extraction failed: {e}")
pose_image = Image.new('RGB', pil_image.size, (0, 0, 0))
pose_image.save(pos_dir / f"{idx:04d}.jpg")
pose_image.save(neg_dir / f"{idx:04d}.jpg")
extracted_count += 1
if progress_callback:
progress_callback(idx / len(frame_indices))
cap.release()
# ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return str(pos_dir), str(neg_dir), fps, extracted_count
def load_steadydancer_components():
"""
SteadyDancer ์ปดํฌ๋„ŒํŠธ ๋กœ๋“œ (generate.py ๋ฐฉ์‹ ์ฐธ์กฐ)
ZeroGPU์—์„œ๋Š” ๋งค ํ˜ธ์ถœ๋งˆ๋‹ค ์ƒˆ๋กœ ๋กœ๋“œํ•ด์•ผ ํ•จ
"""
ensure_repo()
ensure_model()
print("๐Ÿ“ฅ Loading SteadyDancer components...")
# SteadyDancer ๋‚ด๋ถ€ ๋ชจ๋“ˆ import
from wan.configs import WAN_CONFIGS
from wan.modules.vae import WanVAE
from wan.modules.t5 import T5EncoderModel
from wan.modules.clip import CLIPModel
from wan.modules.model import WanModel
cfg = WAN_CONFIGS["i2v-14B"]
# ๋ชจ๋ธ ๊ฒฝ๋กœ ์„ค์ •
ckpt_dir = str(MODEL_DIR)
# T5 ํ…์ŠคํŠธ ์ธ์ฝ”๋”
t5_encoder = T5EncoderModel(
text_len=cfg.text_len,
dtype=cfg.t5_dtype,
device="cuda",
checkpoint_path=f"{ckpt_dir}/models_t5_umt5-xxl-enc-bf16.pth",
tokenizer_path=f"{ckpt_dir}/google_umt5-xxl",
spiece_path=f"{ckpt_dir}/google_umt5-xxl/spiece.model",
)
# CLIP ๋น„์ „ ์ธ์ฝ”๋”
clip_encoder = CLIPModel(
dtype=cfg.clip_dtype,
device="cuda",
checkpoint_path=f"{ckpt_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
)
# VAE
vae = WanVAE(
vae_pth=f"{ckpt_dir}/Wan2.1_VAE.pth",
device="cuda",
dtype=cfg.vae_dtype,
)
# Main Model (DiT)
model = WanModel.from_pretrained(ckpt_dir, torch_dtype=torch.bfloat16)
model = model.to("cuda")
model.eval()
print("โœ… All components loaded")
return cfg, t5_encoder, clip_encoder, vae, model
@spaces.GPU(duration=MAX_DURATION_SECONDS)
def generate_video_gpu(
ref_image_path: str,
pos_folder: str,
neg_folder: str,
prompt: str,
cfg_scale: float,
condition_guide_scale: float,
seed: int,
width: int,
height: int,
output_path: str,
num_frames: int = 49,
):
"""
GPU์—์„œ ๋น„๋””์˜ค ์ƒ์„ฑ (SteadyDancer ๋‚ด๋ถ€ API ์ง์ ‘ ์‚ฌ์šฉ)
"""
import random
import subprocess
from PIL import Image
print(f"๐ŸŽฌ Starting generation: {width}x{height}, seed={seed}, frames={num_frames}")
# ์‹œ๋“œ ์„ค์ •
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
try:
# SteadyDancer ๋‚ด๋ถ€ ๋ชจ๋“ˆ import
if str(REPO_DIR) not in sys.path:
sys.path.insert(0, str(REPO_DIR))
from wan.pipelines.pipeline_dancer import DancerPipeline
from wan.configs import WAN_CONFIGS
cfg = WAN_CONFIGS["i2v-14B"]
# ํŒŒ์ดํ”„๋ผ์ธ ์ƒ์„ฑ
print("๐Ÿ“ฆ Creating DancerPipeline...")
pipe = DancerPipeline(
config=cfg,
checkpoint_dir=str(MODEL_DIR),
device_id=0,
dtype=torch.bfloat16,
)
# ์ฐธ์กฐ ์ด๋ฏธ์ง€ ๋กœ๋“œ
ref_image = Image.open(ref_image_path).convert("RGB")
# ์ƒ์„ฑ ์‹คํ–‰
print("๐ŸŽจ Running inference...")
output = pipe.generate(
image=ref_image,
prompt=prompt,
cond_pos_folder=pos_folder,
cond_neg_folder=neg_folder,
size=f"{width}*{height}",
num_frames=num_frames,
sample_guide_scale=cfg_scale,
condition_guide_scale=condition_guide_scale,
seed=seed,
save_path=output_path,
)
print(f"โœ… Generation complete!")
except Exception as e:
print(f"โš ๏ธ Direct API failed: {e}")
print("โš ๏ธ Trying CLI fallback...")
import traceback
traceback.print_exc()
# Fallback: CLI ์‹คํ–‰
cmd = [
sys.executable, str(REPO_DIR / "generate_dancer.py"),
"--task", "i2v-14B",
"--size", f"{width}*{height}",
"--image", ref_image_path,
"--cond_pos_folder", pos_folder,
"--cond_neg_folder", neg_folder,
"--prompt", prompt,
"--save_file", output_path,
"--sample_guide_scale", str(cfg_scale),
"--condition_guide_scale", str(condition_guide_scale),
"--base_seed", str(seed),
"--ckpt_dir", str(MODEL_DIR),
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=250)
if result.returncode != 0:
error_msg = result.stderr or result.stdout or str(e)
raise gr.Error(f"์ƒ์„ฑ ์‹คํŒจ: {error_msg[:300]}")
finally:
# ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return output_path
def generate_video(
reference_image,
driving_video,
prompt,
cfg_scale,
condition_guide_scale,
seed,
resolution,
max_frames,
progress=gr.Progress()
):
"""
๋ฉ”์ธ ์ƒ์„ฑ ํ•จ์ˆ˜ (Gradio ์ธํ„ฐํŽ˜์ด์Šค)
- ํฌ์ฆˆ ์ถ”์ถœ: CPU (GPU ์‹œ๊ฐ„ ์ ˆ์•ฝ)
- ๋น„๋””์˜ค ์ƒ์„ฑ: GPU (@spaces.GPU)
์ค‘์š”: ZeroGPU ํ† ํฐ ๋งŒ๋ฃŒ ๋ฐฉ์ง€๋ฅผ ์œ„ํ•ด ์ตœ๋Œ€ํ•œ ๋นจ๋ฆฌ GPU ํ•จ์ˆ˜ ํ˜ธ์ถœ
"""
if reference_image is None:
raise gr.Error("โŒ ์ฐธ์กฐ ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.")
if driving_video is None:
raise gr.Error("โŒ ๋“œ๋ผ์ด๋น™ ๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.")
# ๋ชจ๋ธ์ด ์ค€๋น„๋˜์ง€ ์•Š์•˜์œผ๋ฉด ์—๋Ÿฌ (warmup์—์„œ ๋ฏธ๋ฆฌ ๋˜์–ด์•ผ ํ•จ)
if not MODEL_DIR.exists():
progress(0.1, desc="โณ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์ค‘ (์ฒซ ์‹คํ–‰ ์‹œ ์˜ค๋ž˜ ๊ฑธ๋ฆผ)...")
ensure_model()
progress(0.05, desc="๐Ÿ”ง ํ™˜๊ฒฝ ์„ค์ • ์ค‘...")
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
# 1. ์ฐธ์กฐ ์ด๋ฏธ์ง€ ์ €์žฅ (๋น ๋ฆ„)
progress(0.08, desc="๐Ÿ“ธ ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์ค‘...")
ref_image_path = tmpdir / "reference.png"
if isinstance(reference_image, str):
shutil.copy(reference_image, ref_image_path)
elif isinstance(reference_image, np.ndarray):
Image.fromarray(reference_image).save(ref_image_path)
else:
reference_image.save(ref_image_path)
# 2. ํฌ์ฆˆ ์ถ”์ถœ (CPU - ์ตœ๋Œ€ํ•œ ๋นจ๋ฆฌ!)
progress(0.1, desc="๐Ÿ•บ ํฌ์ฆˆ ์ถ”์ถœ ์ค‘...")
pose_dir = tmpdir / "poses"
pose_dir.mkdir(exist_ok=True)
# ํ”„๋ ˆ์ž„ ์ˆ˜ ๋” ์ œํ•œ (ํ† ํฐ ๋งŒ๋ฃŒ ๋ฐฉ์ง€)
actual_max_frames = min(int(max_frames), 49) # ์ตœ๋Œ€ 49ํ”„๋ ˆ์ž„์œผ๋กœ ์ œํ•œ
def pose_progress(p):
progress(0.1 + 0.25 * p, desc=f"๐Ÿ•บ ํฌ์ฆˆ ์ถ”์ถœ ์ค‘... {int(p*100)}%")
pos_folder, neg_folder, fps, frame_count = extract_poses_from_video(
driving_video,
pose_dir,
max_frames=actual_max_frames,
progress_callback=pose_progress
)
if frame_count == 0:
raise gr.Error("โŒ ๋น„๋””์˜ค์—์„œ ํ”„๋ ˆ์ž„์„ ์ถ”์ถœํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
progress(0.35, desc=f"โœ… {frame_count}๊ฐœ ํ”„๋ ˆ์ž„ ์ถ”์ถœ ์™„๋ฃŒ")
# 3. ํ•ด์ƒ๋„ ํŒŒ์‹ฑ
width, height = map(int, resolution.split("x"))
# 4. ์ถœ๋ ฅ ๊ฒฝ๋กœ
output_path = str(tmpdir / "output.mp4")
# 5. GPU ์ƒ์„ฑ ์‹คํ–‰ (์ตœ๋Œ€ํ•œ ๋นจ๋ฆฌ ํ˜ธ์ถœ!)
progress(0.4, desc="๐ŸŽฌ ๋น„๋””์˜ค ์ƒ์„ฑ ์ค‘ (GPU)...")
final_prompt = prompt.strip() if prompt and prompt.strip() else "A person dancing gracefully"
try:
generate_video_gpu(
ref_image_path=str(ref_image_path),
pos_folder=pos_folder,
neg_folder=neg_folder,
prompt=final_prompt,
cfg_scale=cfg_scale,
condition_guide_scale=condition_guide_scale,
seed=int(seed),
width=width,
height=height,
output_path=output_path,
num_frames=frame_count,
)
except Exception as e:
error_msg = str(e)
if "Expired ZeroGPU proxy token" in error_msg:
raise gr.Error(
"โŒ ZeroGPU ํ† ํฐ ๋งŒ๋ฃŒ๋จ. ํŽ˜์ด์ง€๋ฅผ ์ƒˆ๋กœ๊ณ ์นจํ•˜๊ณ  ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”. "
"ํŒ: ํ”„๋ ˆ์ž„ ์ˆ˜๋ฅผ 30 ์ดํ•˜๋กœ ์ค„์—ฌ๋ณด์„ธ์š”."
)
raise gr.Error(f"โŒ ์ƒ์„ฑ ์‹คํŒจ: {error_msg[:300]}")
progress(0.95, desc="๐Ÿ“ผ ๋น„๋””์˜ค ์ €์žฅ ์ค‘...")
# 6. ์ตœ์ข… ์ถœ๋ ฅ ๋ณต์‚ฌ
final_output = Path(tempfile.gettempdir()) / f"steadydancer_output_{seed}.mp4"
if Path(output_path).exists():
shutil.copy(output_path, final_output)
else:
raise gr.Error("โŒ ์ถœ๋ ฅ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
progress(1.0, desc="โœ… ์™„๋ฃŒ!")
return str(final_output)
# ========== Gradio UI ==========
with gr.Blocks(
title="SteadyDancer-14B - ZeroGPU Optimized",
theme=gr.themes.Soft(),
css="""
.main-title { text-align: center; margin-bottom: 1rem; }
.warning-box {
background: linear-gradient(135deg, #fff3cd 0%, #ffeeba 100%);
border: 1px solid #ffc107;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
}
.tip-box {
background: linear-gradient(135deg, #d4edda 0%, #c3e6cb 100%);
border: 1px solid #28a745;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
"""
) as demo:
gr.Markdown("""
# ๐Ÿ•บ SteadyDancer-14B (ZeroGPU ์ตœ์ ํ™”)
## Pose-Guided Human Image Animation
**๋“œ๋ผ์ด๋น™ ๋น„๋””์˜ค์˜ ๋™์ž‘์„ ์ฐธ์กฐ ์ด๋ฏธ์ง€์— ์ „์†กํ•ฉ๋‹ˆ๋‹ค!**
๐Ÿ“ [Paper](https://arxiv.org/abs/2412.12534) |
๐Ÿ”— [GitHub](https://github.com/MCG-NJU/SteadyDancer) |
๐Ÿค— [Model](https://huggingface.co/MCG-NJU/SteadyDancer-14B)
""", elem_classes=["main-title"])
gr.Markdown("""
### โš ๏ธ ZeroGPU ์ œํ•œ์‚ฌํ•ญ
- **์ตœ๋Œ€ ์‹คํ–‰ ์‹œ๊ฐ„**: 5๋ถ„ (300์ดˆ)
- **๊ถŒ์žฅ ํ”„๋ ˆ์ž„ ์ˆ˜**: 20-30 ํ”„๋ ˆ์ž„ (ํƒ€์ž„์•„์›ƒ/ํ† ํฐ ๋งŒ๋ฃŒ ๋ฐฉ์ง€)
- **๊ถŒ์žฅ ํ•ด์ƒ๋„**: 480x832 ๋˜๋Š” ๋” ๋‚ฎ์€ ํ•ด์ƒ๋„
- **์ฒซ ์‹คํ–‰**: ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ๋กœ ์‹œ๊ฐ„์ด ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Œ โ†’ ํŽ˜์ด์ง€ ์ƒˆ๋กœ๊ณ ์นจ ํ›„ ์žฌ์‹œ๋„
""", elem_classes=["warning-box"])
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### ๐Ÿ“ธ ์ž…๋ ฅ")
reference_image = gr.Image(
label="์ฐธ์กฐ ์ด๋ฏธ์ง€ (์• ๋‹ˆ๋ฉ”์ด์…˜ํ•  ์ธ๋ฌผ)",
type="numpy",
sources=["upload", "clipboard"],
height=280
)
driving_video = gr.Video(
label="๋“œ๋ผ์ด๋น™ ๋น„๋””์˜ค (๋™์ž‘ ์†Œ์Šค)",
sources=["upload"],
height=280
)
prompt = gr.Textbox(
label="ํ”„๋กฌํ”„ํŠธ (์„ ํƒ์‚ฌํ•ญ)",
placeholder="์˜ˆ: A person dancing gracefully in a studio",
value=""
)
with gr.Accordion("โš™๏ธ ๊ณ ๊ธ‰ ์„ค์ •", open=True):
resolution = gr.Dropdown(
label="์ถœ๋ ฅ ํ•ด์ƒ๋„",
choices=[
"480x832", # ์„ธ๋กœ (๊ถŒ์žฅ)
"832x480", # ๊ฐ€๋กœ (๊ถŒ์žฅ)
"576x1024", # ์„ธ๋กœ HD
"1024x576", # ๊ฐ€๋กœ HD
"720x1280", # ์„ธ๋กœ HD+
"1280x720", # ๊ฐ€๋กœ HD+
],
value="480x832",
info="โšก ๋‚ฎ์€ ํ•ด์ƒ๋„ = ๋น ๋ฅธ ์ƒ์„ฑ + ํƒ€์ž„์•„์›ƒ ๋ฐฉ์ง€"
)
max_frames = gr.Slider(
label="์ตœ๋Œ€ ํ”„๋ ˆ์ž„ ์ˆ˜",
minimum=10,
maximum=49,
value=30, # ๊ธฐ๋ณธ๊ฐ’ ๋‚ฎ์ถค
step=1,
info="โšก ์ ์€ ํ”„๋ ˆ์ž„ = ๋น ๋ฅธ ์ƒ์„ฑ + ํ† ํฐ ๋งŒ๋ฃŒ ๋ฐฉ์ง€ (30 ๊ถŒ์žฅ)"
)
cfg_scale = gr.Slider(
label="CFG Scale",
minimum=1.0,
maximum=10.0,
value=5.0,
step=0.5
)
condition_guide_scale = gr.Slider(
label="Condition Guide Scale",
minimum=0.0,
maximum=2.0,
value=1.0,
step=0.1
)
seed = gr.Slider(
label="์‹œ๋“œ",
minimum=0,
maximum=999999,
value=42,
step=1
)
generate_btn = gr.Button(
"๐ŸŽฌ ๋น„๋””์˜ค ์ƒ์„ฑ",
variant="primary",
size="lg"
)
with gr.Column(scale=1):
gr.Markdown("### ๐ŸŽฅ ์ถœ๋ ฅ")
output_video = gr.Video(
label="์ƒ์„ฑ๋œ ๋น„๋””์˜ค",
height=450,
autoplay=True
)
gr.Markdown("""
### ๐Ÿ’ก ํŒ
- **์ฐธ์กฐ ์ด๋ฏธ์ง€**: ์ „์‹ ์ด ๋ณด์ด๊ณ  ๋ฐฐ๊ฒฝ์ด ๋‹จ์ˆœํ•œ ์ด๋ฏธ์ง€๊ฐ€ ์ข‹์Šต๋‹ˆ๋‹ค
- **๋“œ๋ผ์ด๋น™ ๋น„๋””์˜ค**: 3-5์ดˆ ์ •๋„์˜ ์งง์€ ๋น„๋””์˜ค๊ฐ€ ์ข‹์Šต๋‹ˆ๋‹ค
- **ํƒ€์ž„์•„์›ƒ ๋ฐœ์ƒ ์‹œ**: ํ”„๋ ˆ์ž„ ์ˆ˜์™€ ํ•ด์ƒ๋„๋ฅผ ๋‚ฎ์ถฐ๋ณด์„ธ์š”
- **์ฒซ ์‹คํ–‰**: ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ๋กœ ์‹œ๊ฐ„์ด ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
""", elem_classes=["tip-box"])
generate_btn.click(
fn=generate_video,
inputs=[
reference_image,
driving_video,
prompt,
cfg_scale,
condition_guide_scale,
seed,
resolution,
max_frames,
],
outputs=output_video
)
def warmup():
"""
Space ์‹œ์ž‘ ์‹œ ๋ชจ๋“  ์ค€๋น„ ์ž‘์—… ์ˆ˜ํ–‰
- ๋ ˆํฌ ํด๋ก 
- ์˜์กด์„ฑ ์„ค์น˜
- ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ (๊ฐ€์žฅ ์˜ค๋ž˜ ๊ฑธ๋ฆผ!)
- ํฌ์ฆˆ ๋””ํ…ํ„ฐ ๋กœ๋“œ
์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ์‚ฌ์šฉ์ž ์š”์ฒญ ์‹œ ZeroGPU ํ† ํฐ ๋งŒ๋ฃŒ ๋ฐฉ์ง€
"""
import subprocess
print("๐Ÿš€ Warming up SteadyDancer-14B...")
# 1. ํ•„์ˆ˜ ์˜์กด์„ฑ ๋จผ์ € ์„ค์น˜
print("๐Ÿ“ฆ Checking dependencies...")
deps_to_install = []
try:
import easydict
except ImportError:
deps_to_install.append("easydict")
try:
import einops
except ImportError:
deps_to_install.append("einops")
try:
import ftfy
except ImportError:
deps_to_install.append("ftfy")
try:
import decord
except ImportError:
deps_to_install.append("decord")
if deps_to_install:
print(f"๐Ÿ“ฆ Installing missing dependencies: {deps_to_install}")
subprocess.run(
[sys.executable, "-m", "pip", "install", "-q"] + deps_to_install,
check=False
)
# 2. ๋ ˆํฌ ํด๋ก 
ensure_repo()
print("โœ… Repository ready")
# 3. ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ (๊ฐ€์žฅ ์ค‘์š”!)
ensure_model()
print("โœ… Model weights ready")
# 4. ํฌ์ฆˆ ๋””ํ…ํ„ฐ ๋ฏธ๋ฆฌ ๋กœ๋“œ (์„ ํƒ์ )
try:
get_pose_detector()
print("โœ… Pose detector ready")
except Exception as e:
print(f"โš ๏ธ Pose detector will be loaded on first use: {e}")
# 5. SteadyDancer ๋ชจ๋“ˆ import ํ…Œ์ŠคํŠธ
try:
sys.path.insert(0, str(REPO_DIR))
from wan.configs import WAN_CONFIGS
print("โœ… SteadyDancer modules importable")
except Exception as e:
print(f"โš ๏ธ SteadyDancer import test failed: {e}")
print(" (Will try again during generation)")
print("๐ŸŽ‰ Warmup complete! Ready for requests.")
if __name__ == "__main__":
# Space ์‹œ์ž‘ ์‹œ ๋ชจ๋“  ์ค€๋น„ ์ž‘์—… ๋ฏธ๋ฆฌ ์ˆ˜ํ–‰
warmup()
demo.launch()