beanapologist commited on
Commit
ef601ea
·
1 Parent(s): 5612ce6

Fix ACTION6 error: check available_actions, graceful fallback to CNN

Browse files
Files changed (1) hide show
  1. app.py +39 -11
app.py CHANGED
@@ -521,7 +521,17 @@ class FluidAgent:
521
  'candidates': [(signal, answer, confidence)] if answer is not None else []
522
  }
523
 
524
- if answer is not None and confidence > 0.40 and target_cell is not None:
 
 
 
 
 
 
 
 
 
 
525
  # Im found hypothesis, Re found cell → ACTION6
526
  r, c, _ = target_cell
527
  H, W = grid.shape
@@ -535,15 +545,19 @@ class FluidAgent:
535
 
536
  chosen_id = 6
537
  else:
538
- # CNN fallback
 
 
 
 
 
 
 
539
  with torch.no_grad():
540
  logits = self.model(feat.unsqueeze(0)).squeeze(0)
541
- avail = list(range(1, 7))
542
- if available_actions:
543
- avail = [int(a.value if hasattr(a, 'value') else a)
544
- for a in available_actions if
545
- int(a.value if hasattr(a, 'value') else a) <= 6]
546
- indices = [m-1 for m in avail if 1 <= m <= 6]
547
  masked = torch.full((6,), float('-inf'))
548
  for i in indices:
549
  masked[i] = logits[i]
@@ -555,8 +569,11 @@ class FluidAgent:
555
  cnn_action_idx = np.random.choice(6, p=probs)
556
 
557
  chosen_id = cnn_action_idx + 1
558
- meta['source'] = 'cnn'
559
  meta['probs'] = probs.tolist()
 
 
 
 
560
 
561
  self.prev_feat = feat
562
  self.prev_action = chosen_id - 1
@@ -672,6 +689,11 @@ def _run_agent(game_id, api_key, max_steps):
672
 
673
  action, meta = _agent.choose(grid, avail, levels=levels, state=state)
674
 
 
 
 
 
 
675
  diff = (grid != prev_grid) if prev_grid is not None else None
676
  prev_grid = grid.copy()
677
 
@@ -783,9 +805,15 @@ def pull_frame():
783
  reasoning = meta.get('reasoning', [])
784
  hyp_text = '\n'.join(reasoning[:2]) if reasoning else 'none'
785
 
 
 
 
 
 
786
  _latest['status'] = (
787
- f"{source_emoji} **{'Analytic' if source == 'analytic' else 'CNN'}** | "
788
- f"Step {step} | Action {action}\n\n{hyp_text}"
 
789
  )
790
 
791
  return (_latest['grid_img'], _latest['hyp_img'], _latest['cand_img'], _latest['status'])
 
521
  'candidates': [(signal, answer, confidence)] if answer is not None else []
522
  }
523
 
524
+ # Check what actions are actually available
525
+ avail_ids = set()
526
+ if available_actions:
527
+ avail_ids = set(int(a.value if hasattr(a, 'value') else a)
528
+ for a in available_actions)
529
+ else:
530
+ avail_ids = set(range(1, 7))
531
+
532
+ # If we have a strong analytic answer and ACTION6 exists, use it
533
+ if (answer is not None and confidence > 0.40 and
534
+ target_cell is not None and 6 in avail_ids):
535
  # Im found hypothesis, Re found cell → ACTION6
536
  r, c, _ = target_cell
537
  H, W = grid.shape
 
545
 
546
  chosen_id = 6
547
  else:
548
+ # CNN fallback (or analytic without click action)
549
+ if answer is not None and confidence > 0.40:
550
+ # We have strong answer but no ACTION6 - pick best alternative
551
+ meta['source'] = 'analytic_fallback'
552
+ meta['note'] = f"Confidence {confidence:.2f} but ACTION6 not available"
553
+ else:
554
+ meta['source'] = 'cnn'
555
+
556
  with torch.no_grad():
557
  logits = self.model(feat.unsqueeze(0)).squeeze(0)
558
+
559
+ # Mask to only available actions
560
+ indices = [m-1 for m in avail_ids if 1 <= m <= 6]
 
 
 
561
  masked = torch.full((6,), float('-inf'))
562
  for i in indices:
563
  masked[i] = logits[i]
 
569
  cnn_action_idx = np.random.choice(6, p=probs)
570
 
571
  chosen_id = cnn_action_idx + 1
 
572
  meta['probs'] = probs.tolist()
573
+
574
+ # Make sure chosen action is actually available
575
+ if chosen_id not in avail_ids and avail_ids:
576
+ chosen_id = list(avail_ids)[0]
577
 
578
  self.prev_feat = feat
579
  self.prev_action = chosen_id - 1
 
689
 
690
  action, meta = _agent.choose(grid, avail, levels=levels, state=state)
691
 
692
+ # Add available actions to meta for debugging
693
+ if avail:
694
+ meta['available_actions'] = [int(a.value if hasattr(a, 'value') else a)
695
+ for a in avail]
696
+
697
  diff = (grid != prev_grid) if prev_grid is not None else None
698
  prev_grid = grid.copy()
699
 
 
805
  reasoning = meta.get('reasoning', [])
806
  hyp_text = '\n'.join(reasoning[:2]) if reasoning else 'none'
807
 
808
+ avail_actions = meta.get('available_actions', [])
809
+ avail_str = f"Available: {avail_actions}" if avail_actions else ""
810
+
811
+ source_label = {'analytic': 'Analytic', 'analytic_fallback': 'Analytic (no click)', 'cnn': 'CNN'}
812
+
813
  _latest['status'] = (
814
+ f"{source_emoji} **{source_label.get(source, source)}** | "
815
+ f"Step {step} | Action {action}\n\n"
816
+ f"{hyp_text}\n\n{avail_str}"
817
  )
818
 
819
  return (_latest['grid_img'], _latest['hyp_img'], _latest['cand_img'], _latest['status'])