Spaces:
Sleeping
Sleeping
File size: 3,617 Bytes
# 846683d
from __future__ import annotations
from typing import Any
from workflow_twin.environment import WorkflowTwinEnv
from workflow_twin.quantizer import RotatedQuantizedMemory
class MemoryBoundedEnv:
    """Wrap a ``WorkflowTwinEnv`` and keep its state history under a budget.

    After every step a snapshot of the wrapped environment's state is appended
    to ``state_history``. Memory use is an estimate: a fixed cost per raw
    snapshot plus a fixed cost per quantized code. When the estimate exceeds
    ``memory_budget`` the history is shrunk:

    * ``mode == "quant"``: older snapshots are replaced by quantized ticket
      embeddings (see ``_compress_history``).
    * any other mode: older snapshots are dropped (see ``_truncate_history``).

    The step reward is reduced when the estimate is still over budget after
    shrinking, and the resulting figures are reported under ``info["memory"]``
    and on the observation.
    """

    # Estimated cost charged for each raw snapshot held in ``state_history``.
    RAW_SNAPSHOT_COST = 600
    # Estimated cost charged for each code held in ``compressed_embeddings``.
    COMPRESSED_CHUNK_COST = 24
    # Most recent raw snapshots that survive a compression pass.
    COMPRESS_KEEP_LAST = 3
    # Most recent raw snapshots that survive a truncation pass.
    TRUNCATE_KEEP_LAST = 10
    # Reward deducted when usage is still over budget after shrinking.
    OVER_BUDGET_PENALTY = 0.2
    # Reward deducted in "baseline" mode.
    BASELINE_PENALTY = 0.05
    # Lower bound applied to the reward after each deduction.
    MIN_REWARD = -1.0

    def __init__(
        self,
        base_env: WorkflowTwinEnv,
        memory_budget: int = 10_000,
        bits: int = 3,
        mode: str = "quant",
        *,
        max_compressed_chunks: int = 50,
    ) -> None:
        """Set up the wrapper around ``base_env``.

        Args:
            base_env: Environment to wrap; its ``embedding_dim`` and ``seed``
                attributes configure the quantizer.
            memory_budget: Upper bound for the estimated memory usage.
            bits: Bit width passed to the quantizer for each embedding.
            mode: ``"quant"`` compresses old history; any other value
                truncates it. ``"baseline"`` additionally incurs
                ``BASELINE_PENALTY``.
            max_compressed_chunks: Maximum number of quantized codes retained;
                the oldest are discarded first.
        """
        self.base_env = base_env
        self.memory_budget = memory_budget
        self.bits = bits
        self.mode = mode
        self.quantizer = RotatedQuantizedMemory(
            dimension=base_env.embedding_dim,
            seed=base_env.seed,
        )
        self.max_compressed_chunks = max_compressed_chunks
        self.state_history: list[dict[str, Any]] = []
        self.compressed_embeddings: list[dict[str, Any]] = []
        self.truncated_snapshots = 0

    def reset(self) -> Any:
        """Reset the wrapped environment and all memory bookkeeping.

        Returns:
            The observation returned by ``base_env.reset()``.
        """
        obs = self.base_env.reset()
        # Use the same snapshot helper as ``step`` so every history entry has
        # the same shape when ``_compress_history`` later inspects it.
        self.state_history = [self._snapshot_state()]
        self.compressed_embeddings = []
        self.truncated_snapshots = 0
        return obs

    def step(self, action: dict[str, Any]) -> tuple[Any, float, bool, dict[str, Any]]:
        """Advance the wrapped environment and enforce the memory budget.

        Args:
            action: Action forwarded unchanged to ``base_env.step``.

        Returns:
            ``(obs, reward, done, info)`` from the wrapped environment, with
            the reward adjusted by memory penalties, ``info["memory"]`` filled
            in, and ``memory_used`` / ``memory_budget`` set on ``obs``.
        """
        obs, reward, done, info = self.base_env.step(action)
        self.state_history.append(self._snapshot_state())

        memory_used = self._compute_memory_usage()
        if memory_used > self.memory_budget:
            if self.mode == "quant":
                self._compress_history()
            else:
                self._truncate_history()
            # Re-measure: shrinking may or may not have brought us under budget.
            memory_used = self._compute_memory_usage()
            if memory_used > self.memory_budget:
                reward = max(self.MIN_REWARD, reward - self.OVER_BUDGET_PENALTY)

        # NOTE(review): the source this was recovered from had lost its
        # indentation, so the nesting of this check is an assumption (applied
        # on every step in "baseline" mode) -- confirm against the original.
        if self.mode == "baseline":
            reward = max(self.MIN_REWARD, reward - self.BASELINE_PENALTY)

        info["memory"] = {
            "memory_used": memory_used,
            "memory_budget": self.memory_budget,
            "compressed_chunks": len(self.compressed_embeddings),
            "truncated_snapshots": self.truncated_snapshots,
            "mode": self.mode,
        }
        obs.memory_used = memory_used
        obs.memory_budget = self.memory_budget
        return obs, reward, done, info

    def _compute_memory_usage(self) -> int:
        """Return the estimated memory used by raw and compressed history."""
        raw = len(self.state_history) * self.RAW_SNAPSHOT_COST
        compressed = len(self.compressed_embeddings) * self.COMPRESSED_CHUNK_COST
        return raw + compressed

    def _snapshot_state(self) -> dict[str, Any]:
        """Capture the wrapped environment's current state.

        Prefers ``base_env._state.model_dump()`` when a ``_state`` attribute is
        present and not ``None``; otherwise falls back to ``base_env.state()``.
        """
        raw_state = getattr(self.base_env, "_state", None)
        if raw_state is not None:
            return raw_state.model_dump()
        return self.base_env.state()

    def _compress_history(self) -> None:
        """Replace all but the newest snapshots with quantized embeddings.

        Snapshots older than the last ``COMPRESS_KEEP_LAST`` are removed from
        ``state_history``. For each removed snapshot that has a non-empty
        ``current_ticket["embedding"]``, a quantized code is appended to
        ``compressed_embeddings``; snapshots without one are dropped with no
        replacement. The code list is capped at ``max_compressed_chunks``,
        discarding the oldest codes first.
        """
        split = len(self.state_history) - self.COMPRESS_KEEP_LAST
        if split <= 0:
            return

        # Explicit split index (rather than negative slicing) stays correct
        # even if a subclass sets COMPRESS_KEEP_LAST to 0.
        old = self.state_history[:split]
        recent = self.state_history[split:]

        # Build the new codes in a local list first so a quantizer failure
        # leaves the stored state untouched.
        codes: list[dict[str, Any]] = []
        for snapshot in old:
            ticket = snapshot.get("current_ticket")
            embedding = ticket.get("embedding") if ticket else None
            if not embedding:
                continue
            codes.append(self.quantizer.quantize_prod(embedding, bits=self.bits))

        self.compressed_embeddings.extend(codes)
        overflow = len(self.compressed_embeddings) - self.max_compressed_chunks
        if overflow > 0:
            self.compressed_embeddings = self.compressed_embeddings[overflow:]
        self.state_history = recent

    def _truncate_history(self) -> None:
        """Drop all but the newest ``TRUNCATE_KEEP_LAST`` snapshots.

        The number of dropped snapshots is added to ``truncated_snapshots``.
        """
        excess = len(self.state_history) - self.TRUNCATE_KEEP_LAST
        if excess <= 0:
            return
        self.truncated_snapshots += excess
        self.state_history = self.state_history[excess:]
|