Spaces:
Sleeping
Sleeping
File size: 7,554 Bytes
846683d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | from __future__ import annotations
import hashlib
import random
import uuid
from typing import Any
from .dynamics import apply_action, is_terminal
from .entities import SystemState, Ticket
from .graders import grade_episode
from .models import Action, Observation, StepResult
from .quantizer import RotatedQuantizedMemory
from .reward import compute_reward
from .tasks import load_tasks
class OpenEnv:
    """Episodic environment built around a single ticket per episode.

    ``reset`` samples a ticket from the task pool loaded for ``difficulty``;
    ``step`` applies an action to that ticket, optionally passes the ticket
    embedding through a quantizer, and returns an observation, a scalar
    reward, a done flag and an info dict.
    """

    # Seed given to the quantizer when the caller supplies no seed at all.
    _DEFAULT_QUANTIZER_SEED = 42

    def __init__(
        self,
        difficulty: str = "easy",
        seed: int | None = None,
        use_quantizer: bool = False,
        quant_mode: str = "full",
        quant_every_n_steps: int = 1,
        embedding_dim: int = 16,
        quant_bits: int = 3,
        distortion_lambda: float = 0.2,
        inner_product_lambda: float = 0.1,
    ) -> None:
        """Configure the environment; no episode exists until ``reset``.

        Args:
            difficulty: Passed to ``load_tasks`` to obtain the task pool.
            seed: Seeds the task sampler and, when given, the quantizer.
            use_quantizer: Build a quantizer and allow ``step`` to use it.
            quant_mode: One of "off", "full", "throttle", "status" or
                "hybrid" (compared case-insensitively); any other value
                behaves like "full".
            quant_every_n_steps: Schedule period for "throttle"/"hybrid";
                values below 1 are clamped to 1.
            embedding_dim: Length of generated embeddings, also passed to
                the quantizer constructor.
            quant_bits: Bit width passed to the quantizer on every call.
            distortion_lambda: Weight applied to the quantization MSE.
            inner_product_lambda: Weight applied to the inner-product error.
        """
        self.difficulty = difficulty
        self._random = random.Random(seed)
        self._task_pool = load_tasks(difficulty)
        self._state: SystemState | None = None
        self.use_quantizer = use_quantizer
        self.quant_mode = quant_mode
        self.quant_every_n_steps = max(1, quant_every_n_steps)
        self.embedding_dim = embedding_dim
        self.quant_bits = quant_bits
        self.distortion_lambda = distortion_lambda
        self.inner_product_lambda = inner_product_lambda
        self._quantizer = (
            RotatedQuantizedMemory(embedding_dim, self._quantizer_seed(seed))
            if use_quantizer
            else None
        )

    @classmethod
    def _quantizer_seed(cls, seed: int | None) -> int:
        """Return the seed for the quantizer, defaulting only when ``seed`` is None.

        A plain ``seed or default`` would discard an explicit ``seed=0``.
        """
        return cls._DEFAULT_QUANTIZER_SEED if seed is None else seed

    def _should_quantize(self, previous_ticket: Ticket, current_ticket: Ticket) -> tuple[bool, str]:
        """Decide whether the current step should quantize the embedding.

        Returns:
            ``(should_quantize, decision)`` where ``decision`` is a short
            label naming the rule that produced the answer.
        """
        # Identity check: a quantizer object that happens to be falsy must
        # still count as present.
        if not self.use_quantizer or self._quantizer is None:
            return False, "disabled"

        mode = self.quant_mode.lower()
        if mode == "off":
            return False, "mode_off"
        if current_ticket.embedding is None:
            return False, "no_embedding"
        if mode == "full":
            return True, "mode_full"

        step_count = self._state.step_count if self._state else 0
        on_schedule = step_count % self.quant_every_n_steps == 0
        status_changed = previous_ticket.status != current_ticket.status

        if mode == "throttle":
            return on_schedule, "schedule" if on_schedule else "throttled"
        if mode == "status":
            return status_changed, "status_change" if status_changed else "no_status_change"
        if mode == "hybrid":
            if on_schedule or status_changed:
                return True, "schedule_or_status"
            return False, "throttled_no_status_change"

        # Unrecognised modes behave like "full" rather than failing.
        return True, "unknown_mode_fallback_full"

    def _build_embedding(self, summary: str, severity: str) -> list[float]:
        """Derive a deterministic unit-length embedding from ticket text.

        The SHA-256 digest of ``"summary|severity|embedding_dim"`` supplies
        the bytes; each byte is mapped into [-1.0, 1.0] and the vector is
        scaled to unit Euclidean norm.
        """
        key = f"{summary}|{severity}|{self.embedding_dim}".encode("utf-8")
        digest = hashlib.sha256(key).digest()
        # The digest is reused cyclically when embedding_dim exceeds its length.
        values = [
            (digest[i % len(digest)] / 127.5) - 1.0
            for i in range(self.embedding_dim)
        ]
        norm = sum(v * v for v in values) ** 0.5
        if norm > 0:
            values = [v / norm for v in values]
        return values

    def _sample_ticket(self) -> Ticket:
        """Pick a random task from the pool and wrap it in a ``Ticket``.

        A task without an embedding gets a generated one; missing
        ``severity`` defaults to "low" and ``max_attempts`` to 4.
        """
        task = self._random.choice(self._task_pool)
        embedding = task.get("embedding")
        if embedding is None:
            embedding = self._build_embedding(task["summary"], task.get("severity", "low"))
        return Ticket(
            id=task["id"],
            summary=task["summary"],
            severity=task.get("severity", "low"),
            embedding=embedding,
            max_attempts=task.get("max_attempts", 4),
        )

    def _to_observation(self) -> Observation:
        """Build an ``Observation`` from the current ticket.

        Raises:
            RuntimeError: If no episode has been started.
        """
        if self._state is None:
            raise RuntimeError("Environment has not been reset.")
        ticket = self._state.ticket
        return Observation(
            ticket_id=ticket.id,
            ticket_status=ticket.status,
            attempts_used=ticket.attempts_used,
            # Never report a negative number of remaining attempts.
            attempts_remaining=max(ticket.max_attempts - ticket.attempts_used, 0),
            severity=ticket.severity,
            summary=ticket.summary,
            embedding=ticket.embedding,
        )

    def reset(self) -> Observation:
        """Start a new episode with a freshly sampled ticket and return its observation."""
        ticket = self._sample_ticket()
        self._state = SystemState(episode_id=str(uuid.uuid4()), ticket=ticket)
        return self._to_observation()

    def step(self, action: Action | dict[str, Any]) -> tuple[Observation, float, bool, dict[str, Any]]:
        """Apply ``action`` to the current ticket and advance the episode.

        If no episode is active, ``reset`` is called first. A dict action is
        expanded into ``Action(**action)``.

        Returns:
            ``(observation, reward_value, done, info)``; ``info`` carries
            the episode id, step count, score and a ``"quantization"`` dict.

        Raises:
            RuntimeError: If no state exists even after the implicit reset.
        """
        if self._state is None:
            self.reset()
        if isinstance(action, dict):
            action = Action(**action)
        state = self._state
        if state is None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise RuntimeError("Environment has not been reset.")

        previous_ticket = state.ticket
        current_ticket = apply_action(previous_ticket, action)
        state.ticket = current_ticket
        state.step_count += 1
        state.done = is_terminal(current_ticket)

        distortion_penalty = 0.0
        inner_product_penalty = 0.0
        # NOTE(review): "enabled" stays False whenever quantization is not
        # applied on this step, even if use_quantizer is True — confirm intent.
        quantization_info: dict[str, Any] = {
            "enabled": False,
            "mode": self.quant_mode,
            "compression_bits": self.quant_bits,
            "applied": False,
            "decision": "disabled",
        }

        # The decision is made after step_count has been incremented.
        should_quantize, decision = self._should_quantize(previous_ticket, current_ticket)
        quantization_info["decision"] = decision

        if should_quantize:
            original_embedding = current_ticket.embedding
            quant_code, reconstructed_embedding = self._quantizer.quantize_and_dequantize_prod(
                original_embedding,
                self.quant_bits,
            )
            # Fall back to the current embedding when the previous one is missing or empty.
            query_embedding = previous_ticket.embedding or original_embedding
            distortion = self._quantizer.compute_distortion(
                original_embedding,
                reconstructed_embedding,
                query_embedding,
            )
            distortion_penalty = self.distortion_lambda * distortion["mse"]
            inner_product_penalty = self.inner_product_lambda * distortion["inner_product_error"]
            # The ticket keeps the reconstructed embedding from here on.
            current_ticket.embedding = reconstructed_embedding
            quantization_info = {
                "enabled": True,
                "mode": self.quant_mode,
                "quantizer": "rotated_quantized_memory",
                "bits": self.quant_bits,
                "compression_bits": self.quant_bits,
                "distortion_mse": distortion["mse"],
                "inner_product_error": distortion["inner_product_error"],
                "applied": True,
                "decision": decision,
            }

        reward = compute_reward(
            previous_ticket,
            current_ticket,
            action,
            distortion_penalty=distortion_penalty,
            inner_product_penalty=inner_product_penalty,
        )

        # The episode is graded once, on the step that ends it.
        if state.done:
            state.score = grade_episode(state)

        result = StepResult(
            observation=self._to_observation(),
            reward=reward,
            done=state.done,
            info={
                "episode_id": state.episode_id,
                "step_count": state.step_count,
                "score": state.score,
                "quantization": quantization_info,
            },
        )
        return (
            result.observation,
            result.reward.value,
            result.done,
            result.info,
        )

    def get_state(self) -> dict[str, Any]:
        """Return the state as a dict, or ``{"initialized": False}`` before the first reset."""
        if self._state is None:
            return {"initialized": False}
        return self._state.model_dump()

    def state(self) -> dict[str, Any]:
        """Alias for ``get_state``."""
        return self.get_state()
|