Spaces:
Paused
Paused
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import os | |
| import tempfile | |
| import shutil | |
| import imageio | |
| import pandas as pd | |
| import numpy as np | |
| from diffsynth import ModelManager, WanVideoReCamMasterPipeline, save_video | |
| import json | |
| from torchvision.transforms import v2 | |
| from einops import rearrange | |
| import torchvision | |
| from PIL import Image | |
| import logging | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Every downloaded model artifact lives below this directory; it can be
# overridden through the RECAMMASTER_MODELS_DIR environment variable.
MODELS_ROOT_DIR = os.environ.get("RECAMMASTER_MODELS_DIR", "/data/models")
logger.info(f"Using models root directory: {MODELS_ROOT_DIR}")

# Camera-movement presets offered in the UI, keyed by the string ids "1".."10".
# process_video_for_recammaster turns an id into the "camNN" key it reads from
# the camera extrinsics JSON.
CAMERA_TRANSFORMATIONS = {
    str(preset_id): preset_label
    for preset_id, preset_label in enumerate(
        (
            "Pan Right",
            "Pan Left",
            "Tilt Up",
            "Tilt Down",
            "Zoom In",
            "Zoom Out",
            "Translate Up (with rotation)",
            "Translate Down (with rotation)",
            "Arc Left (with rotation)",
            "Arc Right (with rotation)",
        ),
        start=1,
    )
}

# Model state shared across Gradio callbacks; assigned by load_models().
model_manager, pipe = None, None
is_model_loaded = False

# Wan2.1 base model: repository id, local mirror directory and weight files.
WAN21_REPO_ID = "Wan-AI/Wan2.1-T2V-1.3B"
WAN21_LOCAL_DIR = f"{MODELS_ROOT_DIR}/{WAN21_REPO_ID}"
WAN21_FILES = [
    "diffusion_pytorch_model.safetensors",
    "models_t5_umt5-xxl-enc-bf16.pth",
    "Wan2.1_VAE.pth",
]

# UMT5-XXL tokenizer files, given as paths inside the Wan2.1 repository.
UMT5_XXL_TOKENIZER_FILES = [
    f"google/umt5-xxl/{tokenizer_file}"
    for tokenizer_file in (
        "special_tokens_map.json",
        "spiece.model",
        "tokenizer.json",
        "tokenizer_config.json",
    )
]

# ReCamMaster fine-tuned checkpoint location.
RECAMMASTER_REPO_ID = "KwaiVGI/ReCamMaster-Wan2.1"
RECAMMASTER_CHECKPOINT_FILE = "step20000.ckpt"
RECAMMASTER_LOCAL_DIR = f"{MODELS_ROOT_DIR}/ReCamMaster/checkpoints"

# Folder (relative to the working directory) for sample camera data and videos.
TEST_DATA_DIR = "example_test_data"
def download_umt5_xxl_tokenizer(progress_callback=None):
    """Fetch the UMT5-XXL tokenizer files from the Wan2.1 HuggingFace repo.

    Each entry of UMT5_XXL_TOKENIZER_FILES is stored under WAN21_LOCAL_DIR,
    keeping its repo-relative sub-path. Files already present are not fetched.

    Args:
        progress_callback: optional callable ``(fraction, desc=...)`` (for
            example a ``gr.Progress``) that receives per-file progress.

    Returns:
        List of local paths, one per tokenizer file, in listing order.

    Raises:
        Whatever ``hf_hub_download`` raises; the error is logged first.
    """
    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, desc=message)

    file_count = len(UMT5_XXL_TOKENIZER_FILES)
    results = []
    for index, repo_path in enumerate(UMT5_XXL_TOKENIZER_FILES):
        base_name = os.path.basename(repo_path)
        target_path = f"{WAN21_LOCAL_DIR}/{repo_path}"
        counter = f"{index + 1}/{file_count}"
        fraction = index / file_count

        report(fraction, f"Checking tokenizer file {counter}: {base_name}")
        if os.path.exists(target_path):
            logger.info(f"✓ Tokenizer file {base_name} already exists at {target_path}")
            results.append(target_path)
            continue

        os.makedirs(f"{WAN21_LOCAL_DIR}/{os.path.dirname(repo_path)}", exist_ok=True)
        logger.info(f"Downloading tokenizer file {base_name} from {WAN21_REPO_ID}/{repo_path}...")
        report(fraction, f"Downloading tokenizer file {counter}: {base_name}")
        try:
            fetched_path = hf_hub_download(
                repo_id=WAN21_REPO_ID,
                filename=repo_path,
                local_dir=WAN21_LOCAL_DIR,
                local_dir_use_symlinks=False,
            )
            logger.info(f"✓ Successfully downloaded tokenizer file {base_name} to {fetched_path}!")
            results.append(fetched_path)
        except Exception as err:
            logger.error(f"✗ Error downloading tokenizer file {base_name}: {err}")
            raise

    report(1.0, "All tokenizer files downloaded successfully!")
    return results
def download_wan21_models(progress_callback=None):
    """Fetch the Wan2.1 weight files listed in WAN21_FILES into WAN21_LOCAL_DIR.

    Files already present locally are not downloaded again.

    Args:
        progress_callback: optional callable ``(fraction, desc=...)`` that
            receives per-file progress.

    Returns:
        List of local paths (strings), one per weight file, in listing order.

    Raises:
        Whatever ``hf_hub_download`` raises; the error is logged first.
    """
    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, desc=message)

    file_count = len(WAN21_FILES)
    results = []
    root = Path(WAN21_LOCAL_DIR)
    root.mkdir(parents=True, exist_ok=True)

    for index, name in enumerate(WAN21_FILES):
        target = root / name
        counter = f"{index + 1}/{file_count}"
        fraction = index / file_count

        report(fraction, f"Checking Wan2.1 file {counter}: {name}")
        if target.exists():
            logger.info(f"✓ {name} already exists at {target}")
            results.append(str(target))
            continue

        logger.info(f"Downloading {name} from {WAN21_REPO_ID}...")
        report(fraction, f"Downloading Wan2.1 file {counter}: {name}")
        try:
            fetched_path = hf_hub_download(
                repo_id=WAN21_REPO_ID,
                filename=name,
                local_dir=WAN21_LOCAL_DIR,
                local_dir_use_symlinks=False,
            )
            logger.info(f"✓ Successfully downloaded {name} to {fetched_path}!")
            results.append(fetched_path)
        except Exception as err:
            logger.error(f"✗ Error downloading {name}: {err}")
            raise

    report(1.0, "All Wan2.1 models downloaded successfully!")
    return results
def download_recammaster_checkpoint(progress_callback=None):
    """Ensure the ReCamMaster checkpoint is present under RECAMMASTER_LOCAL_DIR.

    Args:
        progress_callback: optional callable ``(fraction, desc=...)``; called at
            the start and end of a download (not when the file already exists).

    Returns:
        The checkpoint location: a ``Path`` when it already existed, otherwise
        the path string returned by ``hf_hub_download``.

    Raises:
        Whatever ``hf_hub_download`` raises; the error is logged first.
    """
    target_dir = Path(RECAMMASTER_LOCAL_DIR)
    checkpoint_path = target_dir / RECAMMASTER_CHECKPOINT_FILE

    if checkpoint_path.exists():
        logger.info(f"✓ ReCamMaster checkpoint already exists at {checkpoint_path}")
        return checkpoint_path

    target_dir.mkdir(parents=True, exist_ok=True)
    for message in (
        "Downloading ReCamMaster checkpoint from HuggingFace...",
        f"Repository: {RECAMMASTER_REPO_ID}",
        f"File: {RECAMMASTER_CHECKPOINT_FILE}",
        f"Destination: {checkpoint_path}",
    ):
        logger.info(message)

    if progress_callback:
        progress_callback(0.0, desc="Downloading ReCamMaster checkpoint...")

    try:
        fetched_path = hf_hub_download(
            repo_id=RECAMMASTER_REPO_ID,
            filename=RECAMMASTER_CHECKPOINT_FILE,
            local_dir=RECAMMASTER_LOCAL_DIR,
            local_dir_use_symlinks=False,
        )
        logger.info(f"✓ Successfully downloaded ReCamMaster checkpoint to {fetched_path}!")
        if progress_callback:
            progress_callback(1.0, desc="ReCamMaster checkpoint downloaded successfully!")
        return fetched_path
    except Exception as err:
        logger.error(f"✗ Error downloading checkpoint: {err}")
        raise
def _rotate_about_x(matrix, angle):
    """Write a rotation of ``angle`` radians about the X axis into ``matrix`` (in place)."""
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    matrix[1, 1], matrix[1, 2] = cos_a, -sin_a
    matrix[2, 1], matrix[2, 2] = sin_a, cos_a


def _rotate_about_y(matrix, angle):
    """Write a rotation of ``angle`` radians about the Y axis into ``matrix`` (in place)."""
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    matrix[0, 0], matrix[0, 2] = cos_a, sin_a
    matrix[2, 0], matrix[2, 2] = -sin_a, cos_a


def _sample_camera_c2w(cam_idx, frame_idx):
    """Return the synthetic 4x4 camera-to-world matrix for preset ``cam_idx`` at ``frame_idx``.

    Translation is stored in column 3. These are simple illustrative
    trajectories, not calibrated camera paths.
    NOTE(review): process_video_for_recammaster divides translations by 100,
    so these sample translations end up very small; confirm the intended units.
    """
    c2w = np.eye(4)
    if cam_idx in (1, 2):  # Pan Right / Pan Left: slide along X
        c2w[0, 3] = (0.01 if cam_idx == 1 else -0.01) * frame_idx
    elif cam_idx in (3, 4):  # Tilt Up / Tilt Down: rotate about X
        _rotate_about_x(c2w, (0.005 if cam_idx == 3 else -0.005) * frame_idx)
    elif cam_idx in (5, 6):  # Zoom In / Zoom Out: move along -Z / +Z
        c2w[2, 3] = (-0.01 if cam_idx == 5 else 0.01) * frame_idx
    elif cam_idx in (7, 8):  # Translate Up / Down with a slow rotation about Y
        c2w[1, 3] = (0.01 if cam_idx == 7 else -0.01) * frame_idx
        _rotate_about_y(c2w, (0.003 if cam_idx == 7 else -0.003) * frame_idx)
    elif cam_idx in (9, 10):  # Arc Left / Arc Right on a circle of radius 2
        angle = (0.005 if cam_idx == 9 else -0.005) * frame_idx
        radius = 2.0
        c2w[0, 3] = -radius * np.sin(angle)
        c2w[2, 3] = -radius * np.cos(angle) + radius
        # Rotated by angle + pi, intended to look back at the circle's center.
        _rotate_about_y(c2w, angle + np.pi)
    return c2w


def _format_matrix(matrix):
    """Serialize a 2-D array row by row as ``"[a b c d] [e f g h] ..."``.

    Rows are separated by ``"] ["``, which is exactly what ``parse_matrix``
    splits on, so every row round-trips with all of its values.
    """
    return " ".join("[" + " ".join(str(value) for value in row) + "]" for row in matrix)


def create_test_data_structure(progress_callback=None):
    """Create sample camera extrinsics for the demo under ``TEST_DATA_DIR``.

    Creates the ``cameras/`` and ``videos/`` sub-directories and, unless it
    already exists, writes ``cameras/camera_extrinsics.json`` mapping
    ``"frame0".."frame80"`` -> ``"cam01".."cam10"`` -> matrix string.

    Each stored matrix is the *transpose* of the synthetic camera-to-world
    matrix, because process_video_for_recammaster transposes every parsed
    matrix (``traj.transpose(0, 2, 1)``) before treating it as camera-to-world.
    Storing the untransposed matrix moved all translation into the bottom row,
    which the relative-pose step (``[:3, :]``) discards.

    Args:
        progress_callback: optional callable ``(fraction, desc=...)``.
    """
    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, desc=message)

    report(0.0, "Creating test data structure...")

    data_dir = Path(f"{TEST_DATA_DIR}/cameras")
    videos_dir = Path(f"{TEST_DATA_DIR}/videos")
    data_dir.mkdir(parents=True, exist_ok=True)
    videos_dir.mkdir(parents=True, exist_ok=True)

    camera_file = data_dir / "camera_extrinsics.json"
    if camera_file.exists():
        # NOTE(review): an existing file is kept as-is, even one written in an
        # older format; delete it to regenerate.
        logger.info(f"✓ Camera extrinsics already exist at {camera_file}")
        report(1.0, "Test data structure already exists")
        return

    report(0.3, "Generating camera extrinsics data...")
    # 81 frames x 10 camera presets. Rows are serialized individually so that
    # parse_matrix reads back four rows of four values.
    camera_data = {
        f"frame{frame_idx}": {
            f"cam{cam_idx:02d}": _format_matrix(_sample_camera_c2w(cam_idx, frame_idx).T)
            for cam_idx in range(1, 11)
        }
        for frame_idx in range(81)
    }

    report(0.7, "Saving camera extrinsics data...")
    with open(camera_file, "w") as f:
        json.dump(camera_data, f, indent=2)

    logger.info(f"Created sample camera extrinsics at {camera_file}")
    logger.info(f"Created directory for example videos at {videos_dir}")
    report(1.0, "Test data structure created successfully!")
class Camera:
    """Pose holder built from a 4x4 camera-to-world matrix.

    ``c2w`` may be anything ``np.array`` accepts with 16 elements (nested 4x4
    or flat); it is copied and reshaped to 4x4.

    Attributes:
        c2w_mat: the camera-to-world matrix as a (4, 4) ndarray.
        w2c_mat: its inverse, the world-to-camera matrix.
    """

    def __init__(self, c2w):
        matrix = np.array(c2w).reshape(4, 4)
        self.c2w_mat = matrix
        self.w2c_mat = np.linalg.inv(matrix)
def parse_matrix(matrix_str):
    """Parse a bracketed matrix string such as ``"[1 0] [0 1]"`` into an ndarray.

    The string is trimmed and split into rows on the literal separator
    ``"] ["``; any remaining ``[`` / ``]`` characters are discarded and each
    row's whitespace-separated tokens are converted with ``float``.
    """
    strip_brackets = str.maketrans("", "", "[]")
    return np.array([
        [float(token) for token in chunk.translate(strip_brackets).split()]
        for chunk in matrix_str.strip().split("] [")
    ])
def get_relative_pose(cam_params):
    """Express camera poses relative to the first camera.

    The reference pose is the identity (``cam_to_origin`` is 0). Entry 0 of
    the result is that reference; entry ``i > 0`` is
    ``reference @ w2c[0] @ c2w[i]``.

    Args:
        cam_params: sequence of objects exposing 4x4 ``c2w_mat`` and
            ``w2c_mat`` arrays (e.g. ``Camera`` instances).

    Returns:
        float32 ndarray of shape ``(len(cam_params), 4, 4)``.
    """
    world_to_cams = [cam.w2c_mat for cam in cam_params]
    cam_to_worlds = [cam.c2w_mat for cam in cam_params]

    cam_to_origin = 0
    reference_c2w = np.eye(4)
    reference_c2w[1, 3] = -cam_to_origin

    world_to_reference = reference_c2w @ world_to_cams[0]
    poses = [reference_c2w]
    poses.extend(world_to_reference @ c2w for c2w in cam_to_worlds[1:])
    return np.array(poses, dtype=np.float32)
def _link_umt5_tokenizer():
    """Best-effort: expose the downloaded tokenizer at ``{MODELS_ROOT_DIR}/google/umt5-xxl``.

    Some libraries may look for the tokenizer there rather than inside the
    Wan2.1 directory. Failures are logged as warnings and otherwise ignored.
    """
    google_dir = f"{MODELS_ROOT_DIR}/google"
    umt5_xxl_symlink = f"{google_dir}/umt5-xxl"
    umt5_xxl_source = f"{WAN21_LOCAL_DIR}/google/umt5-xxl"
    try:
        os.makedirs(google_dir, exist_ok=True)
        if not os.path.exists(umt5_xxl_symlink) and os.path.exists(umt5_xxl_source):
            # os.symlink also works on Windows (given symlink privilege) and raises
            # OSError on failure. A raw CreateSymbolicLinkA call reports failure
            # only via its return value, which was being ignored.
            os.symlink(umt5_xxl_source, umt5_xxl_symlink, target_is_directory=True)
            logger.info(f"Created symlink from {umt5_xxl_source} to {umt5_xxl_symlink}")
    except (OSError, NotImplementedError) as e:
        # A warning, not an error: loading proceeds without the link.
        logger.warning(f"Could not create symlink for google/umt5-xxl: {str(e)}")


def _init_recammaster_modules(dit):
    """Attach the extra per-block camera layers ReCamMaster adds to the DiT.

    Every block gets ``cam_encoder`` (Linear 12 -> dim, zero-initialised; 12 is
    a flattened 3x4 relative camera pose) and ``projector`` (Linear dim -> dim,
    identity-initialised). The strict checkpoint load in load_models must then
    supply values for all of these parameters.
    """
    dim = dit.blocks[0].self_attn.q.weight.shape[0]
    for block in dit.blocks:
        block.cam_encoder = nn.Linear(12, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))


def load_models(progress_callback=None):
    """Download every required asset and build the global ReCamMaster pipeline.

    Steps: create the sample camera data, fetch the ReCamMaster checkpoint, the
    Wan2.1 weights and the UMT5-XXL tokenizer, then build ``pipe`` on CUDA,
    attach the ReCamMaster camera layers and load the checkpoint into its DiT.
    Assigns the module globals ``model_manager``, ``pipe`` and ``is_model_loaded``.

    Args:
        progress_callback: optional callable ``(fraction, desc=...)`` (e.g. a
            ``gr.Progress``); also forwarded to the individual download helpers.

    Returns:
        A status string. Every failure message starts with "Error" (on_load
        relies on this); exceptions are logged with traceback, not raised.
    """
    global model_manager, pipe, is_model_loaded

    if is_model_loaded:
        return "Models already loaded!"

    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, desc=message)

    try:
        logger.info("Starting model loading...")

        # 1. Sample camera extrinsics read by process_video_for_recammaster.
        report(0.05, "Setting up test data structure...")
        try:
            create_test_data_structure(progress_callback)
        except Exception as e:
            error_msg = f"Error creating test data structure: {str(e)}"
            logger.exception(error_msg)
            return error_msg

        # 2. ReCamMaster checkpoint.
        report(0.1, "Checking for ReCamMaster checkpoint...")
        try:
            ckpt_path = download_recammaster_checkpoint(progress_callback)
            logger.info(f"Using checkpoint at {ckpt_path}")
        except Exception as e:
            error_msg = f"Error downloading ReCamMaster checkpoint: {str(e)}"
            logger.exception(error_msg)
            return error_msg

        # 3. Wan2.1 base weights.
        report(0.2, "Checking for Wan2.1 models...")
        try:
            wan21_paths = download_wan21_models(progress_callback)
            logger.info(f"Using Wan2.1 models: {wan21_paths}")
        except Exception as e:
            error_msg = f"Error downloading Wan2.1 models: {str(e)}"
            logger.exception(error_msg)
            return error_msg

        # 4. UMT5-XXL tokenizer files.
        report(0.3, "Checking for UMT5-XXL tokenizer files...")
        try:
            tokenizer_paths = download_umt5_xxl_tokenizer(progress_callback)
            logger.info(f"Using UMT5-XXL tokenizer files: {tokenizer_paths}")
        except Exception as e:
            error_msg = f"Error downloading UMT5-XXL tokenizer files: {str(e)}"
            logger.exception(error_msg)
            return error_msg

        report(0.4, "Loading model manager...")
        _link_umt5_tokenizer()

        # Weights are loaded on the CPU in bfloat16; the pipeline is moved to CUDA below.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        report(0.5, "Loading Wan2.1 models...")

        model_files = [f"{WAN21_LOCAL_DIR}/{filename}" for filename in WAN21_FILES]
        for model_file in model_files:
            logger.info(f"Loading model from: {model_file}")
            if not os.path.exists(model_file):
                error_msg = f"Error: Model file not found: {model_file}"
                logger.error(error_msg)
                return error_msg

        # Point transformers at the local model directory to find the tokenizer.
        os.environ["TRANSFORMERS_CACHE"] = MODELS_ROOT_DIR
        os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable tokenizers parallelism warning
        model_manager.load_models(model_files)

        report(0.7, "Creating pipeline...")
        pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

        report(0.8, "Initializing ReCamMaster modules...")
        _init_recammaster_modules(pipe.dit)

        report(0.9, "Loading ReCamMaster checkpoint...")
        if not os.path.exists(ckpt_path):
            error_msg = f"Error: ReCamMaster checkpoint not found at {ckpt_path} even after download attempt."
            logger.error(error_msg)
            return error_msg
        # NOTE(review): torch.load unpickles the file, so this trusts the contents
        # of the KwaiVGI/ReCamMaster-Wan2.1 repo. If the checkpoint is a plain
        # state dict, consider weights_only=True.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        pipe.dit.load_state_dict(state_dict, strict=True)
        pipe.to("cuda")
        pipe.to(dtype=torch.bfloat16)

        is_model_loaded = True
        report(1.0, "Models loaded successfully!")
        logger.info("Models loaded successfully!")
        return "Models loaded successfully!"
    except Exception as e:
        logger.exception(f"Error loading models: {str(e)}")
        return f"Error loading models: {str(e)}"
def extract_frames_from_video(video_path, output_dir, max_frames=81):
    """Write the first ``max_frames`` frames of a video to ``output_dir`` as PNGs.

    Videos with fewer than ``max_frames`` frames are padded by repeating their
    last frame, so exactly ``max_frames`` files ``frame_0000.png``, ... are written.

    Args:
        video_path: path readable by ``imageio.get_reader``.
        output_dir: destination directory (created if missing).
        max_frames: number of frames to write.

    Returns:
        ``(frame_count, fps)``: the number of PNGs written and the reader's
        ``"fps"`` metadata value (None when the metadata has no such entry).

    Raises:
        ValueError: if the video contains no frames.
    """
    from itertools import islice

    os.makedirs(output_dir, exist_ok=True)

    reader = imageio.get_reader(video_path)
    try:
        fps = reader.get_meta_data().get("fps")
        # Decode only the frames that are kept, not the whole video.
        frames = list(islice(reader, max(max_frames, 0)))
    finally:
        # Release the decoder even if reading fails part-way.
        reader.close()

    if len(frames) < max_frames:
        if not frames:
            raise ValueError(f"Video contains no frames: {video_path}")
        logger.info(f"Video has {len(frames)} frames, padding to {max_frames} frames")
        frames.extend([frames[-1]] * (max_frames - len(frames)))

    for i, frame in enumerate(frames):
        frame_path = os.path.join(output_dir, f"frame_{i:04d}.png")
        imageio.imwrite(frame_path, frame)

    return len(frames), fps
def _load_source_video(video_path, height, width, num_frames=81):
    """Load, crop/resize and normalise the first ``num_frames`` frames of a video.

    Short videos are padded by repeating the last processed frame.

    Returns:
        Tensor of shape (1, C, num_frames, height, width): batch, channels,
        time, height, width. Values are normalised with mean/std 0.5.

    Raises:
        ValueError: if the video yields no frames.
    """
    from itertools import islice

    frame_process = v2.Compose([
        v2.CenterCrop(size=(height, width)),
        v2.Resize(size=(height, width), antialias=True),
        v2.ToTensor(),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    def crop_and_resize(image):
        # Scale so the image covers (height, width); CenterCrop trims the excess.
        width_img, height_img = image.size
        scale = max(width / width_img, height / height_img)
        return torchvision.transforms.functional.resize(
            image,
            (round(height_img * scale), round(width_img * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )

    reader = imageio.get_reader(video_path)
    try:
        # Iterate instead of probing get_data(i) past the end. Iteration stops
        # cleanly at end of stream, and real decode/transform errors are no
        # longer masked by a bare ``except`` that silently padded the clip.
        frames = [
            frame_process(crop_and_resize(Image.fromarray(raw)))
            for raw in islice(reader, num_frames)
        ]
    finally:
        reader.close()

    if not frames:
        raise ValueError("Video is too short!")
    # Repeat the last frame if we ran out of frames.
    frames.extend([frames[-1]] * (num_frames - len(frames)))

    video = rearrange(torch.stack(frames, dim=0), "T C H W -> C T H W")
    return video.unsqueeze(0)  # add batch dimension


def _load_target_camera(cam_type):
    """Build the relative target-camera trajectory for preset ``cam_type``.

    Reads every 4th frame (21 of 81) of the ``camNN`` trajectory from the
    sample extrinsics JSON, converts each stored matrix, and expresses each
    pose relative to the first sampled one.

    Returns:
        bfloat16 tensor of shape (1, 21, 12): batch, sampled frame, flattened 3x4 pose.
    """
    tgt_camera_path = f"./{TEST_DATA_DIR}/cameras/camera_extrinsics.json"
    with open(tgt_camera_path, "r") as file:
        cam_data = json.load(file)

    cam_key = f"cam{int(cam_type):02d}"
    sampled_frames = range(0, 81, 4)  # sample every 4 frames
    traj = np.stack([parse_matrix(cam_data[f"frame{idx}"][cam_key]) for idx in sampled_frames])
    # Stored matrices are transposed relative to camera-to-world.
    traj = traj.transpose(0, 2, 1)

    c2ws = []
    for c2w in traj:
        # NOTE(review): presumably a coordinate-convention conversion (axis
        # reorder, flip of the second column, translation scaled by 1/100) — confirm.
        c2w = c2w[:, [1, 2, 0, 3]]  # fancy indexing copies, so in-place edits are safe
        c2w[:3, 1] *= -1.
        c2w[:3, 3] /= 100
        c2ws.append(c2w)

    tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
    relative_poses = [
        torch.as_tensor(get_relative_pose([tgt_cam_params[0], cam]))[:, :3, :][1]
        for cam in tgt_cam_params
    ]
    pose_embedding = torch.stack(relative_poses, dim=0)  # 21x3x4
    pose_embedding = rearrange(pose_embedding, "b c d -> b (c d)")
    return pose_embedding.to(torch.bfloat16).unsqueeze(0)  # add batch dimension


def process_video_for_recammaster(video_path, text_prompt, cam_type, height=480, width=832):
    """Re-render ``video_path`` along camera preset ``cam_type`` with the loaded pipeline.

    Args:
        video_path: source video; padded/truncated to 81 frames.
        text_prompt: description of the video content passed to the pipeline.
        cam_type: camera preset id ("1".."10" or an int); see CAMERA_TRANSFORMATIONS.
        height, width: resolution the source frames are cropped/resized to.

    Returns:
        Whatever the global ``pipe`` returns (generate_recammaster_video passes
        it to save_video).

    Raises:
        ValueError: if the video has no frames.
    """
    video_tensor = _load_source_video(video_path, height, width)
    camera_tensor = _load_target_camera(cam_type)

    return pipe(
        prompt=[text_prompt],
        negative_prompt=["worst quality, low quality, blurry, jittery, distorted"],
        source_video=video_tensor,
        target_camera=camera_tensor,
        cfg_scale=5.0,
        num_inference_steps=50,
        seed=0,
        tiled=True,
    )
def generate_recammaster_video(
    video_file,
    text_prompt,
    camera_type,
    progress=gr.Progress()
):
    """Gradio callback: re-render an uploaded video with the selected camera movement.

    Args:
        video_file: the uploaded video, either a file path string or an object
            exposing the path as ``.name``; None when nothing was uploaded.
        text_prompt: text description forwarded to the pipeline.
        camera_type: key into CAMERA_TRANSFORMATIONS.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(output_path, status_message)``; ``output_path`` is None on failure.
    """
    if not is_model_loaded:
        return None, "Error: Models not loaded! Please load models first."
    if video_file is None:
        return None, "Please upload a video file."

    # NOTE(review): gr.Video delivers a filepath string, so an unconditional
    # ``video_file.name`` raised AttributeError. File-like upload objects
    # exposing ``.name`` are still accepted.
    source_path = getattr(video_file, "name", video_file)

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            progress(0.1, desc="Processing input video...")
            input_video_path = os.path.join(temp_dir, "input.mp4")
            shutil.copy(source_path, input_video_path)

            progress(0.2, desc="Extracting video frames...")
            num_frames, fps = extract_frames_from_video(input_video_path, os.path.join(temp_dir, "frames"))
            logger.info(f"Extracted {num_frames} frames at {fps} fps")

            progress(0.3, desc="Processing with ReCamMaster...")
            generated = process_video_for_recammaster(
                input_video_path,
                text_prompt,
                camera_type
            )

            progress(0.9, desc="Saving output video...")
            output_path = os.path.join(temp_dir, "output.mp4")
            save_video(generated, output_path, fps=30, quality=5)

            # Copy out of temp_dir, which is deleted on exit. mkstemp returns an
            # open descriptor; close it at once instead of leaving a file handle
            # for the garbage collector.
            fd, final_output_path = tempfile.mkstemp(suffix=".mp4")
            os.close(fd)
            shutil.copy(output_path, final_output_path)
            progress(1.0, desc="Done!")

        transformation_name = CAMERA_TRANSFORMATIONS.get(str(camera_type), "Unknown")
        status_msg = f"Successfully generated video with '{transformation_name}' camera movement!"
        return final_output_path, status_msg
    except Exception as e:
        # UI boundary: log with traceback and show the message to the user.
        logger.exception(f"Error generating video: {str(e)}")
        return None, f"Error: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="ReCamMaster Demo") as demo:
    # Model-loading status; on_load hides it unless loading failed.
    loading_status = gr.Textbox(
        label="Model Loading Status",
        value="Loading models, please wait...",
        interactive=False,
        visible=True
    )

    gr.Markdown(f"""
# 🎥 ReCamMaster Demo
ReCamMaster allows you to re-capture videos with novel camera trajectories.
Upload a video and select a camera transformation to see the magic!
**Note:** All required models will be automatically downloaded to {MODELS_ROOT_DIR} when you start the app.
You can customize this location by setting the RECAMMASTER_MODELS_DIR environment variable.
""")

    with gr.Row():
        with gr.Column():
            # Video input section
            with gr.Group():
                gr.Markdown("### Step 1: Upload Video")
                video_input = gr.Video(label="Input Video")
                text_prompt = gr.Textbox(
                    label="Text Prompt (describe your video)",
                    placeholder="A person walking in the street",
                    value="A dynamic scene"
                )

            # Camera selection: show the label, submit the preset id.
            with gr.Group():
                gr.Markdown("### Step 2: Select Camera Movement")
                camera_type = gr.Radio(
                    choices=[(label, key) for key, label in CAMERA_TRANSFORMATIONS.items()],
                    label="Camera Transformation",
                    value="1"
                )

            generate_btn = gr.Button("Generate Video", variant="primary")

        with gr.Column():
            # Output section
            output_video = gr.Video(label="Output Video")
            status_output = gr.Textbox(label="Generation Status", interactive=False)

    # Example videos. create_test_data_structure only creates the videos/
    # directory, not these clips, so offer only the files that actually exist.
    # NOTE(review): listing missing files can make gr.Examples fail while
    # preparing the examples.
    example_rows = [
        [f"{TEST_DATA_DIR}/videos/case0.mp4", "A person dancing", "1"],
        [f"{TEST_DATA_DIR}/videos/case1.mp4", "A scenic view", "5"],
    ]
    available_examples = [row for row in example_rows if os.path.exists(row[0])]
    if available_examples:
        gr.Markdown("### Example Videos")
        gr.Examples(
            examples=available_examples,
            inputs=[video_input, text_prompt, camera_type],
        )

    def on_load():
        """Load models when the page loads; keep the status box visible only on error."""
        status = load_models()
        return gr.update(value=status, visible="Error" in status)

    demo.load(on_load, outputs=[loading_status])

    # Event handlers
    generate_btn.click(
        fn=generate_recammaster_video,
        inputs=[video_input, text_prompt, camera_type],
        outputs=[output_video, status_output]
    )

if __name__ == "__main__":
    demo.launch(share=True)