# StreamDiffusionV2-Realtime: models/model_interface.py
# Uploaded by multimodalart using huggingface_hub (commit 5c93746).
from models.scheduler import SchedulerInterface
from abc import abstractmethod, ABC
from typing import List, Optional
import torch
import types
class DiffusionModelInterface(ABC, torch.nn.Module):
    """Abstract base class for diffusion backbones.

    Concrete subclasses must implement ``forward`` and
    ``enable_gradient_checkpointing`` and must assign ``self.scheduler``
    before ``post_init`` / ``get_scheduler`` is called, since both read it.
    """

    # Scheduler instance; ``get_scheduler`` attaches conversion helpers to it.
    scheduler: SchedulerInterface

    @abstractmethod
    def forward(
        self,
        noisy_image_or_video: torch.Tensor,
        conditional_dict: dict,
        timestep: torch.Tensor,
        kv_cache: Optional[List[dict]] = None,
        crossattn_cache: Optional[List[dict]] = None,
        current_start: Optional[int] = None,
        current_end: Optional[int] = None,
    ) -> torch.Tensor:
        """Run the diffusion model and return an x0 prediction.

        Args:
            noisy_image_or_video: Tensor of shape [B, F, C, H, W]; F is 1 for
                images.
            conditional_dict: Conditioning information (e.g. text embeddings,
                image embeddings).
            timestep: Tensor of shape [B, F]; F is 1 for images.
            kv_cache: Optional list of dicts holding the key and value tensors
                of each attention layer.
            crossattn_cache: Optional list of dicts; presumably a per-layer
                cross-attention cache analogous to ``kv_cache`` (confirm
                against concrete implementations).
            current_start: Start index of the current frame in the sequence.
            current_end: End index of the current frame in the sequence.

        All input tensors should live on the same device as the model.

        Returns:
            Tensor of shape [B, F, C, H, W] (F is 1 for images), always in
            x0-prediction form.
        """
        ...

    def get_scheduler(self) -> SchedulerInterface:
        """Attach the ``SchedulerInterface`` conversion helpers to ``self.scheduler``.

        ``convert_x0_to_noise``, ``convert_noise_to_x0`` and
        ``convert_velocity_to_x0`` are taken from ``SchedulerInterface`` and
        stored on the scheduler instance as bound methods. Because instance
        attributes shadow class attributes, these take precedence over any
        same-named methods defined on the scheduler's own class.

        Returns:
            The same scheduler object, now carrying the bound helpers.
        """
        target = self.scheduler
        for helper_name in (
            "convert_x0_to_noise",
            "convert_noise_to_x0",
            "convert_velocity_to_x0",
        ):
            unbound = getattr(SchedulerInterface, helper_name)
            setattr(target, helper_name, types.MethodType(unbound, target))
        self.scheduler = target
        return target

    def post_init(self):
        """Run post-construction setup.

        Call this once the object has been created. For now it only binds the
        scheduler conversion helpers (via ``get_scheduler``); further
        post-construction steps can be added here.
        """
        self.get_scheduler()

    def set_module_grad(self, module_grad: dict) -> None:
        """Toggle gradient tracking on named sub-modules.

        Args:
            module_grad: Maps an attribute name on ``self`` to a bool. For each
                entry, ``requires_grad_(flag)`` is called on that attribute, so
                ``True`` makes its parameters trainable and ``False`` freezes
                them.

        Raises:
            AttributeError: If a key does not name an attribute of ``self``.
        """
        for attr_name, flag in module_grad.items():
            submodule = getattr(self, attr_name)
            submodule.requires_grad_(flag)

    @abstractmethod
    def enable_gradient_checkpointing(self) -> None:
        """Turn on gradient checkpointing for this model.

        This is also known as *activation checkpointing* or *checkpoint
        activations* in other frameworks.
        """
        ...
class VAEInterface(ABC, torch.nn.Module):
    """Abstract VAE wrapper that maps latents back to pixel space."""

    @abstractmethod
    def decode_to_pixel(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode a latent tensor into an image or video.

        Args:
            latent: Tensor of shape [B, F // T, C, H // S, W // S], where T
                and S are the temporal and spatial compression factors.

        Returns:
            Tensor of shape [B, F, C, H, W]; F is 1 for images.
        """
        ...
class TextEncoderInterface(ABC, torch.nn.Module):
    """Abstract text encoder turning prompts into conditioning tensors."""

    @abstractmethod
    def forward(self, text_prompts: List[str]) -> dict:
        """Tokenize and encode text prompts into a latent representation.

        Args:
            text_prompts: List of prompt strings.

        Returns:
            Dict holding the encoded representation of the prompts.
        """
        ...
class InferencePipelineInterface(ABC):
    """Abstract inference pipeline that returns its full denoising trajectory."""

    @abstractmethod
    def inference_with_trajectory(self, noise: torch.Tensor, conditional_dict: dict) -> torch.Tensor:
        """Run the diffusion / distilled generator(s) and record every step.

        Args:
            noise: Tensor sampled from N(0, 1) with shape [B, F, C, H, W];
                F is 1 for images.
            conditional_dict: Conditioning information (e.g. text embeddings,
                image embeddings).

        Returns:
            Tensor of shape [B, T, F, C, H, W], where T is the total number of
            timesteps. Along the T axis, entry 0 is the pure noise and each
            entry i > 0 is the x0 prediction at that timestep.
        """
        ...