File size: 3,617 Bytes
846683d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from __future__ import annotations

from typing import Any

from workflow_twin.environment import WorkflowTwinEnv
from workflow_twin.quantizer import RotatedQuantizedMemory


class MemoryBoundedEnv:
    """Wrap a ``WorkflowTwinEnv`` and keep its state history under a memory budget.

    Every call to ``step`` appends a snapshot of the base environment's state
    to ``state_history``. Memory usage is estimated with a simple cost model:
    ``RAW_SNAPSHOT_COST`` per raw snapshot plus ``COMPRESSED_CHUNK_COST`` per
    compressed chunk. When the estimate exceeds ``memory_budget``, the history
    is shrunk according to ``mode``:

    * ``"quant"``: older snapshots are replaced by quantized codes of their
      ticket embeddings (see ``_compress_history``).
    * any other value: older snapshots are dropped (see ``_truncate_history``).
      The ``"baseline"`` mode additionally pays ``BASELINE_PENALTY``.

    A step that is still over budget after shrinking pays
    ``OVER_BUDGET_PENALTY``. Penalties floor the reward at ``REWARD_FLOOR``,
    but they never raise a reward that is already below that floor.
    """

    # Nominal cost-model units used by _compute_memory_usage (not measured sizes).
    RAW_SNAPSHOT_COST = 600
    COMPRESSED_CHUNK_COST = 24
    # How many of the most recent raw snapshots each shrinking strategy keeps.
    QUANT_KEEP_LAST = 3
    TRUNCATE_KEEP_LAST = 10
    # Reward shaping applied under memory pressure.
    OVER_BUDGET_PENALTY = 0.2
    BASELINE_PENALTY = 0.05
    REWARD_FLOOR = -1.0

    def __init__(
        self,
        base_env: WorkflowTwinEnv,
        memory_budget: int = 10_000,
        bits: int = 3,
        mode: str = "quant",
        max_compressed_chunks: int = 50,
    ) -> None:
        """Create the wrapper.

        Args:
            base_env: Environment to wrap. Its ``embedding_dim`` and ``seed``
                configure the quantizer.
            memory_budget: Threshold, in cost-model units, above which the
                history is shrunk.
            bits: Bit width passed to ``quantize_prod`` when compressing.
            mode: ``"quant"`` to compress old snapshots. Any other value
                truncates them, and ``"baseline"`` also adds a small penalty.
            max_compressed_chunks: Maximum number of compressed codes retained.
                The oldest codes are evicted first.
        """
        self.base_env = base_env
        self.memory_budget = memory_budget
        self.bits = bits
        self.mode = mode
        self.quantizer = RotatedQuantizedMemory(dimension=base_env.embedding_dim, seed=base_env.seed)
        self.max_compressed_chunks = max_compressed_chunks

        self.state_history: list[dict[str, Any]] = []
        # Holds whatever quantize_prod returns, one entry per compressed snapshot.
        self.compressed_embeddings: list[dict[str, Any]] = []
        # Number of raw snapshots discarded by _truncate_history.
        self.truncated_snapshots = 0

    def reset(self):
        """Reset the base environment and clear all memory bookkeeping.

        Returns:
            The observation returned by ``base_env.reset()``.
        """
        obs = self.base_env.reset()
        # Use the same snapshot source as ``step`` so every history entry has
        # the same shape and can be compressed the same way.
        self.state_history = [self._snapshot_state()]
        self.compressed_embeddings = []
        self.truncated_snapshots = 0
        return obs

    def step(self, action: dict):
        """Advance the base environment, record state, and enforce the budget.

        Args:
            action: Forwarded unchanged to ``base_env.step``.

        Returns:
            ``(obs, reward, done, info)`` from the base environment. ``reward``
            may be reduced by memory penalties. ``info["memory"]`` holds memory
            statistics, and ``obs.memory_used`` / ``obs.memory_budget`` are set
            on the observation.
        """
        obs, reward, done, info = self.base_env.step(action)
        self.state_history.append(self._snapshot_state())

        memory_used = self._compute_memory_usage()
        if memory_used > self.memory_budget:
            if self.mode == "quant":
                self._compress_history()
            else:
                self._truncate_history()
            memory_used = self._compute_memory_usage()
            if memory_used > self.memory_budget:
                reward = self._apply_penalty(reward, self.OVER_BUDGET_PENALTY)
            if self.mode == "baseline":
                reward = self._apply_penalty(reward, self.BASELINE_PENALTY)

        info["memory"] = {
            "memory_used": memory_used,
            "memory_budget": self.memory_budget,
            "compressed_chunks": len(self.compressed_embeddings),
            "truncated_snapshots": self.truncated_snapshots,
            "mode": self.mode,
        }
        obs.memory_used = memory_used
        obs.memory_budget = self.memory_budget
        return obs, reward, done, info

    def _apply_penalty(self, reward, penalty: float):
        """Return *reward* minus *penalty*, floored at ``REWARD_FLOOR``.

        A reward already at or below the floor is returned unchanged. Clamping
        it up to the floor would let a penalty increase the reward.
        """
        if reward <= self.REWARD_FLOOR:
            return reward
        return max(self.REWARD_FLOOR, reward - penalty)

    def _compute_memory_usage(self) -> int:
        """Estimate memory usage from the number of raw and compressed entries."""
        raw = len(self.state_history) * self.RAW_SNAPSHOT_COST
        compressed = len(self.compressed_embeddings) * self.COMPRESSED_CHUNK_COST
        return raw + compressed

    def _snapshot_state(self) -> dict[str, Any]:
        """Capture the base environment's current state as a dict.

        Prefers the base env's private ``_state`` object via ``model_dump()``
        (presumably a Pydantic model; confirm against WorkflowTwinEnv) and
        falls back to the public ``state()`` when ``_state`` is missing or None.
        """
        raw_state = getattr(self.base_env, "_state", None)
        if raw_state is not None:
            return raw_state.model_dump()
        return self.base_env.state()

    def _compress_history(self) -> None:
        """Replace all but the newest raw snapshots with quantized embeddings.

        Each older snapshot whose ``current_ticket`` has a non-empty
        ``embedding`` is quantized with ``self.bits``. Snapshots without one
        are dropped without being recorded anywhere. Only the newest
        ``max_compressed_chunks`` codes are retained.
        """
        keep_last = self.QUANT_KEEP_LAST
        if len(self.state_history) <= keep_last:
            return

        old = self.state_history[:-keep_last]
        keep = self.state_history[-keep_last:]

        # Build the batch first so a quantizer failure leaves state untouched.
        compressed: list[dict[str, Any]] = []
        for snapshot in old:
            ticket = snapshot.get("current_ticket")
            if not ticket:
                continue
            embedding = ticket.get("embedding")
            # ``not embedding`` raises on multi-element array-likes (e.g.
            # numpy arrays), so test for missing/empty explicitly.
            if embedding is None or len(embedding) == 0:
                continue
            compressed.append(self.quantizer.quantize_prod(embedding, bits=self.bits))

        self.compressed_embeddings.extend(compressed)
        # Evict the oldest codes. An explicit count avoids the ``[-0:]`` pitfall
        # (which keeps everything) when max_compressed_chunks is 0.
        overflow = len(self.compressed_embeddings) - self.max_compressed_chunks
        if overflow > 0:
            del self.compressed_embeddings[:overflow]
        self.state_history = keep

    def _truncate_history(self) -> None:
        """Drop all but the newest ``TRUNCATE_KEEP_LAST`` snapshots and count them."""
        keep_last = self.TRUNCATE_KEEP_LAST
        removed = len(self.state_history) - keep_last
        if removed <= 0:
            return
        self.truncated_snapshots += removed
        self.state_history = self.state_history[-keep_last:]