from dataclasses import dataclass
from functools import partial
import logging
import math
import os
import random
import sys
import time
import typing as tp
import warnings

import torch
from torch import nn
import torch.nn.init as init
import einops
from transformers import AutoProcessor, CLIPVisionModelWithProjection, VideoMAEModel

from ..utils import utils
from ..modules.streaming import StreamingModule, State
from ..modules.transformer import StreamingTransformer, create_norm_fn
from ..modules.conditioners import (
    ConditionFuser,
    ClassifierFreeGuidanceDropout,
    AttributeDropout,
    ConditioningProvider,
    ConditioningAttributes,
    ConditionType,
)
from ..modules.codebooks_patterns import CodebooksPatternProvider
from ..modules.activations import get_activation_fn
from .transformer_module import Attention, PreNorm, FeedForward

warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms._transforms_video")

logger = logging.getLogger(__name__)
ConditionTensors = tp.Dict[str, ConditionType]
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]


def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """LM layer initialization.
    Inspired from xlformers: https://github.com/fairinternal/xlformers

    Args:
        method (str): Method name for init function. Valid options are:
            'gaussian', 'uniform'.
        input_dim (int): Input dimension of the initialized module.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
    """
    # Scale the std with the input dimension (fan-in).
    std = 1 / math.sqrt(input_dim)
    # Rescale with depth so that deeper layers get a smaller std.
    if init_depth is not None:
        std = std / math.sqrt(2 * init_depth)

    if method == 'gaussian':
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )
    elif method == 'uniform':
        bound = math.sqrt(3) * std  # uniform bound that matches the target std
        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
    else:
        raise ValueError(f"Unsupported layer initialization method: {method}")
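
# Usage sketch (illustrative only): `get_init_fn` returns a partial that is applied
# in-place to a weight tensor. For example, with input_dim=1024 and init_depth=12,
# std = 1/sqrt(1024) / sqrt(2 * 12) ~= 0.0064, and
#     init_fn = get_init_fn('gaussian', 1024, init_depth=12)
#     init_fn(torch.empty(1024, 1024))
# fills the tensor with a truncated normal clipped at +/- 3 std.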


def init_layer(m: nn.Module,
               method: str,
               init_depth: tp.Optional[int] = None,
               zero_bias_init: bool = False):
    """Wrapper around `get_init_fn` for proper initialization of LM modules.

    Args:
        m (nn.Module): Module to initialize.
        method (str): Method name for the init function.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
    """
    if isinstance(m, nn.Linear):
        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            # fp16 weights on CPU cannot be initialized in-place: init in fp32 and copy back.
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)
        if zero_bias_init and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)
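
# Usage sketch (illustrative only): `init_layer` is typically applied over a whole
# sub-module, e.g.
#     layer.apply(partial(init_layer, method='gaussian', init_depth=depth,
#                         zero_bias_init=True))
# which re-initializes every nn.Linear / nn.Embedding it contains; this is how
# `LMModel._init_weights` below uses it for the transformer layers.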


class ScaledEmbedding(nn.Embedding):
    """Boost learning rate for embeddings (with scale).
    """
    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr

    def make_optim_group(self):
        group = {"params": list(self.parameters())}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
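
# Usage sketch (illustrative only, values are hypothetical): `make_optim_group`
# returns a parameter group carrying the embedding-specific learning rate, e.g.
#     emb = ScaledEmbedding(2049, 512, lr=1e-3)
#     optimizer = torch.optim.AdamW([emb.make_optim_group()], lr=1e-4)
# so the embedding trains at 1e-3 while other groups keep the base rate.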


@dataclass
class LMOutput:
    # The logits are already re-aligned with the input codes, hence no extra shift is required,
    # e.g. a sequence of S timesteps will yield S predictions (see `compute_predictions`).
    logits: torch.Tensor  # [B, K, T, card]
    mask: torch.Tensor  # [B, K, T]


class Transformer(nn.Module):
    """Minimal pre-norm transformer encoder applied on top of the visual features."""
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)


class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention between two token sequences of the same width.

    `tensor_A` provides the queries and `tensor_B` the keys/values; the output keeps
    the shape of `tensor_A`. The final projection is zero-initialized.
    """
    def __init__(self, x1, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.depth = x1 // num_heads  # per-head dimension

        self.query = nn.Linear(x1, x1)
        self.key = nn.Linear(x1, x1)
        self.value = nn.Linear(x1, x1)

        self.final_linear = nn.Linear(x1, x1)

        self.norm1 = nn.LayerNorm(x1)
        self.norm2 = nn.LayerNorm(x1)

        init.constant_(self.final_linear.weight, 0)
        if self.final_linear.bias is not None:
            init.constant_(self.final_linear.bias, 0)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, tensor_A, tensor_B):
        batch_size = tensor_A.size(0)

        Q = self.split_heads(self.query(tensor_A), batch_size)
        K = self.split_heads(self.key(tensor_B), batch_size)
        V = self.split_heads(self.value(tensor_B), batch_size)

        # Scaled dot-product attention of tensor_A's tokens over tensor_B's tokens.
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.depth ** 0.5)
        attention_scores = torch.softmax(attention_scores, dim=-1)

        attention_output = torch.matmul(attention_scores, V)
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()

        output = attention_output.view(batch_size, -1, self.num_heads * self.depth)

        output = self.norm1(output + tensor_A)
        output = self.norm2(self.final_linear(output) + output)
        return output
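
# Shape sketch (illustrative only): with temporal_dim = 768 and 3 heads,
#     mhca = MultiHeadCrossAttention(768, num_heads=3)
#     a = torch.randn(2, 100, 768)   # e.g. local video tokens
#     b = torch.randn(2, 400, 768)   # e.g. global video tokens
#     out = mhca(a, b)               # -> [2, 100, 768], same shape as `a`
# Below, `compute_video_emb` uses it to let the local tokens attend to the global ones.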


def evenly_sample_or_duplicate_frames(video_tensor, target_frames=32):
    """Resample a video along its first (time) dimension to exactly `target_frames` frames,
    either by evenly subsampling or by evenly duplicating frames."""
    num_frames = video_tensor.size(0)
    if target_frames <= num_frames:
        indices = torch.linspace(0, num_frames - 1, steps=target_frames).long()
        return video_tensor[indices]
    else:
        scale_factor = target_frames / num_frames
        repeated_indices = (torch.arange(target_frames) / scale_factor).long()
        return video_tensor[repeated_indices]
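
# Example (illustrative only): a 10-frame clip resampled to 4 frames keeps frames
# [0, 3, 6, 9]; a 2-frame clip expanded to 4 frames repeats each frame twice ([0, 0, 1, 1]).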


class LMModel(StreamingModule):
    """Transformer-based language model on multiple streams of codes.

    Args:
        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
        condition_provider (ConditioningProvider): Conditioning provider from metadata.
        visual_encoder (str): Name of the visual encoder ('clip' is the only supported value).
        if_add_gobal (bool): Whether to also encode a global video stream and fuse it with the local one.
        fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
        n_q (int): Number of parallel streams to model.
        card (int): Cardinality, vocabulary size.
        dim (int): Dimension of the transformer encoder.
        num_heads (int): Number of heads for the transformer encoder.
        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
        norm (str): Normalization method.
        norm_first (bool): Use pre-norm instead of post-norm.
        emb_lr (float, optional): Embedding-specific learning rate.
        bias_proj (bool): Use bias for output projections.
        weight_init (str, optional): Method for weight initialization.
        depthwise_init (str, optional): Method for depthwise weight initialization.
        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
        cfg_dropout (float): Classifier-free guidance dropout.
        cfg_coef (float): Classifier-free guidance coefficient.
        attribute_dropout (dict): Attribute dropout probabilities.
        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
        depth (int): Number of layers of the local/global temporal transformers over visual features.
        temporal_dim (int): Dimension of the visual encoder features.
        dim_head (int): Per-head dimension of the temporal transformers.
        **kwargs: Additional parameters for the transformer encoder.
    """

    def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
                 visual_encoder,
                 if_add_gobal,
                 fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
                 zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
                 depth=2,
                 temporal_dim=768,
                 dim_head=64,
                 **kwargs):
        super().__init__()
        self.cfg_coef = cfg_coef
        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
        self.att_dropout = AttributeDropout(p=attribute_dropout)
        self.condition_provider = condition_provider
        self.visual_encoder = visual_encoder
        self.if_add_gobal = if_add_gobal
        self.temporal_dim = temporal_dim

        self.fuser = fuser
        self.card = card
        embed_dim = self.card + 1  # extra entry for the special token
        self.n_q = n_q
        self.dim = dim
        self.pattern_provider = pattern_provider
        self.two_step_cfg = two_step_cfg
        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
        if 'activation' in kwargs:
            kwargs['activation'] = get_activation_fn(kwargs['activation'])
        self.transformer = StreamingTransformer(
            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
            norm=norm, norm_first=norm_first, **kwargs)

        self.out_norm: tp.Optional[nn.Module] = None
        if norm_first:
            self.out_norm = create_norm_fn(norm, dim)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
        self._init_weights(weight_init, depthwise_init, zero_bias_init)
        self._fsdp: tp.Optional[nn.Module]
        self.__dict__['_fsdp'] = None

        if self.visual_encoder == 'clip':
            self.visual_encoder_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
            self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        else:
            raise ValueError(f"Unsupported visual encoder: {self.visual_encoder!r}, expected 'clip'.")

        if self.visual_encoder == 'clip':
            temporal_dim = 768
            # 50 tokens per frame = 49 patch tokens + 1 CLS token for CLIP ViT-B/32 at 224x224.
            self.local_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))
            # The visual encoder is kept frozen.
            self.visual_encoder_model = self.visual_encoder_model.eval()
            for param in self.visual_encoder_model.parameters():
                param.requires_grad = False

        self.local_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim * hidden_scale, 0.)

        if self.if_add_gobal:
            if self.visual_encoder == 'clip':
                self.global_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))

            self.global_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim * hidden_scale, 0.)

            cross_attention_num_heads = 3
            self.multi_head_cross_attention = MultiHeadCrossAttention(temporal_dim, cross_attention_num_heads)

        self.visual_feature_proj = nn.Linear(temporal_dim, dim)

    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
        """Initialization of the transformer module weights.

        Args:
            weight_init (str, optional): Weight initialization strategy. See `get_init_fn` for valid options.
            depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
                'current' where the depth corresponds to the current layer index or 'global' where the total number
                of layers is used as depth. If not set, no depthwise initialization strategy is used.
            zero_bias_init (bool): Whether to initialize bias to zero or not.
        """
        assert depthwise_init is None or depthwise_init in ['current', 'global']
        assert depthwise_init is None or weight_init is not None, \
            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
        assert not zero_bias_init or weight_init is not None, \
            "If 'zero_bias_init', a 'weight_init' method should be provided"

        if weight_init is None:
            return

        for emb_layer in self.emb:
            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

        for layer_idx, tr_layer in enumerate(self.transformer.layers):
            depth = None
            if depthwise_init == 'current':
                depth = layer_idx + 1
            elif depthwise_init == 'global':
                depth = len(self.transformer.layers)
            init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
            tr_layer.apply(init_fn)

        for linear in self.linears:
            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

    @property
    def special_token_id(self) -> int:
        return self.card

    @property
    def num_codebooks(self) -> int:
        return self.n_q

    def compute_video_emb(self, video_tensor_list: tp.List, device: tp.Union[str, torch.device]) -> torch.Tensor:
        """Encode the [local, global] video streams and fuse them into one conditioning sequence."""
        assert isinstance(video_tensor_list, list)
        assert self.if_add_gobal
        assert len(video_tensor_list) == 2

        [local_video_tensor, global_video_tensor] = video_tensor_list
        local_image = local_video_tensor.to(dtype=torch.float32)
        global_image = global_video_tensor.to(dtype=torch.float32)

        # [B, C, T, H, W] -> [(B T), C, H, W], so each frame goes through the image encoder.
        local_batch_size, _, local_time_length, _, _ = local_image.size()
        local_image = einops.rearrange(local_image, 'b c t h w -> (b t) c h w')

        global_batch_size, _, global_time_length, _, _ = global_image.size()
        global_image = einops.rearrange(global_image, 'b c t h w -> (b t) c h w')

        local_temporal_transformer = self.local_temporal_transformer
        global_temporal_transformer = self.global_temporal_transformer

        local_video_inputs = self.processor(images=local_image.float(), return_tensors="pt")
        local_pixel_values = local_video_inputs['pixel_values'].to(device)

        global_video_inputs = self.processor(images=global_image.float(), return_tensors="pt")
        global_pixel_values = global_video_inputs['pixel_values'].to(device)

        if self.visual_encoder == 'clip':
            with torch.no_grad():
                local_video_hidden = self.visual_encoder_model(pixel_values=local_pixel_values).last_hidden_state
            local_video_hidden += self.local_pos_embedding
            local_video_hidden = local_temporal_transformer(local_video_hidden)
            # Flatten frames and patch tokens into a single sequence per sample.
            local_video_hidden = einops.rearrange(
                local_video_hidden, '(b t) q h -> b (t q) h',
                b=local_batch_size, t=local_time_length
            )

            with torch.no_grad():
                global_video_hidden = self.visual_encoder_model(pixel_values=global_pixel_values).last_hidden_state
            global_video_hidden += self.global_pos_embedding
            global_video_hidden = global_temporal_transformer(global_video_hidden)
            global_video_hidden = einops.rearrange(
                global_video_hidden, '(b t) q h -> b (t q) h',
                b=global_batch_size, t=global_time_length
            )

        # Local tokens attend to global tokens, then project to the LM dimension.
        video_hidden = self.multi_head_cross_attention(local_video_hidden, global_video_hidden)
        video_emb = self.visual_feature_proj(video_hidden)

        return video_emb
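
    # Shape sketch (illustrative only), assuming CLIP ViT-B/32 features (50 tokens of
    # width 768 per frame): for a local stream of T_l frames and a global stream of
    # T_g frames with batch size B,
    #     local tokens:  [B, T_l * 50, 768]
    #     global tokens: [B, T_g * 50, 768]
    #     fused output:  [B, T_l * 50, dim]   (local attends to global, then projected)
    # This is the `cross_attention_src` consumed by the StreamingTransformer in `forward`.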

    def forward(self, sequence: torch.Tensor,
                conditions: tp.List[ConditioningAttributes],
                video_tensor_list: tp.List,
                precomputed_video_emb: tp.Optional[torch.Tensor] = None
                ) -> torch.Tensor:

        B, K, S = sequence.shape
        assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
        # Sum the embeddings of the K codebooks at each timestep.
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
        self.device = input_.device
        assert self.device.type != 'cpu'

        if precomputed_video_emb is None:
            video_emb = self.compute_video_emb(video_tensor_list, device=self.device)
        else:
            video_emb = precomputed_video_emb

        # Video features are injected through cross-attention in the streaming transformer.
        out = self.transformer(input_, cross_attention_src=video_emb)
        if self.out_norm:
            out = self.out_norm(out)
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]

        # Remove the prefix from the model outputs when conditions were prepended.
        if len(self.fuser.fuse2cond['prepend']) > 0:
            logits = logits[:, :, -S:]
        return logits

    def compute_predictions(
            self, codes: torch.Tensor,
            conditions: tp.List[ConditioningAttributes],
            condition_tensors_list: tp.List) -> LMOutput:
        """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
        forward using the specified codes interleaving pattern.

        Args:
            codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
                K the number of codebooks and T the number of timesteps.
            conditions (list of ConditioningAttributes): Conditions to use when modeling
                the given codes. Note that when evaluating multiple times with the same conditioning
                you should pre-compute those and pass them as `condition_tensors_list`.
            condition_tensors_list (list): Pre-computed [local, global] video conditioning
                tensors, see `conditions`.
        Returns:
            LMOutput: Language model outputs
                logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
                i.e. the first item corresponds to logits to predict the first code, meaning that
                no additional shifting of codes and logits is required.
                mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
                Given the specified interleaving strategies, parts of the logits and codes should
                not be considered as valid predictions because of invalid context.
        """
        B, K, T = codes.shape
        codes = codes.contiguous()

        assert isinstance(condition_tensors_list, list)
        # Map codes [B, K, T] into the interleaved pattern sequence of length S.
        pattern = self.pattern_provider.get_pattern(T)
        sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
            codes, self.special_token_id, keep_only_valid_steps=True
        )

        # Apply the model on the shifted sequence.
        model = self if self._fsdp is None else self._fsdp
        logits = model(sequence_codes, conditions, condition_tensors_list)

        # Map back the logits on the pattern sequence to logits on the original codes,
        # [B, K, S, card] -> [B, K, T, card]; invalid positions are flagged by the mask.
        logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
        logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
            logits, float('nan'), keep_only_valid_steps=True
        )
        logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
        logits_mask = logits_mask[None, :, :].expand(B, -1, -1)
        return LMOutput(logits, logits_mask)
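
    # Interleaving sketch (illustrative only, for a simple per-codebook "delay" pattern
    # with K = 4): codebook k is shifted by k steps and padded with the special token (*),
    #     step:        0  1  2  3  4 ...
    #     codebook 0:  c0 c1 c2 c3 c4
    #     codebook 1:  *  c0 c1 c2 c3
    #     codebook 2:  *  *  c0 c1 c2
    #     codebook 3:  *  *  *  c0 c1
    # `build_pattern_sequence` produces this shifted view and `revert_pattern_logits`
    # maps the outputs back so that logits at position t predict the codes at t.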

    def _sample_next_token(self,
                           sequence: torch.Tensor,
                           cfg_conditions_list: tp.List,
                           unconditional_state: State,
                           use_sampling: bool = False,
                           temp: float = 1.0,
                           top_k: int = 0,
                           top_p: float = 0.0,
                           cfg_coef: tp.Optional[float] = None,
                           two_step_cfg: tp.Optional[bool] = None,
                           precomputed_video_emb: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample next token from the model given a sequence and a set of conditions. The model supports
        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).

        Args:
            sequence (torch.Tensor): Current sequence of shape [B, K, S]
                with K corresponding to the number of codebooks and S the number of sequence steps.
                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
            cfg_conditions_list (list): [local, global] conditions. If CFG is used,
                each should be twice the batch size, being the concatenation of the conditions + null conditions.
            unconditional_state (State): Streaming state of the unconditional branch when using two-step CFG.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Sampling temperature.
            top_k (int): K for "top-k" sampling.
            top_p (float): P for "top-p" sampling.
            cfg_coef (float, optional): Classifier free guidance coefficient.
        Returns:
            next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
        """
        B = sequence.shape[0]
        cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
        model = self if self._fsdp is None else self._fsdp
        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg

        assert isinstance(cfg_conditions_list, list)
        assert len(cfg_conditions_list) == 2
        local_cfg_conditions = cfg_conditions_list[0]
        global_cfg_conditions = cfg_conditions_list[1]

        if two_step_cfg and local_cfg_conditions != {}:
            # Two-step CFG: run the conditional and unconditional passes separately,
            # each with its own streaming state.
            assert isinstance(local_cfg_conditions, tuple), type(local_cfg_conditions)
            local_condition_tensors, local_null_condition_tensors = local_cfg_conditions
            global_condition_tensors, global_null_condition_tensors = global_cfg_conditions
            cond_logits = model(sequence, conditions=[],
                                video_tensor_list=[local_condition_tensors, global_condition_tensors])

            state = self.get_streaming_state()
            self.set_streaming_state(unconditional_state)
            uncond_logits = model(sequence, conditions=[],
                                  video_tensor_list=[local_null_condition_tensors, global_null_condition_tensors])
            unconditional_state.update(self.get_streaming_state())
            self.set_streaming_state(state)
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
        else:
            # Single-pass CFG: duplicate the sequence so the conditional and unconditional
            # logits are produced in one batch.
            local_condition_tensors = cfg_conditions_list[0].to(sequence.device)
            global_condition_tensors = cfg_conditions_list[1].to(sequence.device)
            sequence = torch.cat([sequence, sequence], dim=0)

            if precomputed_video_emb is None:
                video_emb = self.compute_video_emb([cfg_conditions_list[0], cfg_conditions_list[1]], device=sequence.device)
            else:
                video_emb = precomputed_video_emb

            all_logits = model(
                sequence,
                conditions=[],
                video_tensor_list=[],
                precomputed_video_emb=video_emb
            )
            cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef

        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
        logits = logits[..., -1]  # [B, K, card], keep only the last timestep

        if use_sampling and temp > 0.0:
            probs = torch.softmax(logits / temp, dim=-1)
            if top_p > 0.0:
                next_token = utils.sample_top_p(probs, p=top_p)
            elif top_k > 0:
                next_token = utils.sample_top_k(probs, k=top_k)
            else:
                next_token = utils.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        return next_token
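
    # CFG arithmetic sketch (illustrative numbers): with cfg_coef = 3.0, a token whose
    # conditional logit is 2.0 and unconditional logit is 1.0 ends up at
    #     1.0 + (2.0 - 1.0) * 3.0 = 4.0,
    # i.e. the guidance extrapolates away from the unconditional prediction.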

    @torch.no_grad()
    def generate(self,
                 prompt: tp.Optional[torch.Tensor] = None,
                 conditions_list: tp.List = [],
                 num_samples: tp.Optional[int] = None,
                 max_gen_len: int = 256,
                 use_sampling: bool = True,
                 temp: float = 1.0,
                 top_k: int = 250,
                 top_p: float = 0.0,
                 cfg_coef: tp.Optional[float] = None,
                 two_step_cfg: tp.Optional[bool] = None,
                 remove_prompts: bool = False,
                 check: bool = False,
                 callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
        """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
        be performed in a greedy fashion or using sampling with top K and top P strategies.

        Args:
            prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
            conditions_list (list): [local, global] video conditioning tensors.
            num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
            max_gen_len (int): Maximum generation length.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Sampling temperature.
            top_k (int): K for "top-k" sampling.
            top_p (float): P for "top-p" sampling.
            cfg_coef (float, optional): Classifier-free guidance coefficient.
            two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
            remove_prompts (bool): Whether to remove prompts from generation or not.
            check (bool): Whether to apply further checks on generated sequence.
            callback (Callback, optional): Callback function to report generation progress.
        Returns:
            torch.Tensor: Generated tokens.
        """
        assert not self.training, "generation shouldn't be used in training mode."
        first_param = next(iter(self.parameters()))
        device = first_param.device
        assert isinstance(conditions_list, list)

        assert len(conditions_list) == 2
        local_conditions = conditions_list[0]
        global_conditions = conditions_list[1]

        # Check for inconsistent num_samples across the provided inputs.
        possible_num_samples = []
        if num_samples is not None:
            possible_num_samples.append(num_samples)
        elif prompt is not None:
            possible_num_samples.append(prompt.shape[0])
        elif local_conditions is not None:
            possible_num_samples.append(len(local_conditions))
        else:
            possible_num_samples.append(1)

        assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent inputs shapes"
        num_samples = possible_num_samples[0]

        # Build the CFG conditions by concatenating the conditions with their null versions.
        local_cfg_conditions: CFGConditions
        global_cfg_conditions: CFGConditions
        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
        local_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(local_conditions)
        local_cfg_conditions = torch.cat((local_conditions, local_null_conditions), dim=0)
        global_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(global_conditions)
        global_cfg_conditions = torch.cat((global_conditions, global_null_conditions), dim=0)

        if prompt is None:
            assert num_samples > 0
            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)

        B, K, T = prompt.shape
        start_offset = T
        assert start_offset < max_gen_len

        pattern = self.pattern_provider.get_pattern(max_gen_len)
        # This token is used as default value for codes that are not generated yet.
        unknown_token = -1

        # We generate codes up to max_gen_len that will be mapped to the pattern sequence.
        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
        # Fill gen_codes with the prompt if provided.
        gen_codes[..., :start_offset] = prompt
        # Create the gen_sequence with proper interleaving from the pattern: [B, K, S].
        gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
        # Retrieve the start offset in the sequence: the first step that contains the `start_offset` timestep.
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        # The video embedding is computed once and reused at every decoding step.
        video_emb = self.compute_video_emb([local_cfg_conditions, global_cfg_conditions], device=device)

        with self.streaming():
            unconditional_state = self.get_streaming_state()
            prev_offset = 0
            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]

            for offset in range(start_offset_sequence, gen_sequence_len):
                # Get the current sequence (the streaming API caches the previous offsets).
                curr_sequence = gen_sequence[..., prev_offset:offset]
                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
                if check:
                    # Check coherence between mask and sequence.
                    assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
                    # Should never happen as gen_sequence is filled progressively.
                    assert not (curr_sequence == unknown_token).any()
                # Sample the next token from the model; next token shape is [B, K, 1].
                next_token = self._sample_next_token(
                    curr_sequence,
                    [local_cfg_conditions, global_cfg_conditions],
                    unconditional_state,
                    use_sampling,
                    temp,
                    top_k,
                    top_p,
                    cfg_coef=cfg_coef,
                    two_step_cfg=two_step_cfg,
                    precomputed_video_emb=video_emb
                )
                # Ensure the tokens that should be masked are set to special_token_id,
                # as the model never outputs special_token_id.
                valid_mask = mask[..., offset:offset + 1].expand(B, -1, -1)
                next_token[~valid_mask] = self.special_token_id
                # Only write over positions that are still unknown (never overwrite the prompt).
                gen_sequence[..., offset:offset + 1] = torch.where(
                    gen_sequence[..., offset:offset + 1] == unknown_token,
                    next_token,
                    gen_sequence[..., offset:offset + 1]
                )
                prev_offset = offset
                if callback is not None:
                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)

        unconditional_state.clear()
        # Ensure the sequence has been entirely filled.
        assert not (gen_sequence == unknown_token).any()
        # Ensure the gen_sequence matches the pattern mask, i.e. it is valid according to the pattern.
        assert (
            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
        ).all()
        # Map the sequence back to codes, cutting potentially incomplete timesteps.
        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

        # Sanity checks over the returned codes and corresponding masks.
        assert (out_codes[..., :max_gen_len] != unknown_token).all()
        assert (out_mask[..., :max_gen_len] == 1).all()

        out_start_offset = start_offset if remove_prompts else 0
        out_codes = out_codes[..., out_start_offset:max_gen_len]

        # Ensure the returned codes are all valid.
        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
        return out_codes
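
# Usage sketch (illustrative only; argument values are hypothetical): given a built
# `model: LMModel` and preprocessed local/global video tensors of shape [B, C, T, H, W],
#     local_video = torch.randn(1, 3, 30, 224, 224)
#     global_video = torch.randn(1, 3, 32, 224, 224)
#     codes = model.generate(conditions_list=[local_video, global_video],
#                            max_gen_len=500, top_k=250, cfg_coef=3.0)
# `codes` has shape [B, n_q, max_gen_len] and can then be decoded to audio by the
# companion compression model (not part of this file).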