xxxpo13 committed
Commit f02e267 · verified · 1 Parent(s): d8af457

Update modeling_text_encoder.py

Files changed (1)
  1. modeling_text_encoder.py +10 -15
modeling_text_encoder.py CHANGED
@@ -5,13 +5,12 @@ from transformers import (
     CLIPTextModelWithProjection,
     CLIPTokenizer,
     T5EncoderModel,
-    T5TokenizerFast,
-    BitsAndBytesConfig  # Import for 8-bit quantization
+    T5TokenizerFast
 )
 from typing import Union, List, Optional

 class SD3TextEncoderWithMask(nn.Module):
-    def __init__(self, model_path, torch_dtype):
+    def __init__(self, model_path, torch_dtype=torch.float16):
         super().__init__()

         # Tokenizers for CLIP and T5
@@ -27,9 +26,6 @@ class SD3TextEncoderWithMask(nn.Module):
         self.torch_dtype = torch_dtype
         self.tokenizer_max_length = self.tokenizer.model_max_length

-        # Quantization config for T5 model
-        self.quantization_config = BitsAndBytesConfig(load_in_8bit=True)  # Quantize T5 to 8-bit
-
         # Freeze parameters to avoid unnecessary gradient computation
         self._freeze()

@@ -51,12 +47,11 @@ class SD3TextEncoderWithMask(nn.Module):
         ).to("cuda")

         if self.text_encoder_3 is None:
-            # Load the T5 model with 8-bit quantization
+            # Load the FP8 T5 model (adjust path if needed for a specific FP8 checkpoint)
             self.text_encoder_3 = T5EncoderModel.from_pretrained(
-                os.path.join(self.model_path, 'text_encoder_3'),
-                torch_dtype=self.torch_dtype,
-                quantization_config=self.quantization_config  # Apply quantization
-            )  # Do NOT use .to("cuda") for 8-bit quantized models, as they handle device placement automatically
+                os.path.join(self.model_path, 'text_encoder_3_fp8'),  # Use the FP8 version
+                torch_dtype=torch.float8_e4m3fn  # FP8 (e4m3) precision; torch has no plain torch.float8
+            ).to("cuda")

     def _get_t5_prompt_embeds(
         self,
@@ -65,7 +60,7 @@ class SD3TextEncoderWithMask(nn.Module):
         device: Optional[torch.device] = None,
         max_sequence_length: int = 128,
     ):
-        """ Get embeddings from the T5 model. """
+        """ Get embeddings from T5 model. """
         self._load_models_if_needed()  # Lazy loading
         prompt = [prompt] if isinstance(prompt, str) else prompt
         text_inputs = self.tokenizer_3(
@@ -80,7 +75,7 @@ class SD3TextEncoderWithMask(nn.Module):
         text_input_ids = text_inputs.input_ids.to(device)
         prompt_attention_mask = text_inputs.attention_mask.to(device)

-        # Use the T5 model to generate embeddings (quantized)
+        # Use the T5 model to generate embeddings in FP8
         prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype)  # Ensure correct dtype

@@ -116,7 +111,7 @@ class SD3TextEncoderWithMask(nn.Module):
             return_tensors="pt",
         )

         text_input_ids = text_inputs.input_ids.to(device)
         prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)[0]

         # Duplicate embeddings for each image generation
@@ -138,7 +133,7 @@ class SD3TextEncoderWithMask(nn.Module):
         pooled_prompt_2_embed = self._get_clip_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device, clip_model_index=1)
         pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

-        # Get T5 embeddings
+        # Get T5 embeddings in FP8
         prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(prompt, num_images_per_prompt=num_images_per_prompt, device=device)

         return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
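A caveat on the new FP8 load path: PyTorch exposes no torch.float8 dtype (the spelling torch_dtype=torch.float8 would raise an AttributeError); since PyTorch 2.1 the FP8 dtypes are torch.float8_e4m3fn and torch.float8_e5m2, and the diff above uses the e4m3 variant. FP8 is also, for most stock ops, a storage-only format: nn.Linear and friends ship no FP8 kernels, so FP8-stored weights are normally upcast before compute (or routed through the private torch._scaled_mm on supported GPUs). A minimal sketch of that storage-plus-upcast pattern, assuming PyTorch >= 2.1, with illustrative shapes:

import torch

# torch.float8 does not exist; use the e4m3/e5m2 dtypes (PyTorch >= 2.1).
fp8 = torch.float8_e4m3fn

w = torch.randn(128, 128)   # fp32 master weight
w_fp8 = w.to(fp8)           # lossy cast to FP8 for storage
x = torch.randn(4, 128)

# Standard ops lack FP8 kernels, so upcast the stored weight before the
# matmul; `x @ w_fp8.T` would fail with an unsupported-dtype error.
y = x @ w_fp8.to(torch.float32).T
print(y.shape, w_fp8.dtype)  # torch.Size([4, 128]) torch.float8_e4m3fn

The same applies to the T5EncoderModel load above: keeping the encoder in FP8 halves memory versus fp16, but the forward pass still needs an upcast or FP8-aware kernels to actually run.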
 
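For reference, the T5 branch in _get_t5_prompt_embeds boils down to the stock transformers tokenize-then-encode pattern. A self-contained sketch of that path; the t5-small checkpoint and the 128-token max length are stand-ins for the model's own text_encoder_3 assets:

import torch
from transformers import T5EncoderModel, T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
encoder = T5EncoderModel.from_pretrained("t5-small", torch_dtype=torch.float16).to("cuda")

batch = tokenizer(
    ["a photo of an astronaut riding a horse"],
    padding="max_length",
    max_length=128,
    truncation=True,
    return_tensors="pt",
)
input_ids = batch.input_ids.to("cuda")
attention_mask = batch.attention_mask.to("cuda")

with torch.no_grad():
    # [0] is last_hidden_state: (batch, seq_len, d_model)
    prompt_embeds = encoder(input_ids, attention_mask=attention_mask)[0]

print(prompt_embeds.shape)  # torch.Size([1, 128, 512]) for t5-small

The attention_mask returned by the tokenizer is what the class threads through as prompt_attention_mask, so padded positions can be ignored downstream.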