|
|
import warnings |
|
|
warnings.filterwarnings('ignore', category=FutureWarning) |
|
|
warnings.filterwarnings('ignore', category=DeprecationWarning) |
|
|
|
|
|
import gc |
|
|
import os |
|
|
import tempfile |
|
|
import traceback |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
import ftfy |
|
|
import sentencepiece |
|
|
|
|
|
|
|
|
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline |
|
|
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel |
|
|
from diffusers.utils.export_utils import export_to_video |
|
|
|
|
|
|
|
|
class VideoEngine:
    """
    Ultra-fast image-to-video generation on Wan2.2 I2V A14B.

    Loads the diffusers Wan2.2 pipeline with bf16 transformer weights,
    fuses the Lightning step-distillation LoRA into both transformers,
    and applies torchao FP8/INT8 quantization.
    70-90s inference time (compared to 150s baseline).

    Typical usage:
        engine = VideoEngine()
        engine.load_model()
        path = engine.generate_video(image, "a cat walking")
    """

    # Hugging Face repos: base pipeline, bf16 transformer weights, and the
    # Lightning distillation LoRA (enables 4-step inference).
    MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
    TRANSFORMER_REPO = "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers"
    LORA_REPO = "Kijai/WanVideo_comfy"
    LORA_WEIGHT = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"

    # Resolution constraints (pixels). Output dimensions are clamped to
    # [MIN_DIM, MAX_DIM] and snapped to multiples of MULTIPLE_OF; both
    # bounds are themselves multiples of 16, so clamping preserves snapping.
    MAX_DIM = 832
    MIN_DIM = 480
    SQUARE_DIM = 640        # fixed size used for exactly-square inputs
    MULTIPLE_OF = 16

    # Timing constraints. MIN_FRAMES/MAX_FRAMES are inclusive bounds on the
    # frame count handed to the pipeline.
    FIXED_FPS = 16
    MIN_FRAMES = 8
    MAX_FRAMES = 81

    def __init__(self):
        """Detect the runtime environment; model loading is deferred to load_model()."""
        # True when running inside a Hugging Face Space.
        self.is_spaces = os.environ.get('SPACE_ID') is not None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Populated lazily by load_model(). Annotation is a string so it is
        # never evaluated at runtime (avoids touching diffusers names here).
        self.pipeline: "Optional[WanImageToVideoPipeline]" = None
        self.is_loaded = False
        self.use_aoti = False

        print(f"✓ VideoEngine initialized ({self.device})")

    def _check_xformers_available(self) -> bool:
        """Return True if the xFormers package can be imported."""
        try:
            import xformers  # noqa: F401 -- import is the availability probe
            return True
        except ImportError:
            return False

    def load_model(self) -> None:
        """Load, fuse, quantize, and stage the pipeline on the compute device.

        Steps:
          1. Load the base pipeline (bf16 transformers) onto the CPU.
          2. Load the Lightning LoRA into both transformers and fuse it,
             then drop the now-redundant LoRA weights.
          3. Quantize: INT8 weight-only for the text encoder, FP8
             dynamic-activation/weight for both transformers.
          4. AOTI compilation is intentionally skipped in this build.
          5. Move to the device and enable memory optimizations.

        Raises:
            RuntimeError: if torchao quantization fails (it is mandatory
                for this optimized build); chained to the original error.
        """
        if self.is_loaded:
            print("⚠ VideoEngine already loaded")
            return

        try:
            print("=" * 60)
            print("Loading Wan2.2 I2V Engine with FP8 Quantization")
            print("=" * 60)

            print("→ [1/5] Loading base pipeline to CPU...")
            self.pipeline = WanImageToVideoPipeline.from_pretrained(
                self.MODEL_ID,
                transformer=WanTransformer3DModel.from_pretrained(
                    self.TRANSFORMER_REPO,
                    subfolder='transformer',
                    torch_dtype=torch.bfloat16,
                ),
                transformer_2=WanTransformer3DModel.from_pretrained(
                    self.TRANSFORMER_REPO,
                    subfolder='transformer_2',
                    torch_dtype=torch.bfloat16,
                ),
                torch_dtype=torch.bfloat16,
            )
            print("✓ Base pipeline loaded to CPU")

            print("→ [2/5] Loading Lightning LoRA...")
            # Same LoRA file twice: once targeted at each transformer.
            self.pipeline.load_lora_weights(
                self.LORA_REPO, weight_name=self.LORA_WEIGHT,
                adapter_name="lightx2v",
            )
            self.pipeline.load_lora_weights(
                self.LORA_REPO, weight_name=self.LORA_WEIGHT,
                adapter_name="lightx2v_2", load_into_transformer_2=True,
            )
            self.pipeline.set_adapters(
                ["lightx2v", "lightx2v_2"],
                adapter_weights=[1., 1.],
            )
            # Fuse with per-stage scales (3x on transformer, 1x on
            # transformer_2), then unload the separate LoRA weights.
            self.pipeline.fuse_lora(
                adapter_names=["lightx2v"], lora_scale=3.,
                components=["transformer"],
            )
            self.pipeline.fuse_lora(
                adapter_names=["lightx2v_2"], lora_scale=1.,
                components=["transformer_2"],
            )
            self.pipeline.unload_lora_weights()
            print("✓ Lightning LoRA fused")

            print("→ [3/5] Applying FP8 quantization...")
            try:
                from torchao.quantization import (
                    Float8DynamicActivationFloat8WeightConfig,
                    int8_weight_only,
                    quantize_,
                )

                # Text encoder: INT8 weight-only (quality-safe).
                quantize_(self.pipeline.text_encoder, int8_weight_only())

                # Transformers: FP8 dynamic activation + FP8 weights.
                quantize_(
                    self.pipeline.transformer,
                    Float8DynamicActivationFloat8WeightConfig(),
                )
                quantize_(
                    self.pipeline.transformer_2,
                    Float8DynamicActivationFloat8WeightConfig(),
                )

                print("✓ FP8 quantization applied (50% memory reduction)")
            except Exception as e:
                print(f"⚠ Quantization failed: {e}")
                # Chain the original exception so the root cause survives.
                raise RuntimeError(
                    "FP8 quantization required for this optimized version"
                ) from e

            print("→ [4/5] Skipping AOTI compilation...")
            self.use_aoti = False
            print("✓ Using FP8 quantization only")

            print("→ [5/5] Moving to GPU...")
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Use the detected device (was hard-coded 'cuda', which crashed
            # on CPU-only hosts).
            self.pipeline = self.pipeline.to(self.device)

            # Best-effort VAE memory optimizations; absence is not fatal.
            try:
                if hasattr(self.pipeline, 'enable_vae_tiling'):
                    self.pipeline.enable_vae_tiling()
                if hasattr(self.pipeline, 'enable_vae_slicing'):
                    self.pipeline.enable_vae_slicing()
                print(" • VAE tiling/slicing enabled")
            except Exception as e:
                print(f" ⚠ VAE optimizations not available: {e}")

            if torch.cuda.is_available():
                # TF32 matmuls: large speedup on Ampere+ GPUs with
                # negligible quality impact for diffusion inference.
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

            # Best-effort xFormers attention (narrowed from a bare except,
            # which would also have swallowed KeyboardInterrupt).
            try:
                if self._check_xformers_available():
                    self.pipeline.enable_xformers_memory_efficient_attention()
                    print(" • xFormers enabled")
            except Exception:
                pass

            self.is_loaded = True
            print("=" * 60)
            print("✓ VideoEngine Ready")
            print(f" • Device: {self.device}")
            print(" • Quantization: FP8 (50% memory reduction)")
            print("=" * 60)

        except Exception as e:
            print(f"\n{'='*60}")
            print("✗ FATAL ERROR LOADING VIDEO ENGINE")
            print(f"{'='*60}")
            print(f"Error Type: {type(e).__name__}")
            print(f"Error Message: {str(e)}")
            print("\nFull Traceback:")
            print(traceback.format_exc())
            print(f"{'='*60}")
            raise

    def resize_image(self, image: "Image.Image") -> "Image.Image":
        """Resize image to fit model constraints while preserving aspect ratio.

        Square images resize to SQUARE_DIM x SQUARE_DIM. Images whose
        aspect ratio exceeds MAX_DIM/MIN_DIM (either direction) are
        center-cropped to the extreme ratio first. Final dimensions are
        snapped to multiples of MULTIPLE_OF and clamped to [MIN_DIM, MAX_DIM].
        """
        width, height = image.size

        if width == height:
            return image.resize((self.SQUARE_DIM, self.SQUARE_DIM), Image.LANCZOS)

        aspect_ratio = width / height
        MAX_ASPECT_RATIO = self.MAX_DIM / self.MIN_DIM
        MIN_ASPECT_RATIO = self.MIN_DIM / self.MAX_DIM

        image_to_resize = image

        if aspect_ratio > MAX_ASPECT_RATIO:
            # Too wide: center-crop horizontally to the widest legal ratio.
            target_w, target_h = self.MAX_DIM, self.MIN_DIM
            crop_width = int(round(height * MAX_ASPECT_RATIO))
            left = (width - crop_width) // 2
            image_to_resize = image.crop((left, 0, left + crop_width, height))
        elif aspect_ratio < MIN_ASPECT_RATIO:
            # Too tall: center-crop vertically to the tallest legal ratio.
            target_w, target_h = self.MIN_DIM, self.MAX_DIM
            crop_height = int(round(width / MIN_ASPECT_RATIO))
            top = (height - crop_height) // 2
            image_to_resize = image.crop((0, top, width, top + crop_height))
        else:
            # Ratio in range: pin the long side to MAX_DIM, scale the other.
            if width > height:
                target_w = self.MAX_DIM
                target_h = int(round(target_w / aspect_ratio))
            else:
                target_h = self.MAX_DIM
                target_w = int(round(target_h * aspect_ratio))

        # Snap to the model's spatial granularity, then clamp. The clamp
        # bounds are multiples of MULTIPLE_OF, so snapping is preserved.
        final_w = round(target_w / self.MULTIPLE_OF) * self.MULTIPLE_OF
        final_h = round(target_h / self.MULTIPLE_OF) * self.MULTIPLE_OF
        final_w = max(self.MIN_DIM, min(self.MAX_DIM, final_w))
        final_h = max(self.MIN_DIM, min(self.MAX_DIM, final_h))

        return image_to_resize.resize((final_w, final_h), Image.LANCZOS)

    def get_num_frames(self, duration_seconds: float) -> int:
        """Convert a duration in seconds to a clamped frame count.

        Returns 1 + round(duration * FIXED_FPS), clamped to the inclusive
        range [MIN_FRAMES, MAX_FRAMES]. The previous formula clipped
        BEFORE adding 1 and could return MAX_FRAMES + 1 (82) frames,
        exceeding the model's cap; clamping last fixes that off-by-one.
        """
        raw = 1 + int(round(duration_seconds * self.FIXED_FPS))
        return int(np.clip(raw, self.MIN_FRAMES, self.MAX_FRAMES))

    def generate_video(
        self,
        image: "Image.Image",
        prompt: str,
        duration_seconds: float = 3.0,
        num_inference_steps: int = 4,
        guidance_scale: float = 1.0,
        guidance_scale_2: float = 1.0,
        seed: int = 42,
    ) -> str:
        """Generate a video from a still image and a text prompt.

        Args:
            image: source frame; resized via resize_image().
            prompt: text description of the desired motion/content.
            duration_seconds: clip length; converted via get_num_frames().
            num_inference_steps: denoising steps (4 suits the fused
                Lightning distillation LoRA).
            guidance_scale / guidance_scale_2: CFG scales for the two
                transformer stages.
            seed: RNG seed; also used in the output filename.

        Returns:
            Path to the exported MP4 in the system temp directory.

        Raises:
            RuntimeError: if load_model() has not been called.
        """
        if not self.is_loaded:
            raise RuntimeError("VideoEngine not loaded. Call load_model() first.")

        try:
            resized_image = self.resize_image(image)
            num_frames = self.get_num_frames(duration_seconds)

            print("\n→ Generating video:")
            print(f" • Prompt: {prompt}")
            print(f" • Resolution: {resized_image.width}x{resized_image.height}")
            print(f" • Frames: {num_frames} ({duration_seconds}s @ {self.FIXED_FPS}fps)")
            print(f" • Steps: {num_inference_steps}")

            # Free as much VRAM as possible before the run.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            with torch.no_grad():
                # Seed on the active device (was hard-coded 'cuda', which
                # failed on CPU-only hosts).
                generator = torch.Generator(device=self.device).manual_seed(seed)

                output_frames = self.pipeline(
                    image=resized_image,
                    prompt=prompt,
                    height=resized_image.height,
                    width=resized_image.width,
                    num_frames=num_frames,
                    guidance_scale=float(guidance_scale),
                    guidance_scale_2=float(guidance_scale_2),
                    num_inference_steps=int(num_inference_steps),
                    generator=generator,
                ).frames[0]

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            temp_dir = tempfile.gettempdir()
            output_path = os.path.join(temp_dir, f"deltaflow_{seed}.mp4")
            export_to_video(output_frames, output_path, fps=self.FIXED_FPS)

            print(f"✓ Video generated: {output_path}")
            return output_path

        except Exception as e:
            print(f"\n{'='*60}")
            print("✗ FATAL ERROR DURING VIDEO GENERATION")
            print(f"{'='*60}")
            print(f"Error Type: {type(e).__name__}")
            print(f"Error Message: {str(e)}")
            print("\nFull Traceback:")
            print(traceback.format_exc())
            print(f"{'='*60}")
            raise

    def unload_model(self) -> None:
        """Unload pipeline and free CPU/GPU memory. Safe to call repeatedly."""
        if not self.is_loaded:
            return

        try:
            if self.pipeline is not None:
                del self.pipeline
                self.pipeline = None

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self.is_loaded = False
            print("✓ VideoEngine unloaded")

        except Exception as e:
            # Best-effort teardown: report but never raise during cleanup.
            print(f"⚠ Error during unload: {str(e)}")