xxxpo13 committed
Commit d8af457 · verified · 1 Parent(s): 87af658

Update modeling_text_encoder.py

Files changed (1):
  modeling_text_encoder.py  +15 -6
modeling_text_encoder.py CHANGED
@@ -6,6 +6,7 @@ from transformers import (
     CLIPTokenizer,
     T5EncoderModel,
     T5TokenizerFast,
+    BitsAndBytesConfig,  # Import for 8-bit quantization
 )
 from typing import Union, List, Optional
 
@@ -13,12 +14,12 @@ class SD3TextEncoderWithMask(nn.Module):
     def __init__(self, model_path, torch_dtype):
         super().__init__()
 
-        # Initialization of models and tokenizers, but delay moving to device
+        # Tokenizers for CLIP and T5
         self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
         self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
         self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
 
-        # Lazy loading of models for memory efficiency
+        # Lazy loading for memory efficiency
         self.text_encoder = None
         self.text_encoder_2 = None
         self.text_encoder_3 = None
@@ -26,6 +27,9 @@ class SD3TextEncoderWithMask(nn.Module):
         self.torch_dtype = torch_dtype
         self.tokenizer_max_length = self.tokenizer.model_max_length
 
+        # Quantization config for the T5 model
+        self.quantization_config = BitsAndBytesConfig(load_in_8bit=True)  # Quantize T5 to 8-bit
+
         # Freeze parameters to avoid unnecessary gradient computation
         self._freeze()
 
@@ -47,9 +51,12 @@ class SD3TextEncoderWithMask(nn.Module):
         ).to("cuda")
 
         if self.text_encoder_3 is None:
+            # Load the T5 model with 8-bit quantization
             self.text_encoder_3 = T5EncoderModel.from_pretrained(
-                os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
-            ).to("cuda")
+                os.path.join(self.model_path, 'text_encoder_3'),
+                torch_dtype=self.torch_dtype,
+                quantization_config=self.quantization_config,  # Apply quantization
+            )  # Do NOT use .to("cuda") for 8-bit quantized models; they handle device placement automatically
 
     def _get_t5_prompt_embeds(
         self,
@@ -58,7 +65,7 @@ class SD3TextEncoderWithMask(nn.Module):
         device: Optional[torch.device] = None,
         max_sequence_length: int = 128,
     ):
-        """ Get embeddings from T5 model. """
+        """ Get embeddings from the T5 model. """
         self._load_models_if_needed()  # Lazy loading
         prompt = [prompt] if isinstance(prompt, str) else prompt
         text_inputs = self.tokenizer_3(
@@ -72,8 +79,10 @@ class SD3TextEncoderWithMask(nn.Module):
 
         text_input_ids = text_inputs.input_ids.to(device)
         prompt_attention_mask = text_inputs.attention_mask.to(device)
+
+        # Use the (quantized) T5 model to generate embeddings
         prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype)  # Ensure correct dtype
 
         # Duplicate embeddings for each image generation
         batch_size = len(prompt)
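
For context, the commit swaps the full-precision T5 encoder for an 8-bit one loaded through bitsandbytes. Below is a minimal standalone sketch of the same loading pattern, not the repo's code: it assumes `bitsandbytes` and `accelerate` are installed and a CUDA GPU is available, `model_path` and the prompt are placeholders, and the checkpoint layout (`tokenizer_3` / `text_encoder_3` subfolders) mirrors the SD3 layout used here.

import os

import torch
from transformers import BitsAndBytesConfig, T5EncoderModel, T5TokenizerFast

model_path = "path/to/sd3_checkpoint"  # hypothetical; stands in for this repo's model_path

# Same 8-bit config the commit introduces
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, "tokenizer_3"))
text_encoder_3 = T5EncoderModel.from_pretrained(
    os.path.join(model_path, "text_encoder_3"),
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)  # no .to("cuda"): bitsandbytes places the 8-bit weights on the GPU during loading

text_inputs = tokenizer_3(
    ["a photo of a cat"],   # placeholder prompt
    padding="max_length",
    max_length=128,         # matches max_sequence_length in _get_t5_prompt_embeds
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = text_encoder_3(
        text_inputs.input_ids.to(text_encoder_3.device),
        attention_mask=text_inputs.attention_mask.to(text_encoder_3.device),
    )[0]

Quantizing only text_encoder_3 is a reasonable trade-off: the T5 encoder is by far the largest of SD3's three text encoders, and 8-bit weights roughly halve its memory footprint relative to fp16, while the two CLIP encoders continue to be loaded in full precision on the GPU.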