# Spaces: Paused
| import gradio as gr | |
| from fastapi import FastAPI, UploadFile, File | |
| import uvicorn, uuid, os | |
| import subprocess | |
| import sys | |
# Initialize on startup
def _clone_repository(url, dest):
    """Clone ``url`` into the folder ``dest`` unless that folder already exists.

    Best-effort: any failure (git missing, network error, non-zero exit)
    is printed as a warning and swallowed.
    """
    if os.path.exists(dest):
        return
    print(f"Cloning {dest} repository...")
    try:
        # Pass ``dest`` explicitly so the folder checked above is exactly
        # the folder git creates.
        subprocess.run(["git", "clone", url, dest], check=True)
    except Exception as e:
        print(f"Warning: Failed to clone {dest}: {e}")


def _patch_attention(attention_file):
    """Replace ``assert FLASH_ATTN_2_AVAILABLE`` in ``attention_file`` with a fallback.

    The first line containing the assert becomes
    ``if not FLASH_ATTN_2_AVAILABLE:`` followed by a ``return`` of
    ``F.scaled_dot_product_attention(q, k, v, ...)``. An
    ``import torch.nn.functional as F`` line is prepended if missing.
    Nothing is written when the file is absent or no longer contains the
    assert (e.g. it was patched on a previous start). Failures are printed,
    not raised.
    """
    if not os.path.exists(attention_file):
        return
    print("Patching attention.py to disable flash_attn requirement...")
    marker = "assert FLASH_ATTN_2_AVAILABLE"
    try:
        with open(attention_file, "r", encoding="utf-8") as f:
            content = f.read()
        if marker not in content:
            return
        # The injected fallback refers to ``F``; make sure it is imported.
        if "import torch.nn.functional as F" not in content:
            content = "import torch.nn.functional as F\n" + content
        lines = content.split("\n")
        for i, line in enumerate(lines):
            if marker in line:
                # Reuse the assert's exact leading whitespace (tabs included)
                # so the replacement stays at the same block level.
                indent_str = line[: len(line) - len(line.lstrip())]
                lines[i] = f"{indent_str}if not FLASH_ATTN_2_AVAILABLE:"
                # NOTE(review): this passes whatever q/k/v are in scope at the
                # assert site straight to SDPA, which expects (..., L, E)
                # inputs (heads before the sequence axis). Confirm the upstream
                # function has not already reshaped/flattened q/k/v there.
                lines.insert(
                    i + 1,
                    f"{indent_str}    return F.scaled_dot_product_attention("
                    "q, k, v, dropout_p=0.0, is_causal=False)",
                )
                break
        with open(attention_file, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
        print("Successfully patched attention.py")
    except Exception as e:
        print(f"Warning: Failed to patch attention.py: {e}")


def _download_model(repo_id, local_dir, critical_files):
    """Download the Hugging Face repo ``repo_id`` into ``local_dir`` via ``huggingface-cli``.

    The download is skipped only when ``local_dir`` exists AND every name in
    ``critical_files`` is present inside it. A folder left behind by an
    interrupted earlier download is therefore retried rather than skipped
    forever. Failures are printed, not raised.
    """
    already_complete = os.path.isdir(local_dir) and all(
        os.path.exists(f"{local_dir}/{name}") for name in critical_files
    )
    if already_complete:
        print(f"Model {local_dir} already exists, skipping download")
        return
    print(f"Downloading {repo_id}...")
    try:
        # Use huggingface-cli to download all files
        subprocess.run(
            [
                "huggingface-cli", "download", repo_id,
                "--local-dir", local_dir,
                "--local-dir-use-symlinks", "False",
            ],
            check=True,
        )
        print(f"Successfully downloaded {repo_id}")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing ``huggingface-cli`` executable
        # (FileNotFoundError), which would otherwise abort startup.
        print(f"ERROR: Failed to download {repo_id}: {e}")
        print("Please ensure you have sufficient disk space and network connectivity.")


def initialize():
    """Prepare repositories, source patches and model weights (best-effort).

    Steps:

    1. Clone the Wan2.1 and VACE repositories into the working directory if
       their folders are missing.
    2. Patch ``Wan2.1/wan/modules/attention.py`` so the
       ``assert FLASH_ATTN_2_AVAILABLE`` line becomes an SDPA fallback.
    3. Download each model with ``huggingface-cli`` unless its critical
       files are already present.
    4. Print a warning for every critical model file that is still missing.

    No step raises; problems are reported with ``print``.
    """
    print("Initializing Wan2.1 VACE environment...")
    # Clone repositories if they don't exist
    _clone_repository("https://github.com/Wan-Video/Wan2.1.git", "Wan2.1")
    _clone_repository("https://github.com/ali-vilab/VACE.git", "VACE")
    # Patch Wan2.1 attention.py to disable the flash_attn requirement
    _patch_attention("Wan2.1/wan/modules/attention.py")
    # (repo id, local folder, files whose absence is reported as critical)
    models = [
        (
            "Wan-AI/Wan2.1-VACE-1.3B",
            "Wan2.1-VACE-1.3B",
            ["models_t5_umt5-xxl-enc-bf16.pth", "diffusion_pytorch_model.safetensors"],
        ),
        (
            "Wan-AI/Wan2.1-FLF2V-14B-720P",
            "Wan2.1-FLF2V-14B-720P",
            ["models_t5_umt5-xxl-enc-bf16.pth", "diffusion_pytorch_model.safetensors.index.json"],
        ),
    ]
    for repo_id, local_dir, critical_files in models:
        _download_model(repo_id, local_dir, critical_files)
    # Check for critical model files
    for _, local_dir, critical_files in models:
        for name in critical_files:
            file_path = f"{local_dir}/{name}"
            if not os.path.exists(file_path):
                print(f"WARNING: Critical model file missing: {file_path}")
                print("The application may not work properly without this file.")
# Environment setup has to finish before wan_runner is imported: initialize()
# clones the repos, patches attention.py and fetches the weights, which
# wan_runner presumably relies on (confirm in wan_runner).
initialize()
from wan_runner import generate_image, generate_video  # noqa: E402  (must follow initialize())

api = FastAPI()
async def api_generate_video(ref: UploadFile = File(...), first: UploadFile = File(...), last: UploadFile = File(...)):
    """Store three uploaded frames and render a video from them.

    The uploads are written as ``ref.png``, ``first.png`` and ``last.png``
    inside a fresh ``<hex-uuid>`` folder in the working directory. They are
    then passed, together with ``<hex-uuid>/output.mp4``, to ``generate_video``.

    Returns:
        ``{"video_path": "<hex-uuid>/output.mp4"}``.
    """
    # NOTE(review): no @api.post(...) decorator is visible, so this handler is
    # not registered on ``api``; confirm the intended route.
    # NOTE(review): generate_video is called synchronously (not awaited), so
    # it blocks the event loop for the whole call.
    job_dir = uuid.uuid4().hex
    os.makedirs(job_dir, exist_ok=True)
    saved_paths = []
    for upload, filename in ((ref, "ref.png"), (first, "first.png"), (last, "last.png")):
        target = f"{job_dir}/{filename}"
        with open(target, "wb") as fh:
            fh.write(await upload.read())
        saved_paths.append(target)
    video_path = f"{job_dir}/output.mp4"
    generate_video(*saved_paths, video_path)
    return {"video_path": video_path}
async def api_generate_image(ref: UploadFile = File(...), prompt: str = ""):
    """Store an uploaded reference image and generate an image from it.

    The upload is written to ``<hex-uuid>/ref.png`` in the working directory.
    It is then passed with ``prompt`` and ``<hex-uuid>/output.png`` to
    ``generate_image``.

    Returns:
        ``{"image_path": "<hex-uuid>/output.png"}``.
    """
    # NOTE(review): no @api.post(...) decorator is visible, so this handler is
    # not registered on ``api``; confirm the intended route.
    # NOTE(review): generate_image is called synchronously (not awaited), so
    # it blocks the event loop; also no resize / ``size=`` is applied here,
    # unlike the Gradio image tab — confirm that is intended.
    job_dir = uuid.uuid4().hex
    os.makedirs(job_dir, exist_ok=True)
    source_path = f"{job_dir}/ref.png"
    with open(source_path, "wb") as fh:
        fh.write(await ref.read())
    result_path = f"{job_dir}/output.png"
    generate_image(source_path, prompt, result_path)
    return {"image_path": result_path}
with gr.Blocks() as demo:
    with gr.Tab("動画生成"):
        gr.Markdown(
            "### FLF2V-14B 動画生成\n"
            "⚠️ このモデルは**1280×720 (16:9)のみ**サポートしています。"
            "アップロード画像が他のサイズでも**1280×720**に自動リサイズされます。\n"
            "⏱️ 生成される動画は**5秒間**です。"
        )
        ref_img = gr.Image(label="参照画像", type="pil")
        first_img = gr.Image(label="開始画像", type="pil")
        last_img = gr.Image(label="終了画像", type="pil")
        btn_video = gr.Button("動画を生成")
        output_video = gr.Video()

        def video_ui(ref, first, last):
            """Generate a video from reference / first / last frame images.

            Each PIL image is stretched to 1280x720, the only size this tab
            passes to FLF2V-14B. The images are saved as PNGs in a temporary
            directory and handed to ``generate_video`` with ``size="1280*720"``.

            Returns:
                Path of the generated ``<hex-uuid>.mp4`` in the working directory.

            Raises:
                gr.Error: if an image is missing or generation fails.
            """
            import tempfile
            from PIL import Image

            # An empty gr.Image slot arrives as None; report that clearly
            # instead of surfacing an AttributeError from .resize().
            if ref is None or first is None or last is None:
                raise gr.Error("参照画像・開始画像・終了画像をすべてアップロードしてください。")
            # FLF2V-14B only supports 1280x720
            resolution = (1280, 720)
            try:
                with tempfile.TemporaryDirectory() as tmpdir:
                    # Resize every frame to the required resolution. This may
                    # change the aspect ratio, but the model requires it.
                    frame_paths = []
                    for name, img in (("ref", ref), ("first", first), ("last", last)):
                        path = f"{tmpdir}/{name}.png"
                        img.resize(resolution, Image.Resampling.LANCZOS).save(path)
                        frame_paths.append(path)
                    # The output lives outside tmpdir so it survives cleanup.
                    output = f"{uuid.uuid4().hex}.mp4"
                    generate_video(*frame_paths, output, size=f"{resolution[0]}*{resolution[1]}")
                    return output
            except FileNotFoundError as e:
                raise gr.Error(str(e)) from e
            except Exception as e:
                raise gr.Error(f"動画生成エラー: {str(e)}") from e

        btn_video.click(video_ui, [ref_img, first_img, last_img], output_video)

    with gr.Tab("画像生成"):
        gr.Markdown(
            "### VACE-1.3B 画像生成\n"
            "⚠️ このモデルは**832×480（横長）または480×832（縦長）のみ**サポートしています。"
            "アップロード画像が他のサイズでも対応解像度に自動リサイズされます。"
        )
        ref_img2 = gr.Image(label="参照画像", type="pil")
        prompt = gr.Textbox(label="画像プロンプト")
        btn_image = gr.Button("画像を生成")
        output_image = gr.Image()

        def image_ui(ref, prompt):
            """Generate an image from a reference image and a text prompt.

            Picks whichever of 832x480 / 480x832 has the aspect ratio closest
            to the upload and stretches the image to it. It then calls
            ``generate_image`` with the matching ``size`` string ("W*H").

            Returns:
                Path of the generated ``<hex-uuid>.png`` in the working directory.

            Raises:
                gr.Error: if the image is missing or generation fails.
            """
            import tempfile
            from PIL import Image

            # An empty gr.Image slot arrives as None; fail with a clear message.
            if ref is None:
                raise gr.Error("参照画像をアップロードしてください。")
            try:
                with tempfile.TemporaryDirectory() as tmpdir:
                    orig_width, orig_height = ref.size
                    aspect_ratio = orig_width / orig_height
                    # Supported resolutions for the VACE model (per generate.py):
                    # judging from its error message, only these two are
                    # actually accepted.
                    supported_resolutions = [
                        (832, 480),  # ~16:9 landscape
                        (480, 832),  # ~9:16 portrait
                    ]
                    best_resolution = min(
                        supported_resolutions,
                        key=lambda res: abs((res[0] / res[1]) - aspect_ratio),
                    )
                    ref_path = f"{tmpdir}/ref.png"
                    ref.resize(best_resolution, Image.Resampling.LANCZOS).save(ref_path)
                    size_param = f"{best_resolution[0]}*{best_resolution[1]}"
                    # The output lives outside tmpdir so it survives cleanup.
                    output = f"{uuid.uuid4().hex}.png"
                    generate_image(ref_path, prompt, output, size=size_param)
                    return output
            except FileNotFoundError as e:
                raise gr.Error(str(e)) from e
            except Exception as e:
                raise gr.Error(f"画像生成エラー: {str(e)}") from e

        btn_image.click(image_ui, [ref_img2, prompt], output_image)
# Serve the Gradio UI at the root path of the FastAPI application.
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    # Listen on every network interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")