Spaces:
Runtime error
Runtime error
entry
Browse files
app.py
CHANGED
|
@@ -234,7 +234,10 @@ def main():
|
|
| 234 |
|
| 235 |
# Prepare the conditioning
|
| 236 |
cond_tokens, cond_str_tokens = description2tokens(description, metadata.word2id , cfg)
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
| 238 |
if pointer_words is not None:
|
| 239 |
numberical_conditioning = [float(description["cost_to_pointer"][key]) for key in pointer_words if key in description["cost_to_pointer"]]
|
| 240 |
else:
|
|
|
|
| 234 |
|
| 235 |
# Prepare the conditioning
|
| 236 |
cond_tokens, cond_str_tokens = description2tokens(description, metadata.word2id , cfg)
|
| 237 |
+
if is_cuda:
|
| 238 |
+
cond_tokens = torch.tensor(cond_tokens).long().cuda()
|
| 239 |
+
else:
|
| 240 |
+
cond_tokens = torch.tensor(cond_tokens).long()
|
| 241 |
if pointer_words is not None:
|
| 242 |
numberical_conditioning = [float(description["cost_to_pointer"][key]) for key in pointer_words if key in description["cost_to_pointer"]]
|
| 243 |
else:
|