# Source: visual_reasoning-env/server/visual_reasoning_environment.py
# (uploaded by sreeramajay with huggingface_hub, verified commit 9bcd1a5).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import copy
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import VisualReasoningAction, VisualReasoningObservation
except ImportError:
from models import VisualReasoningAction, VisualReasoningObservation
try:
from .regions import REGION_STYLES, layout_members, resolve_bounds, _GRID_CELL
except ImportError:
from server.regions import REGION_STYLES, layout_members, resolve_bounds, _GRID_CELL
try:
from .scenario_loader import build_repository
except ImportError:
from server.scenario_loader import build_repository
try:
from .scoring import (
BREAKDOWN_KEYS,
compute_conflicts,
compute_efficiency_bonus,
compute_score_breakdown,
first_conflict_message,
is_concept_evidenced,
is_hard_invalid,
is_no_op,
normalize_action,
)
except ImportError:
from server.scoring import (
BREAKDOWN_KEYS,
compute_conflicts,
compute_efficiency_bonus,
compute_score_breakdown,
first_conflict_message,
is_concept_evidenced,
is_hard_invalid,
is_no_op,
normalize_action,
)
try:
from .constants import ROLE_VALUES, REGION_STYLES as _RS
except ImportError:
from server.constants import ROLE_VALUES, REGION_STYLES as _RS
try:
from .narration_scorer import warmup_scorer
except ImportError:
from server.narration_scorer import warmup_scorer
# Flat penalty for an action that cannot be normalised or is hard-invalid.
_INVALID_ACTION_PENALTY = -0.02
# Flat penalty for a step that changes nothing (doubled on consecutive no-ops).
_NOOP_ACTION_PENALTY = -0.05
# A clean "complete" step scoring at least this much earns _SUBMIT_BONUS.
_SUBMIT_BONUS_THRESHOLD = 0.90
_SUBMIT_BONUS = 0.05
# Flat bonus per newly-covered concept. Adds a dense positive signal
# for genuine pedagogical progress on top of the delta-based reward —
# addresses sparse rewards during iterative/steady phases.
_NEW_CONCEPT_BONUS = 0.05
# Every per-step reward is clamped into [_REWARD_MIN, _REWARD_MAX].
_REWARD_MIN = -0.2
_REWARD_MAX = 1.0
_LOG_LEVEL = os.getenv("VISUAL_REASONING_LOG_LEVEL", "INFO").upper()
# logging.getLevelName() maps a registered level name to its int value and
# anything else to a "Level ..." string. An unrecognised name (typo, empty
# string, numeric string) would make basicConfig()/setLevel() raise
# ValueError at import time, so fall back to INFO instead.
if not isinstance(logging.getLevelName(_LOG_LEVEL), int):
    _LOG_LEVEL = "INFO"
_LOG_EVERY_STEP = os.getenv("VISUAL_REASONING_LOG_EVERY_STEP", "0") == "1"
logger = logging.getLogger(__name__)
if not logger.handlers:
    logging.basicConfig(level=_LOG_LEVEL)
logger.setLevel(_LOG_LEVEL)
class VisualReasoningEnvironment(Environment):
    """OpenEnv environment where an agent builds a visual explanation step by step.

    An episode is one scenario taken from the repository returned by
    ``build_repository()``. Each ``step()`` receives a
    ``VisualReasoningAction`` whose ops edit a canvas made of regions,
    placeable entities (nodes, pointers, containers), relations (edges),
    annotations and free-floating notes. The resulting state is scored with
    ``compute_score_breakdown`` and the reward is the change in
    ``overall_score`` plus flat shaping terms, clamped to
    ``[_REWARD_MIN, _REWARD_MAX]``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    # Step budget assumed when a scenario has no ``step_budget``. Every reader
    # of the budget goes through _step_budget() so the remaining-step counter,
    # the reported ``max_steps`` and the efficiency bonus always agree.
    _DEFAULT_STEP_BUDGET: int = 10

    def __init__(self) -> None:
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count = 0
        self._repository = build_repository()
        self._scenario: Dict[str, Any] = {}
        # Presumably pre-loads the narration scorer so the first step is not
        # slowed down by lazy initialisation — see narration_scorer.
        warmup_scorer()
        # Canvas and episode bookkeeping; all of it is reinitialised by reset().
        self._entities: Dict[str, Dict[str, Any]] = {}
        self._relations: List[Dict[str, Any]] = []
        self._layout: Dict[str, Dict[str, float]] = {}
        self._annotations: List[Dict[str, Any]] = []
        self._notes: List[Dict[str, Any]] = []
        self._narration_history: List[str] = []
        self._step_history: List[Dict[str, Any]] = []
        self._concept_coverage: List[str] = []
        self._remaining_steps = 0
        self._previous_score = 0.0
        self._done_reason: Optional[str] = None
        self._consecutive_noops = 0
        # op name -> (handler(target_ids, params), minimum number of target_ids).
        # A handler returns the id of the region whose members must be
        # re-laid out, or None.
        self._op_dispatch: Dict[str, Tuple[Callable[..., Optional[str]], int]] = {
            "add_region": (lambda t, p: self._op_add_region(t[0], p), 1),
            "add_node": (lambda t, p: self._op_add_placeable(t[0], "node", p), 1),
            "add_pointer": (lambda t, p: self._op_add_placeable(t[0], "pointer", p), 1),
            "add_container": (lambda t, p: self._op_add_placeable(t[0], "container", p), 1),
            "add_edge": (lambda t, p: self._op_add_edge(t[0], t[1], p), 2),
            "push_to": (lambda t, p: self._op_push(t[0], t[1]), 2),
            "pop_from": (lambda t, p: self._op_pop(t[0]), 1),
            "move_pointer": (lambda t, p: self._op_move_pointer(t[0], p), 1),
            "annotate": (lambda t, p: self._op_annotate(t[0], p), 1),
            "highlight": (lambda t, p: self._op_highlight(t[0]), 1),
            "unhighlight": (lambda t, p: self._op_unhighlight(t[0]), 1),
            "add_note": (lambda t, p: self._op_add_note(p), 0),
            "set_role": (lambda t, p: self._op_set_role(t[0], p), 1),
            "set_value": (lambda t, p: self._op_set_value(t[0], p), 1),
            "remove_edge": (lambda t, p: self._op_remove_edge(t[0], t[1], p), 2),
            "remove_entity": (lambda t, p: self._op_remove_entity(t[0]), 1),
        }

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        scenario_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> VisualReasoningObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Accepted for interface compatibility; not used.
            episode_id: Id reported in ``state``; a random UUID when omitted.
            scenario_id: Load exactly this scenario (takes precedence).
            task_name: Otherwise load the repository's first scenario for this
                task; the name is stripped and lower-cased before lookup.
            **kwargs: Ignored.

        If the requested scenario cannot be loaded (``ValueError``,
        ``KeyError`` or ``TypeError``) a warning is logged and the
        repository's next scenario is used instead. The returned observation
        has an all-zero score breakdown and zero reward.
        """
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self._reset_count += 1
        try:
            self._load_scenario(
                scenario_id=_clean_str(scenario_id),
                task_name=_clean_str(task_name, lower=True),
            )
        except (ValueError, KeyError, TypeError):
            logger.warning(
                "Invalid reset params scenario_id=%r task_name=%r — falling back to next_scenario",
                scenario_id,
                task_name,
            )
            self._load_scenario()
        # Rewards are score deltas, so start from the empty canvas's score.
        self._previous_score = self._reset_score_baseline()
        self._done_reason = None
        self._consecutive_noops = 0
        display = {key: 0.0 for key in BREAKDOWN_KEYS}
        return self._build_observation(
            score_breakdown=display, reward=0.0, done=False,
        )

    def step(
        self,
        action: VisualReasoningAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> VisualReasoningObservation:
        """Apply one agent action and return the scored observation.

        Actions that cannot be normalised, or whose conflicts are
        hard-invalid, are not applied to the canvas; their reward is the flat
        ``_INVALID_ACTION_PENALTY``. Otherwise the ops are applied, and the
        reward is ``overall_score - previous_score`` plus shaping terms:
        no-op penalties, ``_NEW_CONCEPT_BONUS`` per newly covered concept,
        and, on a clean ``complete`` step, the submit and efficiency bonuses.
        The reward is always clamped to ``[_REWARD_MIN, _REWARD_MAX]``.

        The episode ends on a ``complete`` step or on the step that uses up
        the budget. Calling ``step()`` after the budget is exhausted returns
        ``done=True`` with ``action_error="step_budget_exhausted"``.
        ``timeout_s`` and ``kwargs`` are accepted for interface compatibility
        and ignored.
        """
        if not self._scenario:
            logger.warning("step() called before reset() — auto-resetting")
            self.reset()
        self._state.step_count += 1
        if self._remaining_steps <= 0:
            self._done_reason = "step_budget_exhausted"
            return self._build_observation(
                score_breakdown={k: 0.0 for k in BREAKDOWN_KEYS},
                reward=0.0,
                done=True,
                action_error="step_budget_exhausted",
            )

        # ---- validate and apply ------------------------------------------
        prev_state = self._snapshot_state()
        attempted_ops = _raw_attempted_ops(action)
        applied_ops: List[Dict[str, Any]] = []
        normalized, parse_error = normalize_action(action)
        action_error: Optional[str] = parse_error
        penalty_flat = 0.0
        was_hard_invalid = parse_error is not None
        conflicts: Dict[str, int] = {}
        if normalized is None:
            # Unparseable: substitute an empty action so history and scoring
            # still receive a well-formed action.
            raw_step = _raw_step_type(action)
            normalized = _empty_action(
                raw_step if raw_step in ("advance", "complete") else "advance"
            )
            penalty_flat += _INVALID_ACTION_PENALTY
        else:
            conflicts = compute_conflicts(normalized, prev_state)
            if is_hard_invalid(conflicts):
                action_error = first_conflict_message(conflicts, normalized)
                penalty_flat += _INVALID_ACTION_PENALTY
                was_hard_invalid = True
            else:
                applied_ops = self._apply_action(normalized)
            # An invalid role is reported but does not block the action.
            if conflicts.get("invalid_role", 0):
                action_error = first_conflict_message(conflicts, normalized)

        # The scorer sees only the ops that actually changed state, plus the
        # raw ops the agent attempted.
        scoring_action = dict(normalized)
        scoring_action["ops"] = copy.deepcopy(applied_ops)
        scoring_action["attempted_ops"] = copy.deepcopy(attempted_ops)
        self._narration_history.append(normalized.get("narration", ""))

        # ---- concept credit ----------------------------------------------
        # A claimed concept counts only if it is on the checklist and the
        # applied ops evidence it.
        checklist = set(self._scenario.get("concept_checklist") or [])
        credited_concepts: List[str] = []
        newly_covered_count = 0
        if not was_hard_invalid and applied_ops:
            for concept in normalized.get("covered_concepts", []):
                if (
                    concept in checklist
                    and concept not in credited_concepts
                    and is_concept_evidenced(concept, scoring_action)
                ):
                    credited_concepts.append(concept)
                    if concept not in self._concept_coverage:
                        self._concept_coverage.append(concept)
                        newly_covered_count += 1
        self._step_history.append({
            "step_id": len(self._step_history) + 1,
            "step_type": normalized.get("step_type", "advance"),
            "narration": normalized.get("narration", ""),
            "ops": copy.deepcopy(applied_ops),
            "attempted_ops": copy.deepcopy(attempted_ops),
            "covered_concepts": list(credited_concepts),
            "intent": normalized.get("intent", ""),
            "action_error": action_error,
        })
        self._remaining_steps = max(0, self._remaining_steps - 1)

        # ---- scoring -----------------------------------------------------
        new_state = self._snapshot_state()
        template = self._scenario.get("template", "")
        breakdown = compute_score_breakdown(
            self._scenario.get("task_name", "easy"),
            template,
            scoring_action,
            prev_state,
            new_state,
            self._scenario,
            conflicts,
            self._step_history,
        )
        was_noop = (
            not was_hard_invalid
            and is_no_op(scoring_action, prev_state, new_state)
        )
        if was_noop:
            self._consecutive_noops += 1
            noop_penalty = _NOOP_ACTION_PENALTY
            # Repeated no-ops are penalised twice as hard.
            if self._consecutive_noops >= 2:
                noop_penalty *= 2
            penalty_flat += noop_penalty
        elif not was_hard_invalid:
            self._consecutive_noops = 0

        # ---- reward ------------------------------------------------------
        score = breakdown.get("overall_score", 0.0)
        if was_hard_invalid:
            reward = penalty_flat
            # After a hard-invalid, the next valid step's score often
            # drops from the prior peak because a dropped op means less
            # state progress (coverage, attention, etc.). Pin
            # previous_score to the current score so the next step's
            # delta starts from a realistic baseline and isn't punished
            # twice for the same rejection.
            self._previous_score = score
        else:
            reward = (score - self._previous_score) + penalty_flat
            # Dense shaping: immediate positive signal per newly-covered
            # concept — partially replaces the all-at-end submit bonus
            # so the LLM gets feedback throughout an episode.
            if newly_covered_count > 0:
                reward += _NEW_CONCEPT_BONUS * newly_covered_count
            if was_noop:
                # A no-op never earns more than its penalty, and must not
                # lower the baseline the next delta is measured against.
                if reward > penalty_flat:
                    reward = penalty_flat
                self._previous_score = max(self._previous_score, score)
            else:
                self._previous_score = score

        # ---- termination -------------------------------------------------
        done = False
        if normalized.get("step_type") == "complete":
            done = True
            self._done_reason = "complete"
            coverage = breakdown.get("coverage_score", 0.0)
            completion_ready = (
                coverage >= 0.8
                and action_error is None
                and not was_hard_invalid
                and _has_real_canvas_work(new_state)
            )
            if completion_ready:
                # The complete step sometimes dips pacing/attention because
                # its narration is a summary. When the episode wrapped up
                # cleanly (>=80% checklist covered, no action error, real
                # canvas work), floor the step reward at 0 so a clean
                # wrap-up is never net-negative, then add the bonuses.
                reward = max(reward, 0.0)
                if score >= _SUBMIT_BONUS_THRESHOLD:
                    reward += _SUBMIT_BONUS
                reward += compute_efficiency_bonus(
                    self._remaining_steps, self._step_budget(), coverage,
                )
        else:
            # A step that changed nothing on the canvas and credited no new
            # concept must not earn positive reward.
            if not applied_ops and newly_covered_count == 0 and reward > 0.0:
                reward = 0.0
            # Evaluated independently of the clamp above so the step that
            # uses up the budget always ends the episode.
            if self._remaining_steps == 0:
                done = True
                self._done_reason = self._done_reason or "step_budget_exhausted"

        reward = max(_REWARD_MIN, min(_REWARD_MAX, reward))
        return self._build_observation(
            score_breakdown=breakdown,
            reward=reward,
            done=done,
            action_error=action_error,
            action_penalty=penalty_flat if penalty_flat != 0.0 else None,
        )

    @property
    def state(self) -> State:
        """Episode metadata: episode id and number of ``step()`` calls."""
        return self._state

    # ----- scenario loading -----------------------------------------------

    def _step_budget(self) -> int:
        """Return the current scenario's step budget.

        Falls back to ``_DEFAULT_STEP_BUDGET`` when ``step_budget`` is
        missing or ``None``. Raises ``ValueError``/``TypeError`` if it is
        present but not int-convertible.
        """
        raw = self._scenario.get("step_budget")
        return self._DEFAULT_STEP_BUDGET if raw is None else int(raw)

    def _load_scenario(
        self,
        scenario_id: Optional[str] = None,
        task_name: Optional[str] = None,
    ) -> None:
        """Select a scenario and clear all canvas state for a new episode.

        Lookup order: explicit ``scenario_id``, else the first scenario for
        ``task_name``, else the repository's ``next_scenario()``. Repository
        lookup errors propagate to the caller.
        """
        if scenario_id:
            scenario = self._repository.get_scenario(scenario_id)
        elif task_name:
            scenario = self._repository.first_scenario_for_task(task_name)
        else:
            scenario = self._repository.next_scenario()
        self._scenario = scenario
        self._remaining_steps = self._step_budget()
        self._entities = {}
        self._relations = []
        self._layout = {}
        self._annotations = []
        self._notes = []
        self._narration_history = []
        self._step_history = []
        self._concept_coverage = []

    # ----- op application --------------------------------------------------

    def _reset_score_baseline(self) -> float:
        """Score the current (freshly cleared) canvas with an empty action.

        The result seeds ``_previous_score`` so the first step's reward is a
        delta from the empty-canvas score rather than from 0.
        """
        state = self._snapshot_state()
        baseline_action = _empty_action("advance")
        baseline_action["narration"] = "Baseline state."
        conflicts = compute_conflicts(baseline_action, state)
        breakdown = compute_score_breakdown(
            self._scenario.get("task_name", "easy"),
            self._scenario.get("template", ""),
            baseline_action,
            state,
            state,
            self._scenario,
            conflicts,
            [],
        )
        return float(breakdown.get("overall_score", 0.0))

    def _apply_action(self, action: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Apply a normalised action's ops in order.

        Each region touched by an applied op is re-laid out once, after all
        ops have run. Returns deep copies of the ops that changed canvas
        state.
        """
        dirty_regions: Set[str] = set()
        applied_ops: List[Dict[str, Any]] = []
        for op in action.get("ops", []):
            region, applied = self._apply_op(op)
            if applied:
                applied_ops.append(copy.deepcopy(op))
                if region:
                    dirty_regions.add(region)
        for region_id in dirty_regions:
            self._relayout_region(region_id)
        return applied_ops

    def _apply_op(self, op: Dict[str, Any]) -> Tuple[Optional[str], bool]:
        """Run one op through the dispatch table.

        Returns ``(region_id_to_relayout_or_None, changed_state)``. Unknown op
        names, and ops with fewer target ids than the handler needs, are
        ignored and reported as not applied.
        """
        name = op.get("op") or ""
        targets = list(op.get("target_ids") or [])
        params = dict(op.get("params") or {})
        entry = self._op_dispatch.get(name)
        if entry is None:
            return None, False
        handler, min_targets = entry
        if len(targets) < min_targets:
            return None, False
        # An op counts as applied only if it actually changed canvas state.
        before = self._mutation_fingerprint()
        region = handler(targets, params)
        return region, before != self._mutation_fingerprint()

    def _mutation_fingerprint(self) -> Tuple[Any, ...]:
        """Return a comparable copy of entities, relations, layout, annotations and notes."""
        return (
            _copy_entities(self._entities),
            [dict(r) for r in self._relations],
            {k: dict(v) for k, v in self._layout.items()},
            [dict(a) for a in self._annotations],
            [dict(n) for n in self._notes],
        )

    # ----- per-op handlers -------------------------------------------------

    def _op_add_region(self, rid: str, params: Dict[str, Any]) -> Optional[str]:
        """Create region *rid*; a no-op if the id is already taken.

        Unknown ``style`` values fall back to ``"graph"``. The requested grid
        ``position`` is resolved against the cells already occupied, then
        every region's bounds are recomputed and re-laid out. Returns None
        because that recomputation already re-laid out every region.
        """
        if rid in self._entities:
            return None
        style = str(params.get("style") or "graph")
        if style not in REGION_STYLES:
            style = "graph"
        position = params.get("position")
        if isinstance(position, str):
            position = position.strip().lower()
        else:
            position = None
        # Gather already-occupied cells and existing position mappings
        occupied, existing_positions = self._region_occupancy()
        region_count = len([e for e in self._entities.values()
                            if e.get("entity_type") == "region"])
        bounds, resolved_position = resolve_bounds(
            position, occupied, existing_positions, region_count,
        )
        self._entities[rid] = {
            "entity_id": rid, "entity_type": "region", "style": style,
            "title": str(params.get("title") or rid),
            "bounds": dict(bounds),
            "members": [],
            "root": params.get("root"),
            "source": params.get("source"),
            "position": resolved_position,
        }
        self._layout[rid] = {"x": bounds["x0"], "y": bounds["y0"]}
        # Adding a region changes the adaptive expansion of all existing
        # regions — recompute their bounds so earlier regions shrink to
        # make room rather than overlapping.
        self._recompute_all_region_bounds()
        return None

    def _op_add_placeable(
        self, eid: str, etype: str, params: Dict[str, Any]
    ) -> Optional[str]:
        """Create a node, pointer or container *eid* inside a region.

        Uses ``params["region"]`` when it names an existing region, otherwise
        the first region (a default region is created if none exists).
        Nodes take ``params["value"]``, pointers ``params["index"]``, and
        containers ``params["ordered"]`` (default True). Returns the region
        id for relayout, or None if *eid* already exists.
        """
        if eid in self._entities:
            return None
        region_id = str(params.get("region") or self._default_region() or "")
        region = self._entities.get(region_id)
        if not region or region.get("entity_type") != "region":
            region_id = self._ensure_default_region()
            region = self._entities[region_id]
        entity: Dict[str, Any] = {
            "entity_id": eid, "entity_type": etype,
            "region": region_id,
            "role": "default",
        }
        if etype == "node":
            entity["value"] = params.get("value")
        elif etype == "pointer":
            entity["value"] = params.get("index")
        elif etype == "container":
            entity["ordered"] = bool(params.get("ordered", True))
        self._entities[eid] = entity
        region["members"].append(eid)
        return region_id

    def _op_add_edge(self, src: str, dst: str, params: Dict[str, Any]) -> Optional[str]:
        """Add relation src -> dst of ``params["kind"]`` (default ``"edge"``).

        An optional ``label`` is truncated to 20 characters. Unknown
        endpoints and duplicate (src, dst, kind) triples are ignored.
        Returns the region of *src* (or of *dst*) for relayout.
        """
        if src not in self._entities or dst not in self._entities:
            return None
        kind = str(params.get("kind") or "edge")
        label = str(params.get("label") or "")
        key = (src, dst, kind)
        if any(
            (r.get("src"), r.get("dst"), r.get("kind", "edge")) == key
            for r in self._relations
        ):
            return None
        rel: Dict[str, Any] = {"src": src, "dst": dst, "kind": kind}
        if label:
            rel["label"] = label[:20]
        self._relations.append(rel)
        region_id = self._entities[src].get("region") or self._entities[dst].get("region")
        return region_id

    def _op_push(self, container_id: str, item_id: str) -> Optional[str]:
        """Append existing entity *item_id* to a container's ``contents``.

        Ignored for non-containers, unknown items and items already present.
        """
        container = self._entities.get(container_id)
        if not container or container.get("entity_type") != "container":
            return None
        if item_id not in self._entities:
            return None
        members = container.setdefault("contents", [])
        if item_id in members:
            return None
        members.append(item_id)
        return container.get("region")

    def _op_pop(self, container_id: str) -> Optional[str]:
        """Remove one item from a container and annotate it ``[popped]``.

        Removes the first item when the container is ``ordered`` (the
        default) and the last item otherwise. Empty containers are ignored.
        """
        container = self._entities.get(container_id)
        if not container or container.get("entity_type") != "container":
            return None
        members = container.get("contents") or []
        if not members:
            return None
        removed = members.pop(0) if container.get("ordered", True) else members.pop()
        self._add_annotation(removed, "[popped]")
        return container.get("region")

    def _op_move_pointer(self, pid: str, params: Dict[str, Any]) -> None:
        """Point pointer *pid* at ``params["index"]``.

        The index is stored as an int when int-convertible; otherwise it is
        accepted only if it names an existing entity.
        """
        ent = self._entities.get(pid)
        if not ent or ent.get("entity_type") != "pointer":
            return
        if "index" not in params:
            return
        idx = params["index"]
        try:
            ent["value"] = int(idx)
            return
        except (TypeError, ValueError):
            pass
        # Accept an entity ID as a symbolic pointer target.
        if isinstance(idx, str) and idx in self._entities:
            ent["value"] = idx

    def _op_annotate(self, eid: str, params: Dict[str, Any]) -> None:
        """Replace *eid*'s text annotation with ``params["text"]``.

        The text is stripped and truncated to 50 characters. Empty text only
        clears the existing annotation; ``[highlight]`` markers are kept.
        """
        if eid not in self._entities:
            return
        text = str(params.get("text", "")).strip()
        self._set_text_annotation(eid, text[:50])

    def _op_highlight(self, eid: str) -> None:
        """Mark an existing entity with a ``[highlight]`` annotation (idempotent)."""
        if eid not in self._entities:
            return
        self._add_annotation(eid, "[highlight]")

    def _op_unhighlight(self, eid: str) -> None:
        """Remove the ``[highlight]`` annotation from an existing entity."""
        if eid not in self._entities:
            return
        self._annotations = [
            a for a in self._annotations
            if not (a.get("target_id") == eid and a.get("text") == "[highlight]")
        ]

    def _op_add_note(self, params: Dict[str, Any]) -> None:
        """Append a free-floating note with a sequential id ``note_N``.

        The text is stripped and truncated to 100 characters; empty text is
        ignored. ``params["region"]`` is stored as given, without validation.
        """
        text = str(params.get("text", "")).strip()
        if not text:
            return
        # Notes are never removed, so the list length yields unique ids.
        note_id = f"note_{len(self._notes) + 1}"
        self._notes.append({
            "note_id": note_id,
            "text": text[:100],
            "region": params.get("region"),
        })

    def _op_set_role(self, eid: str, params: Dict[str, Any]) -> None:
        """Set an entity's role; a missing role means ``"default"``.

        Roles not in ``ROLE_VALUES`` are ignored.
        """
        ent = self._entities.get(eid)
        if not ent:
            return
        role = str(params.get("role") or "default")
        if role not in ROLE_VALUES:
            return
        ent["role"] = role

    def _op_set_value(self, eid: str, params: Dict[str, Any]) -> None:
        """Set an entity's ``value``; pointer values must be int-convertible."""
        ent = self._entities.get(eid)
        if not ent or "value" not in params:
            return
        if ent.get("entity_type") == "pointer":
            try:
                ent["value"] = int(params["value"])
            except (TypeError, ValueError):
                return
        else:
            ent["value"] = params.get("value")

    def _op_remove_edge(self, src: str, dst: str, params: Dict[str, Any]) -> Optional[str]:
        """Remove the src -> dst relation of ``params["kind"]`` (default ``"edge"``).

        Returns the region of *src* (or of *dst*) for relayout, or None if
        nothing matched.
        """
        kind = str(params.get("kind") or "edge")
        before = len(self._relations)
        self._relations = [
            r for r in self._relations
            if not (r.get("src") == src and r.get("dst") == dst
                    and r.get("kind", "edge") == kind)
        ]
        if len(self._relations) == before:
            return None
        ent = self._entities.get(src) or self._entities.get(dst)
        return ent.get("region") if ent else None

    def _op_remove_entity(self, eid: str) -> Optional[str]:
        """Delete *eid* and everything that references it.

        Drops the entity's relations, annotations and layout slot, and removes
        it from container contents. When a region is removed, its members move
        to the first remaining region (a default region is created if none is
        left). Any other entity is dropped from its region's member list.
        Returns the region to relayout.
        """
        ent = self._entities.pop(eid, None)
        if ent is None:
            return None
        self._relations = [
            r for r in self._relations
            if r.get("src") != eid and r.get("dst") != eid
        ]
        self._annotations = [a for a in self._annotations if a.get("target_id") != eid]
        self._layout.pop(eid, None)
        for other in self._entities.values():
            if isinstance(other, dict) and other.get("entity_type") == "container":
                contents = other.get("contents") or []
                if eid in contents:
                    other["contents"] = [c for c in contents if c != eid]
        affected_region: Optional[str] = None
        if ent.get("entity_type") == "region":
            fallback_region: Optional[str] = None
            for other in self._entities.values():
                if isinstance(other, dict) and other.get("entity_type") == "region":
                    fallback_region = other.get("entity_id")
                    break
            if fallback_region is None:
                fallback_region = self._ensure_default_region()
            # Re-home the orphaned members of the removed region.
            for other_eid, other in self._entities.items():
                if isinstance(other, dict) and other.get("region") == eid:
                    other["region"] = fallback_region
                    target = self._entities.get(fallback_region)
                    if isinstance(target, dict):
                        target.setdefault("members", []).append(other_eid)
            affected_region = fallback_region
        else:
            region_id = ent.get("region")
            region = self._entities.get(region_id) if region_id else None
            if isinstance(region, dict):
                region["members"] = [m for m in region.get("members", []) if m != eid]
                affected_region = region_id
        return affected_region

    def _set_text_annotation(self, target_id: str, text: str) -> None:
        """Replace every non-``[highlight]`` annotation on *target_id* with *text*.

        Empty *text* only clears.
        """
        self._annotations = [
            a for a in self._annotations
            if not (a.get("target_id") == target_id and a.get("text") != "[highlight]")
        ]
        if text:
            self._annotations.append({"target_id": target_id, "text": text})

    # ----- layout helpers --------------------------------------------------

    def _default_region(self) -> Optional[str]:
        """Return the id of the first region in insertion order, or None."""
        for eid, ent in self._entities.items():
            if ent.get("entity_type") == "region":
                return eid
        return None

    def _ensure_default_region(self) -> str:
        """Return the first region's id, creating ``_default_region`` if none exists."""
        existing = self._default_region()
        if existing:
            return existing
        rid = "_default_region"
        self._op_add_region(rid, {"style": "graph", "title": "Canvas"})
        return rid

    def _region_occupancy(self) -> Tuple[Set[Tuple[int, int]], Dict[str, str]]:
        """Return (occupied_cells, {region_id: position_name}) for all regions."""
        occupied: Set[Tuple[int, int]] = set()
        existing_positions: Dict[str, str] = {}
        for eid, ent in self._entities.items():
            if ent.get("entity_type") != "region":
                continue
            pos_name = ent.get("position")
            if pos_name and pos_name in _GRID_CELL:
                occupied.add(_GRID_CELL[pos_name])
                existing_positions[eid] = pos_name
        return occupied, existing_positions

    def _recompute_all_region_bounds(self) -> None:
        """Recompute adaptive bounds for every region and relayout members.

        Regions without a recognised grid position keep their bounds but are
        still re-laid out.
        """
        regions = [
            (eid, ent) for eid, ent in self._entities.items()
            if ent.get("entity_type") == "region"
        ]
        if not regions:
            return
        # Collect all occupied cells first
        occupied: Set[Tuple[int, int]] = set()
        positions_map: Dict[str, str] = {}
        for eid, ent in regions:
            pos_name = ent.get("position")
            if pos_name and pos_name in _GRID_CELL:
                occupied.add(_GRID_CELL[pos_name])
                positions_map[eid] = pos_name
        # Recompute each region's bounds using full occupancy info
        for eid, ent in regions:
            pos_name = ent.get("position")
            if not pos_name or pos_name not in _GRID_CELL:
                continue
            others = occupied - {_GRID_CELL[pos_name]}
            bounds, _ = resolve_bounds(pos_name, others, positions_map, len(regions))
            ent["bounds"] = dict(bounds)
            self._layout[eid] = {"x": bounds["x0"], "y": bounds["y0"]}
        # Relayout all members since bounds changed
        for eid, _ent in regions:
            self._relayout_region(eid)

    def _relayout_region(self, region_id: str) -> None:
        """Recompute member positions of a region via ``layout_members``.

        First drops members whose ``region`` no longer points here.
        """
        region = self._entities.get(region_id)
        if not region or region.get("entity_type") != "region":
            return
        members = [m for m in region.get("members", [])
                   if self._entities.get(m, {}).get("region") == region_id]
        region["members"] = members
        positions = layout_members(region, members, self._relations)
        for eid, pos in positions.items():
            self._layout[eid] = pos

    def _add_annotation(self, target_id: str, text: str) -> None:
        """Append a (target_id, text) annotation unless an identical one exists."""
        for existing in self._annotations:
            if existing.get("target_id") == target_id and existing.get("text") == text:
                return
        self._annotations.append({"target_id": target_id, "text": text})

    # ----- observation + snapshots ----------------------------------------

    def _snapshot_state(self) -> Dict[str, Any]:
        """Return an independent copy of the canvas plus narration/coverage.

        This is the state dict passed to the scoring functions.
        """
        return {
            "entities": _copy_entities(self._entities),
            "relations": [dict(r) for r in self._relations],
            "layout": {k: dict(v) for k, v in self._layout.items()},
            "annotations": [dict(a) for a in self._annotations],
            "notes": [dict(n) for n in self._notes],
            "narration_history": list(self._narration_history),
            "concept_coverage": list(self._concept_coverage),
            "concept_checklist": list(self._scenario.get("concept_checklist") or []),
        }

    def _build_observation(
        self,
        score_breakdown: Dict[str, float],
        reward: float,
        done: bool,
        action_error: Optional[str] = None,
        action_penalty: Optional[float] = None,
    ) -> VisualReasoningObservation:
        """Package scenario metadata and copies of the canvas into an observation.

        Scenario fields that are missing or explicitly ``None`` become empty
        containers, matching ``_snapshot_state``.
        """
        scenario = self._scenario
        return VisualReasoningObservation(
            task_name=scenario.get("task_name", ""),
            scenario_id=scenario.get("scenario_id", ""),
            goal=scenario.get("goal", ""),
            input_data=dict(scenario.get("input_data") or {}),
            constraints=list(scenario.get("constraints") or []),
            concept_checklist=list(scenario.get("concept_checklist") or []),
            max_steps=self._step_budget(),
            entities=_copy_entities(self._entities),
            relations=[dict(r) for r in self._relations],
            layout={k: dict(v) for k, v in self._layout.items()},
            annotations=[dict(a) for a in self._annotations],
            notes=[dict(n) for n in self._notes],
            narration_history=list(self._narration_history),
            step_history=[dict(s) for s in self._step_history],
            concept_coverage=list(self._concept_coverage),
            step_id=len(self._step_history),
            done_reason=self._done_reason,
            score_breakdown=score_breakdown,
            remaining_step_budget=self._remaining_steps,
            done=done,
            reward=reward,
            action_error=action_error,
            action_penalty=action_penalty,
        )
def _copy_entities(entities: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Return a two-level copy of *entities* without paying for ``copy.deepcopy``.

    Each entity dict is copied, and any list or dict stored directly on an
    entity is copied one level deep. Snapshots therefore do not share mutable
    containers with the live canvas when ops append to member lists or
    replace bounds. Anything nested deeper is still shared.
    """
    def _clone(value: Any) -> Any:
        # Only the containers entities actually carry need a fresh copy.
        if isinstance(value, list):
            return list(value)
        if isinstance(value, dict):
            return dict(value)
        return value

    return {
        entity_id: {field: _clone(value) for field, value in entity.items()}
        for entity_id, entity in entities.items()
    }
def _clean_str(value: Any, lower: bool = False) -> Optional[str]:
    """Normalise an optional string parameter.

    Returns the stripped string (lower-cased when *lower* is true), or None
    when *value* is not a string or is blank after stripping.
    """
    text = value.strip() if isinstance(value, str) else ""
    if not text:
        return None
    if lower:
        return text.lower()
    return text
def _raw_step_type(action: Any) -> str:
    """Read ``step_type`` from an action object or dict, before validation.

    Attribute access is tried first, then dict lookup. The value is
    stringified, stripped and lower-cased; a missing or falsy value (or an
    unsupported action type) yields ``""``.
    """
    if hasattr(action, "step_type"):
        raw = action.step_type
    elif isinstance(action, dict):
        raw = action.get("step_type")
    else:
        return ""
    return str(raw or "").strip().lower()
def _raw_attempted_ops(action: Any) -> List[Dict[str, Any]]:
    """Extract the ops the agent tried to run, before any validation.

    Accepts an object with ``model_dump()`` (pydantic-style), a plain dict,
    or any object with an ``ops`` attribute. Malformed entries are tolerated
    rather than rejected:

    * a non-sequence ``ops`` value yields ``[]`` and non-dict ops are skipped;
    * ``target_ids`` must be a list or tuple (otherwise ``[]``), with ``None``
      ids dropped and the rest stringified;
    * non-dict ``params`` become ``{}``;
    * a missing or ``None`` op name is recorded as ``""``.

    Returns one ``{"op", "target_ids", "params"}`` dict per usable op.
    """
    if hasattr(action, "model_dump"):
        data = action.model_dump()
    elif isinstance(action, dict):
        data = action
    else:
        data = {"ops": getattr(action, "ops", [])}
    ops_raw = data.get("ops") or []
    if not isinstance(ops_raw, (list, tuple)):
        return []
    attempted: List[Dict[str, Any]] = []
    for op in ops_raw:
        if not isinstance(op, dict):
            continue
        targets_raw = op.get("target_ids") or []
        targets = (
            [str(t) for t in targets_raw if t is not None]
            if isinstance(targets_raw, (list, tuple))
            else []
        )
        params = op.get("params") or {}
        if not isinstance(params, dict):
            params = {}
        # Guard None explicitly: str(None) would record the bogus op "None".
        name = op.get("op")
        attempted.append({
            "op": "" if name is None else str(name).strip(),
            "target_ids": targets,
            "params": dict(params),
        })
    return attempted
def _has_real_canvas_work(state: Dict[str, Any]) -> bool:
    """Return True if the snapshot holds any canvas content.

    Content means at least one entity, relation, annotation or note;
    narration alone does not count.
    """
    canvas_keys = ("entities", "relations", "annotations", "notes")
    return any(state.get(key) for key in canvas_keys)
def _empty_action(step_type: str) -> Dict[str, Any]:
    """Build a normalised-shape action with no ops, narration or concepts.

    Each call returns fresh lists, so callers may mutate the result freely.
    """
    return dict(
        step_type=step_type,
        narration="",
        ops=[],
        covered_concepts=[],
        intent="",
    )