import math
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import kornia
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.utils.checkpoint import checkpoint
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
)

from ...util import (
    append_dims,
    autocast,
    count_params,
    default,
    disabled_train,
    expand_dims_like,
    instantiate_from_config,
)
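

# Base class for conditioning embedders. `is_trainable`, `ucg_rate` (the rate at
# which this embedder's conditioning is dropped for classifier-free-guidance
# training), and `input_key` are exposed as properties so that GeneralConditioner
# can configure every embedder uniformly from its config entry.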
class AbstractEmbModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

    @property
    def is_trainable(self) -> bool:
        return self._is_trainable

    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        return self._ucg_rate

    @property
    def input_key(self) -> str:
        return self._input_key

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    @input_key.deleter
    def input_key(self):
        del self._input_key
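

# GeneralConditioner instantiates a list of embedders from config and routes each
# embedder's output into the conditioning dict by tensor rank: 2-D -> "vector",
# 3-D -> "crossattn", 4-D and 5-D -> "concat" (OUTPUT_DIM2KEYS). Outputs that land
# on the same key are concatenated along the dimension given by KEY2CATDIM.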
class GeneralConditioner(nn.Module):
    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]):
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )
            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()
            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

        if len(cor_embs) > 0:
            assert len(cor_p) == 2 ** len(cor_embs)
        self.cor_embs = cor_embs
        self.cor_p = cor_p
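
    # Legacy UCG: rather than zeroing the embedding, raw batch entries (e.g. caption
    # strings) are replaced with `legacy_ucg_val` before encoding. `possibly_...` draws
    # the per-sample decision at rate `ucg_rate`; `surely_...` applies a precomputed mask.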
    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict:
        assert embedder.legacy_ucg_val is not None
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if cond_or_not[i]:
                batch[embedder.input_key][i] = val
        return batch
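
    # Encodes one embedder's inputs and merges the result into `output`. When
    # `cond_or_not` is given (1 = drop conditioning for that sample) the dropout
    # decision is deterministic; otherwise it is sampled i.i.d. per sample from
    # Bernoulli(1 - ucg_rate).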
    def get_single_embedding(
        self,
        embedder,
        batch,
        output,
        cond_or_not: Optional[np.ndarray] = None,
        force_zero_embeddings: Optional[List] = None,
    ):
        embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
        with embedding_context():
            if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                if embedder.legacy_ucg_val is not None:
                    if cond_or_not is None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    else:
                        batch = self.surely_get_ucg_val(embedder, batch, cond_or_not)
                emb_out = embedder(batch[embedder.input_key])
            elif hasattr(embedder, "input_keys"):
                emb_out = embedder(*[batch[k] for k in embedder.input_keys])
        assert isinstance(
            emb_out, (torch.Tensor, list, tuple)
        ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
        if not isinstance(emb_out, (list, tuple)):
            emb_out = [emb_out]
        for emb in emb_out:
            out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
            if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                if cond_or_not is None:
                    emb = (
                        expand_dims_like(
                            torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)),
                            emb,
                        )
                        * emb
                    )
                else:
                    emb = (
                        expand_dims_like(
                            torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device),
                            emb,
                        )
                        * emb
                    )
            if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings:
                emb = torch.zeros_like(emb)
            if out_key in output:
                output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key])
            else:
                output[out_key] = emb
        return output
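
    # For correlated dropout, one outcome index per sample is drawn from the
    # 2**len(cor_embs) joint outcomes with probabilities `cor_p`; its binary digits
    # (least significant first) give the drop/keep decision for each embedder
    # listed in `cor_embs`. All remaining embedders use independent dropout.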
    def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict:
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []

        if len(self.cor_embs) > 0:
            batch_size = len(batch[list(batch.keys())[0]])
            rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p)
            for emb_idx in self.cor_embs:
                cond_or_not = rand_idx % 2
                rand_idx //= 2
                output = self.get_single_embedding(
                    self.embedders[emb_idx],
                    batch,
                    output=output,
                    cond_or_not=cond_or_not,
                    force_zero_embeddings=force_zero_embeddings,
                )

        for i, embedder in enumerate(self.embedders):
            if i in self.cor_embs:
                continue
            output = self.get_single_embedding(
                embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings
            )
        return output
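
    # Builds the (conditional, unconditional) pair for classifier-free guidance:
    # all stochastic dropout is temporarily disabled, and `force_uc_zero_embeddings`
    # lists the input keys whose embeddings are zeroed in the unconditional branch.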
    def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None):
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = list()
        for embedder in self.embedders:
            ucg_rates.append(embedder.ucg_rate)
            embedder.ucg_rate = 0.0
        cor_embs = self.cor_embs
        cor_p = self.cor_p
        self.cor_embs = []
        self.cor_p = []

        c = self(batch_c)
        uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)

        for embedder, rate in zip(self.embedders, ucg_rates):
            embedder.ucg_rate = rate
        self.cor_embs = cor_embs
        self.cor_p = cor_p
        return c, uc


class FrozenT5Embedder(AbstractEmbModel):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self,
        model_dir="google/t5-v1_1-xxl",
        device="cuda",
        max_length=77,
        freeze=True,
        cache_dir=None,
    ):
        super().__init__()
        if model_dir != "google/t5-v1_1-xxl":
            self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
            self.transformer = T5EncoderModel.from_pretrained(model_dir)
        else:
            self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
            self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
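
    # Text is padded/truncated to a fixed `max_length`, and autocast is explicitly
    # disabled for the encoder pass below, likely because large T5 encoders are
    # prone to fp16 overflow (our reading; the file itself does not say why).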
    # @autocast
    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        with torch.autocast("cuda", enabled=False):
            outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
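

# Minimal usage sketch (hypothetical: the `target` import path, config values, and
# batch contents below are illustrative assumptions, not taken from this file):
#
#   conditioner = GeneralConditioner(
#       emb_models=[
#           {
#               "target": "sgm.modules.encoders.modules.FrozenT5Embedder",
#               "input_key": "txt",
#               "ucg_rate": 0.1,
#           }
#       ]
#   )
#   batch = {"txt": ["a photo of a cat"]}
#   c, uc = conditioner.get_unconditional_conditioning(batch, force_uc_zero_embeddings=["txt"])
#   # c["crossattn"] holds the T5 last_hidden_state, shape [1, 77, hidden_dim];
#   # uc["crossattn"] is all zeros because "txt" is in force_uc_zero_embeddings.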