Delete bria_utils.py
Browse files- bria_utils.py +0 -71
bria_utils.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
from typing import Union, Optional, List
|
| 2 |
-
import torch
|
| 3 |
-
from diffusers.utils import logging
|
| 4 |
-
from transformers import (
|
| 5 |
-
T5EncoderModel,
|
| 6 |
-
T5TokenizerFast,
|
| 7 |
-
)
|
| 8 |
-
import numpy as np
|
| 9 |
-
|
| 10 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 11 |
-
|
| 12 |
-
def get_t5_prompt_embeds(
|
| 13 |
-
tokenizer: T5TokenizerFast ,
|
| 14 |
-
text_encoder: T5EncoderModel,
|
| 15 |
-
prompt: Union[str, List[str]] = None,
|
| 16 |
-
num_images_per_prompt: int = 1,
|
| 17 |
-
max_sequence_length: int = 128,
|
| 18 |
-
device: Optional[torch.device] = None,
|
| 19 |
-
):
|
| 20 |
-
device = device or text_encoder.device
|
| 21 |
-
|
| 22 |
-
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 23 |
-
batch_size = len(prompt)
|
| 24 |
-
|
| 25 |
-
text_inputs = tokenizer(
|
| 26 |
-
prompt,
|
| 27 |
-
# padding="max_length",
|
| 28 |
-
max_length=max_sequence_length,
|
| 29 |
-
truncation=True,
|
| 30 |
-
add_special_tokens=True,
|
| 31 |
-
return_tensors="pt",
|
| 32 |
-
)
|
| 33 |
-
text_input_ids = text_inputs.input_ids
|
| 34 |
-
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 35 |
-
|
| 36 |
-
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 37 |
-
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 38 |
-
logger.warning(
|
| 39 |
-
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 40 |
-
f" {max_sequence_length} tokens: {removed_text}"
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
| 44 |
-
|
| 45 |
-
# Concat zeros to max_sequence
|
| 46 |
-
b, seq_len, dim = prompt_embeds.shape
|
| 47 |
-
if seq_len<max_sequence_length:
|
| 48 |
-
padding = torch.zeros((b,max_sequence_length-seq_len,dim),dtype=prompt_embeds.dtype,device=prompt_embeds.device)
|
| 49 |
-
prompt_embeds = torch.concat([prompt_embeds,padding],dim=1)
|
| 50 |
-
|
| 51 |
-
prompt_embeds = prompt_embeds.to(device=device)
|
| 52 |
-
|
| 53 |
-
_, seq_len, _ = prompt_embeds.shape
|
| 54 |
-
|
| 55 |
-
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 56 |
-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 57 |
-
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 58 |
-
|
| 59 |
-
return prompt_embeds
|
| 60 |
-
|
| 61 |
-
# in order the get the same sigmas as in training and sample from them
|
| 62 |
-
def get_original_sigmas(num_train_timesteps=1000,num_inference_steps=1000):
|
| 63 |
-
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
| 64 |
-
sigmas = timesteps / num_train_timesteps
|
| 65 |
-
|
| 66 |
-
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps-1, num_inference_steps)]
|
| 67 |
-
new_sigmas = sigmas[inds]
|
| 68 |
-
return new_sigmas
|
| 69 |
-
|
| 70 |
-
def is_ng_none(negative_prompt):
|
| 71 |
-
return negative_prompt is None or negative_prompt=='' or (isinstance(negative_prompt,list) and negative_prompt[0] is None) or (type(negative_prompt)==list and negative_prompt[0]=='')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|