xxxpo13 committed
Commit b3af2c2 · verified · 1 Parent(s): b6a633b

Update modeling_text_encoder.py

Files changed (1)
  1. modeling_text_encoder.py +109 -111
modeling_text_encoder.py CHANGED
@@ -1,15 +1,13 @@
-now cam you optimise this file
-
 import torch
 import torch.nn as nn
 import os


 from transformers import (
-    CLIPTextModelWithProjection,
-    CLIPTokenizer,
-    T5EncoderModel,
-    T5TokenizerFast,
 )


@@ -18,155 +16,155 @@ from typing import Any, Callable, Dict, List, Optional, Union


 class SD3TextEncoderWithMask(nn.Module):
-    def __init__(self, model_path, torch_dtype):
-        super().__init__()

         self.device_0 = torch.device('cuda:0')  # GPU 0 for text encoder
         self.device_1 = torch.device('cuda:1')  # GPU 1 for other tasks

-        # CLIP-L
-        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
-        self.tokenizer_max_length = self.tokenizer.model_max_length
-        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype)


-        # CLIP-G
-        self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
-        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype)


-        # T5
-        self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
-        self.text_encoder_3 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_3'), torch_dtype=torch_dtype).to(self.device_1)
-
-        self._freeze()


-    def _freeze(self):
-        for param in self.parameters():
-            param.requires_grad = False


-    def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
-        max_sequence_length: int = 128,
-    ):
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)


-        text_inputs = self.tokenizer_3(
-            prompt,
-            padding="max_length",
-            max_length=max_sequence_length,
-            truncation=True,
-            add_special_tokens=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        prompt_attention_mask = text_inputs.attention_mask
-        prompt_attention_mask = prompt_attention_mask.to(device)
-        prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
-        dtype = self.text_encoder_3.dtype
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)


-        _, seq_len, _ = prompt_embeds.shape


-        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)


-        return prompt_embeds, prompt_attention_mask


-    def _get_clip_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]],
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
-        clip_skip: Optional[int] = None,
-        clip_model_index: int = 0,
-    ):


-        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
-        clip_text_encoders = [self.text_encoder, self.text_encoder_2]


-        tokenizer = clip_tokenizers[clip_model_index]
-        text_encoder = clip_text_encoders[clip_model_index]


-        batch_size = len(prompt)


-        text_inputs = tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )


-        text_input_ids = text_inputs.input_ids
-        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
-        pooled_prompt_embeds = prompt_embeds[0]
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)


-        return pooled_prompt_embeds


-    def encode_prompt(self,
-        prompt,
-        num_images_per_prompt=1,
-        clip_skip: Optional[int] = None,
-        device=None,
-    ):
-        prompt = [prompt] if isinstance(prompt, str) else prompt


-        pooled_prompt_embed = self._get_clip_prompt_embeds(
-            prompt=prompt,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            clip_skip=clip_skip,
-            clip_model_index=0,
-        )
-        pooled_prompt_2_embed = self._get_clip_prompt_embeds(
-            prompt=prompt,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            clip_skip=clip_skip,
-            clip_model_index=1,
-        )
-        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)


-        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
-            prompt=prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            device=device,
-        )
-        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds


-    def forward(self, input_prompts, device):
-        with torch.no_grad():
-            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, clip_skip=None, device=device)


-        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds


 implement cpu offload

 import torch
 import torch.nn as nn
 import os


 from transformers import (
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5TokenizerFast,
 )


 class SD3TextEncoderWithMask(nn.Module):
+    def __init__(self, model_path, torch_dtype):
+        super().__init__()

         self.device_0 = torch.device('cuda:0')  # GPU 0 for text encoder
         self.device_1 = torch.device('cuda:1')  # GPU 1 for other tasks

+        # CLIP-L
+        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
+        self.tokenizer_max_length = self.tokenizer.model_max_length
+        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype)


+        # CLIP-G
+        self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
+        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype)


+        # T5
+        self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
+        self.text_encoder_3 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_3'), torch_dtype=torch_dtype).to(self.device_1)
+
+        self._freeze()


+    def _freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False


+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        max_sequence_length: int = 128,
+    ):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)


+        text_inputs = self.tokenizer_3(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        prompt_attention_mask = text_inputs.attention_mask
+        prompt_attention_mask = prompt_attention_mask.to(device)
+        prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
+        dtype = self.text_encoder_3.dtype
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)


+        _, seq_len, _ = prompt_embeds.shape


+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)


+        return prompt_embeds, prompt_attention_mask


+    def _get_clip_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        clip_skip: Optional[int] = None,
+        clip_model_index: int = 0,
+    ):


+        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
+        clip_text_encoders = [self.text_encoder, self.text_encoder_2]


+        tokenizer = clip_tokenizers[clip_model_index]
+        text_encoder = clip_text_encoders[clip_model_index]


+        batch_size = len(prompt)


+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )


+        text_input_ids = text_inputs.input_ids
+        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+        pooled_prompt_embeds = prompt_embeds[0]
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)


+        return pooled_prompt_embeds


+    def encode_prompt(self,
+        prompt,
+        num_images_per_prompt=1,
+        clip_skip: Optional[int] = None,
+        device=None,
+    ):
+        prompt = [prompt] if isinstance(prompt, str) else prompt


+        pooled_prompt_embed = self._get_clip_prompt_embeds(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            clip_skip=clip_skip,
+            clip_model_index=0,
+        )
+        pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            clip_skip=clip_skip,
+            clip_model_index=1,
+        )
+        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)


+        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+            prompt=prompt,
+            num_images_per_prompt=num_images_per_prompt,
+            device=device,
+        )
+        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds


+    def forward(self, input_prompts, device):
+        with torch.no_grad():
+            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, clip_skip=None, device=device)


+        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds


 implement cpu offload
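
The trailing note asks for CPU offload, which this commit does not add. Below is a minimal sketch, not part of the repository, of how sequential CPU offload could look for this class: each encoder stays on CPU and is moved to the execution device only while it is encoding. It assumes an SD3TextEncoderWithMask instance whose sub-encoders start on CPU (i.e. without the .to(self.device_1) on the T5 encoder); the names offloaded and encode_prompt_with_offload are hypothetical helpers, not existing APIs.

from contextlib import contextmanager

import torch


@contextmanager
def offloaded(module, device):
    # Hypothetical helper: temporarily move `module` onto `device`,
    # then return it to CPU and release the cached GPU memory.
    module.to(device)
    try:
        yield module
    finally:
        module.to("cpu")
        torch.cuda.empty_cache()


@torch.no_grad()
def encode_prompt_with_offload(encoder, prompts, device=torch.device("cuda:0")):
    # Mirrors SD3TextEncoderWithMask.encode_prompt, but only one text
    # encoder occupies the GPU at a time. `encoder` is assumed to be an
    # SD3TextEncoderWithMask built with all sub-encoders left on CPU.
    prompts = [prompts] if isinstance(prompts, str) else prompts

    with offloaded(encoder.text_encoder, device):
        pooled_1 = encoder._get_clip_prompt_embeds(
            prompt=prompts, device=device, num_images_per_prompt=1, clip_model_index=0)
    with offloaded(encoder.text_encoder_2, device):
        pooled_2 = encoder._get_clip_prompt_embeds(
            prompt=prompts, device=device, num_images_per_prompt=1, clip_model_index=1)
    pooled_prompt_embeds = torch.cat([pooled_1, pooled_2], dim=-1)

    with offloaded(encoder.text_encoder_3, device):
        prompt_embeds, prompt_attention_mask = encoder._get_t5_prompt_embeds(
            prompt=prompts, num_images_per_prompt=1, device=device)

    return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds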