Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from __future__ import annotations | |
| import copy | |
| import logging | |
| import os | |
| from typing import Any, Callable, Dict, List, Optional, Set, Tuple | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..models import VisualReasoningAction, VisualReasoningObservation | |
| except ImportError: | |
| from models import VisualReasoningAction, VisualReasoningObservation | |
| try: | |
| from .regions import REGION_STYLES, layout_members, resolve_bounds, _GRID_CELL | |
| except ImportError: | |
| from server.regions import REGION_STYLES, layout_members, resolve_bounds, _GRID_CELL | |
| try: | |
| from .scenario_loader import build_repository | |
| except ImportError: | |
| from server.scenario_loader import build_repository | |
| try: | |
| from .scoring import ( | |
| BREAKDOWN_KEYS, | |
| compute_conflicts, | |
| compute_efficiency_bonus, | |
| compute_score_breakdown, | |
| first_conflict_message, | |
| is_concept_evidenced, | |
| is_hard_invalid, | |
| is_no_op, | |
| normalize_action, | |
| ) | |
| except ImportError: | |
| from server.scoring import ( | |
| BREAKDOWN_KEYS, | |
| compute_conflicts, | |
| compute_efficiency_bonus, | |
| compute_score_breakdown, | |
| first_conflict_message, | |
| is_concept_evidenced, | |
| is_hard_invalid, | |
| is_no_op, | |
| normalize_action, | |
| ) | |
| try: | |
| from .constants import ROLE_VALUES, REGION_STYLES as _RS | |
| except ImportError: | |
| from server.constants import ROLE_VALUES, REGION_STYLES as _RS | |
| try: | |
| from .narration_scorer import warmup_scorer | |
| except ImportError: | |
| from server.narration_scorer import warmup_scorer | |
# Flat reward for an action that cannot be parsed or has hard conflicts.
_INVALID_ACTION_PENALTY = -0.02
# Flat reward for a valid action that leaves the canvas unchanged; step()
# doubles it from the second consecutive no-op onwards.
_NOOP_ACTION_PENALTY = -0.05
# Minimum overall score a clean ``complete`` step needs to earn _SUBMIT_BONUS.
_SUBMIT_BONUS_THRESHOLD = 0.90
_SUBMIT_BONUS = 0.05
# Flat bonus per newly-covered concept. Adds a dense positive signal
# for genuine pedagogical progress on top of the delta-based reward —
# addresses sparse rewards during iterative/steady phases.
_NEW_CONCEPT_BONUS = 0.05
# Every per-step reward is clamped into [_REWARD_MIN, _REWARD_MAX].
_REWARD_MIN = -0.2
_REWARD_MAX = 1.0

_LOG_LEVEL = os.getenv("VISUAL_REASONING_LOG_LEVEL", "INFO").strip().upper()
if not isinstance(logging.getLevelName(_LOG_LEVEL), int):
    # getLevelName() maps a registered level name to its int value and
    # returns a "Level X" string otherwise. An unknown name would make
    # basicConfig()/setLevel() raise ValueError at import time, so fall
    # back to INFO instead of crashing the server on a typo.
    _LOG_LEVEL = "INFO"
# Opt-in flag read from the environment.
# NOTE(review): not referenced elsewhere in this module — confirm other
# modules use it, or drop it.
_LOG_EVERY_STEP = os.getenv("VISUAL_REASONING_LOG_EVERY_STEP", "0") == "1"

logger = logging.getLogger(__name__)
if not logger.handlers:
    logging.basicConfig(level=_LOG_LEVEL)
logger.setLevel(_LOG_LEVEL)
class VisualReasoningEnvironment(Environment):
    """Environment in which an agent builds an explanatory diagram step by step.

    ``reset()`` loads a scenario (goal, input data, concept checklist, step
    budget, ...) from the repository returned by ``build_repository()``.
    ``step()`` applies the ops carried by a ``VisualReasoningAction`` to a
    canvas of entities (regions, nodes, pointers, containers), relations,
    annotations and notes. It then scores the canvas with
    ``compute_score_breakdown`` and returns a ``VisualReasoningObservation``.
    The observation's reward is the change in overall score plus flat
    shaping terms, clamped to ``[_REWARD_MIN, _REWARD_MAX]``.
    """

    # All mutable state lives on the instance (nothing is shared at class
    # level). Presumably read by the openenv server to allow one instance
    # per session — TODO confirm.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        """Create an environment with an empty canvas and no scenario loaded."""
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count = 0
        self._repository = build_repository()
        self._scenario: Dict[str, Any] = {}
        warmup_scorer()
        # Canvas state; cleared by _load_scenario() on every reset().
        self._entities: Dict[str, Dict[str, Any]] = {}
        self._relations: List[Dict[str, Any]] = []
        self._layout: Dict[str, Dict[str, float]] = {}
        self._annotations: List[Dict[str, Any]] = []
        self._notes: List[Dict[str, Any]] = []
        self._narration_history: List[str] = []
        self._step_history: List[Dict[str, Any]] = []
        self._concept_coverage: List[str] = []
        # Episode bookkeeping used by step() for budget and reward shaping.
        self._remaining_steps = 0
        self._previous_score = 0.0
        self._done_reason: Optional[str] = None
        self._consecutive_noops = 0
        # op name -> (handler(target_ids, params), minimum len(target_ids)).
        # A handler returns the id of a region whose members need relayout,
        # or None when no relayout is needed.
        self._op_dispatch: Dict[str, Tuple[Callable[..., Optional[str]], int]] = {
            "add_region": (lambda t, p: self._op_add_region(t[0], p), 1),
            "add_node": (lambda t, p: self._op_add_placeable(t[0], "node", p), 1),
            "add_pointer": (lambda t, p: self._op_add_placeable(t[0], "pointer", p), 1),
            "add_container": (lambda t, p: self._op_add_placeable(t[0], "container", p), 1),
            "add_edge": (lambda t, p: self._op_add_edge(t[0], t[1], p), 2),
            "push_to": (lambda t, p: self._op_push(t[0], t[1]), 2),
            "pop_from": (lambda t, p: self._op_pop(t[0]), 1),
            "move_pointer": (lambda t, p: self._op_move_pointer(t[0], p), 1),
            "annotate": (lambda t, p: self._op_annotate(t[0], p), 1),
            "highlight": (lambda t, p: self._op_highlight(t[0]), 1),
            "unhighlight": (lambda t, p: self._op_unhighlight(t[0]), 1),
            "add_note": (lambda t, p: self._op_add_note(p), 0),
            "set_role": (lambda t, p: self._op_set_role(t[0], p), 1),
            "set_value": (lambda t, p: self._op_set_value(t[0], p), 1),
            "remove_edge": (lambda t, p: self._op_remove_edge(t[0], t[1], p), 2),
            "remove_entity": (lambda t, p: self._op_remove_entity(t[0]), 1),
        }

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        scenario_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> VisualReasoningObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Episode id for the new ``State``; a random UUID if omitted.
            scenario_id: Load exactly this scenario (takes precedence over task_name).
            task_name: Load the first scenario of this task (lower-cased first).
            **kwargs: Ignored.

        If the requested scenario/task fails to load with ValueError,
        KeyError or TypeError, a warning is logged and the repository's next
        scenario is used instead. The returned observation has reward 0.0
        and an all-zero score breakdown.
        """
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self._reset_count += 1
        try:
            self._load_scenario(
                scenario_id=_clean_str(scenario_id),
                task_name=_clean_str(task_name, lower=True),
            )
        except (ValueError, KeyError, TypeError):
            logger.warning(
                "Invalid reset params scenario_id=%r task_name=%r — falling back to next_scenario",
                scenario_id,
                task_name,
            )
            self._load_scenario()
        # Score of the empty canvas, so the first step's reward reflects only
        # the progress made by that step.
        self._previous_score = self._reset_score_baseline()
        self._done_reason = None
        self._consecutive_noops = 0
        display = {key: 0.0 for key in BREAKDOWN_KEYS}
        return self._build_observation(
            score_breakdown=display, reward=0.0, done=False,
        )

    def step(
        self,
        action: VisualReasoningAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> VisualReasoningObservation:
        """Apply one agent action, score the resulting canvas and return the observation.

        Flow:
          * Auto-reset if no scenario is loaded. If the step budget is already
            spent, return a done observation with reward 0.0.
          * Normalize the action. If ``normalize_action`` returns no action,
            or the action has hard conflicts, no ops are applied and
            ``_INVALID_ACTION_PENALTY`` is charged. Otherwise its ops are applied.
          * Credit claimed concepts that are on the scenario checklist and
            evidenced by the applied ops. Record the step and spend one unit
            of budget.
          * Score ``prev_state`` -> ``new_state``.

        Reward:
          * Rejected steps get only their flat penalty.
          * Accepted steps get ``score - previous_score`` plus the flat
            penalties and ``_NEW_CONCEPT_BONUS`` per newly covered concept.
          * No-ops are capped at their (escalating) penalty.
          * A ``complete`` step ends the episode. It may earn the submit and
            efficiency bonuses when coverage >= 0.8 and the step was clean.
          * Otherwise the episode ends when the budget reaches 0.
          * The result is clamped to ``[_REWARD_MIN, _REWARD_MAX]``.

        Args:
            action: The agent's action; validity is decided by ``normalize_action``.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Ignored.
        """
        if not self._scenario:
            logger.warning("step() called before reset() — auto-resetting")
            self.reset()
        self._state.step_count += 1
        if self._remaining_steps <= 0:
            self._done_reason = "step_budget_exhausted"
            return self._build_observation(
                score_breakdown={k: 0.0 for k in BREAKDOWN_KEYS},
                reward=0.0, done=True,
                action_error="step_budget_exhausted",
            )
        prev_state = self._snapshot_state()
        attempted_ops = _raw_attempted_ops(action)
        applied_ops: List[Dict[str, Any]] = []
        normalized, parse_error = normalize_action(action)
        action_error: Optional[str] = parse_error
        penalty_flat = 0.0
        # NOTE(review): if normalize_action can return BOTH an action and a
        # parse_error, the ops below are still applied, but the step is
        # treated as hard-invalid (no concept credit, reward == penalty_flat,
        # which is 0.0 in that case) — confirm this is intended.
        was_hard_invalid = parse_error is not None
        if normalized is None:
            # Unparseable: substitute an empty action, keeping the requested
            # step type only if it is one of the two known ones.
            raw_step = _raw_step_type(action)
            normalized = _empty_action(raw_step if raw_step in ("advance", "complete") else "advance")
            penalty_flat += _INVALID_ACTION_PENALTY
            conflicts: Dict[str, int] = {}
        else:
            conflicts = compute_conflicts(normalized, prev_state)
            if is_hard_invalid(conflicts):
                action_error = first_conflict_message(conflicts, normalized)
                penalty_flat += _INVALID_ACTION_PENALTY
                was_hard_invalid = True
            else:
                applied_ops = self._apply_action(normalized)
                # An invalid role is reported, but the action still counts
                # as applied.
                if conflicts.get("invalid_role", 0):
                    action_error = first_conflict_message(conflicts, normalized)
        # Score against the ops that actually changed the canvas, while still
        # exposing what the agent attempted.
        scoring_action = dict(normalized)
        scoring_action["ops"] = copy.deepcopy(applied_ops)
        scoring_action["attempted_ops"] = copy.deepcopy(attempted_ops)
        self._narration_history.append(normalized.get("narration", ""))
        checklist = set(self._scenario.get("concept_checklist") or [])
        credited_concepts: List[str] = []
        newly_covered_count = 0
        if not was_hard_invalid and applied_ops:
            for concept in normalized.get("covered_concepts", []):
                if (
                    concept in checklist
                    and concept not in credited_concepts
                    and is_concept_evidenced(concept, scoring_action)
                ):
                    credited_concepts.append(concept)
                    if concept not in self._concept_coverage:
                        self._concept_coverage.append(concept)
                        newly_covered_count += 1
        self._step_history.append({
            "step_id": len(self._step_history) + 1,
            "step_type": normalized.get("step_type", "advance"),
            "narration": normalized.get("narration", ""),
            "ops": copy.deepcopy(applied_ops),
            "attempted_ops": copy.deepcopy(attempted_ops),
            "covered_concepts": list(credited_concepts),
            "intent": normalized.get("intent", ""),
            "action_error": action_error,
        })
        self._remaining_steps = max(0, self._remaining_steps - 1)
        new_state = self._snapshot_state()
        template = self._scenario.get("template", "")
        breakdown = compute_score_breakdown(
            self._scenario.get("task_name", "easy"),
            template,
            scoring_action,
            prev_state,
            new_state,
            self._scenario,
            conflicts,
            self._step_history,
        )
        was_noop = (
            not was_hard_invalid
            and is_no_op(scoring_action, prev_state, new_state)
        )
        if was_noop:
            self._consecutive_noops += 1
            noop_penalty = _NOOP_ACTION_PENALTY
            # Repeated no-ops are punished twice as hard.
            if self._consecutive_noops >= 2:
                noop_penalty *= 2
            penalty_flat += noop_penalty
        elif not was_hard_invalid:
            self._consecutive_noops = 0
        score = breakdown.get("overall_score", 0.0)
        if was_hard_invalid:
            reward = penalty_flat
            # After a hard-invalid, the next valid step's score often
            # drops from the prior peak because a dropped op means less
            # state progress (coverage, attention, etc.). Pin
            # previous_score to the current score so the next step's
            # delta starts from a realistic baseline and isn't punished
            # twice for the same rejection.
            self._previous_score = score
        else:
            reward = (score - self._previous_score) + penalty_flat
            # Dense shaping: immediate positive signal per newly-covered
            # concept — partially replaces the all-at-end submit bonus
            # so the LLM gets feedback throughout an episode.
            if newly_covered_count > 0:
                reward += _NEW_CONCEPT_BONUS * newly_covered_count
            # A no-op never earns more than its penalty, whatever the score did.
            if was_noop and reward > penalty_flat:
                reward = penalty_flat
            # A no-op never lowers the baseline, so a score dip caused by
            # idling is not later "recovered" as a positive delta.
            if was_noop:
                self._previous_score = max(self._previous_score, score)
            else:
                self._previous_score = score
        done = False
        if normalized.get("step_type") == "complete":
            done = True
            self._done_reason = "complete"
            coverage = breakdown.get("coverage_score", 0.0)
            completion_ready = (
                coverage >= 0.8
                and action_error is None
                and not was_hard_invalid
                and _has_real_canvas_work(new_state)
            )
            # The complete step sometimes dips pacing/attention because its
            # narration is a summary. Don't punish the dip if the episode
            # actually wrapped up cleanly (≥80% checklist covered & no
            # action error). Floor the complete-step delta at 0 in that
            # case so a clean wrap-up isn't net-negative.
            if completion_ready and reward < 0.0:
                reward = max(reward, 0.0)
            if completion_ready and score >= _SUBMIT_BONUS_THRESHOLD:
                reward += _SUBMIT_BONUS
            max_budget = int(self._scenario.get("step_budget", 0))
            if completion_ready:
                reward += compute_efficiency_bonus(self._remaining_steps, max_budget, coverage)
            elif not applied_ops and newly_covered_count == 0 and reward > 0.0:
                # A premature, empty "complete" earns no positive reward.
                reward = 0.0
        elif self._remaining_steps == 0:
            done = True
            self._done_reason = self._done_reason or "step_budget_exhausted"
        reward = max(_REWARD_MIN, min(_REWARD_MAX, reward))
        return self._build_observation(
            score_breakdown=breakdown,
            reward=reward,
            done=done,
            action_error=action_error,
            action_penalty=penalty_flat if penalty_flat != 0.0 else None,
        )

    def state(self) -> State:
        """Return the episode ``State`` (episode id and step counter)."""
        # NOTE(review): openenv's Environment base class may declare `state`
        # as a @property; if so, `env.state` yields this bound method rather
        # than the State object — confirm against the installed openenv-core
        # before relying on attribute-style access.
        return self._state

    # ----- scenario loading -----------------------------------------------

    def _load_scenario(
        self,
        scenario_id: Optional[str] = None,
        task_name: Optional[str] = None,
    ) -> None:
        """Load a scenario by id, by task, or the repository's next one, and clear the canvas.

        Lookup errors raised by the repository propagate to the caller.
        """
        if scenario_id:
            scenario = self._repository.get_scenario(scenario_id)
        elif task_name:
            scenario = self._repository.first_scenario_for_task(task_name)
        else:
            scenario = self._repository.next_scenario()
        self._scenario = scenario
        self._remaining_steps = int(scenario.get("step_budget", 10))
        self._entities = {}
        self._relations = []
        self._layout = {}
        self._annotations = []
        self._notes = []
        self._narration_history = []
        self._step_history = []
        self._concept_coverage = []

    # ----- op application --------------------------------------------------

    def _reset_score_baseline(self) -> float:
        """Return the overall score of the current canvas under an empty "advance" action."""
        state = self._snapshot_state()
        baseline_action = _empty_action("advance")
        baseline_action["narration"] = "Baseline state."
        conflicts = compute_conflicts(baseline_action, state)
        breakdown = compute_score_breakdown(
            self._scenario.get("task_name", "easy"),
            self._scenario.get("template", ""),
            baseline_action,
            state,
            state,
            self._scenario,
            conflicts,
            [],
        )
        return float(breakdown.get("overall_score", 0.0))

    def _apply_action(self, action: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Apply each op of a normalized action in order, then relayout the touched regions.

        Returns:
            Deep copies of the ops that actually changed the canvas.
        """
        dirty_regions: Set[str] = set()
        applied_ops: List[Dict[str, Any]] = []
        for op in action.get("ops", []):
            region, applied = self._apply_op(op)
            if applied:
                applied_ops.append(copy.deepcopy(op))
            if region:
                dirty_regions.add(region)
        for region_id in dirty_regions:
            self._relayout_region(region_id)
        return applied_ops

    def _apply_op(self, op: Dict[str, Any]) -> Tuple[Optional[str], bool]:
        """Dispatch a single op to its handler.

        Unknown op names, and ops with fewer ``target_ids`` than the handler
        needs, are skipped without touching the canvas.

        Returns:
            ``(region_to_relayout, changed)``. ``changed`` is True when the
            canvas fingerprint differs before and after the handler ran.
        """
        name = op.get("op") or ""
        targets = list(op.get("target_ids") or [])
        params = dict(op.get("params") or {})
        entry = self._op_dispatch.get(name)
        if entry is None:
            return None, False
        handler, min_targets = entry
        if len(targets) < min_targets:
            return None, False
        before = self._mutation_fingerprint()
        region = handler(targets, params)
        return region, before != self._mutation_fingerprint()

    def _mutation_fingerprint(self) -> Tuple[Any, ...]:
        """Return a comparable copy of the canvas, used to detect whether an op changed anything."""
        return (
            _copy_entities(self._entities),
            [dict(r) for r in self._relations],
            {k: dict(v) for k, v in self._layout.items()},
            [dict(a) for a in self._annotations],
            [dict(n) for n in self._notes],
        )

    # ----- per-op handlers -------------------------------------------------

    def _op_add_region(self, rid: str, params: Dict[str, Any]) -> Optional[str]:
        """Create region *rid*; a no-op if the id is already taken.

        Unknown ``style`` values fall back to ``"graph"``. The region's
        bounds come from ``resolve_bounds`` given the requested grid
        ``position`` and the cells other regions already occupy. All
        regions' bounds are then recomputed, which also relayouts their
        members, so this always returns None.
        """
        if rid in self._entities:
            return None
        style = str(params.get("style") or "graph")
        if style not in REGION_STYLES:
            style = "graph"
        position = params.get("position")
        if isinstance(position, str):
            position = position.strip().lower()
        else:
            position = None
        # Gather already-occupied cells and existing position mappings
        occupied, existing_positions = self._region_occupancy()
        region_count = len([e for e in self._entities.values()
                            if e.get("entity_type") == "region"])
        bounds, resolved_position = resolve_bounds(
            position, occupied, existing_positions, region_count,
        )
        self._entities[rid] = {
            "entity_id": rid, "entity_type": "region", "style": style,
            "title": str(params.get("title") or rid),
            "bounds": dict(bounds),
            "members": [],
            "root": params.get("root"),
            "source": params.get("source"),
            "position": resolved_position,
        }
        self._layout[rid] = {"x": bounds["x0"], "y": bounds["y0"]}
        # Adding a region changes the adaptive expansion of all existing
        # regions — recompute their bounds so earlier regions shrink to
        # make room rather than overlapping.
        self._recompute_all_region_bounds()
        return None

    def _op_add_placeable(
        self, eid: str, etype: str, params: Dict[str, Any]
    ) -> Optional[str]:
        """Add a node, pointer or container *eid*; a no-op if the id is already taken.

        The entity goes into ``params["region"]``. If that is missing, it
        goes into the first existing region. If the target is not a region,
        or no region exists, it goes into the default region (created if
        needed). Nodes take ``params["value"]``, pointers take
        ``params["index"]`` and containers take ``params["ordered"]``
        (default True).

        Returns:
            The region the entity was placed in, so it gets relaid out.
        """
        if eid in self._entities:
            return None
        region_id = str(params.get("region") or self._default_region() or "")
        region = self._entities.get(region_id)
        if not region or region.get("entity_type") != "region":
            region_id = self._ensure_default_region()
            region = self._entities[region_id]
        entity: Dict[str, Any] = {
            "entity_id": eid, "entity_type": etype,
            "region": region_id,
            "role": "default",
        }
        if etype == "node":
            entity["value"] = params.get("value")
        elif etype == "pointer":
            entity["value"] = params.get("index")
        elif etype == "container":
            entity["ordered"] = bool(params.get("ordered", True))
        self._entities[eid] = entity
        region["members"].append(eid)
        return region_id

    def _op_add_edge(self, src: str, dst: str, params: Dict[str, Any]) -> Optional[str]:
        """Add a ``(src, dst, kind)`` relation between two existing entities.

        Duplicates (same src, dst and kind) are ignored. Labels are
        truncated to 20 characters.

        Returns:
            The region of *src* (or of *dst* if *src* has none) for relayout,
            or None if nothing was added.
        """
        if src not in self._entities or dst not in self._entities:
            return None
        kind = str(params.get("kind") or "edge")
        label = str(params.get("label") or "")
        key = (src, dst, kind)
        if any(
            (r.get("src"), r.get("dst"), r.get("kind", "edge")) == key
            for r in self._relations
        ):
            return None
        rel: Dict[str, Any] = {"src": src, "dst": dst, "kind": kind}
        if label:
            rel["label"] = label[:20]
        self._relations.append(rel)
        region_id = self._entities[src].get("region") or self._entities[dst].get("region")
        return region_id

    def _op_push(self, container_id: str, item_id: str) -> Optional[str]:
        """Append existing entity *item_id* to a container's contents unless it is already there."""
        container = self._entities.get(container_id)
        if not container or container.get("entity_type") != "container":
            return None
        if item_id not in self._entities:
            return None
        members = container.setdefault("contents", [])
        if item_id in members:
            return None
        members.append(item_id)
        return container.get("region")

    def _op_pop(self, container_id: str) -> Optional[str]:
        """Remove one item from a container and mark it with a ``[popped]`` annotation.

        Ordered containers (the default) remove the first item; unordered
        ones remove the last item.
        """
        container = self._entities.get(container_id)
        if not container or container.get("entity_type") != "container":
            return None
        members = container.get("contents") or []
        if not members:
            return None
        removed = members.pop(0) if container.get("ordered", True) else members.pop()
        self._add_annotation(removed, "[popped]")
        return container.get("region")

    def _op_move_pointer(self, pid: str, params: Dict[str, Any]) -> None:
        """Point pointer *pid* at ``params["index"]``.

        The index is stored as an int when it converts cleanly. Otherwise,
        if it is the id of an existing entity, it is stored as that
        symbolic target. Any other value is ignored.
        """
        ent = self._entities.get(pid)
        if not ent or ent.get("entity_type") != "pointer":
            return
        if "index" not in params:
            return
        idx = params["index"]
        try:
            ent["value"] = int(idx)
            return
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")) — untrusted input must not
            # crash step().
            pass
        # Accept an entity ID as a symbolic pointer target.
        if isinstance(idx, str) and idx in self._entities:
            ent["value"] = idx

    def _op_annotate(self, eid: str, params: Dict[str, Any]) -> None:
        """Replace the text annotation of *eid* with ``params["text"]`` (max 50 chars).

        Empty or missing text (including an explicit None) clears the
        existing text annotation. ``[highlight]`` markers are kept.
        """
        if eid not in self._entities:
            return
        raw_text = params.get("text")
        # Treat an explicit None like a missing key instead of the literal "None".
        text = "" if raw_text is None else str(raw_text).strip()
        self._set_text_annotation(eid, text[:50] if text else "")

    def _op_highlight(self, eid: str) -> None:
        """Mark existing entity *eid* with a ``[highlight]`` annotation (deduplicated)."""
        if eid not in self._entities:
            return
        self._add_annotation(eid, "[highlight]")

    def _op_unhighlight(self, eid: str) -> None:
        """Remove the ``[highlight]`` annotation from existing entity *eid*."""
        if eid not in self._entities:
            return
        self._annotations = [
            a for a in self._annotations
            if not (a.get("target_id") == eid and a.get("text") == "[highlight]")
        ]

    def _op_add_note(self, params: Dict[str, Any]) -> None:
        """Append a free-standing note (text truncated to 100 chars) with a sequential id.

        Empty or missing text (including an explicit None) adds nothing.
        """
        raw_text = params.get("text")
        # Treat an explicit None like a missing key instead of the literal "None".
        text = "" if raw_text is None else str(raw_text).strip()
        if not text:
            return
        note_id = f"note_{len(self._notes) + 1}"
        self._notes.append({
            "note_id": note_id,
            "text": text[:100],
            "region": params.get("region"),
        })

    def _op_set_role(self, eid: str, params: Dict[str, Any]) -> None:
        """Set the role of *eid*; roles not in ``ROLE_VALUES`` are ignored."""
        ent = self._entities.get(eid)
        if not ent:
            return
        role = str(params.get("role") or "default")
        if role not in ROLE_VALUES:
            return
        ent["role"] = role

    def _op_set_value(self, eid: str, params: Dict[str, Any]) -> None:
        """Set ``value`` on *eid* from ``params["value"]``.

        Pointer values must convert to int, otherwise the op is ignored.
        Values for other entity types are stored as given.
        """
        ent = self._entities.get(eid)
        if not ent or "value" not in params:
            return
        if ent.get("entity_type") == "pointer":
            try:
                ent["value"] = int(params["value"])
            except (TypeError, ValueError, OverflowError):
                # OverflowError: int(float("inf")) — ignore rather than crash.
                return
        else:
            ent["value"] = params.get("value")

    def _op_remove_edge(self, src: str, dst: str, params: Dict[str, Any]) -> Optional[str]:
        """Remove the ``(src, dst, kind)`` relation, if present.

        Returns:
            The region of *src* (or of *dst*) for relayout, or None if
            nothing was removed.
        """
        kind = str(params.get("kind") or "edge")
        before = len(self._relations)
        self._relations = [
            r for r in self._relations
            if not (r.get("src") == src and r.get("dst") == dst
                    and r.get("kind", "edge") == kind)
        ]
        if len(self._relations) == before:
            return None
        ent = self._entities.get(src) or self._entities.get(dst)
        return ent.get("region") if ent else None

    def _op_remove_entity(self, eid: str) -> Optional[str]:
        """Delete entity *eid* together with every reference to it.

        The following are removed:
          * relations touching *eid*;
          * its annotations;
          * its layout entry;
          * its occurrences in container contents.

        Removing a region moves its members to the first remaining region
        (creating the default region if none is left). The remaining
        regions' adaptive bounds are then recomputed. Removing any other
        entity drops it from its region's member list.

        Returns:
            The region whose members need relayout, or None if *eid* did
            not exist.
        """
        ent = self._entities.pop(eid, None)
        if ent is None:
            return None
        self._relations = [
            r for r in self._relations
            if r.get("src") != eid and r.get("dst") != eid
        ]
        self._annotations = [a for a in self._annotations if a.get("target_id") != eid]
        self._layout.pop(eid, None)
        for other in self._entities.values():
            if isinstance(other, dict) and other.get("entity_type") == "container":
                contents = other.get("contents") or []
                if eid in contents:
                    other["contents"] = [c for c in contents if c != eid]
        affected_region: Optional[str] = None
        if ent.get("entity_type") == "region":
            fallback_region: Optional[str] = None
            for other in self._entities.values():
                if isinstance(other, dict) and other.get("entity_type") == "region":
                    fallback_region = other.get("entity_id")
                    break
            if fallback_region is None:
                fallback_region = self._ensure_default_region()
            for other_eid, other in self._entities.items():
                if isinstance(other, dict) and other.get("region") == eid:
                    other["region"] = fallback_region
                    target = self._entities.get(fallback_region)
                    if isinstance(target, dict):
                        target.setdefault("members", []).append(other_eid)
            # The removed region's grid cell is free again. Mirror
            # _op_add_region and recompute every remaining region's adaptive
            # bounds so neighbours can reclaim the space instead of staying
            # shrunk around a region that no longer exists.
            self._recompute_all_region_bounds()
            affected_region = fallback_region
        else:
            region_id = ent.get("region")
            region = self._entities.get(region_id) if region_id else None
            if isinstance(region, dict):
                region["members"] = [m for m in region.get("members", []) if m != eid]
            affected_region = region_id
        return affected_region

    def _set_text_annotation(self, target_id: str, text: str) -> None:
        """Replace every non-highlight annotation of *target_id* with *text*.

        An empty *text* only removes the existing annotations.
        """
        self._annotations = [
            a for a in self._annotations
            if not (a.get("target_id") == target_id and a.get("text") != "[highlight]")
        ]
        if text:
            self._annotations.append({"target_id": target_id, "text": text})

    # ----- layout helpers --------------------------------------------------

    def _default_region(self) -> Optional[str]:
        """Return the id of the first region entity, or None if there are none."""
        for eid, ent in self._entities.items():
            if ent.get("entity_type") == "region":
                return eid
        return None

    def _ensure_default_region(self) -> str:
        """Return the first region id, creating ``_default_region`` ("Canvas") if none exists."""
        existing = self._default_region()
        if existing:
            return existing
        rid = "_default_region"
        self._op_add_region(rid, {"style": "graph", "title": "Canvas"})
        return rid

    def _region_occupancy(self) -> Tuple[Set[Tuple[int, int]], Dict[str, str]]:
        """Return (occupied_cells, {region_id: position_name}) for all regions."""
        occupied: Set[Tuple[int, int]] = set()
        existing_positions: Dict[str, str] = {}
        for eid, ent in self._entities.items():
            if ent.get("entity_type") != "region":
                continue
            pos_name = ent.get("position")
            if pos_name and pos_name in _GRID_CELL:
                occupied.add(_GRID_CELL[pos_name])
                existing_positions[eid] = pos_name
        return occupied, existing_positions

    def _recompute_all_region_bounds(self) -> None:
        """Recompute adaptive bounds for every region and relayout members."""
        regions = [
            (eid, ent) for eid, ent in self._entities.items()
            if ent.get("entity_type") == "region"
        ]
        if not regions:
            return
        # Collect all occupied cells first (same scan _region_occupancy does).
        occupied, positions_map = self._region_occupancy()
        # Recompute each region's bounds using full occupancy info; a region's
        # own cell is excluded so it is not treated as blocking itself.
        for eid, ent in regions:
            pos_name = ent.get("position")
            if not pos_name or pos_name not in _GRID_CELL:
                continue
            others = occupied - {_GRID_CELL[pos_name]}
            bounds, _ = resolve_bounds(pos_name, others, positions_map, len(regions))
            ent["bounds"] = dict(bounds)
            self._layout[eid] = {"x": bounds["x0"], "y": bounds["y0"]}
        # Relayout all members since bounds changed
        for eid, _ in regions:
            self._relayout_region(eid)

    def _relayout_region(self, region_id: str) -> None:
        """Recompute layout positions for the members of *region_id*.

        Before laying out, it drops member ids that no longer point back at
        this region.
        """
        region = self._entities.get(region_id)
        if not region or region.get("entity_type") != "region":
            return
        members = [m for m in region.get("members", [])
                   if self._entities.get(m, {}).get("region") == region_id]
        region["members"] = members
        positions = layout_members(region, members, self._relations)
        for eid, pos in positions.items():
            self._layout[eid] = pos

    def _add_annotation(self, target_id: str, text: str) -> None:
        """Append a ``(target_id, text)`` annotation unless an identical one already exists."""
        for existing in self._annotations:
            if existing.get("target_id") == target_id and existing.get("text") == text:
                return
        self._annotations.append({"target_id": target_id, "text": text})

    # ----- observation + snapshots ----------------------------------------

    def _snapshot_state(self) -> Dict[str, Any]:
        """Return a copy of the canvas plus narration, coverage and checklist, for scoring."""
        return {
            "entities": _copy_entities(self._entities),
            "relations": [dict(r) for r in self._relations],
            "layout": {k: dict(v) for k, v in self._layout.items()},
            "annotations": [dict(a) for a in self._annotations],
            "notes": [dict(n) for n in self._notes],
            "narration_history": list(self._narration_history),
            "concept_coverage": list(self._concept_coverage),
            "concept_checklist": list(self._scenario.get("concept_checklist") or []),
        }

    def _build_observation(
        self,
        score_breakdown: Dict[str, float],
        reward: float,
        done: bool,
        action_error: Optional[str] = None,
        action_penalty: Optional[float] = None,
    ) -> VisualReasoningObservation:
        """Assemble the observation from the loaded scenario and the current canvas.

        Top-level containers are copied; nested values (e.g. the ``ops``
        lists inside step-history entries) are still shared with the
        environment.
        """
        return VisualReasoningObservation(
            task_name=self._scenario.get("task_name", ""),
            scenario_id=self._scenario.get("scenario_id", ""),
            goal=self._scenario.get("goal", ""),
            input_data=dict(self._scenario.get("input_data", {})),
            constraints=list(self._scenario.get("constraints", [])),
            concept_checklist=list(self._scenario.get("concept_checklist", [])),
            max_steps=int(self._scenario.get("step_budget", 0)),
            entities=_copy_entities(self._entities),
            relations=[dict(r) for r in self._relations],
            layout={k: dict(v) for k, v in self._layout.items()},
            annotations=[dict(a) for a in self._annotations],
            notes=[dict(n) for n in self._notes],
            narration_history=list(self._narration_history),
            step_history=[dict(s) for s in self._step_history],
            concept_coverage=list(self._concept_coverage),
            step_id=len(self._step_history),
            done_reason=self._done_reason,
            score_breakdown=score_breakdown,
            remaining_step_budget=self._remaining_steps,
            done=done,
            reward=reward,
            action_error=action_error,
            action_penalty=action_penalty,
        )
def _copy_entities(entities: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Return a two-level copy of *entities* without the cost of ``copy.deepcopy``.

    Every entity dict is copied. Each list or dict value inside an entity is
    copied one more level, so later in-place edits to an entity's lists or
    dicts (members, contents, bounds, ...) do not leak into the copy. Values
    nested deeper than that are still shared.
    """

    def _dup(value: Any) -> Any:
        # One extra level of copying for the mutable containers entities hold.
        if isinstance(value, list):
            return list(value)
        if isinstance(value, dict):
            return dict(value)
        return value

    return {
        entity_id: {field: _dup(value) for field, value in entity.items()}
        for entity_id, entity in entities.items()
    }
def _clean_str(value: Any, lower: bool = False) -> Optional[str]:
    """Normalize an optional string parameter.

    Returns None for non-strings and for strings that are empty after
    stripping. Otherwise returns the stripped string, lower-cased when
    *lower* is true.
    """
    text = value.strip() if isinstance(value, str) else ""
    if not text:
        return None
    if lower:
        return text.lower()
    return text
def _raw_step_type(action: Any) -> str:
    """Best-effort read of ``step_type`` from an action object or dict.

    The value is stripped and lower-cased. Returns "" when it is missing or
    falsy, or when *action* is neither an object with ``step_type`` nor a dict.
    """
    if hasattr(action, "step_type"):
        raw = action.step_type
    elif isinstance(action, dict):
        raw = action.get("step_type")
    else:
        return ""
    return str(raw or "").strip().lower()
def _raw_attempted_ops(action: Any) -> List[Dict[str, Any]]:
    """Extract the ops the agent tried to run, before any validation.

    Accepts a pydantic-style model (via ``model_dump``), a plain dict, or
    any object with an ``ops`` attribute. Non-dict ops are skipped. Each
    remaining op is reduced to ``{"op", "target_ids", "params"}``:
      * the op name is stripped;
      * ``None`` target ids are dropped and the rest stringified;
      * non-list targets become ``[]``;
      * non-dict params become ``{}``.
    Returns ``[]`` when ``ops`` is not a list.
    """

    def _summarize(op: Dict[str, Any]) -> Dict[str, Any]:
        raw_targets = op.get("target_ids") or []
        if isinstance(raw_targets, list):
            target_ids = [str(t) for t in raw_targets if t is not None]
        else:
            target_ids = []
        raw_params = op.get("params") or {}
        return {
            "op": str(op.get("op", "")).strip(),
            "target_ids": target_ids,
            "params": dict(raw_params) if isinstance(raw_params, dict) else {},
        }

    if hasattr(action, "model_dump"):
        payload = action.model_dump()
    elif isinstance(action, dict):
        payload = action
    else:
        payload = {"ops": getattr(action, "ops", [])}
    raw_ops = payload.get("ops") or []
    if not isinstance(raw_ops, list):
        return []
    return [_summarize(op) for op in raw_ops if isinstance(op, dict)]
def _has_real_canvas_work(state: Dict[str, Any]) -> bool:
    """Return True if the snapshot has any entities, relations, annotations or notes."""
    canvas_keys = ("entities", "relations", "annotations", "notes")
    return any(state.get(key) for key in canvas_keys)
def _empty_action(step_type: str) -> Dict[str, Any]:
    """Build a normalized action of *step_type* with no narration, ops, concepts or intent.

    A new dict with new lists is returned on every call.
    """
    return dict(
        step_type=step_type,
        narration="",
        ops=[],
        covered_concepts=[],
        intent="",
    )