from __future__ import annotations import hashlib import random import uuid from typing import Any from .dynamics import apply_action, is_terminal from .entities import SystemState, Ticket from .graders import grade_episode from .models import Action, Observation, StepResult from .quantizer import RotatedQuantizedMemory from .reward import compute_reward from .tasks import load_tasks class OpenEnv: def __init__( self, difficulty: str = "easy", seed: int | None = None, use_quantizer: bool = False, quant_mode: str = "full", quant_every_n_steps: int = 1, embedding_dim: int = 16, quant_bits: int = 3, distortion_lambda: float = 0.2, inner_product_lambda: float = 0.1, ) -> None: self.difficulty = difficulty self._random = random.Random(seed) self._task_pool = load_tasks(difficulty) self._state: SystemState | None = None self.use_quantizer = use_quantizer self.quant_mode = quant_mode self.quant_every_n_steps = max(1, quant_every_n_steps) self.embedding_dim = embedding_dim self.quant_bits = quant_bits self.distortion_lambda = distortion_lambda self.inner_product_lambda = inner_product_lambda self._quantizer = RotatedQuantizedMemory(embedding_dim, seed or 42) if use_quantizer else None def _should_quantize(self, previous_ticket: Ticket, current_ticket: Ticket) -> tuple[bool, str]: if not self.use_quantizer or not self._quantizer: return False, "disabled" mode = self.quant_mode.lower() if mode == "off": return False, "mode_off" if current_ticket.embedding is None: return False, "no_embedding" if mode == "full": return True, "mode_full" step_count = self._state.step_count if self._state else 0 on_schedule = step_count % self.quant_every_n_steps == 0 status_changed = previous_ticket.status != current_ticket.status if mode == "throttle": return (on_schedule, "schedule" if on_schedule else "throttled") if mode == "status": return (status_changed, "status_change" if status_changed else "no_status_change") if mode == "hybrid": should = on_schedule or status_changed if should: return True, "schedule_or_status" return False, "throttled_no_status_change" return True, "unknown_mode_fallback_full" def _build_embedding(self, summary: str, severity: str) -> list[float]: key = f"{summary}|{severity}|{self.embedding_dim}".encode("utf-8") digest = hashlib.sha256(key).digest() values: list[float] = [] for i in range(self.embedding_dim): byte = digest[i % len(digest)] values.append((byte / 127.5) - 1.0) norm = sum(v * v for v in values) ** 0.5 if norm > 0: values = [v / norm for v in values] return values def _sample_ticket(self) -> Ticket: task = self._random.choice(self._task_pool) embedding = task.get("embedding") if embedding is None: embedding = self._build_embedding(task["summary"], task.get("severity", "low")) return Ticket( id=task["id"], summary=task["summary"], severity=task.get("severity", "low"), embedding=embedding, max_attempts=task.get("max_attempts", 4), ) def _to_observation(self) -> Observation: if self._state is None: raise RuntimeError("Environment has not been reset.") ticket = self._state.ticket return Observation( ticket_id=ticket.id, ticket_status=ticket.status, attempts_used=ticket.attempts_used, attempts_remaining=max(ticket.max_attempts - ticket.attempts_used, 0), severity=ticket.severity, summary=ticket.summary, embedding=ticket.embedding, ) def reset(self) -> Observation: ticket = self._sample_ticket() self._state = SystemState(episode_id=str(uuid.uuid4()), ticket=ticket) return self._to_observation() def step(self, action: Action | dict[str, Any]) -> tuple[Observation, float, bool, dict[str, Any]]: if self._state is None: self.reset() if isinstance(action, dict): action = Action(**action) assert self._state is not None previous_ticket = self._state.ticket current_ticket = apply_action(previous_ticket, action) self._state.ticket = current_ticket self._state.step_count += 1 self._state.done = is_terminal(current_ticket) distortion_penalty = 0.0 inner_product_penalty = 0.0 quantization_info: dict[str, Any] = { "enabled": False, "mode": self.quant_mode, "compression_bits": self.quant_bits, "applied": False, "decision": "disabled", } should_quantize, decision = self._should_quantize(previous_ticket, current_ticket) quantization_info["decision"] = decision if should_quantize: original_embedding = current_ticket.embedding quant_code, reconstructed_embedding = self._quantizer.quantize_and_dequantize_prod( original_embedding, self.quant_bits, ) query_embedding = previous_ticket.embedding or original_embedding distortion = self._quantizer.compute_distortion( original_embedding, reconstructed_embedding, query_embedding, ) distortion_penalty = self.distortion_lambda * distortion["mse"] inner_product_penalty = self.inner_product_lambda * distortion["inner_product_error"] current_ticket.embedding = reconstructed_embedding quantization_info = { "enabled": True, "mode": self.quant_mode, "quantizer": "rotated_quantized_memory", "bits": self.quant_bits, "compression_bits": self.quant_bits, "distortion_mse": distortion["mse"], "inner_product_error": distortion["inner_product_error"], "applied": True, "decision": decision, } reward = compute_reward( previous_ticket, current_ticket, action, distortion_penalty=distortion_penalty, inner_product_penalty=inner_product_penalty, ) if self._state.done: self._state.score = grade_episode(self._state) result = StepResult( observation=self._to_observation(), reward=reward, done=self._state.done, info={ "episode_id": self._state.episode_id, "step_count": self._state.step_count, "score": self._state.score, "quantization": quantization_info, }, ) return ( result.observation, result.reward.value, result.done, result.info, ) def get_state(self) -> dict[str, Any]: if self._state is None: return {"initialized": False} return self._state.model_dump() def state(self) -> dict[str, Any]: return self.get_state()