"""Wan text-to-video generation with on-demand LoRA adapter support.

Wraps a diffusers ``WanPipeline`` behind a lazily-created singleton and
provides download/load/unload management for a small catalog of LoRAs.
"""

import gc
import os
import time
from pathlib import Path
from typing import Any, Callable, Optional

import numpy as np
import safetensors.torch
import torch
from PIL import Image

# Configuration
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"  # Base model
LORA_CACHE_DIR = "/tmp/lora_cache"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 halves memory on GPU; fp32 on CPU where bf16 support is spotty.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Ensure LoRA cache directory exists
os.makedirs(LORA_CACHE_DIR, exist_ok=True)

# Predefined LoRA configurations
AVAILABLE_LORAS = {
    "wan-fast-lora": {
        "repo": "Kijai/Wan2.1-fp8-diffusers",  # FP8 quantized for speed
        "filename": "wan2.1_fast_lora.safetensors",
        "description": "Optimized for 2-3x faster generation",
        "trigger_words": [],
    },
    "wan-quality-lora": {
        "repo": "Kijai/Wan2.1-fp8-diffusers",
        "filename": "wan2.1_quality_lora.safetensors",
        "description": "Enhanced visual quality",
        "trigger_words": ["high quality", "detailed"],
    },
    "wan-motion-lora": {
        "repo": "Kijai/Wan2.1-fp8-diffusers",
        "filename": "wan2.1_motion_lora.safetensors",
        "description": "Better motion dynamics",
        "trigger_words": ["smooth motion", "dynamic"],
    },
}


def get_available_loras() -> list:
    """Get list of available LoRA names."""
    return list(AVAILABLE_LORAS.keys())


class WanVideoGenerator:
    """Wan video generator (base model = ``MODEL_ID``) with LoRA support.

    NOTE(review): the original docstring and load message said
    "Wan2.2-TI2V-5B" while ``MODEL_ID`` points at Wan2.1-T2V-14B; the
    messages now reference ``MODEL_ID`` so they cannot drift again.
    """

    def __init__(self):
        # Populated by _load_model(); kept as attributes so load_lora /
        # unload_lora can track the active adapter across calls.
        self.pipeline = None      # diffusers WanPipeline instance
        self.current_lora = None  # name of the active adapter, or None
        self.lora_scale = 0.0     # weight of the active adapter
        self._load_model()

    def _load_model(self):
        """Load the base model with memory optimizations.

        diffusers/transformers are imported lazily so that merely importing
        this module stays cheap.
        """
        from diffusers import WanPipeline, WanTransformer3DModel
        from diffusers.schedulers import UniPCMultistepScheduler
        from transformers import AutoTokenizer, T5EncoderModel

        print(f"Loading {MODEL_ID} on {DEVICE}...")

        # Load transformer with memory optimizations
        transformer = WanTransformer3DModel.from_pretrained(
            MODEL_ID,
            subfolder="transformer",
            torch_dtype=DTYPE,
            use_safetensors=True,
        )

        # Load tokenizer and text encoder
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            subfolder="tokenizer",
        )
        text_encoder = T5EncoderModel.from_pretrained(
            MODEL_ID,
            subfolder="text_encoder",
            torch_dtype=DTYPE,
        )

        # Assemble the pipeline from the pre-loaded components.
        self.pipeline = WanPipeline.from_pretrained(
            MODEL_ID,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            torch_dtype=DTYPE,
        )

        # Enable memory optimizations (GPU only)
        if DEVICE == "cuda":
            self.pipeline.enable_model_cpu_offload()
            # Enable attention slicing for lower memory
            self.pipeline.enable_attention_slicing()

        # Use efficient scheduler (fewer steps for comparable quality)
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )

        print("Model loaded successfully!")

    def load_lora(self, lora_name: str, scale: float = 0.8):
        """Load a LoRA adapter on demand.

        Args:
            lora_name: key into ``AVAILABLE_LORAS``.
            scale: adapter weight applied via ``set_adapters``.

        Raises:
            ValueError: if ``lora_name`` is not a known adapter.
        """
        if lora_name not in AVAILABLE_LORAS:
            raise ValueError(f"Unknown LoRA: {lora_name}")

        # No-op when the requested adapter is already active at (nearly)
        # the requested scale; 0.01 tolerance absorbs float round-trip.
        if self.current_lora == lora_name and abs(self.lora_scale - scale) < 0.01:
            print(f"LoRA {lora_name} already loaded with scale {scale}")
            return

        # Unload previous LoRA before loading the new one.
        if self.current_lora:
            self.unload_lora()

        lora_config = AVAILABLE_LORAS[lora_name]
        lora_path = self._download_lora(lora_config)

        print(f"Loading LoRA: {lora_name} with scale {scale}...")

        # Load LoRA weights under a named adapter so it can be toggled.
        self.pipeline.load_lora_weights(
            lora_path,
            adapter_name=lora_name,
        )

        # Set LoRA scale
        self.pipeline.set_adapters([lora_name], adapter_weights=[scale])

        self.current_lora = lora_name
        self.lora_scale = scale
        print(f"LoRA {lora_name} loaded successfully!")

    def _download_lora(self, lora_config: dict) -> str:
        """Download LoRA weights if not cached; return the local file path."""
        from huggingface_hub import hf_hub_download

        lora_path = os.path.join(LORA_CACHE_DIR, lora_config["filename"])
        if not os.path.exists(lora_path):
            print(f"Downloading LoRA: {lora_config['filename']}...")
            lora_path = hf_hub_download(
                repo_id=lora_config["repo"],
                filename=lora_config["filename"],
                local_dir=LORA_CACHE_DIR,
            )
        return lora_path

    def unload_lora(self):
        """Unload the current LoRA adapter (best-effort).

        The tracking state is reset in ``finally`` so a failed unload does
        not leave the generator believing the adapter is still active.
        """
        if self.current_lora and self.pipeline:
            try:
                self.pipeline.disable_lora()
                self.pipeline.unload_lora_weights()
                print(f"Unloaded LoRA: {self.current_lora}")
            except Exception as e:
                # Deliberate best-effort: a stuck adapter should not crash
                # the next generation request.
                print(f"Warning: Could not unload LoRA: {e}")
            finally:
                self.current_lora = None
                self.lora_scale = 0.0

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        image: Optional[Image.Image] = None,
        height: int = 480,
        width: int = 848,
        num_frames: int = 25,
        guidance_scale: float = 5.0,
        num_inference_steps: int = 20,
        fps: int = 16,
        seed: Optional[int] = None,
        progress_callback: Optional[Callable[[float], None]] = None,
    ) -> str:
        """Generate video from a text (or text+image) prompt.

        Args:
            prompt / negative_prompt: conditioning text.
            image: optional init image for TI2V mode.
                NOTE(review): a plain WanPipeline (T2V) may not accept an
                ``image`` kwarg — confirm against the pipeline variant used.
            height, width, num_frames: output geometry.
            guidance_scale, num_inference_steps: sampling parameters.
            fps: frame rate of the encoded video.
            seed: optional seed for reproducible generation.
            progress_callback: called with progress in [0, 1] each step.

        Returns:
            Path to the encoded ``.mp4`` file under ``/tmp``.
        """
        # Seeded generator for reproducibility; None lets torch pick.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=DEVICE).manual_seed(seed)

        # Prepare pipeline kwargs
        kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "output_type": "pil",
        }

        # Add image for TI2V
        if image is not None:
            kwargs["image"] = image

        # Generate with progress tracking
        start_time = time.time()

        # Per-step callback so callers can surface a progress bar.
        def callback_on_step_end(pipeline, i, t, callback_kwargs):
            if progress_callback:
                progress = (i + 1) / num_inference_steps
                progress_callback(progress)
            return callback_kwargs

        kwargs["callback_on_step_end"] = callback_on_step_end

        # Generate frames (first batch element only).
        output = self.pipeline(**kwargs)
        frames = output.frames[0]

        # Save video. Second-resolution timestamp — concurrent calls within
        # the same second would collide on this path.
        output_path = f"/tmp/output_{int(time.time())}.mp4"
        self._save_video(frames, output_path, fps)

        elapsed = time.time() - start_time
        print(f"Generation completed in {elapsed:.2f}s")

        return output_path

    def _save_video(self, frames: list, output_path: str, fps: int):
        """Encode a list of PIL frames as an H.264 video file."""
        import imageio

        # Convert PIL images to numpy arrays
        frames_np = [np.array(frame) for frame in frames]

        # Write video
        with imageio.get_writer(
            output_path, fps=fps, codec="libx264", quality=8
        ) as writer:
            for frame in frames_np:
                writer.append_data(frame)

        print(f"Video saved to: {output_path}")


# Singleton instance (model load is expensive; create at most once).
_generator_instance = None


def get_generator() -> WanVideoGenerator:
    """Get or create the generator instance."""
    global _generator_instance
    if _generator_instance is None:
        _generator_instance = WanVideoGenerator()
    return _generator_instance