xxxpo13 commited on
Commit
968db84
·
verified ·
1 Parent(s): 719504e

Upload modeling_text_encoder.py

Browse files
Files changed (1) hide show
  1. modeling_text_encoder.py +142 -0
modeling_text_encoder.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ from transformers import (
5
+ CLIPTextModelWithProjection,
6
+ CLIPTokenizer,
7
+ T5EncoderModel,
8
+ T5TokenizerFast,
9
+ )
10
+ from typing import Union, List, Optional
11
+
12
+ class SD3TextEncoderWithMask(nn.Module):
13
+ def __init__(self, model_path, torch_dtype):
14
+ super().__init__()
15
+
16
+ # Initialization of models and tokenizers, but delay moving to device
17
+ self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
18
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
19
+ self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
20
+
21
+ # Lazy loading of models for memory efficiency
22
+ self.text_encoder = None
23
+ self.text_encoder_2 = None
24
+ self.text_encoder_3 = None
25
+ self.model_path = model_path
26
+ self.torch_dtype = torch_dtype
27
+ self.tokenizer_max_length = self.tokenizer.model_max_length
28
+
29
+ # Freeze parameters to avoid unnecessary gradient computation
30
+ self._freeze()
31
+
32
+ def _freeze(self):
33
+ """ Freeze all model parameters to avoid training overhead. """
34
+ for param in self.parameters():
35
+ param.requires_grad = False
36
+
37
+ def _load_models_if_needed(self):
38
+ """ Load models only if they haven't been loaded already. """
39
+ if self.text_encoder is None:
40
+ self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
41
+ os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
42
+ ).to("cuda")
43
+
44
+ if self.text_encoder_2 is None:
45
+ self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
46
+ os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
47
+ ).to("cuda")
48
+
49
+ if self.text_encoder_3 is None:
50
+ self.text_encoder_3 = T5EncoderModel.from_pretrained(
51
+ os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
52
+ ).to("cuda")
53
+
54
+ def _get_t5_prompt_embeds(
55
+ self,
56
+ prompt: Union[str, List[str]] = None,
57
+ num_images_per_prompt: int = 1,
58
+ device: Optional[torch.device] = None,
59
+ max_sequence_length: int = 128,
60
+ ):
61
+ """ Get embeddings from T5 model. """
62
+ self._load_models_if_needed() # Lazy loading
63
+ prompt = [prompt] if isinstance(prompt, str) else prompt
64
+ text_inputs = self.tokenizer_3(
65
+ prompt,
66
+ padding="max_length",
67
+ max_length=max_sequence_length,
68
+ truncation=True,
69
+ add_special_tokens=True,
70
+ return_tensors="pt",
71
+ )
72
+
73
+ text_input_ids = text_inputs.input_ids.to(device)
74
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
75
+ prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
76
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype, device=device)
77
+
78
+ # Duplicate embeddings for each image generation
79
+ batch_size = len(prompt)
80
+ _, seq_len, _ = prompt_embeds.shape
81
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)
82
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1).repeat(num_images_per_prompt, 1)
83
+
84
+ return prompt_embeds, prompt_attention_mask
85
+
86
+ def _get_clip_prompt_embeds(
87
+ self,
88
+ prompt: Union[str, List[str]],
89
+ num_images_per_prompt: int = 1,
90
+ device: Optional[torch.device] = None,
91
+ clip_model_index: int = 0,
92
+ ):
93
+ """ Get embeddings from CLIP model. """
94
+ self._load_models_if_needed() # Lazy loading
95
+
96
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
97
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
98
+
99
+ tokenizer = clip_tokenizers[clip_model_index]
100
+ text_encoder = clip_text_encoders[clip_model_index]
101
+
102
+ text_inputs = tokenizer(
103
+ prompt,
104
+ padding="max_length",
105
+ max_length=self.tokenizer_max_length,
106
+ truncation=True,
107
+ return_tensors="pt",
108
+ )
109
+
110
+ text_input_ids = text_inputs.input_ids.to(device)
111
+ prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0]
112
+
113
+ # Duplicate embeddings for each image generation
114
+ batch_size = len(prompt)
115
+ pooled_prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, -1)
116
+
117
+ return pooled_prompt_embeds
118
+
119
+ def encode_prompt(self,
120
+ prompt,
121
+ num_images_per_prompt=1,
122
+ device=None
123
+ ):
124
+ """ Encode the prompt using both CLIP and T5 models. """
125
+ prompt = [prompt] if isinstance(prompt, str) else prompt
126
+
127
+ # Get embeddings from both CLIP models
128
+ pooled_prompt_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=0)
129
+ pooled_prompt_2_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=1)
130
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
131
+
132
+ # Get T5 embeddings
133
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device)
134
+
135
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
136
+
137
+ def forward(self, input_prompts, device):
138
+ """ Forward pass for encoding prompts. """
139
+ with torch.no_grad():
140
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, num_images_per_prompt=1, device=device)
141
+
142
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds