Spaces:
Configuration error
Configuration error
| import torch | |
| from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler | |
| from diffusers.utils import export_to_video | |
| from transformers import CLIPVisionModel | |
| import gradio as gr | |
| import tempfile | |
| import spaces | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| from PIL import Image | |
| import random | |
| import logging | |
| import gc | |
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model identifiers (base pipeline and the LoRA file fetched from the Hub)
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"

# Generation / UI parameters
MOD_VALUE = 32  # height and width are snapped down to multiples of this
DEFAULT_H_SLIDER_VALUE = DEFAULT_W_SLIDER_VALUE = 512  # square default for Zero GPU
NEW_FORMULA_MAX_AREA = float(480 * 832)
SLIDER_MIN_H = SLIDER_MIN_W = 128
SLIDER_MAX_H = SLIDER_MAX_W = 896
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 24
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 81

# Default prompts shown in the UI
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "static, blurred, low quality, watermark, text"
# Global model loading: runs once at import time and leaves `pipe` as a
# module-level object that generate_video() calls.
logger.info("Loading model components...")
# The image encoder and VAE are loaded in float32; the rest of the pipeline
# is loaded in bfloat16.
image_encoder = CLIPVisionModel.from_pretrained(
    MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(
    MODEL_ID, subfolder="vae", torch_dtype=torch.float32
)
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.to("cuda")

# LoRA loading: best effort. If any step fails the app keeps running with
# whatever state the pipeline is in and only logs a warning.
try:
    causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
    pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
    pipe.fuse_lora()
    logger.info("LoRA loaded successfully")
except Exception as e:
    logger.warning(f"LoRA loading failed: {e}")

# Memory optimisation: only use methods WanImageToVideoPipeline supports.
# NOTE(review): the pipeline is moved to CUDA above and CPU offload is then
# requested here; confirm that both are intended.
try:
    pipe.enable_model_cpu_offload()
    logger.info("CPU offload enabled")
except Exception as e:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, and record the reason.
    logger.info("CPU offload not available: %s", e)

logger.info("Model loaded and ready")
def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
                                  min_slider_h, max_slider_h,
                                  min_slider_w, max_slider_w,
                                  default_h, default_w):
    """Pick an output (height, width) that follows the image's aspect ratio.

    The target area is `calculation_max_area`, split between height and width
    according to the aspect ratio of `pil_image`. Both sides are snapped down
    to a multiple of `mod_val` (never below `mod_val`) and then clipped into
    the slider ranges.

    Args:
        pil_image: Object exposing `.size` as `(width, height)`.
        mod_val: Granularity both dimensions are snapped to.
        calculation_max_area: Pixel budget (height * width) to aim for.
        min_slider_h, max_slider_h: Allowed height range.
        min_slider_w, max_slider_w: Allowed width range.
        default_h, default_w: Returned unchanged when the image reports a
            non-positive width or height.

    Returns:
        Tuple `(new_h, new_w)`.
    """
    src_w, src_h = pil_image.size
    if src_w <= 0 or src_h <= 0:
        return default_h, default_w

    # Conservative limits for Zero GPU.
    # NOTE(review): this is true whenever the `spaces` module exposes `GPU`,
    # which is presumably always; confirm this is the intended environment check.
    restricted = hasattr(spaces, 'GPU')

    ratio = src_h / src_w
    # Use a smaller pixel budget and lower slider ceilings when restricted.
    area = min(calculation_max_area, 320.0 * 320.0) if restricted else calculation_max_area
    ceiling_h = min(max_slider_h, 640) if restricted else max_slider_h
    ceiling_w = min(max_slider_w, 640) if restricted else max_slider_w

    def _snap(value):
        # Round down to a multiple of mod_val, but never below mod_val itself.
        return max(mod_val, (value // mod_val) * mod_val)

    snapped_h = _snap(round(np.sqrt(area * ratio)))
    snapped_w = _snap(round(np.sqrt(area / ratio)))

    new_h = int(np.clip(snapped_h, min_slider_h, (ceiling_h // mod_val) * mod_val))
    new_w = int(np.clip(snapped_w, min_slider_w, (ceiling_w // mod_val) * mod_val))
    return new_h, new_w
def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
    """Update the height/width sliders after the input image changes.

    Args:
        uploaded_pil_image: The uploaded image, or None when it was cleared.
        current_h_val: Current height slider value (unused; kept because the
            event wiring passes it).
        current_w_val: Current width slider value (unused; kept because the
            event wiring passes it).

    Returns:
        Two `gr.update` objects, for the height and width sliders. The slider
        defaults are used when there is no image or the calculation fails.
    """
    default_updates = (
        gr.update(value=DEFAULT_H_SLIDER_VALUE),
        gr.update(value=DEFAULT_W_SLIDER_VALUE),
    )
    if uploaded_pil_image is None:
        return default_updates

    try:
        new_h, new_w = _calculate_new_dimensions_wan(
            uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
        )
    except Exception:
        # Best effort: fall back to the defaults, but keep the traceback in
        # the logs instead of discarding the cause.
        logger.exception("Failed to calculate dimensions for the uploaded image")
        gr.Warning("Error attempting to calculate new dimensions")
        return default_updates
    return gr.update(value=new_h), gr.update(value=new_w)
def get_duration(input_image, prompt, height, width,
                 negative_prompt, duration_seconds,
                 guidance_scale, steps,
                 seed, randomize_seed,
                 progress):
    """Return a time budget in seconds for one generation request.

    The parameter list mirrors generate_video's, presumably so this function
    can serve as a per-call duration callback (e.g. for
    `spaces.GPU(duration=...)`); confirm against how it is wired up. Only
    `steps` and `duration_seconds` influence the result.

    Returns:
        int: 90 when both `steps > 4` and `duration_seconds > 2`; 80 (Zero
        GPU) or 75 (regular GPU) when exactly one of them holds; otherwise
        70 (Zero GPU) or 60 (regular GPU).
    """
    # Budgets indexed by how many "heavy" conditions hold (0, 1 or 2).
    # The Zero GPU branch allocates more conservative (longer) budgets.
    if hasattr(spaces, 'GPU'):
        budgets = (70, 80, 90)
    else:
        budgets = (60, 75, 90)
    heavy_conditions = int(steps > 4) + int(duration_seconds > 2)
    return budgets[heavy_conditions]
# Request a GPU for each call. The time budget is computed per request by
# get_duration, which takes the same arguments as this function. Without this
# decorator get_duration is never used.
@spaces.GPU(duration=get_duration)
def generate_video(input_image, prompt, height, width,
                   negative_prompt=default_negative_prompt, duration_seconds=2,
                   guidance_scale=1, steps=4,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
    """Animate `input_image` into a short MP4 using the global `pipe`.

    Args:
        input_image: PIL image to animate; required.
        prompt: Text prompt describing the motion.
        height, width: Requested size; each is snapped down to a multiple of
            MOD_VALUE (minimum MOD_VALUE).
        negative_prompt: Negative text prompt.
        duration_seconds: Clip length; converted to a frame count at FIXED_FPS
            and clipped to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL].
        guidance_scale: Guidance scale passed to the pipeline.
        steps: Number of inference steps.
        seed: Seed used when `randomize_seed` is false.
        randomize_seed: When true, a random seed in [0, MAX_SEED] is drawn.
        progress: Gradio progress tracker (not referenced directly).

    Returns:
        Tuple `(video_path, current_seed)`: path of the written .mp4 file and
        the seed that was actually used.

    Raises:
        gr.Error: If no image is given, the resolution exceeds the Zero GPU
            pixel limit, the GPU runs out of memory, or generation fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # NOTE(review): this is true whenever the `spaces` module exposes `GPU`,
    # not only on Zero GPU hardware; confirm this is the intended check.
    zero_gpu = hasattr(spaces, 'GPU')

    # Extra validation for the Zero GPU environment.
    if zero_gpu:
        # Pixel limit
        max_pixels = 409600  # 640x640
        if height * width > max_pixels:
            raise gr.Error(f"Resolution too high for Zero GPU. Maximum {max_pixels:,} pixels (e.g., 640×640)")
        # Duration limit
        if duration_seconds > 2.5:
            duration_seconds = 2.5
            gr.Warning("Duration limited to 2.5s in Zero GPU environment")
        # Steps limit
        if steps > 8:
            steps = 8
            gr.Warning("Steps limited to 8 in Zero GPU environment")

    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    # Convert to a plain int so a numpy scalar is not passed to the pipeline.
    num_frames = int(np.clip(int(round(duration_seconds * FIXED_FPS)),
                             MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    # Additional frame cap on Zero GPU (2.5 seconds).
    if zero_gpu:
        num_frames = min(num_frames, int(2.5 * FIXED_FPS))

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    logger.info("Generating video: %sx%s, %s frames, seed=%s",
                target_h, target_w, num_frames, current_seed)

    # Resize the input image to the generation resolution.
    resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)

    try:
        with torch.inference_mode():
            output_frames_list = pipe(
                image=resized_image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=target_h,
                width=target_w,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed)
            ).frames[0]
    except torch.cuda.OutOfMemoryError as e:
        # Release what we can before reporting the failure to the user.
        gc.collect()
        torch.cuda.empty_cache()
        raise gr.Error("GPU out of memory. Try smaller resolution or shorter duration.") from e
    except Exception as e:
        # logger.exception keeps the traceback; the user sees a short message.
        logger.exception("Generation failed: %s", e)
        raise gr.Error(f"Video generation failed: {str(e)[:100]}") from e

    # Save the video. The temp file is only used to reserve a path
    # (delete=False); it is closed before export_to_video writes to it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

    # Memory cleanup
    del output_frames_list
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return video_path, current_seed
# CSS styles (existing UI preserved).
# Passed to gr.Blocks(css=css); the class names below are attached to
# components through their `elem_classes` arguments and the inline HTML.
css = """
.container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}
.header {
    text-align: center;
    margin-bottom: 30px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 40px;
    border-radius: 20px;
    color: white;
    box-shadow: 0 10px 30px rgba(0,0,0,0.2);
    position: relative;
    overflow: hidden;
}
.header::before {
    content: '';
    position: absolute;
    top: -50%;
    left: -50%;
    width: 200%;
    height: 200%;
    background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, transparent 70%);
    animation: pulse 4s ease-in-out infinite;
}
@keyframes pulse {
    0%, 100% { transform: scale(1); opacity: 0.5; }
    50% { transform: scale(1.1); opacity: 0.8; }
}
.header h1 {
    font-size: 3em;
    margin-bottom: 10px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
    position: relative;
    z-index: 1;
}
.header p {
    font-size: 1.2em;
    opacity: 0.95;
    position: relative;
    z-index: 1;
}
.gpu-status {
    position: absolute;
    top: 10px;
    right: 10px;
    background: rgba(0,0,0,0.3);
    padding: 5px 15px;
    border-radius: 20px;
    font-size: 0.8em;
}
.main-content {
    background: rgba(255, 255, 255, 0.95);
    border-radius: 20px;
    padding: 30px;
    box-shadow: 0 5px 20px rgba(0,0,0,0.1);
    backdrop-filter: blur(10px);
}
.input-section {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    padding: 25px;
    border-radius: 15px;
    margin-bottom: 20px;
}
.generate-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    font-size: 1.3em;
    padding: 15px 40px;
    border-radius: 30px;
    border: none;
    cursor: pointer;
    transition: all 0.3s ease;
    box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
    width: 100%;
    margin-top: 20px;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6);
}
.generate-btn:active {
    transform: translateY(0);
}
.video-output {
    background: #f8f9fa;
    padding: 20px;
    border-radius: 15px;
    text-align: center;
    min-height: 400px;
    display: flex;
    align-items: center;
    justify-content: center;
}
.accordion {
    background: rgba(255, 255, 255, 0.7);
    border-radius: 10px;
    margin-top: 15px;
    padding: 15px;
}
.slider-container {
    background: rgba(255, 255, 255, 0.5);
    padding: 15px;
    border-radius: 10px;
    margin: 10px 0;
}
body {
    background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
    background-size: 400% 400%;
    animation: gradient 15s ease infinite;
}
@keyframes gradient {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
.warning-box {
    background: rgba(255, 193, 7, 0.1);
    border: 1px solid rgba(255, 193, 7, 0.3);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    color: #856404;
    font-size: 0.9em;
}
.info-box {
    background: rgba(52, 152, 219, 0.1);
    border: 1px solid rgba(52, 152, 219, 0.3);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    color: #2c5282;
    font-size: 0.9em;
}
.footer {
    text-align: center;
    margin-top: 30px;
    color: #666;
    font-size: 0.9em;
}
"""
# Gradio UI (existing structure preserved).
# NOTE(review): the container nesting below was reconstructed from a listing
# whose indentation had been stripped; confirm it against the deployed layout.

# Evaluated once and reused for every Zero GPU specific limit in the layout.
_zero_gpu_ui = hasattr(spaces, 'GPU')

# Static HTML fragments rendered by the layout.
_HEADER_HTML = """
<div class="header">
    <h1>🎬 AI Video Magic Studio</h1>
    <p>Transform your images into captivating videos with Wan 2.1 + CausVid LoRA</p>
    <div class="gpu-status">🖥️ Zero GPU Optimized</div>
</div>
"""

_ZERO_GPU_TIPS_HTML = """
<div class="warning-box">
    <strong>💡 Zero GPU Performance Tips:</strong>
    <ul style="margin: 5px 0; padding-left: 20px;">
        <li>Maximum duration: 2.5 seconds</li>
        <li>Maximum resolution: 640×640 pixels</li>
        <li>Recommended: 512×512 at 2 seconds</li>
        <li>Use 4-6 steps for optimal speed/quality balance</li>
        <li>Processing time: ~60-90 seconds</li>
    </ul>
</div>
"""

_QUICK_START_HTML = """
<div class="info-box">
    <strong>🎯 Quick Start Guide:</strong>
    <ol style="margin: 5px 0; padding-left: 20px;">
        <li>Upload your image - AI will calculate optimal dimensions</li>
        <li>Enter a creative prompt or use the default</li>
        <li>Adjust duration (2s recommended for best results)</li>
        <li>Click Generate and wait for completion</li>
    </ol>
</div>
"""

_TIP_FOOTER_HTML = """
<div class="footer">
    <p>💡 Tip: For best results, use clear images with good lighting and distinct subjects</p>
</div>
"""

_POWERED_BY_HTML = """
<div style="background: rgba(255,255,255,0.9); border-radius: 10px; padding: 15px; margin-top: 20px; font-size: 0.8em; text-align: center;">
    <p style="margin: 0; color: #666;">
        <strong style="color: #667eea;">Powered by:</strong>
        Wan 2.1 I2V (14B) + CausVid LoRA • 🚀 4-8 steps fast inference • 🎬 Up to 81 frames
    </p>
</div>
"""

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        # Header with GPU status
        gr.HTML(_HEADER_HTML)

        # GPU memory warning
        if _zero_gpu_ui:
            gr.HTML(_ZERO_GPU_TIPS_HTML)

        # Info box
        gr.HTML(_QUICK_START_HTML)

        with gr.Row(elem_classes="main-content"):
            # Left column: inputs and settings
            with gr.Column(scale=1):
                gr.Markdown("### 📸 Input Settings")

                with gr.Column(elem_classes="input-section"):
                    input_image = gr.Image(
                        type="pil", label="🖼️ Upload Your Image", elem_classes="image-upload"
                    )
                    prompt_input = gr.Textbox(
                        label="✨ Animation Prompt",
                        value=default_prompt_i2v,
                        placeholder="Describe how you want your image to move...",
                        lines=2,
                    )
                    duration_input = gr.Slider(
                        minimum=round(MIN_FRAMES_MODEL / FIXED_FPS, 1),
                        maximum=2.5 if _zero_gpu_ui else round(MAX_FRAMES_MODEL / FIXED_FPS, 1),
                        step=0.1,
                        value=2,
                        label=f"⏱️ Video Duration (seconds) - Clamped to {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps",
                        elem_classes="slider-container",
                    )

                with gr.Accordion("🎛️ Advanced Settings", open=False, elem_classes="accordion"):
                    negative_prompt = gr.Textbox(
                        label="🚫 Negative Prompt", value=default_negative_prompt, lines=3
                    )
                    with gr.Row():
                        seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="🎲 Seed")
                        randomize_seed = gr.Checkbox(label="🔀 Randomize", value=True)
                    with gr.Row():
                        height_slider = gr.Slider(
                            minimum=SLIDER_MIN_H,
                            maximum=640 if _zero_gpu_ui else SLIDER_MAX_H,
                            step=MOD_VALUE,
                            value=DEFAULT_H_SLIDER_VALUE,
                            label=f"📏 Height (multiple of {MOD_VALUE})",
                        )
                        width_slider = gr.Slider(
                            minimum=SLIDER_MIN_W,
                            maximum=640 if _zero_gpu_ui else SLIDER_MAX_W,
                            step=MOD_VALUE,
                            value=DEFAULT_W_SLIDER_VALUE,
                            label=f"📐 Width (multiple of {MOD_VALUE})",
                        )
                    steps_slider = gr.Slider(
                        minimum=1,
                        maximum=8 if _zero_gpu_ui else 30,
                        step=1,
                        value=4,
                        label="🔧 Quality Steps (4-6 recommended)",
                    )
                    # Created hidden, but its value is still passed to generate_video.
                    guidance_scale = gr.Slider(
                        minimum=0.0,
                        maximum=20.0,
                        step=0.5,
                        value=1.0,
                        label="🎯 Guidance Scale",
                        visible=False,
                    )

                generate_btn = gr.Button(
                    "🎬 Generate Video", variant="primary", elem_classes="generate-btn"
                )

            # Right column: result
            with gr.Column(scale=1):
                gr.Markdown("### 🎥 Generated Video")
                video_output = gr.Video(label="", autoplay=True, elem_classes="video-output")
                gr.HTML(_TIP_FOOTER_HTML)

        # Examples
        gr.Examples(
            examples=[
                ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 512, 512],
                ["forg.jpg", "the frog jumps around", 448, 576],
            ],
            inputs=[input_image, prompt_input, height_slider, width_slider],
            outputs=[video_output, seed],
            fn=generate_video,
            cache_examples=False,  # caching disabled to save memory
        )

        # Summary line
        gr.HTML(_POWERED_BY_HTML)

    # Event handlers.
    # Uploading and clearing the image both recompute the dimension sliders.
    for _image_event in (input_image.upload, input_image.clear):
        _image_event(
            fn=handle_image_upload_for_dims_wan,
            inputs=[input_image, height_slider, width_slider],
            outputs=[height_slider, width_slider],
        )

    # The input order must match generate_video's positional parameters.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image, prompt_input, height_slider, width_slider,
            negative_prompt, duration_input, guidance_scale,
            steps_slider, seed, randomize_seed,
        ],
        outputs=[video_output, seed],
    )

if __name__ == "__main__":
    demo.queue().launch()