Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -117,15 +117,15 @@ def infer(model, replay_file,
|
|
| 117 |
remove_ties_mask = is_ot if not ignore_ties else torch.ones(len(preds), dtype=torch.bool)
|
| 118 |
remove_ties_mask = remove_ties_mask.numpy()
|
| 119 |
if remove_ties_mask.any():
|
| 120 |
-
tie_probs = preds[remove_ties_mask, "Tie"]
|
| 121 |
q = (1 - tie_probs)
|
| 122 |
for c in preds.columns:
|
| 123 |
if c.startswith("Blue") or c.startswith("Orange"):
|
| 124 |
-
preds[remove_ties_mask, c] /= q
|
| 125 |
if ignore_ties:
|
| 126 |
preds = preds.drop("Tie", axis=1)
|
| 127 |
else:
|
| 128 |
-
preds[remove_ties_mask, "Tie"] = 0.0
|
| 129 |
|
| 130 |
return preds
|
| 131 |
|
|
|
|
| 117 |
remove_ties_mask = is_ot if not ignore_ties else torch.ones(len(preds), dtype=torch.bool)
|
| 118 |
remove_ties_mask = remove_ties_mask.numpy()
|
| 119 |
if remove_ties_mask.any():
|
| 120 |
+
tie_probs = preds.loc[remove_ties_mask, "Tie"]
|
| 121 |
q = (1 - tie_probs)
|
| 122 |
for c in preds.columns:
|
| 123 |
if c.startswith("Blue") or c.startswith("Orange"):
|
| 124 |
+
preds.loc[remove_ties_mask, c] /= q
|
| 125 |
if ignore_ties:
|
| 126 |
preds = preds.drop("Tie", axis=1)
|
| 127 |
else:
|
| 128 |
+
preds.loc[remove_ties_mask, "Tie"] = 0.0
|
| 129 |
|
| 130 |
return preds
|
| 131 |
|