Spaces:
Runtime error
Runtime error
Badr AlKhamissi
commited on
Commit
·
e8f6bdd
1
Parent(s):
9530fad
losses fix
Browse files- code/losses.py +2 -5
code/losses.py
CHANGED
|
@@ -14,12 +14,11 @@ from transformers import CLIPProcessor, CLIPModel
|
|
| 14 |
from diffusers import StableDiffusionPipeline
|
| 15 |
|
| 16 |
class SDSLoss(nn.Module):
|
| 17 |
-
def __init__(self, cfg, device):
|
| 18 |
super(SDSLoss, self).__init__()
|
| 19 |
self.cfg = cfg
|
| 20 |
self.device = device
|
| 21 |
-
self.pipe =
|
| 22 |
-
torch_dtype=torch.float16, use_auth_token=cfg.token)
|
| 23 |
self.pipe = self.pipe.to(self.device)
|
| 24 |
|
| 25 |
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
|
|
@@ -55,8 +54,6 @@ class SDSLoss(nn.Module):
|
|
| 55 |
text_embeddings = img_emb
|
| 56 |
uncond_embeddings = img_emb
|
| 57 |
|
| 58 |
-
print(text_embeddings.size())
|
| 59 |
-
print(uncond_embeddings.size())
|
| 60 |
self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 61 |
self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)
|
| 62 |
del self.pipe.tokenizer
|
|
|
|
| 14 |
from diffusers import StableDiffusionPipeline
|
| 15 |
|
| 16 |
class SDSLoss(nn.Module):
|
| 17 |
+
def __init__(self, cfg, device, model):
|
| 18 |
super(SDSLoss, self).__init__()
|
| 19 |
self.cfg = cfg
|
| 20 |
self.device = device
|
| 21 |
+
self.pipe = model
|
|
|
|
| 22 |
self.pipe = self.pipe.to(self.device)
|
| 23 |
|
| 24 |
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
|
|
|
|
| 54 |
text_embeddings = img_emb
|
| 55 |
uncond_embeddings = img_emb
|
| 56 |
|
|
|
|
|
|
|
| 57 |
self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 58 |
self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)
|
| 59 |
del self.pipe.tokenizer
|