File size: 8,248 Bytes
e46a321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import torch
import os
import gc
import time
from typing import Optional, Callable, Any
from pathlib import Path
import numpy as np
from PIL import Image
import safetensors.torch

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
_HAS_CUDA = torch.cuda.is_available()

MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"  # diffusers-format base checkpoint
LORA_CACHE_DIR = "/tmp/lora_cache"
DEVICE = "cuda" if _HAS_CUDA else "cpu"
# bfloat16 when a GPU is present, float32 otherwise.
DTYPE = torch.bfloat16 if _HAS_CUDA else torch.float32

# Make sure the LoRA cache directory exists (runs at import time).
Path(LORA_CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Every bundled LoRA is fetched from this single Hub repository.
# NOTE(review): confirm that this repo and the filenames below exist on the Hub.
_LORA_REPO = "Kijai/Wan2.1-fp8-diffusers"

# Predefined LoRA adapters, keyed by the name accepted by load_lora().
AVAILABLE_LORAS = {
    "wan-fast-lora": dict(
        repo=_LORA_REPO,
        filename="wan2.1_fast_lora.safetensors",
        description="Optimized for 2-3x faster generation",
        trigger_words=[],
    ),
    "wan-quality-lora": dict(
        repo=_LORA_REPO,
        filename="wan2.1_quality_lora.safetensors",
        description="Enhanced visual quality",
        trigger_words=["high quality", "detailed"],
    ),
    "wan-motion-lora": dict(
        repo=_LORA_REPO,
        filename="wan2.1_motion_lora.safetensors",
        description="Better motion dynamics",
        trigger_words=["smooth motion", "dynamic"],
    ),
}


def get_available_loras() -> list:
    """Return the predefined LoRA names, in the order they are declared."""
    names = [*AVAILABLE_LORAS]
    return names


class WanVideoGenerator:
    """Wan video generator (diffusers ``WanPipeline``) with on-demand LoRA support.

    The checkpoint loaded is ``MODEL_ID``. At most one LoRA adapter from
    ``AVAILABLE_LORAS`` is active at any time.

    Attributes:
        pipeline: The loaded diffusers pipeline (``None`` until loading finishes).
        current_lora: Name of the active LoRA adapter, or ``None``.
        lora_scale: Weight of the active adapter; ``0.0`` when none is loaded.
    """

    def __init__(self):
        self.pipeline = None
        self.current_lora = None
        self.lora_scale = 0.0
        # Model weights are loaded eagerly, at construction time.
        self._load_model()

    def _load_model(self):
        """Build ``self.pipeline`` from ``MODEL_ID`` and apply runtime settings.

        The transformer, tokenizer and T5 text encoder are loaded explicitly and
        passed to ``WanPipeline.from_pretrained``. On CUDA, model CPU offload and
        attention slicing are enabled. The scheduler is replaced by a
        ``UniPCMultistepScheduler`` built from the checkpoint's scheduler config.
        """
        from diffusers import WanPipeline, WanTransformer3DModel
        from diffusers.schedulers import UniPCMultistepScheduler
        from transformers import AutoTokenizer, T5EncoderModel

        # Report the checkpoint that is actually being loaded.
        print(f"Loading {MODEL_ID} on {DEVICE}...")

        transformer = WanTransformer3DModel.from_pretrained(
            MODEL_ID,
            subfolder="transformer",
            torch_dtype=DTYPE,
            use_safetensors=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            subfolder="tokenizer",
        )
        text_encoder = T5EncoderModel.from_pretrained(
            MODEL_ID,
            subfolder="text_encoder",
            torch_dtype=DTYPE,
        )

        self.pipeline = WanPipeline.from_pretrained(
            MODEL_ID,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            torch_dtype=DTYPE,
        )

        if DEVICE == "cuda":
            # Memory optimizations for GPU runs.
            self.pipeline.enable_model_cpu_offload()
            # NOTE(review): attention slicing may be a no-op for this
            # transformer architecture — confirm it has an effect.
            self.pipeline.enable_attention_slicing()

        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )

        print("Model loaded successfully!")

    def load_lora(self, lora_name: str, scale: float = 0.8):
        """Activate LoRA adapter ``lora_name`` at strength ``scale``.

        If the same adapter is already active, only its scale is updated (a
        no-op when the scale differs by less than 0.01). Otherwise any active
        adapter is unloaded, and the requested weights are fetched via
        ``_download_lora`` and loaded.

        Raises:
            ValueError: if ``lora_name`` is not a key of ``AVAILABLE_LORAS``.
        """
        if lora_name not in AVAILABLE_LORAS:
            raise ValueError(f"Unknown LoRA: {lora_name}")

        if self.current_lora == lora_name:
            if abs(self.lora_scale - scale) < 0.01:
                print(f"LoRA {lora_name} already loaded with scale {scale}")
                return
            # Same adapter at a new strength: re-weight it in place instead of
            # unloading it and re-reading the weights from disk.
            self.pipeline.set_adapters([lora_name], adapter_weights=[scale])
            self.lora_scale = scale
            print(f"LoRA {lora_name} scale updated to {scale}")
            return

        if self.current_lora:
            self.unload_lora()

        lora_path = self._download_lora(AVAILABLE_LORAS[lora_name])

        print(f"Loading LoRA: {lora_name} with scale {scale}...")
        self.pipeline.load_lora_weights(
            lora_path,
            adapter_name=lora_name,
        )
        try:
            self.pipeline.set_adapters([lora_name], adapter_weights=[scale])
        except Exception:
            # Don't leave an adapter loaded that ``current_lora`` does not track:
            # unload_lora() only acts when ``current_lora`` is set, so it
            # would never remove it.
            self.pipeline.unload_lora_weights()
            raise

        self.current_lora = lora_name
        self.lora_scale = scale
        print(f"LoRA {lora_name} loaded successfully!")

    def _download_lora(self, lora_config: dict) -> str:
        """Return a local path to the LoRA weights described by ``lora_config``.

        Reads ``lora_config["repo"]`` and ``lora_config["filename"]``. If the file
        is not already present in ``LORA_CACHE_DIR``, it is downloaded from the
        Hugging Face Hub into that directory.
        """
        from huggingface_hub import hf_hub_download

        lora_path = os.path.join(LORA_CACHE_DIR, lora_config["filename"])
        if os.path.exists(lora_path):
            return lora_path

        print(f"Downloading LoRA: {lora_config['filename']}...")
        return hf_hub_download(
            repo_id=lora_config["repo"],
            filename=lora_config["filename"],
            local_dir=LORA_CACHE_DIR,
        )

    def unload_lora(self):
        """Unload the active LoRA adapter, if any (best effort).

        Failures are reported as a warning rather than raised. The tracking
        state is reset either way.
        """
        if self.current_lora and self.pipeline:
            try:
                self.pipeline.disable_lora()
                self.pipeline.unload_lora_weights()
                print(f"Unloaded LoRA: {self.current_lora}")
            except Exception as e:
                print(f"Warning: Could not unload LoRA: {e}")
            finally:
                self.current_lora = None
                self.lora_scale = 0.0

    def _pipeline_accepts(self, name: str) -> bool:
        """Return True if ``self.pipeline`` can be called with keyword ``name``.

        A ``**kwargs`` parameter counts as accepting any keyword. If the
        signature cannot be introspected, this returns True and lets the
        pipeline call decide.
        """
        import inspect

        try:
            params = inspect.signature(self.pipeline).parameters.values()
        except (TypeError, ValueError):
            return True
        return any(
            p.name == name or p.kind is inspect.Parameter.VAR_KEYWORD
            for p in params
        )

    @staticmethod
    def _new_output_path() -> str:
        """Return a fresh ``/tmp/output_<unix-seconds>_<random>.mp4`` path.

        The random suffix keeps generations that finish within the same second
        from overwriting each other's files.
        """
        import uuid

        return f"/tmp/output_{int(time.time())}_{uuid.uuid4().hex[:12]}.mp4"

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        image: Optional["Image.Image"] = None,
        height: int = 480,
        width: int = 848,
        num_frames: int = 25,
        guidance_scale: float = 5.0,
        num_inference_steps: int = 20,
        fps: int = 16,
        seed: Optional[int] = None,
        progress_callback: Optional[Callable[[float], None]] = None,
    ) -> str:
        """Generate a video and return the path of the saved ``.mp4`` file.

        Args:
            prompt, negative_prompt, height, width, num_frames, guidance_scale,
                num_inference_steps: Forwarded unchanged to the pipeline call.
            image: Optional conditioning image. It is forwarded only if the
                loaded pipeline's call signature accepts an ``image`` argument.
            fps: Frame rate of the written video.
            seed: If given, seeds a ``torch.Generator`` on ``DEVICE``.
            progress_callback: Called after each denoising step with the
                completed fraction ``(step + 1) / num_inference_steps``.

        Returns:
            Path of a newly written ``.mp4`` file under ``/tmp``.

        Raises:
            TypeError: if ``image`` is given but the pipeline does not accept it.
        """
        # Fail fast with a clear message: passing ``image`` to a pipeline whose
        # call signature does not take it would fail inside the pipeline call.
        if image is not None and not self._pipeline_accepts("image"):
            raise TypeError(
                f"{type(self.pipeline).__name__} does not accept an 'image' "
                "argument; image conditioning is not supported by the loaded pipeline."
            )

        generator = None
        if seed is not None:
            generator = torch.Generator(device=DEVICE).manual_seed(seed)

        def callback_on_step_end(pipeline, i, t, callback_kwargs):
            # Signature expected by diffusers' ``callback_on_step_end`` hook;
            # the tensors dict must be returned.
            if progress_callback:
                progress_callback((i + 1) / num_inference_steps)
            return callback_kwargs

        kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "output_type": "pil",
            "callback_on_step_end": callback_on_step_end,
        }
        if image is not None:
            kwargs["image"] = image

        start_time = time.time()

        output = self.pipeline(**kwargs)
        # First (only) video of the batch.
        frames = output.frames[0]

        output_path = self._new_output_path()
        self._save_video(frames, output_path, fps)

        elapsed = time.time() - start_time
        print(f"Generation completed in {elapsed:.2f}s")

        return output_path

    def _save_video(self, frames: list, output_path: str, fps: int):
        """Encode ``frames`` (PIL images or arrays) as H.264 video at ``output_path``.

        Raises:
            ValueError: if ``frames`` is empty; nothing would be written.
        """
        # len() rather than truthiness so numpy-array frame stacks also work.
        if len(frames) == 0:
            raise ValueError("No frames to save")

        import imageio

        with imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8) as writer:
            for frame in frames:
                # Convert one frame at a time instead of materialising every
                # frame as a numpy array up front (halves peak memory).
                writer.append_data(np.asarray(frame))

        print(f"Video saved to: {output_path}")


# Process-wide generator, created lazily on the first get_generator() call.
_generator_instance = None


def get_generator() -> WanVideoGenerator:
    """Return the shared WanVideoGenerator, constructing it on first use.

    NOTE(review): there is no locking here, so concurrent first calls from
    several threads could each construct (and load) their own generator.
    """
    global _generator_instance
    if _generator_instance is not None:
        return _generator_instance
    _generator_instance = WanVideoGenerator()
    return _generator_instance