Spaces:
Sleeping
Sleeping
Commit ·
ef601ea
1
Parent(s): 5612ce6
Fix ACTION6 error: check available_actions, graceful fallback to CNN
Browse files
app.py
CHANGED
|
@@ -521,7 +521,17 @@ class FluidAgent:
|
|
| 521 |
'candidates': [(signal, answer, confidence)] if answer is not None else []
|
| 522 |
}
|
| 523 |
|
| 524 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
# Im found hypothesis, Re found cell → ACTION6
|
| 526 |
r, c, _ = target_cell
|
| 527 |
H, W = grid.shape
|
|
@@ -535,15 +545,19 @@ class FluidAgent:
|
|
| 535 |
|
| 536 |
chosen_id = 6
|
| 537 |
else:
|
| 538 |
-
# CNN fallback
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
with torch.no_grad():
|
| 540 |
logits = self.model(feat.unsqueeze(0)).squeeze(0)
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
for a in available_actions if
|
| 545 |
-
int(a.value if hasattr(a, 'value') else a) <= 6]
|
| 546 |
-
indices = [m-1 for m in avail if 1 <= m <= 6]
|
| 547 |
masked = torch.full((6,), float('-inf'))
|
| 548 |
for i in indices:
|
| 549 |
masked[i] = logits[i]
|
|
@@ -555,8 +569,11 @@ class FluidAgent:
|
|
| 555 |
cnn_action_idx = np.random.choice(6, p=probs)
|
| 556 |
|
| 557 |
chosen_id = cnn_action_idx + 1
|
| 558 |
-
meta['source'] = 'cnn'
|
| 559 |
meta['probs'] = probs.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
|
| 561 |
self.prev_feat = feat
|
| 562 |
self.prev_action = chosen_id - 1
|
|
@@ -672,6 +689,11 @@ def _run_agent(game_id, api_key, max_steps):
|
|
| 672 |
|
| 673 |
action, meta = _agent.choose(grid, avail, levels=levels, state=state)
|
| 674 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
diff = (grid != prev_grid) if prev_grid is not None else None
|
| 676 |
prev_grid = grid.copy()
|
| 677 |
|
|
@@ -783,9 +805,15 @@ def pull_frame():
|
|
| 783 |
reasoning = meta.get('reasoning', [])
|
| 784 |
hyp_text = '\n'.join(reasoning[:2]) if reasoning else 'none'
|
| 785 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 786 |
_latest['status'] = (
|
| 787 |
-
f"{source_emoji} **{
|
| 788 |
-
f"Step {step} | Action {action}\n\n
|
|
|
|
| 789 |
)
|
| 790 |
|
| 791 |
return (_latest['grid_img'], _latest['hyp_img'], _latest['cand_img'], _latest['status'])
|
|
|
|
| 521 |
'candidates': [(signal, answer, confidence)] if answer is not None else []
|
| 522 |
}
|
| 523 |
|
| 524 |
+
# Check what actions are actually available
|
| 525 |
+
avail_ids = set()
|
| 526 |
+
if available_actions:
|
| 527 |
+
avail_ids = set(int(a.value if hasattr(a, 'value') else a)
|
| 528 |
+
for a in available_actions)
|
| 529 |
+
else:
|
| 530 |
+
avail_ids = set(range(1, 7))
|
| 531 |
+
|
| 532 |
+
# If we have a strong analytic answer and ACTION6 exists, use it
|
| 533 |
+
if (answer is not None and confidence > 0.40 and
|
| 534 |
+
target_cell is not None and 6 in avail_ids):
|
| 535 |
# Im found hypothesis, Re found cell → ACTION6
|
| 536 |
r, c, _ = target_cell
|
| 537 |
H, W = grid.shape
|
|
|
|
| 545 |
|
| 546 |
chosen_id = 6
|
| 547 |
else:
|
| 548 |
+
# CNN fallback (or analytic without click action)
|
| 549 |
+
if answer is not None and confidence > 0.40:
|
| 550 |
+
# We have strong answer but no ACTION6 - pick best alternative
|
| 551 |
+
meta['source'] = 'analytic_fallback'
|
| 552 |
+
meta['note'] = f"Confidence {confidence:.2f} but ACTION6 not available"
|
| 553 |
+
else:
|
| 554 |
+
meta['source'] = 'cnn'
|
| 555 |
+
|
| 556 |
with torch.no_grad():
|
| 557 |
logits = self.model(feat.unsqueeze(0)).squeeze(0)
|
| 558 |
+
|
| 559 |
+
# Mask to only available actions
|
| 560 |
+
indices = [m-1 for m in avail_ids if 1 <= m <= 6]
|
|
|
|
|
|
|
|
|
|
| 561 |
masked = torch.full((6,), float('-inf'))
|
| 562 |
for i in indices:
|
| 563 |
masked[i] = logits[i]
|
|
|
|
| 569 |
cnn_action_idx = np.random.choice(6, p=probs)
|
| 570 |
|
| 571 |
chosen_id = cnn_action_idx + 1
|
|
|
|
| 572 |
meta['probs'] = probs.tolist()
|
| 573 |
+
|
| 574 |
+
# Make sure chosen action is actually available
|
| 575 |
+
if chosen_id not in avail_ids and avail_ids:
|
| 576 |
+
chosen_id = list(avail_ids)[0]
|
| 577 |
|
| 578 |
self.prev_feat = feat
|
| 579 |
self.prev_action = chosen_id - 1
|
|
|
|
| 689 |
|
| 690 |
action, meta = _agent.choose(grid, avail, levels=levels, state=state)
|
| 691 |
|
| 692 |
+
# Add available actions to meta for debugging
|
| 693 |
+
if avail:
|
| 694 |
+
meta['available_actions'] = [int(a.value if hasattr(a, 'value') else a)
|
| 695 |
+
for a in avail]
|
| 696 |
+
|
| 697 |
diff = (grid != prev_grid) if prev_grid is not None else None
|
| 698 |
prev_grid = grid.copy()
|
| 699 |
|
|
|
|
| 805 |
reasoning = meta.get('reasoning', [])
|
| 806 |
hyp_text = '\n'.join(reasoning[:2]) if reasoning else 'none'
|
| 807 |
|
| 808 |
+
avail_actions = meta.get('available_actions', [])
|
| 809 |
+
avail_str = f"Available: {avail_actions}" if avail_actions else ""
|
| 810 |
+
|
| 811 |
+
source_label = {'analytic': 'Analytic', 'analytic_fallback': 'Analytic (no click)', 'cnn': 'CNN'}
|
| 812 |
+
|
| 813 |
_latest['status'] = (
|
| 814 |
+
f"{source_emoji} **{source_label.get(source, source)}** | "
|
| 815 |
+
f"Step {step} | Action {action}\n\n"
|
| 816 |
+
f"{hyp_text}\n\n{avail_str}"
|
| 817 |
)
|
| 818 |
|
| 819 |
return (_latest['grid_img'], _latest['hyp_img'], _latest['cand_img'], _latest['status'])
|