# Hugging Face Space: Wan text/image-to-video generation service with LoRA support.
import gc
import os
import time
from pathlib import Path
from typing import Any, Callable, Optional

import numpy as np
import safetensors.torch
import torch
from PIL import Image
# --- Configuration ---------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"  # base diffusers checkpoint
LORA_CACHE_DIR = "/tmp/lora_cache"            # on-disk cache for LoRA weights

# Prefer CUDA + bf16 when a GPU is present; fall back to CPU + fp32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Make sure the LoRA cache directory exists before any download happens.
os.makedirs(LORA_CACHE_DIR, exist_ok=True)
# All bundled adapters live in the same repository.
_WAN_LORA_REPO = "Kijai/Wan2.1-fp8-diffusers"

# Registry of LoRA adapters the UI may load on demand. Each entry records the
# hub repo, the weights filename, a human-readable description, and any
# trigger words the adapter responds to.
AVAILABLE_LORAS = {
    "wan-fast-lora": {
        "repo": _WAN_LORA_REPO,  # FP8-quantized weights for speed
        "filename": "wan2.1_fast_lora.safetensors",
        "description": "Optimized for 2-3x faster generation",
        "trigger_words": [],
    },
    "wan-quality-lora": {
        "repo": _WAN_LORA_REPO,
        "filename": "wan2.1_quality_lora.safetensors",
        "description": "Enhanced visual quality",
        "trigger_words": ["high quality", "detailed"],
    },
    "wan-motion-lora": {
        "repo": _WAN_LORA_REPO,
        "filename": "wan2.1_motion_lora.safetensors",
        "description": "Better motion dynamics",
        "trigger_words": ["smooth motion", "dynamic"],
    },
}
def get_available_loras() -> list:
    """Return the names of all registered LoRA adapters, in declaration order."""
    return [*AVAILABLE_LORAS]
class WanVideoGenerator:
    """Text/image-to-video generator built on the Wan diffusers pipeline.

    Loads the base checkpoint named by ``MODEL_ID`` once at construction and
    supports hot-swapping the adapters declared in ``AVAILABLE_LORAS``.

    NOTE(review): the original docstring and log message said
    "Wan2.2-TI2V-5B", but ``MODEL_ID`` points at Wan2.1-T2V-14B; messages now
    reference ``MODEL_ID`` so logs match what is actually loaded.
    """

    def __init__(self):
        # Pipeline is built eagerly; LoRA state starts empty.
        self.pipeline = None
        self.current_lora = None  # name of the active adapter, or None
        self.lora_scale = 0.0     # strength applied to the active adapter
        self._load_model()

    def _load_model(self):
        """Load the base pipeline with memory optimizations.

        Imports diffusers/transformers lazily so merely importing this module
        stays cheap.
        """
        from diffusers import WanPipeline, WanTransformer3DModel
        from diffusers.schedulers import UniPCMultistepScheduler
        from transformers import AutoTokenizer, T5EncoderModel

        # Fix: previously logged a hard-coded (and wrong) model name.
        print(f"Loading {MODEL_ID} on {DEVICE}...")

        transformer = WanTransformer3DModel.from_pretrained(
            MODEL_ID,
            subfolder="transformer",
            torch_dtype=DTYPE,
            use_safetensors=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            subfolder="tokenizer",
        )
        text_encoder = T5EncoderModel.from_pretrained(
            MODEL_ID,
            subfolder="text_encoder",
            torch_dtype=DTYPE,
        )
        self.pipeline = WanPipeline.from_pretrained(
            MODEL_ID,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            torch_dtype=DTYPE,
        )

        if DEVICE == "cuda":
            # Offload idle submodules to CPU and slice attention so the 14B
            # model fits in limited VRAM.
            self.pipeline.enable_model_cpu_offload()
            self.pipeline.enable_attention_slicing()

        # UniPC reaches good quality in fewer sampling steps.
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )
        print("Model loaded successfully!")

    def load_lora(self, lora_name: str, scale: float = 0.8):
        """Activate a LoRA adapter by name, downloading weights if needed.

        Args:
            lora_name: key into ``AVAILABLE_LORAS``.
            scale: adapter strength passed to ``set_adapters``.

        Raises:
            ValueError: if ``lora_name`` is not a known adapter.
        """
        if lora_name not in AVAILABLE_LORAS:
            raise ValueError(f"Unknown LoRA: {lora_name}")

        if self.current_lora == lora_name:
            if abs(self.lora_scale - scale) < 0.01:
                print(f"LoRA {lora_name} already loaded with scale {scale}")
                return
            # Fix: same adapter at a different strength only needs a rescale.
            # The original unloaded, re-downloaded, and reloaded the weights.
            self.pipeline.set_adapters([lora_name], adapter_weights=[scale])
            self.lora_scale = scale
            print(f"LoRA {lora_name} rescaled to {scale}")
            return

        # Switching adapters: drop the previous one first.
        if self.current_lora:
            self.unload_lora()

        lora_config = AVAILABLE_LORAS[lora_name]
        lora_path = self._download_lora(lora_config)
        print(f"Loading LoRA: {lora_name} with scale {scale}...")
        self.pipeline.load_lora_weights(
            lora_path,
            adapter_name=lora_name,
        )
        self.pipeline.set_adapters([lora_name], adapter_weights=[scale])
        self.current_lora = lora_name
        self.lora_scale = scale
        print(f"LoRA {lora_name} loaded successfully!")

    def _download_lora(self, lora_config: dict) -> str:
        """Return a local path to the LoRA weights, downloading on first use."""
        from huggingface_hub import hf_hub_download

        lora_path = os.path.join(LORA_CACHE_DIR, lora_config["filename"])
        if not os.path.exists(lora_path):
            print(f"Downloading LoRA: {lora_config['filename']}...")
            # hf_hub_download returns the actual on-disk path; for the flat
            # filenames used here it matches the cache-hit path above.
            lora_path = hf_hub_download(
                repo_id=lora_config["repo"],
                filename=lora_config["filename"],
                local_dir=LORA_CACHE_DIR,
            )
        return lora_path

    def unload_lora(self):
        """Detach the active LoRA adapter (best effort) and reset state."""
        if self.current_lora and self.pipeline:
            try:
                self.pipeline.disable_lora()
                self.pipeline.unload_lora_weights()
                print(f"Unloaded LoRA: {self.current_lora}")
            except Exception as e:
                # Best effort: a failed unload must not take down the service.
                print(f"Warning: Could not unload LoRA: {e}")
            finally:
                self.current_lora = None
                self.lora_scale = 0.0
                # Fix: reclaim memory held by the detached adapter weights.
                gc.collect()
                if DEVICE == "cuda":
                    torch.cuda.empty_cache()

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        image: Optional[Image.Image] = None,
        height: int = 480,
        width: int = 848,
        num_frames: int = 25,
        guidance_scale: float = 5.0,
        num_inference_steps: int = 20,
        fps: int = 16,
        seed: Optional[int] = None,
        progress_callback: Optional[Callable[[float], None]] = None,
    ) -> str:
        """Generate a video from a text prompt (optionally image-conditioned).

        Args:
            prompt / negative_prompt: diffusion conditioning text.
            image: optional conditioning image for text+image-to-video.
            height, width, num_frames: output geometry.
            guidance_scale, num_inference_steps: sampler settings.
            fps: playback rate of the saved file.
            seed: optional seed for reproducible sampling.
            progress_callback: called with progress in (0, 1] after each step.

        Returns:
            Path to the saved .mp4 file under /tmp.
        """
        generator = None
        if seed is not None:
            generator = torch.Generator(device=DEVICE).manual_seed(seed)

        kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "output_type": "pil",
        }
        # Image conditioning enables the TI2V path.
        if image is not None:
            kwargs["image"] = image

        start_time = time.time()

        def callback_on_step_end(pipeline, i, t, callback_kwargs):
            # Report fraction of steps completed; must return callback_kwargs
            # per the diffusers callback contract.
            if progress_callback:
                progress_callback((i + 1) / num_inference_steps)
            return callback_kwargs

        kwargs["callback_on_step_end"] = callback_on_step_end

        output = self.pipeline(**kwargs)
        frames = output.frames[0]

        # Fix: second-resolution timestamps collide for back-to-back calls;
        # nanosecond resolution makes each output path unique in practice.
        output_path = f"/tmp/output_{time.time_ns()}.mp4"
        self._save_video(frames, output_path, fps)

        elapsed = time.time() - start_time
        print(f"Generation completed in {elapsed:.2f}s")
        return output_path

    def _save_video(self, frames: list, output_path: str, fps: int):
        """Encode a list of PIL frames as an H.264 .mp4 at the given fps."""
        import imageio

        frames_np = [np.array(frame) for frame in frames]
        with imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8) as writer:
            for frame in frames_np:
                writer.append_data(frame)
        print(f"Video saved to: {output_path}")
# Lazily-created process-wide generator instance.
_generator_instance = None


def get_generator() -> WanVideoGenerator:
    """Return the shared WanVideoGenerator, constructing it on first use."""
    global _generator_instance
    if _generator_instance is not None:
        return _generator_instance
    _generator_instance = WanVideoGenerator()
    return _generator_instance