Onise committed on
Commit
e46a321
·
verified ·
1 Parent(s): 61723d2

Update model.py from anycoder

Browse files
Files changed (1) hide show
  1. model.py +254 -0
model.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import gc
4
+ import time
5
+ from typing import Optional, Callable, Any
6
+ from pathlib import Path
7
+ import numpy as np
8
+ from PIL import Image
9
+ import safetensors.torch
10
+
11
+ # Configuration
12
+ MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers" # Base model
13
+ LORA_CACHE_DIR = "/tmp/lora_cache"
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
+ DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
16
+
17
+ # Ensure LoRA cache directory exists
18
+ os.makedirs(LORA_CACHE_DIR, exist_ok=True)
19
+
20
+ # Predefined LoRA configurations
21
+ AVAILABLE_LORAS = {
22
+ "wan-fast-lora": {
23
+ "repo": "Kijai/Wan2.1-fp8-diffusers", # FP8 quantized for speed
24
+ "filename": "wan2.1_fast_lora.safetensors",
25
+ "description": "Optimized for 2-3x faster generation",
26
+ "trigger_words": []
27
+ },
28
+ "wan-quality-lora": {
29
+ "repo": "Kijai/Wan2.1-fp8-diffusers",
30
+ "filename": "wan2.1_quality_lora.safetensors",
31
+ "description": "Enhanced visual quality",
32
+ "trigger_words": ["high quality", "detailed"]
33
+ },
34
+ "wan-motion-lora": {
35
+ "repo": "Kijai/Wan2.1-fp8-diffusers",
36
+ "filename": "wan2.1_motion_lora.safetensors",
37
+ "description": "Better motion dynamics",
38
+ "trigger_words": ["smooth motion", "dynamic"]
39
+ }
40
+ }
41
+
42
+
43
def get_available_loras() -> list:
    """Return the names of every predefined LoRA configuration.

    Names are the keys of ``AVAILABLE_LORAS``, in declaration order.
    """
    return [lora_name for lora_name in AVAILABLE_LORAS]
46
+
47
+
48
class WanVideoGenerator:
    """Video generator built on a diffusers Wan pipeline, with on-demand LoRA support.

    Loads the base model eagerly in ``__init__`` and exposes ``generate`` for
    producing an .mp4 file from a text prompt (optionally with an input image).

    NOTE(review): this class and its log messages refer to "Wan2.2-TI2V-5B",
    but ``MODEL_ID`` points at "Wan-AI/Wan2.1-T2V-14B-Diffusers" — confirm
    which checkpoint is actually intended.
    """

    def __init__(self) -> None:
        # pipeline: the diffusers WanPipeline, set by _load_model()
        self.pipeline = None
        # current_lora: name (key of AVAILABLE_LORAS) of the active adapter, or None
        self.current_lora = None
        # lora_scale: adapter weight of the active LoRA; 0.0 when none is loaded
        self.lora_scale = 0.0
        # Eagerly load the base model; this downloads weights on first run.
        self._load_model()

    def _load_model(self) -> None:
        """Load the base model with optimizations.

        Builds the transformer, tokenizer and text encoder separately, then
        assembles a WanPipeline, enables memory optimizations on CUDA, and
        swaps in a UniPC multistep scheduler.
        """
        from diffusers import WanPipeline, WanTransformer3DModel
        from diffusers.schedulers import UniPCMultistepScheduler
        from transformers import AutoTokenizer, T5EncoderModel

        # NOTE(review): message names a different model than MODEL_ID — see class note.
        print(f"Loading Wan2.2-TI2V-5B model on {DEVICE}...")

        # Load transformer with memory optimizations
        transformer = WanTransformer3DModel.from_pretrained(
            MODEL_ID,
            subfolder="transformer",
            torch_dtype=DTYPE,
            use_safetensors=True,
        )

        # Load text encoder and its tokenizer from the same repo.
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            subfolder="tokenizer",
        )
        text_encoder = T5EncoderModel.from_pretrained(
            MODEL_ID,
            subfolder="text_encoder",
            torch_dtype=DTYPE,
        )

        # Create pipeline, reusing the components loaded above instead of
        # letting from_pretrained re-load them.
        self.pipeline = WanPipeline.from_pretrained(
            MODEL_ID,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            torch_dtype=DTYPE,
        )

        # Enable memory optimizations (CPU offload trades speed for VRAM).
        if DEVICE == "cuda":
            self.pipeline.enable_model_cpu_offload()
            # Enable attention slicing for lower memory
            self.pipeline.enable_attention_slicing()

        # Use efficient scheduler (UniPC allows fewer inference steps).
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )

        print("Model loaded successfully!")

    def load_lora(self, lora_name: str, scale: float = 0.8) -> None:
        """Load a LoRA adapter on demand.

        Args:
            lora_name: key into AVAILABLE_LORAS.
            scale: adapter weight passed to set_adapters.

        Raises:
            ValueError: if lora_name is not a known LoRA.
        """
        if lora_name not in AVAILABLE_LORAS:
            raise ValueError(f"Unknown LoRA: {lora_name}")

        # No-op when the same adapter is already active at (nearly) this scale.
        if self.current_lora == lora_name and abs(self.lora_scale - scale) < 0.01:
            print(f"LoRA {lora_name} already loaded with scale {scale}")
            return

        # Unload previous LoRA — only one adapter is active at a time.
        if self.current_lora:
            self.unload_lora()

        lora_config = AVAILABLE_LORAS[lora_name]
        lora_path = self._download_lora(lora_config)

        print(f"Loading LoRA: {lora_name} with scale {scale}...")

        # Load LoRA weights
        self.pipeline.load_lora_weights(
            lora_path,
            adapter_name=lora_name,
        )

        # Set LoRA scale
        self.pipeline.set_adapters([lora_name], adapter_weights=[scale])

        self.current_lora = lora_name
        self.lora_scale = scale
        print(f"LoRA {lora_name} loaded successfully!")

    def _download_lora(self, lora_config: dict) -> str:
        """Download LoRA weights if not cached.

        Returns the local filesystem path to the .safetensors file.

        NOTE(review): the existence pre-check assumes hf_hub_download with
        ``local_dir=LORA_CACHE_DIR`` places the file directly at
        LORA_CACHE_DIR/<filename>; if the repo stores the file under a
        subdirectory, the cache check will always miss — confirm.
        """
        from huggingface_hub import hf_hub_download

        lora_path = os.path.join(LORA_CACHE_DIR, lora_config["filename"])

        if not os.path.exists(lora_path):
            print(f"Downloading LoRA: {lora_config['filename']}...")
            # hf_hub_download returns the actual downloaded path; keep it
            # rather than the guessed one.
            lora_path = hf_hub_download(
                repo_id=lora_config["repo"],
                filename=lora_config["filename"],
                local_dir=LORA_CACHE_DIR,
            )

        return lora_path

    def unload_lora(self) -> None:
        """Unload the current LoRA adapter, if any.

        Best-effort: failures from the pipeline are logged, not raised, and
        the bookkeeping (current_lora / lora_scale) is always reset.
        """
        if self.current_lora and self.pipeline:
            try:
                self.pipeline.disable_lora()
                self.pipeline.unload_lora_weights()
                print(f"Unloaded LoRA: {self.current_lora}")
            except Exception as e:
                print(f"Warning: Could not unload LoRA: {e}")
            finally:
                self.current_lora = None
                self.lora_scale = 0.0

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        image: Optional[Image.Image] = None,
        height: int = 480,
        width: int = 848,
        num_frames: int = 25,
        guidance_scale: float = 5.0,
        num_inference_steps: int = 20,
        fps: int = 16,
        seed: Optional[int] = None,
        progress_callback: Optional[Callable[[float], None]] = None,
    ) -> str:
        """Generate a video from a text prompt (optionally conditioned on an image).

        Args:
            prompt: text prompt describing the video.
            negative_prompt: what to avoid in the output.
            image: optional conditioning image (TI2V mode).
            height / width: output resolution in pixels.
            num_frames: number of frames to generate.
            guidance_scale: classifier-free guidance strength.
            num_inference_steps: denoising steps.
            fps: frame rate of the saved file.
            seed: RNG seed for reproducibility; None means nondeterministic.
            progress_callback: called with a float in (0, 1] after each step.

        Returns:
            Path to the saved .mp4 file under /tmp.
        """

        # Seed only when requested; otherwise let the pipeline use its default RNG.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=DEVICE).manual_seed(seed)

        # Prepare kwargs
        kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "output_type": "pil",
        }

        # Add image for TI2V.
        # NOTE(review): the pipeline built in _load_model is the text-to-video
        # WanPipeline; confirm it accepts an ``image`` kwarg, otherwise this
        # path needs the image-to-video pipeline variant.
        if image is not None:
            kwargs["image"] = image

        # Generate with progress tracking
        start_time = time.time()

        # Per-step callback: forwards fractional progress to the caller.
        def callback_on_step_end(pipeline, i, t, callback_kwargs):
            if progress_callback:
                progress = (i + 1) / num_inference_steps
                progress_callback(progress)
            return callback_kwargs

        kwargs["callback_on_step_end"] = callback_on_step_end

        # Generate frames
        output = self.pipeline(**kwargs)

        # First (only) video in the batch: a list of PIL images.
        frames = output.frames[0]

        # Save video; timestamped name avoids collisions between runs
        # (second-level resolution only).
        output_path = f"/tmp/output_{int(time.time())}.mp4"
        self._save_video(frames, output_path, fps)

        elapsed = time.time() - start_time
        print(f"Generation completed in {elapsed:.2f}s")

        return output_path

    def _save_video(self, frames: list, output_path: str, fps: int) -> None:
        """Save frames as an H.264 video file.

        Args:
            frames: list of PIL images (one per frame).
            output_path: destination .mp4 path.
            fps: frame rate to encode at.
        """
        import imageio

        # Convert PIL images to numpy arrays
        frames_np = [np.array(frame) for frame in frames]

        # Write video; the context manager finalizes/closes the file.
        with imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8) as writer:
            for frame in frames_np:
                writer.append_data(frame)

        print(f"Video saved to: {output_path}")
244
+
245
+
246
# Singleton instance, created lazily by get_generator().
_generator_instance = None


def get_generator() -> WanVideoGenerator:
    """Return the process-wide WanVideoGenerator, building it on first use.

    Later calls reuse the same instance so the heavyweight model load in
    WanVideoGenerator.__init__ happens only once per process.
    """
    global _generator_instance
    if _generator_instance is not None:
        return _generator_instance
    _generator_instance = WanVideoGenerator()
    return _generator_instance