Spaces:
Running on Zero
Running on Zero
| from models.scheduler import SchedulerInterface | |
| from abc import abstractmethod, ABC | |
| from typing import List, Optional | |
| import torch | |
| import types | |
class DiffusionModelInterface(ABC, torch.nn.Module):
    """
    Base interface for diffusion model wrappers.
    Combines a denoising forward pass with a scheduler attribute and small helpers
    for scheduler setup, per-module gradient toggling and gradient checkpointing.
    """
    # Declared as an annotation only; no value is assigned in this class.
    scheduler: SchedulerInterface

    def forward(
        self, noisy_image_or_video: torch.Tensor, conditional_dict: dict,
        timestep: torch.Tensor, kv_cache: Optional[List[dict]] = None,
        crossattn_cache: Optional[List[dict]] = None,
        current_start: Optional[int] = None,
        current_end: Optional[int] = None
    ) -> torch.Tensor:
        """
        Run the diffusion model on a noisy input.
        Input:
            - noisy_image_or_video: a tensor with shape [B, F, C, H, W]; F is 1 for images.
            - conditional_dict: a dictionary holding the conditioning information
                (e.g. text embeddings, image embeddings).
            - timestep: a tensor with shape [B, F]; F is 1 for images.
                All data should live on the same device as the model.
            - kv_cache: a list of dictionaries with the key and value tensors of each attention layer.
            - crossattn_cache: a list of dictionaries; presumably the cross-attention counterpart
                of kv_cache (not described in this file -- confirm against implementations).
            - current_start: the start index of the current frame in the sequence.
            - current_end: the end index of the current frame in the sequence.
        Output: a tensor with shape [B, F, C, H, W]; F is 1 for images.
            The output is always expected to be an X0 prediction.
        This base implementation does nothing and returns None.
        """
        ...

    def get_scheduler(self) -> SchedulerInterface:
        """
        Attach SchedulerInterface's conversion functions to the current scheduler.
        convert_x0_to_noise, convert_noise_to_x0 and convert_velocity_to_x0 are taken from
        SchedulerInterface and bound to this model's scheduler instance as methods.
        The scheduler is then stored back on the model and returned.
        """
        target = self.scheduler
        conversion_names = (
            "convert_x0_to_noise",
            "convert_noise_to_x0",
            "convert_velocity_to_x0",
        )
        for attr_name in conversion_names:
            # Bind the function defined on the interface class to this particular instance.
            bound = types.MethodType(getattr(SchedulerInterface, attr_name), target)
            setattr(target, attr_name, bound)
        self.scheduler = target
        return target

    def post_init(self):
        """
        Custom initialization steps meant to be called after the object is created.
        At the moment this only binds the conversion methods to the scheduler
        (via get_scheduler); more steps can be added here when needed.
        """
        self.get_scheduler()

    def set_module_grad(self, module_grad: dict) -> None:
        """
        Switch gradient tracking on or off for named modules of this object.
        Parameters:
            - module_grad (dict): maps the name of a module (an attribute of this object)
                to a bool telling whether that module's parameters require gradients.
        Functionality:
            For every entry, the attribute is looked up with getattr and its
            requires_grad_ is called with the given flag.
        """
        for module_name, trainable in module_grad.items():
            submodule = getattr(self, module_name)
            submodule.requires_grad_(trainable)

    def enable_gradient_checkpointing(self) -> None:
        """
        Activate gradient checkpointing for the current model (other frameworks may call this
        *activation checkpointing* or *checkpoint activations*).
        This base implementation is a no-op.
        """
        ...
class VAEInterface(ABC, torch.nn.Module):
    """
    Interface for VAE wrappers that turn latents back into pixels.
    """

    def decode_to_pixel(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Decode a latent representation into an image or video.
        Input: a tensor with shape [B, F // T, C, H // S, W // S], where T and S are the
            temporal and spatial compression factors.
        Output: a tensor with shape [B, F, C, H, W]; F is 1 for images.
        This base implementation does nothing and returns None.
        """
        ...
class TextEncoderInterface(ABC, torch.nn.Module):
    """
    Interface for text encoder wrappers.
    """

    def forward(self, text_prompts: List[str]) -> dict:
        """
        Tokenize the text prompts with a tokenizer and encode them into a latent representation.
        Input: a list of strings.
        Output: a dictionary holding the encoded representation of the text prompts.
        This base implementation does nothing and returns None.
        """
        ...
class InferencePipelineInterface(ABC):
    """
    Interface for inference pipelines built on diffusion / distilled generators.
    """

    def inference_with_trajectory(self, noise: torch.Tensor, conditional_dict: dict) -> torch.Tensor:
        """
        Run inference with the given diffusion / distilled generators.
        Input:
            - noise: a tensor sampled from N(0, 1) with shape [B, F, C, H, W]; F is 1 for images.
            - conditional_dict: a dictionary holding the conditioning information
                (e.g. text embeddings, image embeddings).
        Output:
            - output: a tensor with shape [B, T, F, C, H, W], where T is the total number
                of timesteps. output[0] is pure noise; output[i] for i > 0 is the x0
                prediction at each timestep.
        This base implementation does nothing and returns None.
        """
        ...