import os
from typing import List, Optional, Union

import torch
import torch.nn as nn
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)


class SD3TextEncoderWithMask(nn.Module):
    """Wraps the three SD3 text encoders (two CLIP models and T5) with lazy
    loading, placing them on GPU 0 so GPU 1 stays free for other tasks."""

    def __init__(self, model_path, torch_dtype):
        super().__init__()
        # Device assignment: GPU 0 hosts the text encoders,
        # GPU 1 is reserved for the rest of the pipeline.
        self.device_0 = torch.device('cuda:0')  # GPU 0 for text encoders
        self.device_1 = torch.device('cuda:1')  # GPU 1 for other tasks
        # Tokenizers for the two CLIP encoders and the T5 encoder
        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
        self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
        # Text encoders are loaded lazily on first use
        self.text_encoder = None
        self.text_encoder_2 = None
        self.text_encoder_3 = None
        self.model_path = model_path
        self.torch_dtype = torch_dtype
        self.tokenizer_max_length = self.tokenizer.model_max_length
        # Freeze parameters to avoid training overhead
        self._freeze()

    def _freeze(self):
        """Freeze all model parameters to avoid training overhead."""
        for param in self.parameters():
            param.requires_grad = False

    def _load_models_if_needed(self):
        """Load the text encoders on first use and move them to GPU 0."""
        if self.text_encoder is None:
            self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
                os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
            ).to(self.device_0)  # Move to GPU 0
        if self.text_encoder_2 is None:
            self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
                os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
            ).to(self.device_0)  # Move to GPU 0
        if self.text_encoder_3 is None:
            self.text_encoder_3 = T5EncoderModel.from_pretrained(
                os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
            ).to(self.device_0)  # Move to GPU 0

    def _get_t5_prompt_embeds(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 128,
    ):
        """Get sequence embeddings and the attention mask from the T5 encoder."""
        self._load_models_if_needed()  # Lazy loading
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype, device=device)
        # Duplicate the embeddings (and mask) once per requested image
        batch_size = len(prompt)
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
            batch_size * num_images_per_prompt, seq_len, -1
        )
        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1).repeat(num_images_per_prompt, 1)
        return prompt_embeds, prompt_attention_mask

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_model_index: int = 0,
    ):
        """Get the pooled embedding from one of the two CLIP encoders."""
        self._load_models_if_needed()  # Lazy loading
        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]
        tokenizer = clip_tokenizers[clip_model_index]
        text_encoder = clip_text_encoders[clip_model_index]
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        # Output index 0 of CLIPTextModelWithProjection is the projected
        # pooled embedding (text_embeds)
        pooled_prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0]
        # Duplicate the pooled embedding once per requested image
        batch_size = len(prompt)
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
            batch_size * num_images_per_prompt, -1
        )
        return pooled_prompt_embeds

    def encode_prompt(self,
                      prompt,
                      num_images_per_prompt=1,
                      device=None):
        """Encode the prompt with both CLIP encoders and the T5 encoder."""
        prompt = [prompt] if isinstance(prompt, str) else prompt
        # Pooled embeddings from both CLIP models (on GPU 0)
        pooled_prompt_embed = self._get_clip_prompt_embeds(
            prompt, num_images_per_prompt=num_images_per_prompt, device=self.device_0, clip_model_index=0
        )
        pooled_prompt_2_embed = self._get_clip_prompt_embeds(
            prompt, num_images_per_prompt=num_images_per_prompt, device=self.device_0, clip_model_index=1
        )
        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
        # Sequence embeddings from T5 (on GPU 0)
        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
            prompt, num_images_per_prompt=num_images_per_prompt, device=self.device_0
        )
        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds

    def forward(self, input_prompts):
        """Forward pass: encode prompts without tracking gradients."""
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts)
        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
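

# A minimal usage sketch (assumptions: the standard SD3 checkpoint layout with
# tokenizer/, tokenizer_2/, tokenizer_3/ and text_encoder*/ subfolders;
# './stable-diffusion-3' is a hypothetical path):
#
#   encoder = SD3TextEncoderWithMask('./stable-diffusion-3', torch_dtype=torch.float16)
#   prompt_embeds, mask, pooled = encoder(["a photo of a cat"])
#   # For the standard SD3 checkpoints:
#   # prompt_embeds: (batch, 128, 4096) T5 sequence embeddings on GPU 0
#   # pooled:        (batch, 2048) concatenated CLIP pooled embeddings on GPU 0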


# Example: using GPU 1 for other parts of the model
class OtherModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Example layer; the whole module is moved to GPU 1 below
        self.fc = nn.Linear(512, 512)

    def forward(self, x):
        return self.fc(x)


if __name__ == "__main__":
    # In the main script or generation process, use GPU 1 for other tasks
    other_model = OtherModel().to('cuda:1')         # Load on GPU 1
    input_data = torch.randn(64, 512).to('cuda:1')  # Move input data to GPU 1
    # Perform the forward pass on GPU 1
    output = other_model(input_data)
    print(output)
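
    # Cross-GPU handoff sketch, assuming an SD3 checkpoint at
    # './stable-diffusion-3' (hypothetical path): the encoder returns
    # embeddings on GPU 0, so move them explicitly before a GPU-1 module
    # consumes them.
    encoder = SD3TextEncoderWithMask('./stable-diffusion-3', torch_dtype=torch.float16)
    prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = encoder(["a photo of a cat"])
    prompt_embeds = prompt_embeds.to('cuda:1')  # explicit transfer to GPU 1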