|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import base64 |
|
|
import tempfile |
|
|
import shutil |
|
|
from typing import Dict, Any, Optional, List |
|
|
import torch |
|
|
import numpy as np |
|
|
from huggingface_hub import snapshot_download |
|
|
import logging |
|
|
import subprocess |
|
|
import warnings |
|
|
# Suppress every Python warning for the whole process; note this also hides
# deprecation warnings raised by the libraries used below.
warnings.simplefilter("ignore")

# Configure root logging at INFO so the handler's progress messages are emitted,
# then obtain this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class EndpointHandler:
    """
    Hugging Face Inference Endpoint handler for Wan-2.1 MultiTalk video generation.

    Construction downloads the Wan 2.1 I2V, chinese-wav2vec2-base and
    MeiGen-MultiTalk weights into ``/data/weights``. It then builds one of three
    pipelines, in order of preference: a Diffusers-format pipeline, an
    original-format pipeline, or a fallback that loads only the Wav2Vec2 audio
    encoder.

    NOTE(review): despite the naming, no diffusion sampling happens yet.
    ``_create_dit_model`` is a stub that returns ``None``, and
    ``_generate_with_full_pipeline`` delegates to ``_generate_with_simple_pipeline``.
    Every request therefore animates the reference image with
    ``_apply_audio_driven_animation``: mouth-region darkening, horizontal sway and
    brightness flicker, all driven by the Wav2Vec2 feature norm.
    """

    # Frame rate used both to align audio features and to encode the output video.
    OUTPUT_FPS = 30

    def __init__(self, path=""):
        """
        Download the model weights and build the best pipeline that loads.

        Args:
            path: Model repository path supplied by the endpoint runtime. It is
                currently unused; weights always go to ``/data/weights``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing Wan 2.1 MultiTalk Handler on device: {self.device}")

        # Declare every model slot up front. Later code can then test for ``None``
        # instead of failing on an attribute that a partial initialisation never set.
        self.vae = None
        self.dit = None
        self.multitalk = None
        self.image_encoder = None
        self.image_processor = None
        self.audio_model = None
        self.audio_processor = None
        self.scheduler = None
        self.initialized = False

        self.weights_dir = "/data/weights"
        os.makedirs(self.weights_dir, exist_ok=True)

        self._download_models()
        self._initialize_wan_pipeline()

        logger.info("Wan 2.1 MultiTalk Handler initialization complete")

    def _wan_model_dirs(self):
        """Return ``(diffusers_dir, original_dir)`` for the two supported Wan 2.1 layouts."""
        return (
            os.path.join(self.weights_dir, "Wan2.1-I2V-14B-480P-Diffusers"),
            os.path.join(self.weights_dir, "Wan2.1-I2V-14B-480P"),
        )

    def _snapshot_model(self, model_info: Dict[str, str], hf_token: Optional[str]) -> None:
        """
        Download ``model_info['repo_id']`` into ``model_info['local_dir']``.

        The download is skipped when that directory already exists. If the
        download fails, the partially written directory is deleted before the
        exception is re-raised. Otherwise the existence check would skip it on the
        next start, and the loader could pick an incomplete model directory.
        """
        local_dir = model_info["local_dir"]
        if os.path.exists(local_dir):
            logger.info(f"Model already exists: {model_info['description']}")
            return
        try:
            # NOTE(review): recent huggingface_hub releases deprecate resume_download
            # and local_dir_use_symlinks. Confirm that the pinned version still
            # accepts them.
            snapshot_download(
                repo_id=model_info["repo_id"],
                local_dir=local_dir,
                token=hf_token,
                resume_download=True,
                local_dir_use_symlinks=False,
            )
        except Exception:
            # Nothing existed here before this call, so everything present is partial.
            shutil.rmtree(local_dir, ignore_errors=True)
            raise
        logger.info(f"Successfully downloaded {model_info['description']}")

    def _download_models(self):
        """
        Download all required models from the Hugging Face Hub into ``weights_dir``.

        ``HF_TOKEN`` from the environment is used when set. If the Diffusers-format
        Wan 2.1 repo fails, the original-format repo is tried instead. Each failure
        is logged and does not stop the remaining downloads; pipeline
        initialisation later decides what can actually be loaded.
        """
        logger.info("Starting Wan 2.1 model downloads...")
        hf_token = os.environ.get("HF_TOKEN", None)

        wan_diffusers_dir, wan_original_dir = self._wan_model_dirs()
        models_to_download = [
            {
                "repo_id": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
                "local_dir": wan_diffusers_dir,
                "description": "Wan2.1 I2V Diffusers model (full implementation)",
            },
            {
                "repo_id": "TencentGameMate/chinese-wav2vec2-base",
                "local_dir": os.path.join(self.weights_dir, "chinese-wav2vec2-base"),
                "description": "Audio encoder for speech features",
            },
            {
                "repo_id": "MeiGen-AI/MeiGen-MultiTalk",
                "local_dir": os.path.join(self.weights_dir, "MeiGen-MultiTalk"),
                "description": "MultiTalk conditioning model for lip-sync",
            },
        ]
        alt_model = {
            "repo_id": "Wan-AI/Wan2.1-I2V-14B-480P",
            "local_dir": wan_original_dir,
            "description": "Wan2.1 I2V model (original format)",
        }

        for model_info in models_to_download:
            logger.info(f"Downloading {model_info['description']}: {model_info['repo_id']}")
            try:
                self._snapshot_model(model_info, hf_token)
            except Exception as e:
                logger.error(f"Failed to download {model_info['description']}: {str(e)}")
                if "Wan2.1-I2V-14B-480P-Diffusers" in model_info["repo_id"]:
                    logger.info("Trying alternative Wan2.1 model...")
                    # Guarded separately, so a second failure does not abort the
                    # audio-encoder and MultiTalk downloads that follow.
                    try:
                        self._snapshot_model(alt_model, hf_token)
                    except Exception as alt_error:
                        logger.error(f"Failed to download {alt_model['description']}: {str(alt_error)}")

        self._link_multitalk_weights()

    def _link_multitalk_weights(self):
        """
        Copy MultiTalk weight and config files into the chosen Wan 2.1 model directory.

        Despite the name, files are copied with ``shutil.copy2``, not symlinked. An
        existing destination is unlinked first, so a stale file or symlink is
        replaced rather than written through. Missing source files are skipped, and
        copy errors are logged as warnings.
        """
        logger.info("Integrating MultiTalk weights with Wan2.1...")

        wan_diffusers_dir, wan_original_dir = self._wan_model_dirs()
        multitalk_dir = os.path.join(self.weights_dir, "MeiGen-MultiTalk")
        wan_dir = wan_diffusers_dir if os.path.exists(wan_diffusers_dir) else wan_original_dir

        if not os.path.isdir(wan_dir):
            logger.warning(f"Wan2.1 model directory not found, skipping MultiTalk integration: {wan_dir}")
            return

        for filename in ("multitalk_adapter.safetensors", "multitalk_config.json", "audio_projection.safetensors"):
            src_path = os.path.join(multitalk_dir, filename)
            dst_path = os.path.join(wan_dir, filename)
            if not os.path.exists(src_path):
                continue
            try:
                if os.path.lexists(dst_path):
                    os.unlink(dst_path)
                shutil.copy2(src_path, dst_path)
                logger.info(f"Integrated {filename} with Wan2.1")
            except Exception as e:
                logger.warning(f"Could not integrate {filename}: {e}")

    def _initialize_wan_pipeline(self):
        """
        Build the model pipeline, preferring the Diffusers layout when it was downloaded.

        Any exception from the chosen loader is logged, and the handler then drops
        to ``_init_fallback_pipeline``.
        """
        logger.info("Initializing Wan 2.1 diffusion pipeline...")
        wan_diffusers_dir, wan_original_dir = self._wan_model_dirs()
        wav2vec_path = os.path.join(self.weights_dir, "chinese-wav2vec2-base")

        try:
            if os.path.exists(wan_diffusers_dir):
                logger.info("Loading Wan 2.1 with Diffusers format...")
                self._init_diffusers_pipeline(wan_diffusers_dir, wav2vec_path)
            else:
                logger.info("Loading Wan 2.1 with original format...")
                self._init_original_pipeline(wan_original_dir, wav2vec_path)

            self.initialized = True
            logger.info("Wan 2.1 pipeline initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Wan 2.1 pipeline: {str(e)}", exc_info=True)
            self._init_fallback_pipeline()

    def _init_diffusers_pipeline(self, model_dir: str, wav2vec_path: str):
        """
        Load the VAE, CLIP image encoder, Wav2Vec2 encoder and DiT from a Diffusers layout.

        Missing ``vae``, ``image_encoder`` or ``transformer`` sub-directories leave the
        corresponding attribute ``None``. Loading errors are logged and re-raised.
        """
        try:
            from diffusers import AutoencoderKL, DDIMScheduler
            from transformers import (
                CLIPImageProcessor,
                CLIPVisionModel,
                Wav2Vec2FeatureExtractor,
                Wav2Vec2Model,
            )

            # Half precision only on GPU; many CPU kernels have no float16 implementation.
            dtype = torch.float16 if self.device.type == "cuda" else torch.float32

            vae_path = os.path.join(model_dir, "vae")
            if os.path.exists(vae_path):
                logger.info("Loading Wan-VAE...")
                # NOTE(review): confirm AutoencoderKL can load this repo's ``vae`` folder.
                # If it cannot, the error propagates and the fallback pipeline is used.
                self.vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=dtype)
                self.vae.to(self.device)
                self.vae.eval()
            else:
                logger.warning("VAE not found, will use default")
                self.vae = None

            image_encoder_path = os.path.join(model_dir, "image_encoder")
            if os.path.exists(image_encoder_path):
                logger.info("Loading CLIP image encoder...")
                self.image_encoder = CLIPVisionModel.from_pretrained(image_encoder_path, torch_dtype=dtype)
                self.image_processor = CLIPImageProcessor.from_pretrained(image_encoder_path)
                self.image_encoder.to(self.device)
                self.image_encoder.eval()
            else:
                logger.warning("Image encoder not found")
                self.image_encoder = None
                self.image_processor = None

            logger.info("Loading Wav2Vec2 audio encoder...")
            self.audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
            self.audio_model = Wav2Vec2Model.from_pretrained(wav2vec_path, torch_dtype=dtype)
            self.audio_model.to(self.device)
            self.audio_model.eval()

            dit_path = os.path.join(model_dir, "transformer")
            if os.path.exists(dit_path):
                logger.info("Loading Wan 2.1 DiT model...")
                self._load_dit_model(dit_path)
            else:
                logger.warning("DiT model not found")
                # Must be set explicitly: generation tests ``self.dit`` against None.
                self.dit = None

            self.scheduler = DDIMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
                steps_offset=1,
                prediction_type="epsilon",
            )

            logger.info("Diffusers pipeline loaded successfully")
        except ImportError as e:
            logger.error(f"Diffusers import error: {e}")
            raise
        except Exception as e:
            logger.error(f"Diffusers pipeline error: {e}")
            raise

    def _init_original_pipeline(self, model_dir: str, wav2vec_path: str):
        """
        Load the original-format Wan 2.1 modules found inside ``model_dir``.

        ``model_dir`` is put on ``sys.path`` so ``wan_multitalk``, ``wan_vae`` and
        ``wan_dit`` can be imported from it. If those imports fail, this method
        switches to the fallback pipeline itself. Other loading errors propagate.
        """
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        try:
            from wan_multitalk import MultiTalkModel
            from wan_vae import WanVAE
            from wan_dit import WanDiT
        except ImportError:
            logger.warning("Could not import Wan2.1 modules, using simplified implementation")
            self._init_fallback_pipeline()
            return

        from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

        logger.info("Loading original Wan 2.1 models...")
        self.vae = WanVAE.from_pretrained(os.path.join(model_dir, "vae"))
        self.dit = WanDiT.from_pretrained(os.path.join(model_dir, "dit"))
        self.multitalk = MultiTalkModel.from_pretrained(os.path.join(self.weights_dir, "MeiGen-MultiTalk"))

        self.audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
        self.audio_model = Wav2Vec2Model.from_pretrained(wav2vec_path)

        for module in (self.vae, self.dit, self.multitalk, self.audio_model):
            module.to(self.device)
            module.eval()

        logger.info("Original pipeline loaded successfully")

    def _init_fallback_pipeline(self):
        """
        Load only the Wav2Vec2 audio encoder and a DDIM scheduler.

        The VAE, DiT, MultiTalk and image-encoder slots are cleared, so generation
        uses the un-encoded reference image. This method is not guarded: if the
        audio encoder cannot load, the exception propagates.
        """
        logger.info("Initializing fallback pipeline with basic components...")

        from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
        from diffusers import DDIMScheduler

        wav2vec_path = os.path.join(self.weights_dir, "chinese-wav2vec2-base")

        self.audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
        self.audio_model = Wav2Vec2Model.from_pretrained(wav2vec_path)
        self.audio_model.to(self.device)
        self.audio_model.eval()

        self.scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
        )

        # Release anything a failed full initialisation may have left loaded.
        self.vae = None
        self.dit = None
        self.multitalk = None
        self.image_encoder = None
        self.image_processor = None
        self.initialized = True

        logger.info("Fallback pipeline ready")

    def _load_dit_model(self, dit_path: str):
        """
        Load a single-file DiT checkpoint from ``dit_path`` and store the result in ``self.dit``.

        Only the three file names listed below are tried. NOTE(review): a sharded
        checkpoint (index JSON plus numbered shards) matches none of them and ends
        up as "No DiT model file found". Any error is logged and leaves
        ``self.dit = None``.
        """
        try:
            from safetensors.torch import load_file

            candidates = (
                os.path.join(dit_path, "diffusion_pytorch_model.safetensors"),
                os.path.join(dit_path, "pytorch_model.bin"),
                os.path.join(dit_path, "model.safetensors"),
            )
            for model_file in candidates:
                if not os.path.exists(model_file):
                    continue
                logger.info(f"Loading DiT from {model_file}")
                if model_file.endswith(".safetensors"):
                    state_dict = load_file(model_file)
                else:
                    # weights_only=True refuses to unpickle arbitrary objects from the checkpoint.
                    state_dict = torch.load(model_file, map_location=self.device, weights_only=True)
                self.dit = self._create_dit_model(state_dict)
                return

            logger.warning("No DiT model file found")
            self.dit = None
        except Exception as e:
            logger.error(f"Failed to load DiT model: {e}")
            self.dit = None

    def _create_dit_model(self, state_dict):
        """
        Build a DiT module from ``state_dict``.

        NOTE(review): this is a placeholder. It always returns ``None``, so the
        loaded weights are discarded.
        """
        logger.info("Creating DiT model structure...")
        return None

    @staticmethod
    def _guess_extension(media_type: str, mime: str) -> str:
        """
        Pick a temp-file suffix from a MIME type or ``data:`` URL header.

        Images map to ``.jpg`` for JPEG and to ``.png`` otherwise. Any other media
        type is treated as audio and maps to ``.mp3`` for MP3/MPEG and to ``.wav``
        otherwise.
        """
        mime = (mime or "").lower()
        if media_type == "image":
            return ".jpg" if ("jpeg" in mime or "jpg" in mime) else ".png"
        return ".mp3" if ("mp3" in mime or "mpeg" in mime) else ".wav"

    @staticmethod
    def _write_temp_file(chunks, suffix: str) -> str:
        """
        Write an iterable of byte chunks to a new temp file and return its path.

        The file is created securely and is not deleted on close; the caller owns
        it. If writing fails, the file is removed before the error is re-raised.
        """
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            with tmp:
                for chunk in chunks:
                    if chunk:
                        tmp.write(chunk)
        except BaseException:
            os.unlink(tmp.name)
            raise
        return tmp.name

    @staticmethod
    def _remove_file(path: Optional[str]) -> None:
        """Best-effort deletion of a temp file; ``None`` and already-missing paths are ignored."""
        if not path:
            return
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"Could not remove temporary file {path}: {e}")

    def _download_media(self, url: str, media_type: str = "image") -> str:
        """
        Save an image or audio input as a local temp file and return its path.

        Args:
            url: An ``http(s)`` URL, or a ``data:`` URL with a base64 or
                percent-encoded payload.
            media_type: ``"image"``; any other value is treated as audio. It only
                affects the file suffix.

        Returns:
            Path of a temp file that the caller must delete.

        Raises:
            ValueError: if a ``data:`` URL has no ``,`` separator.
            requests.RequestException: if the HTTP download fails.
        """
        if url.startswith("data:"):
            logger.info(f"Processing base64 {media_type}")
            header, sep, payload = url.partition(",")
            if not sep:
                raise ValueError(f"Malformed data URL for {media_type}: missing ','")
            if ";base64" in header.lower():
                media_data = base64.b64decode(payload)
            else:
                import urllib.parse

                media_data = urllib.parse.unquote_to_bytes(payload)
            return self._write_temp_file([media_data], self._guess_extension(media_type, header))

        import requests

        # NOTE(review): this fetches an arbitrary client-supplied URL with no size cap.
        # Consider an allow-list and a maximum size if the endpoint is publicly reachable.
        logger.info(f"Downloading {media_type} from URL...")
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            ext = self._guess_extension(media_type, response.headers.get("content-type", ""))
            return self._write_temp_file(response.iter_content(chunk_size=8192), ext)

    @staticmethod
    def _frame_count(seconds: float, fps: int) -> int:
        """Return the number of video frames for ``seconds`` at ``fps``, rounded and never below 1."""
        return max(1, int(round(seconds * fps)))

    def _extract_audio_features(self, audio_path: str, target_fps: int = 30, duration: int = 5) -> torch.Tensor:
        """
        Encode the first ``duration`` seconds of audio with Wav2Vec2, aligned to video frames.

        Audio is loaded at 16 kHz. Wav2Vec2's ``last_hidden_state`` is linearly
        resampled along the time axis (dim 1) to ``_frame_count(duration,
        target_fps)`` steps, giving one feature vector per output frame.
        """
        import librosa
        import torch.nn.functional as F

        logger.info("Extracting audio features with Wav2Vec2...")

        audio, _ = librosa.load(audio_path, sr=16000, duration=duration)
        inputs = self.audio_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)

        # Floating inputs must match the encoder's dtype, since it may be float16 on GPU.
        # Integer tensors such as an attention mask keep their dtype.
        model_dtype = self.audio_model.dtype
        inputs = {
            name: tensor.to(self.device, dtype=model_dtype) if tensor.is_floating_point() else tensor.to(self.device)
            for name, tensor in inputs.items()
        }

        with torch.no_grad():
            audio_features = self.audio_model(**inputs).last_hidden_state

        num_frames = self._frame_count(duration, target_fps)
        if audio_features.shape[1] != num_frames:
            audio_features = F.interpolate(
                audio_features.transpose(1, 2),
                size=num_frames,
                mode="linear",
                align_corners=False,
            ).transpose(1, 2)

        return audio_features

    def _prepare_image_latents(self, image_path: str) -> torch.Tensor:
        """
        Load the reference image and encode it for generation.

        The image is converted to RGB, resized to 854x480 with Lanczos resampling
        and normalised to [-1, 1]. With a VAE, the sampled VAE latents scaled by
        ``vae.config.scaling_factor`` are returned. Without one, the normalised
        image tensor of shape (1, 3, 480, 854) is returned.
        """
        from PIL import Image
        import torchvision.transforms as transforms

        logger.info("Encoding reference image to latents...")

        # Context manager closes the underlying file; convert() returns a loaded copy.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        image = image.resize((854, 480), Image.Resampling.LANCZOS)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        image_tensor = transform(image).unsqueeze(0).to(self.device)

        if self.vae is None:
            return image_tensor

        with torch.no_grad():
            image_tensor = image_tensor.to(self.vae.dtype)
            latents = self.vae.encode(image_tensor).latent_dist.sample()
            return latents * self.vae.config.scaling_factor

    def _generate_video_diffusion(
        self,
        image_latents: torch.Tensor,
        audio_features: torch.Tensor,
        prompt: str = "",
        num_frames: int = 150,
        num_inference_steps: int = 30,
        guidance_scale: float = 5.0,
    ) -> List[np.ndarray]:
        """
        Generate ``num_frames`` RGB frames conditioned on the image and audio features.

        When a DiT is loaded, the full pipeline is used; otherwise the simplified
        one. ``prompt``, ``num_inference_steps`` and ``guidance_scale`` are only
        forwarded to the full pipeline, which currently ignores them.
        """
        logger.info(f"Generating video with diffusion: {num_frames} frames, {num_inference_steps} steps")

        if getattr(self, "dit", None) is not None:
            return self._generate_with_full_pipeline(
                image_latents, audio_features, prompt,
                num_frames, num_inference_steps, guidance_scale,
            )
        return self._generate_with_simple_pipeline(image_latents, audio_features, num_frames)

    def _generate_with_full_pipeline(
        self,
        image_latents: torch.Tensor,
        audio_features: torch.Tensor,
        prompt: str,
        num_frames: int,
        num_inference_steps: int,
        guidance_scale: float,
    ) -> List[np.ndarray]:
        """
        Placeholder for DiT-based sampling.

        NOTE(review): this currently delegates to ``_generate_with_simple_pipeline``
        and ignores ``prompt``, ``num_inference_steps`` and ``guidance_scale``.
        """
        logger.info("Using full Wan 2.1 diffusion pipeline...")
        return self._generate_with_simple_pipeline(image_latents, audio_features, num_frames)

    @staticmethod
    def _tensor_to_uint8_image(image: torch.Tensor) -> np.ndarray:
        """Convert a CHW tensor with values in [-1, 1] into an HWC ``uint8`` array."""
        array = image.detach().float().cpu().permute(1, 2, 0).numpy()
        return ((array + 1.0) * 127.5).clip(0, 255).astype(np.uint8)

    def _generate_with_simple_pipeline(
        self,
        image_latents: torch.Tensor,
        audio_features: torch.Tensor,
        num_frames: int,
    ) -> List[np.ndarray]:
        """
        Render frames by animating one reference image with per-frame audio features.

        With a VAE, the latents are decoded back to an image first. Without one,
        ``image_latents`` is already the [-1, 1] image from
        ``_prepare_image_latents``. Frame ``i`` uses audio step ``i``, clamped to
        the last available step.
        """
        logger.info("Generating frames with audio conditioning...")

        if self.vae is not None and image_latents.dim() == 4:
            with torch.no_grad():
                decoded = self.vae.decode(image_latents / self.vae.config.scaling_factor).sample
            ref_image = self._tensor_to_uint8_image(decoded[0])
        else:
            # Always [-1, 1] here. Guessing the range from min() mis-scaled bright images.
            ref_image = self._tensor_to_uint8_image(image_latents[0])

        last_audio_index = audio_features.shape[1] - 1
        frames = []
        for frame_idx in range(num_frames):
            frame_audio = audio_features[:, min(frame_idx, last_audio_index), :]
            frames.append(
                self._apply_audio_driven_animation(ref_image.copy(), frame_audio, frame_idx, num_frames)
            )
        return frames

    def _apply_audio_driven_animation(
        self,
        frame: np.ndarray,
        audio_feature: torch.Tensor,
        frame_idx: int,
        total_frames: int,
    ) -> np.ndarray:
        """
        Apply a heuristic, audio-driven animation to one copy of the reference frame.

        The intensity is the L2 norm of ``audio_feature`` divided by 100, clamped
        to [0, 1]. Above 0.3, an ellipse centred at (0.5 * width, 0.65 * height) is
        darkened. The frame is then shifted horizontally by
        ``sin(0.1 * frame_idx) * intensity * 2`` pixels and scaled in brightness by
        ``1 + 0.05 * sin(0.2 * frame_idx) * intensity``. ``total_frames`` is unused.
        ``frame`` is modified in place before the warp.
        """
        import cv2

        # Norm in float32, so a half-precision feature vector cannot overflow.
        audio_intensity = torch.norm(audio_feature.float()).item() / 100.0
        audio_intensity = min(max(audio_intensity, 0), 1)

        h, w = frame.shape[:2]
        center_y = int(h * 0.65)
        center_x = int(w * 0.5)

        if audio_intensity > 0.3:
            # Both radii are at least 6 px here, because intensity exceeds 0.3.
            mouth_height = int(20 * audio_intensity)
            mouth_width = int(30 * audio_intensity)

            y_coords, x_coords = np.ogrid[:h, :w]
            mask = ((x_coords - center_x) ** 2 / (mouth_width ** 2) +
                    (y_coords - center_y) ** 2 / (mouth_height ** 2)) <= 1

            if np.any(mask):
                darkness = 0.7 + 0.3 * (1 - audio_intensity)
                frame[mask] = (frame[mask] * darkness).astype(np.uint8)

        movement = np.sin(frame_idx * 0.1) * audio_intensity * 2
        shift = np.float32([[1, 0, movement], [0, 1, 0]])
        frame = cv2.warpAffine(frame, shift, (w, h), borderMode=cv2.BORDER_REFLECT)

        brightness = 1.0 + 0.05 * np.sin(frame_idx * 0.2) * audio_intensity
        return np.clip(frame * brightness, 0, 255).astype(np.uint8)

    def _create_video_from_frames(
        self,
        frames: List[np.ndarray],
        audio_path: str,
        fps: int = 30,
    ) -> str:
        """
        Encode ``frames`` to H.264 MP4 and mux in the audio from ``audio_path``.

        Returns the path of a temp file owned by the caller. If the ffmpeg merge
        fails, or ffmpeg cannot be started, the error is logged and the silent video
        is returned instead. No intermediate file is left behind either way.

        Raises:
            ValueError: if ``frames`` is empty.
        """
        import imageio

        if not frames:
            raise ValueError("No frames to encode")

        logger.info(f"Creating video from {len(frames)} frames at {fps} FPS...")

        # Reserve unique, securely created paths; tempfile.mktemp is race-prone.
        silent_path = self._write_temp_file([], ".mp4")
        try:
            writer = imageio.get_writer(
                silent_path,
                fps=fps,
                codec="libx264",
                quality=8,
                pixelformat="yuv420p",
                ffmpeg_params=["-preset", "fast"],
            )
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                writer.close()
        except Exception:
            self._remove_file(silent_path)
            raise

        output_path = self._write_temp_file([], ".mp4")
        cmd = [
            "ffmpeg", "-i", silent_path, "-i", audio_path,
            "-c:v", "libx264", "-c:a", "aac",
            "-preset", "fast", "-crf", "22",
            "-movflags", "+faststart",
            "-shortest", "-y", output_path,
        ]

        logger.info("Merging video with audio...")
        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
        except OSError as e:  # e.g. no ffmpeg binary on PATH
            logger.error(f"FFmpeg merge error: {e}")
            self._remove_file(output_path)
            return silent_path

        if result.returncode != 0:
            logger.error(f"FFmpeg merge error: {result.stderr}")
            self._remove_file(output_path)
            return silent_path

        self._remove_file(silent_path)
        return output_path

    @staticmethod
    def _parse_generation_params(input_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate and coerce the numeric request parameters.

        Returns a dict with ``seconds``, ``steps`` (``int``) and ``guidance_scale``
        (``float``). ``seconds`` defaults to 5 and is kept as an ``int`` when it is
        a whole number, so integer requests report an integer duration.

        Raises:
            ValueError: if a value is non-numeric or non-finite, or if ``seconds``
                or ``steps`` is not positive.
        """
        import math

        try:
            seconds = float(input_data.get("seconds", 5))
            steps = int(input_data.get("steps", 30))
            guidance_scale = float(input_data.get("guidance_scale", 5.0))
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Invalid numeric parameter: {exc}") from exc

        if not math.isfinite(seconds) or seconds <= 0:
            raise ValueError("'seconds' must be a positive number")
        if steps <= 0:
            raise ValueError("'steps' must be a positive integer")
        if not math.isfinite(guidance_scale):
            raise ValueError("'guidance_scale' must be a finite number")

        if seconds.is_integer():
            seconds = int(seconds)
        return {"seconds": seconds, "steps": steps, "guidance_scale": guidance_scale}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process one inference request for Wan 2.1 MultiTalk video generation.

        Parameters may be given at the top level of ``data`` or under ``"inputs"``:

        * ``image_url`` and ``audio_url`` (required): http(s) or ``data:`` URLs.
        * ``prompt``, ``seconds`` (default 5), ``steps`` (default 30) and
          ``guidance_scale`` (default 5.0).

        Returns:
            ``{"success": True, "video": <base64 MP4>, ...}`` on success, or
            ``{"success": False, "error": <message>}`` on any failure. Exceptions
            are never propagated. All temp files are removed before returning.
        """
        logger.info("Processing Wan 2.1 MultiTalk inference request")

        image_path = audio_path = video_path = None
        try:
            input_data = data["inputs"] if isinstance(data, dict) and "inputs" in data else data
            if not isinstance(input_data, dict):
                return {"error": "Request payload must be a JSON object", "success": False}

            image_url = input_data.get("image_url")
            audio_url = input_data.get("audio_url")
            prompt = input_data.get("prompt", "A person speaking naturally with lip sync")

            if not image_url or not audio_url:
                return {
                    "error": "Missing required parameters: image_url and audio_url",
                    "success": False,
                }

            try:
                params = self._parse_generation_params(input_data)
            except ValueError as exc:
                return {"error": str(exc), "success": False}
            seconds = params["seconds"]
            steps = params["steps"]
            guidance_scale = params["guidance_scale"]
            fps = self.OUTPUT_FPS

            logger.info(f"Generating {seconds}s video with {steps} steps")

            # Both downloads sit inside the try, so finally removes either file on failure.
            image_path = self._download_media(image_url, "image")
            audio_path = self._download_media(audio_url, "audio")

            audio_features = self._extract_audio_features(audio_path, target_fps=fps, duration=seconds)
            image_latents = self._prepare_image_latents(image_path)

            frames = self._generate_video_diffusion(
                image_latents=image_latents,
                audio_features=audio_features,
                prompt=prompt,
                num_frames=self._frame_count(seconds, fps),
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
            )
            # Report the real frame size; a VAE round-trip can change the width.
            height, width = frames[0].shape[:2]

            video_path = self._create_video_from_frames(frames=frames, audio_path=audio_path, fps=fps)

            with open(video_path, "rb") as video_file:
                video_bytes = video_file.read()
            video_base64 = base64.b64encode(video_bytes).decode("utf-8")

            size_mb = len(video_bytes) / 1024 / 1024
            logger.info(f"Generated video size: {size_mb:.2f} MB")

            return {
                "success": True,
                "video": video_base64,
                "format": "mp4",
                "duration": seconds,
                "resolution": f"{width}x{height}",
                "fps": fps,
                "size_mb": round(size_mb, 2),
                "message": f"Generated {seconds}s Wan 2.1 MultiTalk video at 480p",
                "model": "Wan-2.1-I2V-14B-480P with MultiTalk",
            }
        except Exception as e:
            logger.error(f"Request processing failed: {str(e)}", exc_info=True)
            return {
                "error": f"Video generation failed: {str(e)}",
                "success": False,
            }
        finally:
            for path in (image_path, audio_path, video_path):
                self._remove_file(path)