Update modeling_text_encoder.py
Browse files- modeling_text_encoder.py +3 -3
modeling_text_encoder.py
CHANGED
|
@@ -43,17 +43,17 @@ class SD3TextEncoderWithMask(nn.Module):
|
|
| 43 |
if self.text_encoder is None:
|
| 44 |
self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
| 45 |
os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
|
| 46 |
-
).to(self.
|
| 47 |
|
| 48 |
if self.text_encoder_2 is None:
|
| 49 |
self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
|
| 50 |
os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
|
| 51 |
-
).to(self.
|
| 52 |
|
| 53 |
if self.text_encoder_3 is None:
|
| 54 |
self.text_encoder_3 = T5EncoderModel.from_pretrained(
|
| 55 |
os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
|
| 56 |
-
).to(self.
|
| 57 |
|
| 58 |
def _get_t5_prompt_embeds(
|
| 59 |
self,
|
|
|
|
| 43 |
if self.text_encoder is None:
|
| 44 |
self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
| 45 |
os.path.join(self.model_path, 'text_encoder'), torch_dtype=self.torch_dtype
|
| 46 |
+
).to(self.device_0) # Move to GPU 0
|
| 47 |
|
| 48 |
if self.text_encoder_2 is None:
|
| 49 |
self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
|
| 50 |
os.path.join(self.model_path, 'text_encoder_2'), torch_dtype=self.torch_dtype
|
| 51 |
+
).to(self.device_0) # Move to GPU 0
|
| 52 |
|
| 53 |
if self.text_encoder_3 is None:
|
| 54 |
self.text_encoder_3 = T5EncoderModel.from_pretrained(
|
| 55 |
os.path.join(self.model_path, 'text_encoder_3'), torch_dtype=self.torch_dtype
|
| 56 |
+
).to(self.device_0) # Move to GPU 0
|
| 57 |
|
| 58 |
def _get_t5_prompt_embeds(
|
| 59 |
self,
|