Spaces:
Runtime error
Runtime error
Update ldm/modules/encoders/modules.py
Browse files
ldm/modules/encoders/modules.py
CHANGED
|
@@ -61,7 +61,7 @@ class FrozenT5Embedder(AbstractEncoder):
|
|
| 61 |
super().__init__()
|
| 62 |
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
| 63 |
self.transformer = T5EncoderModel.from_pretrained(version)
|
| 64 |
-
self.device = device
|
| 65 |
self.max_length = max_length # TODO: typical value?
|
| 66 |
if freeze:
|
| 67 |
self.freeze()
|
|
@@ -98,7 +98,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|
| 98 |
assert layer in self.LAYERS
|
| 99 |
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
| 100 |
self.transformer = CLIPTextModel.from_pretrained(version)
|
| 101 |
-
self.device = device
|
| 102 |
self.max_length = max_length
|
| 103 |
if freeze:
|
| 104 |
self.freeze()
|
|
@@ -148,7 +148,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
|
| 148 |
del model.visual
|
| 149 |
self.model = model
|
| 150 |
|
| 151 |
-
self.device = device
|
| 152 |
self.max_length = max_length
|
| 153 |
if freeze:
|
| 154 |
self.freeze()
|
|
@@ -194,7 +194,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
|
| 194 |
|
| 195 |
|
| 196 |
class FrozenCLIPT5Encoder(AbstractEncoder):
|
| 197 |
-
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=
|
| 198 |
clip_max_length=77, t5_max_length=77):
|
| 199 |
super().__init__()
|
| 200 |
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
|
|
|
| 61 |
super().__init__()
|
| 62 |
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
| 63 |
self.transformer = T5EncoderModel.from_pretrained(version)
|
| 64 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 65 |
self.max_length = max_length # TODO: typical value?
|
| 66 |
if freeze:
|
| 67 |
self.freeze()
|
|
|
|
| 98 |
assert layer in self.LAYERS
|
| 99 |
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
| 100 |
self.transformer = CLIPTextModel.from_pretrained(version)
|
| 101 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 102 |
self.max_length = max_length
|
| 103 |
if freeze:
|
| 104 |
self.freeze()
|
|
|
|
| 148 |
del model.visual
|
| 149 |
self.model = model
|
| 150 |
|
| 151 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 152 |
self.max_length = max_length
|
| 153 |
if freeze:
|
| 154 |
self.freeze()
|
|
|
|
| 194 |
|
| 195 |
|
| 196 |
class FrozenCLIPT5Encoder(AbstractEncoder):
|
| 197 |
+
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
|
| 198 |
clip_max_length=77, t5_max_length=77):
|
| 199 |
super().__init__()
|
| 200 |
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|