Hugging Face Spaces — commit view (Space status: Runtime error).
Commit: "Update inference.py" — Browse files.
File changed: inference.py (+5 lines, -2 lines)
@@ -45,6 +45,11 @@ def generate_response(goal, option1, option2):

Old version (before the change):

45 |     enc1 = tokenizer(text1, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
46 |     enc2 = tokenizer(text2, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
47 |
48 |     for k in enc1:
49 |         enc1[k] = enc1[k].to(device)
50 |         enc2[k] = enc2[k].to(device)
@@ -56,13 +61,11 @@ def generate_response(goal, option1, option2):

Old version (before the change; lines marked "-" were removed):

56 |     logits1 = get_logits(out1)
57 |     logits2 = get_logits(out2)
58 |
59 | -   # Sanity check shape
60 |     if logits1.shape[-1] < 2 or logits2.shape[-1] < 2:
61 |         raise ValueError("Model did not return 2-class logits.")
62 |
63 |     score1 = logits1[0][1].item()
64 |     score2 = logits2[0][1].item()
65 | -
66 |     evo_result = option1 if score1 > score2 else option2
67 |
68 |     except Exception as e:
New version of hunk 1 (after the change; lines marked "+" were added):

45 |     enc1 = tokenizer(text1, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
46 |     enc2 = tokenizer(text2, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
47 |
48 | +   # Remove token_type_ids to avoid crash in EvoTransformer
49 | +   enc1.pop("token_type_ids", None)
50 | +   enc2.pop("token_type_ids", None)
51 | +
52 | +   # Move tensors to device
53 |     for k in enc1:
54 |         enc1[k] = enc1[k].to(device)
55 |         enc2[k] = enc2[k].to(device)
New version of hunk 2 (after the change; the "# Sanity check shape" comment and a blank line were removed):

61 |     logits1 = get_logits(out1)
62 |     logits2 = get_logits(out2)
63 |
64 |     if logits1.shape[-1] < 2 or logits2.shape[-1] < 2:
65 |         raise ValueError("Model did not return 2-class logits.")
66 |
67 |     score1 = logits1[0][1].item()
68 |     score2 = logits2[0][1].item()
69 |     evo_result = option1 if score1 > score2 else option2
70 |
71 |     except Exception as e: