# ==================================================================================================
# ZERO-SHOT-VIDEO-GENERATION - model.py (Neural Orchestration & Architecture)
# ==================================================================================================
#
# π DESCRIPTION
# This script constitutes the core machine learning inference orchestration class (`Model`). It
# abstracts the underlying PyTorch and Diffusers infrastructure, enabling seamless initialization,
# dynamic loading, and execution of various latent diffusion models (e.g., standard Text2Video,
# ControlNet-enhanced derivations). Crucially, this script also introduces logical chunking to
# mitigate Out-Of-Memory (OOM) errors during temporal inference streams common on consumer-grade hardware.
#
# π€ AUTHORS
# - Amey Thakur (https://github.com/Amey-Thakur)
#
# π€π» CREDITS
# Based directly on the foundational logic of Text2Video-Zero.
# Source Authors: Picsart AI Research (PAIR), UT Austin, U of Oregon, UIUC
# Reference: https://arxiv.org/abs/2303.13439
#
# π PROJECT LINKS
# Repository: https://github.com/Amey-Thakur/ZERO-SHOT-VIDEO-GENERATION
# Live Demo: https://huggingface.co/spaces/ameythakur/Zero-Shot-Video-Generation
# Video Demo: https://youtu.be/za9hId6UPoY
#
# 📅 RELEASE DATE
# November 22, 2023
#
# π LICENSE
# Released under the MIT License
# ==================================================================================================
from enum import Enum
import gc
import os
import numpy as np
import tomesd
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel
from diffusers.schedulers import EulerAncestralDiscreteScheduler, DDIMScheduler
from text_to_video_pipeline import TextToVideoPipeline
import utils
import gradio_utils
class ModelType(Enum):
    """Enumeration identifying target diffusion frameworks supported by the Model abstractor.

    Members are used as dictionary keys and compared by identity/equality
    (e.g. ``model_type == ModelType.Text2Video``). Values are plain ints; the
    previous trailing commas accidentally turned every value into a 1-tuple.
    """
    Pix2Pix_Video = 1
    Text2Video = 2
    ControlNetCanny = 3
    ControlNetCannyDB = 4
    ControlNetPose = 5
    ControlNetDepth = 6
class Model:
    """
    Primary interface for managing diffusion pipeline lifecycles, execution states, and tensor
    operations bridging natural language prompts directly to multi-frame video matrices.

    At most one diffusers pipeline is held at a time (``self.pipe``), together with the
    ``ModelType`` and model name it was loaded for. A single ``torch.Generator`` is shared by
    every sampling call, so seeding it makes a run reproducible.
    """

    # Quality-boosting suffix appended to every Text2Video prompt.
    _QUALITY_SUFFIX = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"

    def __init__(self, device, dtype, **kwargs):
        """
        Create an empty model holder; no weights are loaded until ``set_model`` is called.

        Args:
            device: Device for pipelines and the shared generator (a string such as
                ``"cpu"`` / ``"cuda"`` or a ``torch.device``).
            dtype: Torch dtype used as the default ``torch_dtype`` when loading weights.
            **kwargs: Accepted for call-site compatibility; ignored.
        """
        self.device = device
        self.dtype = dtype
        self.generator = torch.Generator(device=device)
        # Pipeline class to instantiate for each supported ModelType.
        self.pipe_dict = {
            ModelType.Pix2Pix_Video: StableDiffusionInstructPix2PixPipeline,
            ModelType.Text2Video: TextToVideoPipeline,
            ModelType.ControlNetCanny: StableDiffusionControlNetPipeline,
            ModelType.ControlNetCannyDB: StableDiffusionControlNetPipeline,
            ModelType.ControlNetPose: StableDiffusionControlNetPipeline,
            ModelType.ControlNetDepth: StableDiffusionControlNetPipeline,
        }
        # Cross-frame attention processors that keep frames visually consistent.
        # NOTE(review): unet_chunk_size presumably equals the number of guidance branches each
        # pipeline batches through the UNet (2 = uncond/cond, 3 = InstructPix2Pix) — confirm
        # against utils.CrossFrameAttnProcessor.
        self.controlnet_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2)
        self.pix2pix_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=3)
        self.text2video_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2)
        self.pipe = None
        self.model_type = None
        self.states = {}
        self.model_name = ""
        # Attention processors the Text2Video UNet had right after loading. They are restored
        # when cross-frame attention is switched off on a later call with the same model.
        self._default_attn_procs = None

    def set_model(self, model_type: "ModelType", model_id: str, **kwargs):
        """
        Load ``model_id`` into the pipeline class registered for ``model_type``.

        Any previously loaded pipeline is released first, followed by garbage collection and a
        CUDA cache flush, so that two models are not held at once. If ``./models`` contains a
        ``stable-diffusion-safety-checker`` or ``clip-vit-large-patch14`` directory, it is
        loaded from disk and passed as ``safety_checker`` / ``feature_extractor``. This only
        happens when the caller has not supplied those arguments, and it allows offline use.

        Args:
            model_type: Key of ``self.pipe_dict`` selecting the pipeline class.
            model_id: Hub id or local directory passed to ``from_pretrained``.
            **kwargs: Extra ``from_pretrained`` arguments; ``torch_dtype`` defaults to
                ``self.dtype``.
        """
        if getattr(self, "pipe", None) is not None:
            del self.pipe
            self.pipe = None
            # Forget what was loaded so that a failed reload below cannot leave a stale
            # model_type/model_name describing a pipeline that no longer exists.
            self.model_type = None
            self.model_name = ""
            # Collect first so reference cycles still holding tensors are freed, then let
            # PyTorch's CUDA caching allocator release the now-unused blocks.
            gc.collect()
            torch.cuda.empty_cache()

        # Offline/local loading of auxiliary components.
        models_dir = os.path.join(os.getcwd(), "models")
        local_safety_path = os.path.join(models_dir, "stable-diffusion-safety-checker")
        local_clip_path = os.path.join(models_dir, "clip-vit-large-patch14")
        if os.path.exists(local_safety_path) and 'safety_checker' not in kwargs:
            print(f"Loading local safety checker from {local_safety_path}")
            from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
            kwargs['safety_checker'] = StableDiffusionSafetyChecker.from_pretrained(
                local_safety_path).to(self.device).to(self.dtype)
        if os.path.exists(local_clip_path) and 'feature_extractor' not in kwargs:
            print(f"Loading local feature extractor from {local_clip_path}")
            from transformers import CLIPImageProcessor
            kwargs['feature_extractor'] = CLIPImageProcessor.from_pretrained(local_clip_path)

        # Pass torch_dtype explicitly to avoid loading as float16 and converting afterwards
        # (which triggers a warning on CPU devices).
        kwargs.setdefault('torch_dtype', self.dtype)
        self.pipe = self.pipe_dict[model_type].from_pretrained(
            model_id, **kwargs).to(self.device)
        self.model_type = model_type
        self.model_name = model_id

    def inference_chunk(self, frame_ids, **kwargs):
        """
        Run the active pipeline on the subset of frames listed in ``frame_ids``.

        ``prompt`` / ``negative_prompt`` may be a string (repeated for every frame) or a
        per-frame sequence. Either way they are indexed with ``frame_ids``, as are ``latents``
        and ``image`` when present. ``video_length`` is replaced by ``len(frame_ids)``. For
        Text2Video the ids themselves are forwarded as ``frame_ids`` so the pipeline can place
        each frame on the global motion timeline.

        Args:
            frame_ids: Frame indices to generate; may repeat an index (e.g. an anchor frame).
            **kwargs: Pipeline arguments for the whole video.

        Returns:
            The pipeline output object, or ``None`` when no pipeline is loaded.
        """
        if getattr(self, "pipe", None) is None:
            return None
        total_frames = kwargs.get('video_length', len(frame_ids))

        prompt = kwargs.pop('prompt', '')
        if prompt is None:
            prompt = ''
        if isinstance(prompt, str):
            prompt = [prompt] * total_frames
        prompt = np.array(prompt)

        negative_prompt = kwargs.pop('negative_prompt', '')
        if negative_prompt is None:
            negative_prompt = ''
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * total_frames
        negative_prompt = np.array(negative_prompt)

        latents = None
        if 'latents' in kwargs:
            latents = kwargs.pop('latents')[frame_ids]
        if 'image' in kwargs:
            kwargs['image'] = kwargs['image'][frame_ids]
        if 'video_length' in kwargs:
            kwargs['video_length'] = len(frame_ids)
        if self.model_type == ModelType.Text2Video:
            kwargs["frame_ids"] = frame_ids
        return self.pipe(prompt=prompt[frame_ids].tolist(),
                         negative_prompt=negative_prompt[frame_ids].tolist(),
                         latents=latents,
                         generator=self.generator,
                         **kwargs)

    @staticmethod
    def _plan_chunks(video_length, chunk_size):
        """
        Return the list of frame ids for each pipeline call of a chunked run.

        The first chunk covers frames ``0 .. chunk_size - 1``. Every later chunk is frame ``0``
        (an anchor) followed by up to ``chunk_size - 1`` new frames. The anchor is needed
        because the Text2Video pipeline generates the first frame of each call as the
        unwarped base frame; it also gives cross-frame attention a shared reference. Each
        call therefore processes at most ``chunk_size`` frames, or 2 when ``chunk_size == 1``.
        The caller must discard the first output of every chunk except the first.

        Raises:
            ValueError: If ``chunk_size`` or ``video_length`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if video_length < 1:
            raise ValueError(f"video_length must be >= 1, got {video_length}")
        first_end = min(chunk_size, video_length)
        plan = [list(range(first_end))]
        step = max(chunk_size - 1, 1)
        for start in range(first_end, video_length, step):
            end = min(start + step, video_length)
            plan.append([0] + list(range(start, end)))
        return plan

    def _apply_token_merging(self, merging_ratio):
        """
        Enable or disable ToMe token merging on the loaded pipeline.

        ``None`` leaves the pipeline untouched. A ratio > 0 patches it, and a ratio <= 0 removes
        any patch left by an earlier call, since the pipeline persists across calls.
        """
        if merging_ratio is None:
            return
        if merging_ratio > 0:
            tomesd.apply_patch(self.pipe, ratio=merging_ratio)
        else:
            tomesd.remove_patch(self.pipe)

    def inference(self, **kwargs):
        """
        Run the loaded pipeline, either in one call or chunk-by-chunk to bound memory use.

        Keyword args consumed here:
            split_to_chunks (bool, default False): process frames in chunks.
            chunk_size (int, default 8): maximum frames per pipeline call when chunking.
        ``merging_ratio``, if present, toggles ToMe token merging and is still forwarded to
        the pipeline. All other kwargs are passed to the pipeline.

        When chunking, the shared generator is re-seeded with its initial seed before every
        chunk, so all chunks start from the same noise and the same base frame. Anchor outputs
        are dropped, and the chunks are concatenated: ``np.concatenate`` when the pipeline
        returns arrays, otherwise a flat list.

        Returns:
            The generated frames, or ``None`` when no pipeline is loaded.

        Raises:
            ValueError: When chunking with ``chunk_size`` or ``video_length`` < 1.
        """
        if getattr(self, "pipe", None) is None:
            return None
        split_to_chunks = kwargs.pop('split_to_chunks', False)
        chunk_size = kwargs.pop('chunk_size', 8)
        video_length = kwargs.get('video_length', 8)

        self._apply_token_merging(kwargs.get('merging_ratio'))

        if not split_to_chunks:
            # Use the shared, seeded generator so non-chunked runs are reproducible too.
            kwargs.setdefault('generator', self.generator)
            return self.pipe(**kwargs).images

        # inference_chunk always passes self.generator; drop any caller-supplied one to avoid
        # a duplicate keyword argument.
        kwargs.pop('generator', None)
        seed = self.generator.initial_seed()
        all_frames = []
        for chunk_idx, frame_ids in enumerate(self._plan_chunks(video_length, chunk_size)):
            self.generator.manual_seed(seed)
            images = self.inference_chunk(frame_ids, **kwargs).images
            # Every chunk after the first begins with anchor frame 0; discard its output.
            all_frames.append(images if chunk_idx == 0 else images[1:])

        if isinstance(all_frames[0], np.ndarray):
            return np.concatenate(all_frames, axis=0)
        return [img for chunk in all_frames for img in chunk]

    @staticmethod
    def _decorate_prompt(prompt, suffix=_QUALITY_SUFFIX):
        """
        Strip trailing whitespace and one trailing ``,`` or ``.``, then append ``", " + suffix``.
        """
        prompt = prompt.rstrip()
        if prompt.endswith((",", ".")):
            prompt = prompt[:-1]
        return prompt.rstrip() + ", " + suffix

    def _load_text2video(self, model_name):
        """
        (Re)load the Text2Video pipeline for ``model_name`` and switch it to DDIM sampling.

        Weights are read from ``./models/<last path component of model_name>`` when that
        directory exists, otherwise from ``model_name`` itself. ``self.model_name`` keeps the
        original name for change detection. The UNet's freshly loaded attention processors are
        remembered so they can be restored when cross-frame attention is disabled.
        """
        print(f"Model update to {model_name}")
        local_model_path = os.path.join(os.getcwd(), "models", model_name.split('/')[-1])
        if os.path.exists(local_model_path):
            print(f"Using local model weights from {local_model_path}")
            load_path = local_model_path
        else:
            load_path = model_name
        unet = UNet2DConditionModel.from_pretrained(
            load_path, subfolder="unet", torch_dtype=self.dtype)
        self.set_model(ModelType.Text2Video, model_id=load_path, unet=unet)
        self.model_name = model_name  # Keep the original name for state tracking
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self._default_attn_procs = self.pipe.unet.attn_processors

    def process_text2video(self,
                           prompt,
                           model_name="dreamlike-art/dreamlike-photoreal-2.0",
                           motion_field_strength_x=12,
                           motion_field_strength_y=12,
                           t0=44,
                           t1=47,
                           n_prompt="",
                           chunk_size=8,
                           video_length=8,
                           watermark='',
                           merging_ratio=0.0,
                           seed=0,
                           resolution=512,
                           fps=2,
                           use_cf_attn=True,
                           use_motion_field=True,
                           smooth_bg=False,
                           smooth_bg_strength=0.4,
                           path=None):
        """
        Generate a video from a text prompt with the Text2Video-Zero pipeline and encode it.

        The pipeline is (re)loaded only when ``model_name`` or the active model type changes.
        Cross-frame attention is applied or removed on every call according to
        ``use_cf_attn``. The generator is seeded with ``seed``, and frames are always produced
        in chunks of ``chunk_size``. On CPU, ``resolution``, ``video_length``, ``chunk_size``
        and the number of denoising steps are capped to keep runs within memory and time limits.

        Args:
            prompt: Text prompt; a quality suffix is appended (see ``_decorate_prompt``).
            model_name: Hub id (or name of a directory under ``./models``) of the base model.
            motion_field_strength_x, motion_field_strength_y, t0, t1, use_motion_field,
            smooth_bg, smooth_bg_strength: Forwarded unchanged to ``TextToVideoPipeline``.
            n_prompt: Negative prompt; empty or ``None`` means no negative prompt.
            chunk_size: Maximum frames per pipeline call.
            video_length: Number of frames to generate.
            watermark: Logo name resolved via ``gradio_utils.logo_name_to_path``.
            merging_ratio: ToMe token-merging ratio; 0 disables merging.
            seed: Seed for the shared generator.
            resolution: Output height and width in pixels.
            fps: Frame rate passed to ``utils.create_video``.
            use_cf_attn: Whether to use cross-frame attention in the UNet.
            path: Output path passed to ``utils.create_video``.

        Returns:
            Whatever ``utils.create_video`` returns for the generated frames.
        """
        print("Module Text2Video")

        # --- CPU OPTIMIZATION PROTOCOL ---
        # When running on CPU (e.g., free-tier HF Spaces with 2 vCPU / 16 GB RAM),
        # aggressively reduce computational parameters to prevent OOM and timeouts.
        # torch.device() normalises strings ("cpu", "cpu:0") and torch.device objects alike.
        is_cpu = torch.device(self.device).type == "cpu"
        if is_cpu:
            resolution = min(resolution, 320)
            video_length = min(video_length, 4)
            chunk_size = min(chunk_size, 2)
            num_inference_steps = 30
            # NOTE(review): t0/t1 keep their defaults even though only 30 steps run here —
            # confirm TextToVideoPipeline tolerates or rescales them.
            print(f"CPU mode: resolution={resolution}, video_length={video_length}, "
                  f"chunk_size={chunk_size}, num_inference_steps={num_inference_steps}")
        else:
            num_inference_steps = 50

        if self.model_type != ModelType.Text2Video or model_name != self.model_name:
            self._load_text2video(model_name)

        # Set the attention processors on every call. Otherwise disabling CF-Attn would have no
        # effect while the same model stays loaded.
        if use_cf_attn:
            self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
        elif getattr(self, "_default_attn_procs", None) is not None:
            self.pipe.unet.set_attn_processor(processor=self._default_attn_procs)

        self.generator.manual_seed(seed)

        prompt = self._decorate_prompt(prompt)
        negative_prompt = n_prompt if n_prompt else None

        result = self.inference(prompt=prompt,
                                video_length=video_length,
                                height=resolution,
                                width=resolution,
                                num_inference_steps=num_inference_steps,
                                guidance_scale=7.5,
                                guidance_stop_step=1.0,
                                t0=t0,
                                t1=t1,
                                motion_field_strength_x=motion_field_strength_x,
                                motion_field_strength_y=motion_field_strength_y,
                                use_motion_field=use_motion_field,
                                smooth_bg=smooth_bg,
                                smooth_bg_strength=smooth_bg_strength,
                                seed=seed,
                                output_type='numpy',
                                negative_prompt=negative_prompt,
                                merging_ratio=merging_ratio,
                                split_to_chunks=True,
                                chunk_size=chunk_size,
                                )
        return utils.create_video(result, fps, path=path,
                                  watermark=gradio_utils.logo_name_to_path(watermark))