# Spaces: Running on Zero
# File size: 4,952 Bytes
# 5c93746
from models.scheduler import SchedulerInterface
from abc import abstractmethod, ABC
from typing import List, Optional
import torch
import types
class DiffusionModelInterface(ABC, torch.nn.Module):
    """
    Abstract torch module describing a diffusion model that carries a scheduler.

    Subclasses must implement `forward` and `enable_gradient_checkpointing`.
    """
    scheduler: SchedulerInterface

    @abstractmethod
    def forward(
        self, noisy_image_or_video: torch.Tensor, conditional_dict: dict,
        timestep: torch.Tensor, kv_cache: Optional[List[dict]] = None,
        crossattn_cache: Optional[List[dict]] = None,
        current_start: Optional[int] = None,
        current_end: Optional[int] = None
    ) -> torch.Tensor:
        """
        Run the diffusion model on a noisy input.
        Input:
            - noisy_image_or_video: tensor of shape [B, F, C, H, W]; F is 1 for images.
            - conditional_dict: dictionary with the conditioning information
              (e.g. text embeddings, image embeddings).
            - timestep: tensor of shape [B, F]; F is 1 for images.
              All data should live on the same device as the model.
            - kv_cache: list of dictionaries holding the key and value tensors
              of each attention layer.
            - crossattn_cache: list of dictionaries; presumably the cross-attention
              counterpart of kv_cache -- TODO confirm against implementations.
            - current_start: start index of the current frame in the sequence.
            - current_end: end index of the current frame in the sequence.
        Output: tensor of shape [B, F, C, H, W]; F is 1 for images.
        The output is always expected to be in X0-prediction form.
        """

    def get_scheduler(self) -> SchedulerInterface:
        """
        Attach SchedulerInterface's conversion helpers to this model's scheduler
        as bound methods, store the scheduler back on the model and return it.
        """
        sched = self.scheduler
        helper_names = (
            "convert_x0_to_noise",
            "convert_noise_to_x0",
            "convert_velocity_to_x0",
        )
        # Binding on the instance (rather than relying on inheritance) is
        # presumably meant for schedulers that do not subclass
        # SchedulerInterface -- verify against the scheduler implementations.
        for helper_name in helper_names:
            helper = getattr(SchedulerInterface, helper_name)
            setattr(sched, helper_name, types.MethodType(helper, sched))
        self.scheduler = sched
        return sched

    def post_init(self):
        """
        Custom initialization steps meant to be called once the object exists.
        At the moment this only binds the conversion helpers to the scheduler;
        further steps can be added here when needed.
        """
        self.get_scheduler()

    def set_module_grad(self, module_grad: dict) -> None:
        """
        Switch gradient tracking on or off for selected sub-modules.
        Parameters:
            - module_grad (dict): maps the name of a module (an attribute of this
              object) to a bool telling whether that module's parameters should
              require gradients.
        Functionality:
            Each named attribute is looked up on this object and its
            `requires_grad_` is called with the corresponding bool.
        """
        for attr_name, trainable in module_grad.items():
            target = getattr(self, attr_name)
            target.requires_grad_(trainable)

    @abstractmethod
    def enable_gradient_checkpointing(self) -> None:
        """
        Turn on gradient checkpointing for this model (other frameworks may call
        it *activation checkpointing* or *checkpoint activations*).
        """
class VAEInterface(ABC, torch.nn.Module):
    """
    Abstract torch module for a VAE that can decode latents back to pixels.
    """

    @abstractmethod
    def decode_to_pixel(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Decode a latent representation into an image or video.
        Input: tensor of shape [B, F // T, C, H // S, W // S], with T and S the
            temporal and spatial compression factors.
        Output: tensor of shape [B, F, C, H, W]; F is 1 for images.
        """
class TextEncoderInterface(ABC, torch.nn.Module):
    """
    Abstract torch module for a text encoder operating on raw prompt strings.
    """

    @abstractmethod
    def forward(self, text_prompts: List[str]) -> dict:
        """
        Tokenize the text prompts with a tokenizer and encode them into a
        latent representation.
        Input: list of strings.
        Output: dictionary holding the encoded representation of the prompts.
        """
class InferencePipelineInterface(ABC):
    """
    Abstract interface for an inference pipeline that returns its full
    denoising trajectory.
    """

    @abstractmethod
    def inference_with_trajectory(self, noise: torch.Tensor, conditional_dict: dict) -> torch.Tensor:
        """
        Run inference with the given diffusion / distilled generators.
        Input:
            - noise: tensor sampled from N(0, 1) with shape [B, F, C, H, W];
              F is 1 for images.
            - conditional_dict: dictionary with the conditioning information
              (e.g. text embeddings, image embeddings).
        Output:
            - output: tensor of shape [B, T, F, C, H, W], where T is the total
              number of timesteps. output[0] is pure noise; output[i] for i > 0
              is the x0 prediction at each timestep.
        """