Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import hashlib | |
| import random | |
| import uuid | |
| from typing import Any | |
| from .dynamics import apply_action, is_terminal | |
| from .entities import SystemState, Ticket | |
| from .graders import grade_episode | |
| from .models import Action, Observation, StepResult | |
| from .quantizer import RotatedQuantizedMemory | |
| from .reward import compute_reward | |
| from .tasks import load_tasks | |
class OpenEnv:
    """Single-ticket environment with optional embedding quantization.

    Each episode samples one ticket from the task pool loaded for
    ``difficulty``. ``step`` applies an action through ``apply_action`` and
    returns a ``(observation, reward, done, info)`` tuple. When
    ``use_quantizer`` is set, the ticket embedding may be round-tripped through
    a ``RotatedQuantizedMemory`` after a step, and the resulting distortion is
    passed to ``compute_reward`` as penalties.
    """

    def __init__(
        self,
        difficulty: str = "easy",
        seed: int | None = None,
        use_quantizer: bool = False,
        quant_mode: str = "full",
        quant_every_n_steps: int = 1,
        embedding_dim: int = 16,
        quant_bits: int = 3,
        distortion_lambda: float = 0.2,
        inner_product_lambda: float = 0.1,
    ) -> None:
        """Configure the environment.

        Args:
            difficulty: Task-pool name passed to ``load_tasks``.
            seed: Seed for ticket sampling and for the quantizer. With
                ``None``, sampling is seeded by ``random.Random`` itself and
                the quantizer falls back to seed 42.
            use_quantizer: Whether to construct a ``RotatedQuantizedMemory``.
            quant_mode: When to quantize: ``"off"``, ``"full"``,
                ``"throttle"``, ``"status"`` or ``"hybrid"``
                (case-insensitive). Unrecognized values behave like ``"full"``.
            quant_every_n_steps: Schedule period used by ``"throttle"`` and
                ``"hybrid"``. Values below 1 are clamped to 1.
            embedding_dim: Length of embeddings produced by
                ``_build_embedding``; also passed to the quantizer.
            quant_bits: Bit width passed to the quantizer.
            distortion_lambda: Weight applied to the quantization MSE penalty.
            inner_product_lambda: Weight applied to the inner-product error
                penalty.
        """
        self.difficulty = difficulty
        self._random = random.Random(seed)
        self._task_pool = load_tasks(difficulty)
        self._state: SystemState | None = None
        self.use_quantizer = use_quantizer
        self.quant_mode = quant_mode
        self.quant_every_n_steps = max(1, quant_every_n_steps)
        self.embedding_dim = embedding_dim
        self.quant_bits = quant_bits
        self.distortion_lambda = distortion_lambda
        self.inner_product_lambda = inner_product_lambda
        # `seed or 42` would silently replace an explicit seed=0 with 42, so the
        # quantizer and the sampler would disagree; only default when unset.
        quantizer_seed = 42 if seed is None else seed
        self._quantizer = (
            RotatedQuantizedMemory(embedding_dim, quantizer_seed)
            if use_quantizer
            else None
        )

    def _should_quantize(
        self, previous_ticket: Ticket, current_ticket: Ticket
    ) -> tuple[bool, str]:
        """Decide whether to quantize ``current_ticket``'s embedding this step.

        Args:
            previous_ticket: Ticket before the action was applied.
            current_ticket: Ticket after the action was applied.

        Returns:
            ``(should_quantize, reason)``. ``reason`` is a short tag that
            ``step`` records under ``info["quantization"]["decision"]``.
        """
        if not self.use_quantizer or self._quantizer is None:
            return False, "disabled"
        mode = self.quant_mode.lower()
        if mode == "off":
            return False, "mode_off"
        if current_ticket.embedding is None:
            return False, "no_embedding"
        if mode == "full":
            return True, "mode_full"
        # step() increments step_count before calling this method, so the
        # schedule fires on steps n, 2n, 3n, ... (never on step 0).
        step_count = self._state.step_count if self._state is not None else 0
        on_schedule = step_count % self.quant_every_n_steps == 0
        status_changed = previous_ticket.status != current_ticket.status
        if mode == "throttle":
            return on_schedule, ("schedule" if on_schedule else "throttled")
        if mode == "status":
            return status_changed, (
                "status_change" if status_changed else "no_status_change"
            )
        if mode == "hybrid":
            if on_schedule or status_changed:
                return True, "schedule_or_status"
            return False, "throttled_no_status_change"
        return True, "unknown_mode_fallback_full"

    def _build_embedding(self, summary: str, severity: str) -> list[float]:
        """Derive a deterministic, unit-norm pseudo-embedding from ticket text.

        Each component is one SHA-256 digest byte of
        ``"{summary}|{severity}|{embedding_dim}"``. The byte is mapped linearly
        from [0, 255] to [-1, 1], and the vector is then L2-normalized (left
        as-is if its norm is 0).

        Args:
            summary: Ticket summary text.
            severity: Ticket severity label.

        Returns:
            A list of ``embedding_dim`` floats.
        """
        key = f"{summary}|{severity}|{self.embedding_dim}".encode("utf-8")
        digest = hashlib.sha256(key).digest()
        # The digest has only 32 bytes; larger dimensions reuse them cyclically.
        values = [
            (digest[i % len(digest)] / 127.5) - 1.0
            for i in range(self.embedding_dim)
        ]
        norm = sum(v * v for v in values) ** 0.5
        if norm > 0:
            values = [v / norm for v in values]
        return values

    def _sample_ticket(self) -> Ticket:
        """Draw a task from the pool and wrap it in a new ``Ticket``.

        A missing ``severity`` defaults to ``"low"`` and a missing
        ``max_attempts`` defaults to 4. A task without an ``embedding`` gets
        one from ``_build_embedding``.
        """
        task = self._random.choice(self._task_pool)
        severity = task.get("severity", "low")
        embedding = task.get("embedding")
        if embedding is None:
            embedding = self._build_embedding(task["summary"], severity)
        return Ticket(
            id=task["id"],
            summary=task["summary"],
            severity=severity,
            embedding=embedding,
            max_attempts=task.get("max_attempts", 4),
        )

    def _to_observation(self) -> Observation:
        """Build an ``Observation`` from the current ticket.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Environment has not been reset.")
        ticket = self._state.ticket
        return Observation(
            ticket_id=ticket.id,
            ticket_status=ticket.status,
            attempts_used=ticket.attempts_used,
            attempts_remaining=max(ticket.max_attempts - ticket.attempts_used, 0),
            severity=ticket.severity,
            summary=ticket.summary,
            embedding=ticket.embedding,
        )

    def reset(self) -> Observation:
        """Start a new episode with a freshly sampled ticket and a new UUID."""
        ticket = self._sample_ticket()
        self._state = SystemState(episode_id=str(uuid.uuid4()), ticket=ticket)
        return self._to_observation()

    def step(
        self, action: Action | dict[str, Any]
    ) -> tuple[Observation, float, bool, dict[str, Any]]:
        """Apply ``action`` and advance the episode by one step.

        Calls ``reset`` first if no episode is active. Stepping after ``done``
        is not blocked here; ``apply_action`` is still invoked.

        Args:
            action: An ``Action``, or a dict of ``Action`` constructor fields.

        Returns:
            ``(observation, reward, done, info)``:

            - ``reward`` is the ``value`` of the object returned by
              ``compute_reward``.
            - ``info`` holds ``episode_id``, ``step_count``, ``score`` (graded
              only once the episode is done) and a ``quantization`` sub-dict.
        """
        if self._state is None:
            self.reset()
        if isinstance(action, dict):
            action = Action(**action)
        assert self._state is not None
        previous_ticket = self._state.ticket
        # NOTE(review): status-change detection and reward deltas assume
        # apply_action returns a distinct ticket rather than mutating in place.
        current_ticket = apply_action(previous_ticket, action)
        self._state.ticket = current_ticket
        self._state.step_count += 1
        self._state.done = is_terminal(current_ticket)

        distortion_penalty = 0.0
        inner_product_penalty = 0.0
        # "enabled"/"applied" are True only when quantization actually ran this
        # step, even if a quantizer is configured.
        quantization_info: dict[str, Any] = {
            "enabled": False,
            "mode": self.quant_mode,
            "compression_bits": self.quant_bits,
            "applied": False,
            "decision": "disabled",
        }
        should_quantize, decision = self._should_quantize(previous_ticket, current_ticket)
        quantization_info["decision"] = decision
        if should_quantize:
            original_embedding = current_ticket.embedding
            quant_code, reconstructed_embedding = self._quantizer.quantize_and_dequantize_prod(
                original_embedding,
                self.quant_bits,
            )
            # Measure inner-product error against the pre-action embedding
            # when one exists, otherwise against the embedding itself.
            query_embedding = previous_ticket.embedding or original_embedding
            distortion = self._quantizer.compute_distortion(
                original_embedding,
                reconstructed_embedding,
                query_embedding,
            )
            distortion_penalty = self.distortion_lambda * distortion["mse"]
            inner_product_penalty = self.inner_product_lambda * distortion["inner_product_error"]
            # The ticket keeps the lossy reconstruction for subsequent steps.
            current_ticket.embedding = reconstructed_embedding
            quantization_info = {
                "enabled": True,
                "mode": self.quant_mode,
                "quantizer": "rotated_quantized_memory",
                "bits": self.quant_bits,
                "compression_bits": self.quant_bits,
                "distortion_mse": distortion["mse"],
                "inner_product_error": distortion["inner_product_error"],
                "applied": True,
                "decision": decision,
            }

        reward = compute_reward(
            previous_ticket,
            current_ticket,
            action,
            distortion_penalty=distortion_penalty,
            inner_product_penalty=inner_product_penalty,
        )
        if self._state.done:
            self._state.score = grade_episode(self._state)
        result = StepResult(
            observation=self._to_observation(),
            reward=reward,
            done=self._state.done,
            info={
                "episode_id": self._state.episode_id,
                "step_count": self._state.step_count,
                "score": self._state.score,
                "quantization": quantization_info,
            },
        )
        return (
            result.observation,
            result.reward.value,
            result.done,
            result.info,
        )

    def get_state(self) -> dict[str, Any]:
        """Return ``SystemState.model_dump()``, or ``{"initialized": False}`` before ``reset``."""
        if self._state is None:
            return {"initialized": False}
        return self._state.model_dump()

    def state(self) -> dict[str, Any]:
        """Alias for ``get_state``."""
        return self.get_state()