from __future__ import annotations

from typing import Any

from workflow_twin.environment import WorkflowTwinEnv
from workflow_twin.quantizer import RotatedQuantizedMemory


class MemoryBoundedEnv:
    """Wrap a ``WorkflowTwinEnv`` and keep its snapshot history under a memory budget.

    Every ``step`` appends a snapshot of the wrapped environment's state. When the
    estimated memory usage exceeds ``memory_budget``, older history is evicted:

    * ``mode == "quant"``: older snapshots are replaced by quantized codes of their
      ticket embeddings (see ``_compress_history``);
    * any other mode: older snapshots are simply dropped (see ``_truncate_history``).

    Memory usage is an estimate computed from entry counts, not a measurement:
    each raw snapshot costs ``RAW_SNAPSHOT_COST`` and each compressed code costs
    ``COMPRESSED_CHUNK_COST`` (same units as ``memory_budget``).
    """

    # Estimated cost of one raw state snapshot (units of ``memory_budget``).
    RAW_SNAPSHOT_COST = 600
    # Estimated cost of one quantized embedding code.
    COMPRESSED_CHUNK_COST = 24
    # Most recent raw snapshots kept when compressing ("quant" mode).
    COMPRESS_KEEP_LAST = 3
    # Most recent raw snapshots kept when truncating (all other modes).
    TRUNCATE_KEEP_LAST = 10
    # Subtracted from the reward when history is still over budget after eviction.
    OVER_BUDGET_PENALTY = 0.2
    # Subtracted from the reward in "baseline" mode on steps where eviction ran.
    BASELINE_PENALTY = 0.05
    # Penalties never push the reward below this value.
    REWARD_FLOOR = -1.0

    def __init__(
        self,
        base_env: WorkflowTwinEnv,
        memory_budget: int = 10_000,
        bits: int = 3,
        mode: str = "quant",
    ) -> None:
        """Create the wrapper.

        Args:
            base_env: Environment to wrap; its ``embedding_dim`` and ``seed``
                configure the quantizer.
            memory_budget: Threshold compared against ``_compute_memory_usage()``.
            bits: Passed as ``bits=`` to ``quantizer.quantize_prod``.
            mode: ``"quant"`` compresses old history; any other value truncates it.
                ``"baseline"`` additionally incurs ``BASELINE_PENALTY``.
        """
        self.base_env = base_env
        self.memory_budget = memory_budget
        self.bits = bits
        self.mode = mode
        self.quantizer = RotatedQuantizedMemory(
            dimension=base_env.embedding_dim, seed=base_env.seed
        )
        # Upper bound on retained quantized codes; the oldest are discarded first.
        self.max_compressed_chunks = 50
        self.state_history: list[dict[str, Any]] = []
        self.compressed_embeddings: list[dict[str, Any]] = []
        # Running count of raw snapshots dropped by truncation since the last reset.
        self.truncated_snapshots = 0

    def reset(self):
        """Reset the wrapped env and the memory bookkeeping.

        Returns:
            The observation returned by ``base_env.reset()``.
        """
        obs = self.base_env.reset()
        # Use the same snapshot path as ``step`` so every history entry has one format.
        self.state_history = [self._snapshot_state()]
        self.compressed_embeddings = []
        self.truncated_snapshots = 0
        return obs

    def step(self, action: dict):
        """Step the wrapped env, record a snapshot and enforce the memory budget.

        Returns:
            ``(obs, reward, done, info)`` from ``base_env.step``. ``reward`` may be
            reduced by memory penalties, ``info["memory"]`` holds memory statistics,
            and ``obs.memory_used`` / ``obs.memory_budget`` are set.
        """
        obs, reward, done, info = self.base_env.step(action)
        self.state_history.append(self._snapshot_state())

        memory_used = self._compute_memory_usage()
        if memory_used > self.memory_budget:
            if self.mode == "quant":
                self._compress_history()
            else:
                self._truncate_history()
            memory_used = self._compute_memory_usage()
            if memory_used > self.memory_budget:
                reward = self._apply_penalty(reward, self.OVER_BUDGET_PENALTY)
            # NOTE(review): baseline penalty is applied only on steps where
            # eviction ran — confirm this is the intended nesting.
            if self.mode == "baseline":
                reward = self._apply_penalty(reward, self.BASELINE_PENALTY)

        info["memory"] = {
            "memory_used": memory_used,
            "memory_budget": self.memory_budget,
            "compressed_chunks": len(self.compressed_embeddings),
            "truncated_snapshots": self.truncated_snapshots,
            "mode": self.mode,
        }
        obs.memory_used = memory_used
        obs.memory_budget = self.memory_budget
        return obs, reward, done, info

    def _apply_penalty(self, reward: float, penalty: float) -> float:
        """Return ``reward - penalty`` clamped at ``REWARD_FLOOR``.

        The clamp never increases the reward. A reward already below the floor is
        returned unchanged instead of being raised to the floor.
        """
        return min(reward, max(self.REWARD_FLOOR, reward - penalty))

    def _compute_memory_usage(self) -> int:
        """Estimate memory use from the number of raw snapshots and compressed codes."""
        raw = len(self.state_history) * self.RAW_SNAPSHOT_COST
        compressed = len(self.compressed_embeddings) * self.COMPRESSED_CHUNK_COST
        return raw + compressed

    def _snapshot_state(self) -> dict[str, Any]:
        """Return a dict snapshot of the wrapped env's state.

        Prefers ``base_env._state.model_dump()`` when a ``_state`` attribute is
        present; otherwise falls back to ``base_env.state()``.
        """
        raw_state = getattr(self.base_env, "_state", None)
        if raw_state is not None:
            return raw_state.model_dump()
        return self.base_env.state()

    def _compress_history(self) -> None:
        """Replace all but the newest ``COMPRESS_KEEP_LAST`` snapshots with quantized codes.

        For each evicted snapshot, ``current_ticket["embedding"]`` is quantized and
        appended to ``compressed_embeddings``. Snapshots without a ticket or
        embedding are dropped without a code. The list of codes is then capped at
        ``max_compressed_chunks``, keeping the newest.
        """
        split = len(self.state_history) - self.COMPRESS_KEEP_LAST
        if split <= 0:
            return
        old = self.state_history[:split]
        keep = self.state_history[split:]

        compressed: list[dict[str, Any]] = []
        for snapshot in old:
            ticket = snapshot.get("current_ticket")
            if not ticket:
                continue
            embedding = ticket.get("embedding")
            # Explicit check: ``not embedding`` raises for array-like embeddings.
            if embedding is None or len(embedding) == 0:
                continue
            compressed.append(self.quantizer.quantize_prod(embedding, bits=self.bits))

        self.compressed_embeddings.extend(compressed)
        overflow = len(self.compressed_embeddings) - self.max_compressed_chunks
        if overflow > 0:
            # Drop the oldest codes; a split index also behaves correctly for a cap of 0.
            self.compressed_embeddings = self.compressed_embeddings[overflow:]
        self.state_history = keep

    def _truncate_history(self) -> None:
        """Drop all but the newest ``TRUNCATE_KEEP_LAST`` snapshots, counting what was removed."""
        excess = len(self.state_history) - self.TRUNCATE_KEEP_LAST
        if excess <= 0:
            return
        self.truncated_snapshots += excess
        self.state_history = self.state_history[excess:]