# ameythakur's picture
# text2video
# c688659 verified
# ==================================================================================================
# ZERO-SHOT-VIDEO-GENERATION - model.py (Neural Orchestration & Architecture)
# ==================================================================================================
#
# πŸ“ DESCRIPTION
# This script constitutes the core machine learning inference orchestration class (`Model`). It
# abstracts the underlying PyTorch and Diffusers infrastructure, enabling seamless initialization,
# dynamic loading, and execution of various latent diffusion models (e.g., standard Text2Video,
# ControlNet-enhanced derivations). Crucially, this script also introduces logical chunking to
# mitigate Out-Of-Memory (OOM) errors during temporal inference streams common on consumer-grade hardware.
#
# 👤 AUTHORS
# - Amey Thakur (https://github.com/Amey-Thakur)
#
# 🤝🏻 CREDITS
# Based directly on the foundational logic of Text2Video-Zero.
# Source Authors: Picsart AI Research (PAIR), UT Austin, U of Oregon, UIUC
# Reference: https://arxiv.org/abs/2303.13439
#
# 🔗 PROJECT LINKS
# Repository: https://github.com/Amey-Thakur/ZERO-SHOT-VIDEO-GENERATION
# Live Demo: https://huggingface.co/spaces/ameythakur/Zero-Shot-Video-Generation
# Video Demo: https://youtu.be/za9hId6UPoY
#
# 📅 RELEASE DATE
# November 22, 2023
#
# 📜 LICENSE
# Released under the MIT License
# ==================================================================================================
from enum import Enum
import gc
import os
import numpy as np
import tomesd
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel
from diffusers.schedulers import EulerAncestralDiscreteScheduler, DDIMScheduler
from text_to_video_pipeline import TextToVideoPipeline
import utils
import gradio_utils
class ModelType(Enum):
    """Enumeration identifying target diffusion frameworks supported by the Model abstractor.

    Members carry plain integer values, so ``ModelType(2) is ModelType.Text2Video``.
    A trailing comma after a value (``X = 1,``) would silently make it the tuple ``(1,)``.
    """
    Pix2Pix_Video = 1
    Text2Video = 2
    ControlNetCanny = 3
    ControlNetCannyDB = 4
    ControlNetPose = 5
    ControlNetDepth = 6
class Model:
    """
    Primary interface for managing diffusion pipeline lifecycles, execution states, and tensor
    operations bridging natural language prompts directly to multi-frame video matrices.

    Only one pipeline is resident at a time: ``set_model`` releases the previous pipeline before
    loading the next one. All generation calls share ``self.generator`` for seeding.
    """

    def __init__(self, device, dtype, **kwargs):
        """
        Args:
            device: Device for the shared ``torch.Generator`` and for every loaded pipeline.
            dtype: Torch dtype forwarded as ``torch_dtype`` when weights are loaded.
            **kwargs: Accepted for call-site compatibility; currently unused.
        """
        self.device = device
        self.dtype = dtype
        self.generator = torch.Generator(device=device)
        # Pipeline class used to load each supported ModelType.
        self.pipe_dict = {
            ModelType.Pix2Pix_Video: StableDiffusionInstructPix2PixPipeline,
            ModelType.Text2Video: TextToVideoPipeline,
            ModelType.ControlNetCanny: StableDiffusionControlNetPipeline,
            ModelType.ControlNetCannyDB: StableDiffusionControlNetPipeline,
            ModelType.ControlNetPose: StableDiffusionControlNetPipeline,
            ModelType.ControlNetDepth: StableDiffusionControlNetPipeline,
        }
        # Cross-frame attention processors used to keep appearance consistent across frames.
        # NOTE(review): unet_chunk_size presumably equals the number of batch groups the UNet
        # sees per step (2 with classifier-free guidance, 3 for InstructPix2Pix) - confirm in utils.
        self.controlnet_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2)
        self.pix2pix_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=3)
        self.text2video_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2)
        # Currently loaded pipeline and its identity (cleared whenever the pipeline is unloaded).
        self.pipe = None
        self.model_type = None
        self.states = {}
        self.model_name = ""

    def set_model(self, model_type: ModelType, model_id: str, **kwargs):
        """
        Load ``model_id`` with the pipeline class registered for ``model_type``.

        Any previously loaded pipeline is released first. Garbage is then collected and the
        CUDA cache emptied to limit fragmentation. If ``./models/stable-diffusion-safety-checker``
        or ``./models/clip-vit-large-patch14`` exist, they are loaded as the safety checker and
        feature extractor, unless the caller already supplied those components. This supports
        offline use.

        Args:
            model_type: Key into ``self.pipe_dict`` selecting the pipeline class.
            model_id: Hub identifier or local directory passed to ``from_pretrained``.
            **kwargs: Extra ``from_pretrained`` arguments (e.g. a preloaded ``unet``).
        """
        if getattr(self, "pipe", None) is not None:
            del self.pipe
            self.pipe = None
        # Forget the old identity before loading, so a failed load is not later mistaken
        # for an already-loaded model (callers compare model_type/model_name to skip reloads).
        self.model_type = None
        self.model_name = ""
        # Collect first so freed tensors are actually returned before the cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()

        # Offline/local overrides for auxiliary components.
        models_dir = os.path.join(os.getcwd(), "models")
        local_safety_path = os.path.join(models_dir, "stable-diffusion-safety-checker")
        local_clip_path = os.path.join(models_dir, "clip-vit-large-patch14")
        if os.path.exists(local_safety_path) and 'safety_checker' not in kwargs:
            print(f"Loading local safety checker from {local_safety_path}")
            from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
            kwargs['safety_checker'] = StableDiffusionSafetyChecker.from_pretrained(
                local_safety_path).to(self.device).to(self.dtype)
        if os.path.exists(local_clip_path) and 'feature_extractor' not in kwargs:
            print(f"Loading local feature extractor from {local_clip_path}")
            from transformers import CLIPImageProcessor
            kwargs['feature_extractor'] = CLIPImageProcessor.from_pretrained(local_clip_path)

        # Load directly in the target precision rather than loading as float16 and casting
        # afterwards (the cast path triggers a warning on CPU devices).
        kwargs.setdefault('torch_dtype', self.dtype)
        self.pipe = self.pipe_dict[model_type].from_pretrained(
            model_id, **kwargs).to(self.device)
        self.model_type = model_type
        self.model_name = model_id

    @staticmethod
    def _expand_per_frame(text, num_frames):
        """
        Return ``text`` as an array holding one entry per frame.

        ``None`` is treated as ``''``. A single string is repeated ``num_frames`` times. Any
        other sequence is converted unchanged, so it must already cover every frame index used.
        """
        if text is None:
            text = ''
        if isinstance(text, str):
            text = [text] * num_frames
        return np.array(text)

    def _apply_token_merging(self, merging_ratio):
        """
        Apply ToMe token merging (``tomesd``) when ``merging_ratio > 0``; remove it otherwise.

        ``None`` means the caller did not ask for anything, so the pipeline is left untouched.
        """
        if merging_ratio is None:
            return
        if merging_ratio > 0:
            tomesd.apply_patch(self.pipe, ratio=merging_ratio)
        else:
            tomesd.remove_patch(self.pipe)

    def inference_chunk(self, frame_ids, **kwargs):
        """
        Run the active pipeline on only the frames listed in ``frame_ids``.

        Per-frame inputs (``prompt``, ``negative_prompt``, ``latents``, ``image``) are sliced to
        ``frame_ids``. ``video_length`` (if present) is rewritten to the chunk length. For
        Text2Video the absolute ``frame_ids`` are also forwarded to the pipeline.

        Args:
            frame_ids: Absolute indices of the frames to generate in this call.
            **kwargs: Pipeline arguments. ``video_length``, if present, is the full video length.
                ``generator`` defaults to ``self.generator``.

        Returns:
            The pipeline's output object, or ``None`` when no pipeline is loaded.
        """
        if getattr(self, "pipe", None) is None:
            return None
        frame_ids = list(frame_ids)
        # Broadcast string prompts so that every referenced frame index exists. Covering the
        # highest requested index avoids IndexError in later chunks when video_length is absent.
        num_frames = max(kwargs.get('video_length') or 0, max(frame_ids, default=-1) + 1)
        prompt = self._expand_per_frame(kwargs.pop('prompt', ''), num_frames)
        negative_prompt = self._expand_per_frame(kwargs.pop('negative_prompt', ''), num_frames)

        latents = kwargs.pop('latents', None)
        if latents is not None:
            latents = latents[frame_ids]
        if kwargs.get('image') is not None:
            kwargs['image'] = kwargs['image'][frame_ids]
        if 'video_length' in kwargs:
            kwargs['video_length'] = len(frame_ids)
        if self.model_type == ModelType.Text2Video:
            kwargs["frame_ids"] = frame_ids
        # Pop so a caller-supplied generator does not collide with the explicit keyword below.
        generator = kwargs.pop('generator', self.generator)

        # Dispatch the bounded chunk to the active diffusion pipeline.
        return self.pipe(prompt=prompt[frame_ids].tolist(),
                         negative_prompt=negative_prompt[frame_ids].tolist(),
                         latents=latents,
                         generator=generator,
                         **kwargs)

    def inference(self, **kwargs):
        """
        Generate frames with the active pipeline, either in one call or chunk by chunk.

        Keyword Args:
            split_to_chunks (bool): Generate ``video_length`` frames in chunks of ``chunk_size``
                to bound peak memory (default ``False``).
            chunk_size (int): Frames per chunk (default 8; values below 1 are treated as 1).
            video_length (int): Total frame count used for chunking (default 8).
            seed (int): If given, ``self.generator`` is re-seeded before the single pipeline
                call, or before every chunk. The value is still forwarded to the pipeline.
            merging_ratio (float): If given, ToMe token merging is applied with this ratio
                when positive and removed otherwise. The value is still forwarded.
            **kwargs: Everything else is passed through to the pipeline.

        Returns:
            The generated frames: an ``np.ndarray`` when the pipeline returns arrays, otherwise a
            flat list. Returns ``None`` when no pipeline is loaded.
        """
        if getattr(self, "pipe", None) is None:
            return None
        split_to_chunks = kwargs.pop('split_to_chunks', False)
        chunk_size = max(1, int(kwargs.pop('chunk_size', 8)))
        video_length = kwargs.get('video_length', 8)
        seed = kwargs.get('seed')
        self._apply_token_merging(kwargs.get('merging_ratio'))

        if not split_to_chunks:
            if seed is not None:
                self.generator.manual_seed(seed)
            # Hand the seeded generator to the pipeline so that ``seed`` controls the output.
            kwargs.setdefault('generator', self.generator)
            return self.pipe(**kwargs).images

        chunk_images = []
        for start in range(0, video_length, chunk_size):
            frame_ids = list(range(start, min(start + chunk_size, video_length)))
            # Re-seed for every chunk, so each chunk starts from the same base noise instead of
            # continuing the RNG stream (which would give each chunk unrelated content).
            # NOTE(review): assumes the pipeline draws its initial latent from ``generator``.
            if seed is not None:
                self.generator.manual_seed(seed)
            chunk_images.append(self.inference_chunk(frame_ids, **kwargs).images)

        # Reassemble the chunks in frame order.
        if not chunk_images:
            return []
        if isinstance(chunk_images[0], np.ndarray):
            return np.concatenate(chunk_images, axis=0)
        return [image for chunk in chunk_images for image in chunk]

    def process_text2video(self,
                           prompt,
                           model_name="dreamlike-art/dreamlike-photoreal-2.0",
                           motion_field_strength_x=12,
                           motion_field_strength_y=12,
                           t0=44,
                           t1=47,
                           n_prompt="",
                           chunk_size=8,
                           video_length=8,
                           watermark='',
                           merging_ratio=0.0,
                           seed=0,
                           resolution=512,
                           fps=2,
                           use_cf_attn=True,
                           use_motion_field=True,
                           smooth_bg=False,
                           smooth_bg_strength=0.4,
                           path=None):
        """
        Run the full zero-shot text-to-video process and write the resulting video.

        Steps:
        1. Loads ``model_name`` as a Text2Video pipeline if it is not already active. Weights
           under ``./models/<repo-name>`` are used when present.
        2. Seeds the generator.
        3. Appends fixed quality keywords to ``prompt``.
        4. Generates frames in chunks.
        5. Passes the frames to ``utils.create_video``.

        On CPU devices, resolution, frame count, chunk size and step count are capped to keep
        memory and runtime manageable.

        Args:
            prompt: Text prompt. A single trailing ',' or '.' is removed before keywords are added.
            model_name: Hub id (or name of a folder under ``./models``) of the base weights.
            motion_field_strength_x / motion_field_strength_y: Forwarded to the pipeline.
            t0 / t1: Forwarded to the pipeline.
            n_prompt: Negative prompt. Empty or ``None`` means no negative prompt.
            chunk_size: Frames generated per pipeline call.
            video_length: Number of frames to generate.
            watermark: Logo name resolved through ``gradio_utils.logo_name_to_path``.
            merging_ratio: ToMe token-merging ratio (0 disables it).
            seed: PRNG seed for ``self.generator``.
            resolution: Output height and width in pixels.
            fps: Frame rate of the written video.
            use_cf_attn: Install cross-frame attention when the model is (re)loaded.
            use_motion_field, smooth_bg, smooth_bg_strength: Forwarded to the pipeline.
            path: Output path forwarded to ``utils.create_video``.

        Returns:
            Whatever ``utils.create_video`` returns.
        """
        print("Module Text2Video")
        # --- CPU OPTIMIZATION PROTOCOL ---
        # When running on CPU (e.g., free-tier HF Spaces with 2 vCPU / 16 GB RAM), reduce the
        # computational parameters aggressively to prevent OOM errors and timeouts.
        # torch.device(...) normalises strings such as "cpu" / "cpu:0" and torch.device objects.
        is_cpu = torch.device(self.device).type == "cpu"
        if is_cpu:
            resolution = min(resolution, 320)
            video_length = min(video_length, 4)
            chunk_size = min(chunk_size, 2)
            num_inference_steps = 30
            print(f"CPU mode: resolution={resolution}, video_length={video_length}, "
                  f"chunk_size={chunk_size}, num_inference_steps={num_inference_steps}")
        else:
            num_inference_steps = 50

        if self.model_type != ModelType.Text2Video or model_name != self.model_name:
            print(f"Model update to {model_name}")
            # Prefer a local mirror under ./models/<repo-name>; otherwise fetch by hub id.
            local_model_path = os.path.join(os.getcwd(), "models", model_name.split('/')[-1])
            if os.path.exists(local_model_path):
                print(f"Using local model weights from {local_model_path}")
                load_path = local_model_path
            else:
                load_path = model_name
            unet = UNet2DConditionModel.from_pretrained(
                load_path, subfolder="unet", torch_dtype=self.dtype)
            self.set_model(ModelType.Text2Video, model_id=load_path, unet=unet)
            # Track the requested name (not the resolved path) so the reload check above matches.
            self.model_name = model_name
            # Use a DDIM scheduler configured from the loaded pipeline's scheduler config.
            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
            # NOTE(review): the processor is only installed on (re)load, so toggling use_cf_attn
            # without changing model_name has no effect on an already-loaded pipeline.
            if use_cf_attn:
                self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
        self.generator.manual_seed(seed)

        # Quality keywords always appended to the user's prompt.
        added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
        prompt = (prompt or "").rstrip()
        if prompt.endswith((",", ".")):
            prompt = prompt[:-1].rstrip()
        prompt = prompt + ", " + added_prompt
        negative_prompt = n_prompt if n_prompt else None

        # Generate the frames chunk by chunk.
        result = self.inference(prompt=prompt,
                                video_length=video_length,
                                height=resolution,
                                width=resolution,
                                num_inference_steps=num_inference_steps,
                                guidance_scale=7.5,
                                guidance_stop_step=1.0,
                                t0=t0,
                                t1=t1,
                                motion_field_strength_x=motion_field_strength_x,
                                motion_field_strength_y=motion_field_strength_y,
                                use_motion_field=use_motion_field,
                                smooth_bg=smooth_bg,
                                smooth_bg_strength=smooth_bg_strength,
                                seed=seed,
                                output_type='numpy',
                                negative_prompt=negative_prompt,
                                merging_ratio=merging_ratio,
                                split_to_chunks=True,
                                chunk_size=chunk_size,
                                )
        return utils.create_video(result, fps, path=path, watermark=gradio_utils.logo_name_to_path(watermark))