Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Any | |
| from workflow_twin.environment import WorkflowTwinEnv | |
| from workflow_twin.quantizer import RotatedQuantizedMemory | |
class MemoryBoundedEnv:
    """Wrap a workflow environment and enforce a simulated memory budget.

    Every ``step`` appends a snapshot of the wrapped environment's state to
    ``state_history`` and estimates memory use with a fixed cost per raw
    snapshot and per compressed code (see ``_compute_memory_usage``). When the
    estimate exceeds ``memory_budget`` the history is shrunk:

    * ``mode == "quant"``: older snapshots are replaced by quantized codes of
      their ticket embeddings (``_compress_history``);
    * any other mode: older snapshots are dropped (``_truncate_history``).

    The step reward is penalised when the history is still over budget after
    shrinking, and ``mode == "baseline"`` receives an additional fixed penalty.
    Penalties never push a reward below ``reward_floor`` and never raise it.
    """

    # Simulated cost units charged per raw snapshot / per compressed code.
    raw_snapshot_cost: int = 600
    compressed_chunk_cost: int = 24
    # How many of the most recent raw snapshots each shrinking strategy keeps.
    compress_keep_last: int = 3
    truncate_keep_last: int = 10
    # Reward shaping applied in ``step``.
    reward_floor: float = -1.0
    over_budget_penalty: float = 0.2
    baseline_penalty: float = 0.05

    def __init__(
        self,
        base_env: WorkflowTwinEnv,
        memory_budget: int = 10_000,
        bits: int = 3,
        mode: str = "quant",
    ) -> None:
        """Create the wrapper.

        Args:
            base_env: Environment to wrap. Its ``embedding_dim`` and ``seed``
                configure the quantizer; ``reset``, ``step`` and ``state`` are
                called by this wrapper.
            memory_budget: Simulated memory limit above which history shrinks.
            bits: Passed as ``bits`` to ``quantizer.quantize_prod``.
            mode: ``"quant"`` compresses old history; any other value
                truncates it. ``"baseline"`` also incurs a per-step penalty.
        """
        self.base_env = base_env
        self.memory_budget = memory_budget
        self.bits = bits
        self.mode = mode
        self.quantizer = RotatedQuantizedMemory(dimension=base_env.embedding_dim, seed=base_env.seed)
        # Upper bound on retained compressed codes; the oldest are evicted first.
        self.max_compressed_chunks = 50
        self.state_history: list[dict[str, Any]] = []
        self.compressed_embeddings: list[dict[str, Any]] = []
        self.truncated_snapshots = 0

    def reset(self):
        """Reset the wrapped environment and clear all memory bookkeeping.

        Returns:
            The observation returned by ``base_env.reset()``.
        """
        obs = self.base_env.reset()
        # Use the same snapshot source as ``step`` so every history entry has
        # the same shape when ``_compress_history`` inspects it.
        self.state_history = [self._snapshot_state()]
        self.compressed_embeddings = []
        self.truncated_snapshots = 0
        return obs

    def step(self, action: dict):
        """Advance the wrapped environment and apply the memory budget.

        Args:
            action: Forwarded unchanged to ``base_env.step``.

        Returns:
            ``(obs, reward, done, info)`` from the wrapped environment, with the
            reward possibly penalised, ``info["memory"]`` filled with budget
            statistics, and ``obs.memory_used`` / ``obs.memory_budget`` set.
        """
        obs, reward, done, info = self.base_env.step(action)
        self.state_history.append(self._snapshot_state())

        memory_used = self._compute_memory_usage()
        if memory_used > self.memory_budget:
            if self.mode == "quant":
                self._compress_history()
            else:
                self._truncate_history()
            memory_used = self._compute_memory_usage()

        # Still over budget after shrinking: the strategy could not keep up.
        if memory_used > self.memory_budget:
            reward = self._apply_penalty(reward, self.over_budget_penalty)
        # NOTE(review): the upstream listing lost its indentation; this handicap
        # is reconstructed as applying on every baseline step. Confirm upstream.
        if self.mode == "baseline":
            reward = self._apply_penalty(reward, self.baseline_penalty)

        info["memory"] = {
            "memory_used": memory_used,
            "memory_budget": self.memory_budget,
            "compressed_chunks": len(self.compressed_embeddings),
            "truncated_snapshots": self.truncated_snapshots,
            "mode": self.mode,
        }
        obs.memory_used = memory_used
        obs.memory_budget = self.memory_budget
        return obs, reward, done, info

    def _apply_penalty(self, reward: float, amount: float) -> float:
        """Subtract ``amount`` from ``reward`` without going below ``reward_floor``.

        A reward already below the floor is returned unchanged, so the clamp
        can never turn a penalty into a bonus.
        """
        return max(min(reward, self.reward_floor), reward - amount)

    def _compute_memory_usage(self) -> int:
        """Return the simulated memory footprint of the retained history."""
        return (
            len(self.state_history) * self.raw_snapshot_cost
            + len(self.compressed_embeddings) * self.compressed_chunk_cost
        )

    def _snapshot_state(self) -> dict[str, Any]:
        """Return a dict snapshot of the wrapped environment's current state.

        Prefers ``base_env._state.model_dump()`` and falls back to the public
        ``base_env.state()`` when ``_state`` is missing or ``None``.
        """
        raw_state = getattr(self.base_env, "_state", None)
        if raw_state is not None:
            return raw_state.model_dump()
        return self.base_env.state()

    def _compress_history(self) -> None:
        """Replace all but the newest raw snapshots with quantized embeddings.

        For every older snapshot whose ``current_ticket`` carries a non-empty
        ``embedding``, the code returned by ``quantizer.quantize_prod`` is
        appended to ``compressed_embeddings``; other old snapshots are simply
        discarded. Only the newest ``max_compressed_chunks`` codes are kept.
        """
        keep_last = self.compress_keep_last
        if len(self.state_history) <= keep_last:
            return
        # An explicit split index avoids the ``[:-0]`` / ``[-0:]`` pitfall.
        split = len(self.state_history) - keep_last
        old = self.state_history[:split]
        keep = self.state_history[split:]

        compressed: list[dict[str, Any]] = []
        for snapshot in old:
            ticket = snapshot.get("current_ticket")
            if not ticket:
                continue
            embedding = ticket.get("embedding")
            # Explicit check: ``not embedding`` raises on array-like embeddings.
            if embedding is None or len(embedding) == 0:
                continue
            compressed.append(self.quantizer.quantize_prod(embedding, bits=self.bits))

        self.compressed_embeddings.extend(compressed)
        if len(self.compressed_embeddings) > self.max_compressed_chunks:
            self.compressed_embeddings = self.compressed_embeddings[-self.max_compressed_chunks :]
        self.state_history = keep

    def _truncate_history(self) -> None:
        """Drop all but the newest ``truncate_keep_last`` snapshots, counting drops."""
        excess = len(self.state_history) - self.truncate_keep_last
        if excess <= 0:
            return
        self.truncated_snapshots += excess
        self.state_history = self.state_history[excess:]