Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -84,6 +84,7 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
|
|
| 84 |
attention_mask=prompt_masks,
|
| 85 |
output_hidden_states=True,
|
| 86 |
).hidden_states[-2]
|
|
|
|
| 87 |
|
| 88 |
return prompt_embeds, prompt_masks
|
| 89 |
|
|
@@ -242,13 +243,13 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
|
|
| 242 |
torch.random.manual_seed(int(seed))
|
| 243 |
z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
|
| 244 |
z = z.repeat(2, 1, 1, 1)
|
| 245 |
-
|
| 246 |
with torch.no_grad():
|
| 247 |
if neg_cap != "":
|
| 248 |
cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
|
| 249 |
else:
|
| 250 |
cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
|
| 251 |
-
|
| 252 |
cap_mask = cap_mask.to(cap_feats.device)
|
| 253 |
|
| 254 |
model_kwargs = dict(
|
|
|
|
| 84 |
attention_mask=prompt_masks,
|
| 85 |
output_hidden_states=True,
|
| 86 |
).hidden_states[-2]
|
| 87 |
+
text_encoder.cpu()
|
| 88 |
|
| 89 |
return prompt_embeds, prompt_masks
|
| 90 |
|
|
|
|
| 243 |
torch.random.manual_seed(int(seed))
|
| 244 |
z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
|
| 245 |
z = z.repeat(2, 1, 1, 1)
|
| 246 |
+
model.cpu()
|
| 247 |
with torch.no_grad():
|
| 248 |
if neg_cap != "":
|
| 249 |
cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
|
| 250 |
else:
|
| 251 |
cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
|
| 252 |
+
model.cuda()
|
| 253 |
cap_mask = cap_mask.to(cap_feats.device)
|
| 254 |
|
| 255 |
model_kwargs = dict(
|