xxxpo13 committed on
Commit b6a633b · verified · 1 Parent(s): c951146

Update modeling_text_encoder.py

Files changed (1): modeling_text_encoder.py +164 -160
modeling_text_encoder.py CHANGED
@@ -1,168 +1,172 @@
 
 
 import torch
 import torch.nn as nn
 import os
 
 from transformers import (
-    CLIPTextModelWithProjection,
-    CLIPTokenizer,
-    T5EncoderModel,
-    T5TokenizerFast,
-    BitsAndBytesConfig
 )
 
 from typing import Any, Callable, Dict, List, Optional, Union
 
 
 class SD3TextEncoderWithMask(nn.Module):
-    def __init__(self, model_path, torch_dtype, use_8bit_quantization=True):
-        super().__init__()
-        self.model_path = model_path
-        self.torch_dtype = torch_dtype
-        self.device = "cpu"
-        self.use_8bit_quantization = use_8bit_quantization
-
-        # Initialize tokenizers
-        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
-        self.tokenizer_max_length = self.tokenizer.model_max_length
-        self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
-        self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
-
-        # Initialize encoders (kept on CPU initially; loaded lazily on demand)
-        self.text_encoder = None
-        self.text_encoder_2 = None
-        self.text_encoder_3 = None
-
-        self._freeze()
-
-    def _freeze(self):
-        for param in self.parameters():
-            param.requires_grad = False
-
-    def to(self, device):
-        self.device = device
-        return self
-
-    def _load_and_move_encoder(self, encoder_name):
-        if encoder_name == 'text_encoder' and self.text_encoder is None:
-            self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
-                os.path.join(self.model_path, 'text_encoder'),
-                torch_dtype=self.torch_dtype
-            ).to(self.device)
-        elif encoder_name == 'text_encoder_2' and self.text_encoder_2 is None:
-            self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
-                os.path.join(self.model_path, 'text_encoder_2'),
-                torch_dtype=self.torch_dtype
-            ).to(self.device)
-        elif encoder_name == 'text_encoder_3' and self.text_encoder_3 is None:
-            if self.use_8bit_quantization:
-                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-                self.text_encoder_3 = T5EncoderModel.from_pretrained(
-                    os.path.join(self.model_path, 'text_encoder_3'),
-                    quantization_config=quantization_config,
-                    device_map="auto"  # This will automatically decide the best device mapping
-                )
-            else:
-                self.text_encoder_3 = T5EncoderModel.from_pretrained(
-                    os.path.join(self.model_path, 'text_encoder_3'),
-                    torch_dtype=self.torch_dtype
-                ).to(self.device)
-
-    def _unload_encoder(self, encoder_name):
-        if encoder_name == 'text_encoder':
-            self.text_encoder = None
-        elif encoder_name == 'text_encoder_2':
-            self.text_encoder_2 = None
-        elif encoder_name == 'text_encoder_3':
-            self.text_encoder_3 = None
-        torch.cuda.empty_cache()
-
-    def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,  # accepted because encode_prompt passes device=; inputs follow self.device
-        max_sequence_length: int = 128,
-    ):
-        self._load_and_move_encoder('text_encoder_3')
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
-
-        text_inputs = self.tokenizer_3(
-            prompt,
-            padding="max_length",
-            max_length=max_sequence_length,
-            truncation=True,
-            add_special_tokens=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids.to(self.device)
-        prompt_attention_mask = text_inputs.attention_mask.to(self.device)
-
-        prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
-        if not self.use_8bit_quantization:
-            prompt_embeds = prompt_embeds.to(dtype=self.torch_dtype)
-
-        _, seq_len, _ = prompt_embeds.shape
-
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
-
-        self._unload_encoder('text_encoder_3')
-        return prompt_embeds, prompt_attention_mask
-
-    def _get_clip_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]],
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
-        clip_model_index: int = 0,
-    ):
-        """ Get embeddings from CLIP model. """
-        # Lazily load the CLIP encoder this call needs
-        self._load_and_move_encoder('text_encoder' if clip_model_index == 0 else 'text_encoder_2')
-
-        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
-        clip_text_encoders = [self.text_encoder, self.text_encoder_2]
-
-        tokenizer = clip_tokenizers[clip_model_index]
-        text_encoder = clip_text_encoders[clip_model_index]
-
-        text_inputs = tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-
-        text_input_ids = text_inputs.input_ids.to(device)
-        prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0]
-
-        # Duplicate pooled embeddings for each image generated per prompt
-        batch_size = len(prompt)
-        pooled_prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, -1)
-
-        return pooled_prompt_embeds
-
-    def encode_prompt(self,
-        prompt,
-        num_images_per_prompt=1,
-        device=None
-    ):
-        """ Encode the prompt using both CLIP and T5 models. """
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-
-        # Get embeddings from both CLIP models
-        pooled_prompt_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=0)
-        pooled_prompt_2_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=1)
-        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
-
-        # Get T5 embeddings
-        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device)
-
-        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
-
-    def forward(self, input_prompts, device):
-        """ Forward pass for encoding prompts. """
-        with torch.no_grad():
-            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, num_images_per_prompt=1, device=device)
-
-        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds

 import torch
 import torch.nn as nn
 import os
+
+
 from transformers import (
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5TokenizerFast,
 )
+
+
 from typing import Any, Callable, Dict, List, Optional, Union
 
+
+
 class SD3TextEncoderWithMask(nn.Module):
+    def __init__(self, model_path, torch_dtype):
+        super().__init__()
+
+        # Two-GPU layout: both CLIP encoders on GPU 0, the larger T5 encoder on GPU 1.
+        self.device_0 = torch.device('cuda:0')  # GPU 0: CLIP-L and CLIP-G
+        self.device_1 = torch.device('cuda:1')  # GPU 1: T5
+
+        # CLIP-L
+        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
+        self.tokenizer_max_length = self.tokenizer.model_max_length
+        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype).to(self.device_0)
+
+        # CLIP-G
+        self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
+        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype).to(self.device_0)
+
+        # T5
+        self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
+        self.text_encoder_3 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_3'), torch_dtype=torch_dtype).to(self.device_1)
+
+        self._freeze()
+
+    def _freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        max_sequence_length: int = 128,
+    ):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        text_inputs = self.tokenizer_3(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        # Run T5 on the GPU it is pinned to, then move its outputs to the requested device.
+        text_input_ids = text_inputs.input_ids.to(self.device_1)
+        prompt_attention_mask = text_inputs.attention_mask.to(self.device_1)
+        prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype, device=device)
+        prompt_attention_mask = prompt_attention_mask.to(device)
+
+        _, seq_len, _ = prompt_embeds.shape
+
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+        return prompt_embeds, prompt_attention_mask
+
+    def _get_clip_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        clip_skip: Optional[int] = None,  # accepted for API compatibility; unused here
+        clip_model_index: int = 0,
+    ):
+        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
+        clip_text_encoders = [self.text_encoder, self.text_encoder_2]
+
+        tokenizer = clip_tokenizers[clip_model_index]
+        text_encoder = clip_text_encoders[clip_model_index]
+
+        batch_size = len(prompt)
+
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        # Run CLIP on the GPU it is pinned to; output [0] is the pooled projection.
+        text_input_ids = text_inputs.input_ids.to(self.device_0)
+        prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)
+        pooled_prompt_embeds = prompt_embeds[0].to(device)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+        return pooled_prompt_embeds
+
+    def encode_prompt(self,
+        prompt,
+        num_images_per_prompt=1,
+        clip_skip: Optional[int] = None,
+        device=None,
+    ):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        pooled_prompt_embed = self._get_clip_prompt_embeds(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            clip_skip=clip_skip,
+            clip_model_index=0,
+        )
+        pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            clip_skip=clip_skip,
+            clip_model_index=1,
+        )
+        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+
+        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+            prompt=prompt,
+            num_images_per_prompt=num_images_per_prompt,
+            device=device,
+        )
+        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
+
+    def forward(self, input_prompts, device):
+        with torch.no_grad():
+            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, clip_skip=None, device=device)
+
+        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
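
For reference, a minimal usage sketch of the updated encoder. This is an illustration, not part of the commit: the checkpoint path is a placeholder, and the class as written assumes an SD3-style directory (tokenizer, tokenizer_2, tokenizer_3, text_encoder, text_encoder_2, text_encoder_3 subfolders) plus at least two CUDA devices.

    import torch

    # Hypothetical checkpoint path; any SD3-style layout with the six subfolders works.
    encoder = SD3TextEncoderWithMask('/path/to/sd3_checkpoint', torch_dtype=torch.float16)
    prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = encoder(
        ['a photo of an astronaut riding a horse'],
        device=torch.device('cuda:0'),
    )
    # prompt_embeds: (1, 128, t5_dim); pooled_prompt_embeds: (1, clip_l_dim + clip_g_dim)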
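
The removed version already approximated CPU offload by lazily loading each encoder and dropping it after use. Below is a minimal sketch of how the same idea could be layered onto the new always-resident design, assuming a single GPU; the on_device helper is hypothetical, and diffusers pipelines expose a comparable built-in via enable_model_cpu_offload().

    import contextlib
    import torch

    @contextlib.contextmanager
    def on_device(module, device):
        # Move weights to the GPU only for the duration of one forward pass.
        module.to(device)
        try:
            yield module
        finally:
            module.to('cpu')          # return weights to CPU
            torch.cuda.empty_cache()  # release the freed GPU memory

    # e.g. inside _get_t5_prompt_embeds:
    # with on_device(self.text_encoder_3, device) as t5:
    #     prompt_embeds = t5(text_input_ids.to(device),
    #                        attention_mask=prompt_attention_mask.to(device))[0]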