xxxpo13 committed on
Commit a6a08bc · verified · 1 Parent(s): 8e86ad2

Update modeling_text_encoder.py

Files changed (1)
  1. modeling_text_encoder.py +73 -90
modeling_text_encoder.py CHANGED
@@ -1,49 +1,66 @@
 import torch
 import torch.nn as nn
 import os
-from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import (
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5TokenizerFast,
+)
 from typing import Union, List, Optional
-from functools import lru_cache
 
 class SD3TextEncoderWithMask(nn.Module):
     def __init__(self, model_path, torch_dtype):
         super().__init__()
-
-        # Use lazy loading for models
-        self.model_path = model_path
-        self.torch_dtype = torch_dtype
-
-        # CLIP-L
+
+        # Initialization of models and tokenizers, but delay moving to device
         self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
-        self.tokenizer_max_length = self.tokenizer.model_max_length
-
-        # CLIP-G
         self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
-
-        # T5
         self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
-
+
+        # Lazy loading of models for memory efficiency
+        self.text_encoder = None
+        self.text_encoder_2 = None
+        self.text_encoder_3 = None
+        self.model_path = model_path
+        self.torch_dtype = torch_dtype
+        self.tokenizer_max_length = self.tokenizer.model_max_length
+
+        # Freeze parameters to avoid unnecessary gradient computation
        self._freeze()
 
     def _freeze(self):
+        """ Freeze all model parameters to avoid training overhead. """
         for param in self.parameters():
             param.requires_grad = False
 
-    @lru_cache(maxsize=128)
+    def _load_models_if_needed(self):
+        """ Load models only if they haven't been loaded already. """
+        if self.text_encoder is None:
+            self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
+                os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
+            ).to("cuda")
+
+        if self.text_encoder_2 is None:
+            self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+                os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
+            ).to("cuda")
+
+        if self.text_encoder_3 is None:
+            self.text_encoder_3 = T5EncoderModel.from_pretrained(
+                os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
+            ).to("cuda")
+
     def _get_t5_prompt_embeds(
         self,
-        prompt: Union[str, tuple],
+        prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
         device: Optional[torch.device] = None,
         max_sequence_length: int = 128,
     ):
-        if isinstance(prompt, tuple):
-            prompt = list(prompt)
-        else:
-            prompt = [prompt]
-
-        batch_size = len(prompt)
-
+        """ Get embeddings from T5 model. """
+        self._load_models_if_needed()  # Lazy loading
+        prompt = [prompt] if isinstance(prompt, str) else prompt
         text_inputs = self.tokenizer_3(
             prompt,
             padding="max_length",
@@ -52,57 +69,35 @@ class SD3TextEncoderWithMask(nn.Module):
             add_special_tokens=True,
             return_tensors="pt",
         )
-        text_input_ids = text_inputs.input_ids
-        prompt_attention_mask = text_inputs.attention_mask.to(device)
-
-        if not hasattr(self, 'text_encoder_3'):
-            self.text_encoder_3 = T5EncoderModel.from_pretrained(
-                os.path.join(self.model_path, 'text_encoder_3'),
-                torch_dtype=self.torch_dtype
-            ).to(device)
 
-        with torch.no_grad():
-            prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
-
-        prompt_embeds = prompt_embeds.to(dtype=self.torch_dtype, device=device)
+        text_input_ids = text_inputs.input_ids.to(device)
+        prompt_attention_mask = text_inputs.attention_mask.to(device)
+        prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype, device=device)
 
+        # Duplicate embeddings for each image generation
+        batch_size = len(prompt)
         _, seq_len, _ = prompt_embeds.shape
-
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)
+        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1).repeat(num_images_per_prompt, 1)
 
         return prompt_embeds, prompt_attention_mask
 
-    @lru_cache(maxsize=128)
     def _get_clip_prompt_embeds(
         self,
-        prompt: Union[str, tuple],
+        prompt: Union[str, List[str]],
         num_images_per_prompt: int = 1,
         device: Optional[torch.device] = None,
         clip_model_index: int = 0,
     ):
-        if isinstance(prompt, tuple):
-            prompt = list(prompt)
-        else:
-            prompt = [prompt]
-
-        batch_size = len(prompt)
+        """ Get embeddings from CLIP model. """
+        self._load_models_if_needed()  # Lazy loading
 
         clip_tokenizers = [self.tokenizer, self.tokenizer_2]
-        clip_text_encoders_attr = ['text_encoder', 'text_encoder_2']
+        clip_text_encoders = [self.text_encoder, self.text_encoder_2]
 
         tokenizer = clip_tokenizers[clip_model_index]
-        text_encoder_attr = clip_text_encoders_attr[clip_model_index]
-
-        if not hasattr(self, text_encoder_attr):
-            setattr(self, text_encoder_attr, CLIPTextModelWithProjection.from_pretrained(
-                os.path.join(self.model_path, text_encoder_attr),
-                torch_dtype=self.torch_dtype
-            ).to(device))
-
-        text_encoder = getattr(self, text_encoder_attr)
+        text_encoder = clip_text_encoders[clip_model_index]
 
         text_inputs = tokenizer(
             prompt,
@@ -112,48 +107,36 @@ class SD3TextEncoderWithMask(nn.Module):
             return_tensors="pt",
         )
 
-        text_input_ids = text_inputs.input_ids
-
-        with torch.no_grad():
-            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
-
-        pooled_prompt_embeds = prompt_embeds[0]
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+        text_input_ids = text_inputs.input_ids.to(device)
+        prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0]
+
+        # Duplicate embeddings for each image generation
+        batch_size = len(prompt)
+        pooled_prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, -1)
 
         return pooled_prompt_embeds
 
-    @torch.no_grad()
     def encode_prompt(self,
                       prompt,
                       num_images_per_prompt=1,
-                      device=None,
+                      device=None
     ):
-        if isinstance(prompt, str):
-            prompt = (prompt,)
-        elif isinstance(prompt, list):
-            prompt = tuple(prompt)
-
-        pooled_prompt_embed = self._get_clip_prompt_embeds(
-            prompt=prompt,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            clip_model_index=0,
-        )
-        pooled_prompt_2_embed = self._get_clip_prompt_embeds(
-            prompt=prompt,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            clip_model_index=1,
-        )
+        """ Encode the prompt using both CLIP and T5 models. """
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        # Get embeddings from both CLIP models
+        pooled_prompt_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=0)
+        pooled_prompt_2_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=1)
         pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
 
-        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
-            prompt=prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            device=device,
-        )
+        # Get T5 embeddings
+        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device)
+
         return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
 
     def forward(self, input_prompts, device):
-        return self.encode_prompt(input_prompts, 1, device=device)
+        """ Forward pass for encoding prompts. """
+        with torch.no_grad():
+            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, num_images_per_prompt=1, device=device)
+
+        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
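
Taken together, the new version drops the @lru_cache decorators (which forced prompts into hashable tuples and would have kept returned GPU tensors alive in the cache), replaces the scattered hasattr/setattr loading with one explicit _load_models_if_needed step, and moves torch.no_grad() out of the individual helpers into forward. A minimal usage sketch, assuming a checkpoint directory with the subfolders the constructor expects (tokenizer, tokenizer_2, tokenizer_3, text_encoder, text_encoder_2, text_encoder_3) and an available CUDA device, since the loader pins the encoders to "cuda"; the path below is hypothetical:

import torch

from modeling_text_encoder import SD3TextEncoderWithMask

# Hypothetical checkpoint directory following the layout the constructor's
# os.path.join calls expect.
model_path = "/path/to/sd3-checkpoint"

# Construction is now cheap: only the three tokenizers are loaded here.
encoder = SD3TextEncoderWithMask(model_path, torch_dtype=torch.float16)

# The first call triggers _load_models_if_needed(), which moves all three
# text encoders to the GPU; forward() wraps everything in torch.no_grad().
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = encoder(
    ["a photo of an astronaut riding a horse"], device="cuda"
)

print(prompt_embeds.shape)          # (1, 128, t5_hidden); 128 = max_sequence_length
print(prompt_attention_mask.shape)  # (1, 128)
print(pooled_prompt_embeds.shape)   # (1, clip_l_dim + clip_g_dim)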
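
One quirk worth flagging: _load_models_if_needed pins all three encoders to "cuda", while forward and encode_prompt still accept a device argument that is applied to the tokenized inputs. A caller passing "cpu" or "cuda:1" would end up with inputs and weights on different devices. A device-aware variant (not part of this commit, shown only as a sketch) could thread the argument through:

import os
from typing import Optional

import torch
from transformers import CLIPTextModelWithProjection, T5EncoderModel


def _load_models_if_needed(self, device: Optional[torch.device] = None):
    """Sketch of a device-aware loader: honor the caller's device instead of
    hardcoding "cuda" (falls back to "cuda" when none is given)."""
    device = device if device is not None else "cuda"

    if self.text_encoder is None:
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
            os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
        ).to(device)

    if self.text_encoder_2 is None:
        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
        ).to(device)

    if self.text_encoder_3 is None:
        self.text_encoder_3 = T5EncoderModel.from_pretrained(
            os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
        ).to(device)

The call sites in _get_t5_prompt_embeds and _get_clip_prompt_embeds would then pass their device parameter: self._load_models_if_needed(device).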
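
The repeat(...).view(...) idiom used in both helpers expands each prompt's embedding into num_images_per_prompt contiguous copies along the batch axis. A small self-contained check of that shape arithmetic (sizes are illustrative only):

import torch

batch_size, seq_len, dim = 2, 128, 4096  # illustrative sizes
num_images_per_prompt = 3

prompt_embeds = torch.randn(batch_size, seq_len, dim)

# (B, S, D) -> tile along dim 1 -> (B, S * n, D) -> fold the copies back
# into the batch axis -> (B * n, S, D)
expanded = prompt_embeds.repeat(1, num_images_per_prompt, 1)
expanded = expanded.view(batch_size * num_images_per_prompt, seq_len, dim)

assert expanded.shape == (6, 128, 4096)
# Copies of one prompt are contiguous: [p0, p0, p0, p1, p1, p1]
assert torch.equal(expanded[0], prompt_embeds[0])
assert torch.equal(expanded[2], prompt_embeds[0])
assert torch.equal(expanded[3], prompt_embeds[1])

Note that the attention mask takes a different route, .view(batch_size, -1).repeat(num_images_per_prompt, 1), which orders rows as [p0, p1, p0, p1, ...] rather than [p0, p0, p1, p1, ...]. The two orderings only agree when batch_size == 1 or num_images_per_prompt == 1; the forward path always uses num_images_per_prompt=1, so this does not bite there, but it matters for batched multi-image callers (the same asymmetry existed before this commit).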