import os
from typing import Union, List, Optional

import torch
import torch.nn as nn
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)
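
# Multi-GPU split: the three SD3 text encoders (two CLIP, one T5) are pinned
# to GPU 0 so that GPU 1 stays free for the rest of the generation pipeline.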

class SD3TextEncoderWithMask(nn.Module):
    def __init__(self, model_path, torch_dtype):
        super().__init__()

        # Pin the text encoders to GPU 0 and leave GPU 1 for the rest of the
        # pipeline (assumes a machine with at least two CUDA devices)
        self.device_0 = torch.device('cuda:0')  # GPU 0: text encoders
        self.device_1 = torch.device('cuda:1')  # GPU 1: downstream model

        # Tokenizers for CLIP and T5
        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
        self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))

        # Lazy loading of models
        self.text_encoder = None
        self.text_encoder_2 = None
        self.text_encoder_3 = None
        self.model_path = model_path
        self.torch_dtype = torch_dtype
        self.tokenizer_max_length = self.tokenizer.model_max_length

        # Freeze parameters to avoid training overhead
        self._freeze()

    def _freeze(self):
        """ Freeze all model parameters to avoid training overhead. """
        for param in self.parameters():
            param.requires_grad = False

    def _load_models_if_needed(self):
        """ Load models only if they haven't been loaded already. """
        if self.text_encoder is None:
            self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
                os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
            ).to(self.device_0)  # Move to GPU 0

        if self.text_encoder_2 is None:
            self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
                os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
            ).to(self.device_0)  # Move to GPU 0

        if self.text_encoder_3 is None:
            self.text_encoder_3 = T5EncoderModel.from_pretrained(
                os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
            ).to(self.device_0)  # Move to GPU 0

    def _get_t5_prompt_embeds(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 128,
    ):
        """ Get embeddings from T5 model. """
        self._load_models_if_needed()  # Lazy loading
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype, device=device)

        # Duplicate the embeddings for each image generated per prompt.
        batch_size = len(prompt)
        _, seq_len, _ = prompt_embeds.shape
        # repeat + view yields prompt-major ordering: [p1, p1, ..., p2, p2, ...]
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)
        # repeat_interleave keeps the mask aligned with that ordering; the
        # original tiled repeat misaligned mask and embeddings for batch_size > 1
        prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_images_per_prompt, dim=0)

        return prompt_embeds, prompt_attention_mask

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_model_index: int = 0,
    ):
        """ Get embeddings from CLIP model. """
        self._load_models_if_needed()  # Lazy loading

        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]

        tokenizer = clip_tokenizers[clip_model_index]
        text_encoder = clip_text_encoders[clip_model_index]

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        # For CLIPTextModelWithProjection the first output is the projected
        # pooled embedding, shape [batch_size, projection_dim]
        pooled_prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0]

        # Duplicate the pooled embedding for each image generated per prompt,
        # keeping the same prompt-major ordering as the T5 embeddings
        pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        return pooled_prompt_embeds

    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt=1,
        device=None,
    ):
        """ Encode the prompt with both CLIP encoders and T5. """
        device = device or self.device_0  # the encoders are pinned to GPU 0
        prompt = [prompt] if isinstance(prompt, str) else prompt

        # Pooled embeddings from both CLIP encoders
        pooled_prompt_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=0)
        pooled_prompt_2_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=1)
        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
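        # The concatenation above is typically 768 (CLIP-L) + 1280 (OpenCLIP
        # bigG) = 2048 dims for SD3-style checkpoints (an assumption here)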

        # Sequence embeddings and attention mask from T5
        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device)

        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds

    def forward(self, input_prompts):
        """ Forward pass for encoding prompts. """
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts)

        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
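
# A minimal usage sketch (the checkpoint path is an assumption; any local
# directory with the SD3 diffusers layout of tokenizer/, tokenizer_2/,
# tokenizer_3/ and text_encoder/, text_encoder_2/, text_encoder_3/ works):
#
#     encoder = SD3TextEncoderWithMask('/path/to/sd3', torch_dtype=torch.float16)
#     prompt_embeds, prompt_attention_mask, pooled = encoder(['a photo of a cat'])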

# Example: keep the rest of the model on GPU 1 so it does not compete with
# the text encoders on GPU 0 for memory.
class OtherModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(512, 512)  # example layer

    def forward(self, x):
        return self.fc(x)

if __name__ == "__main__":
    # Place the downstream model and its inputs on GPU 1
    other_model = OtherModel().to('cuda:1')
    input_data = torch.randn(64, 512, device='cuda:1')

    # Forward pass runs entirely on GPU 1
    output = other_model(input_data)
    print(output)
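
    # Tensors created on GPU 0 (e.g. the encoder outputs) must be moved
    # explicitly before a GPU-1 module can consume them. A sketch of the
    # hand-off, with the [1, 128, 4096] T5 output shape as an assumption:
    embeds_on_gpu0 = torch.randn(1, 128, 4096, device='cuda:0')
    embeds_on_gpu1 = embeds_on_gpu0.to('cuda:1')  # explicit cross-device copy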