workflow-twin / env /environment.py
NDGCodes's picture
fix repo structure for HF
1a692ce
from __future__ import annotations
import hashlib
import random
import uuid
from typing import Any
from .dynamics import apply_action, is_terminal
from .entities import SystemState, Ticket
from .graders import grade_episode
from .models import Action, Observation, StepResult
from .quantizer import RotatedQuantizedMemory
from .reward import compute_reward
from .tasks import load_tasks
class OpenEnv:
    """Single-ticket workflow environment with optional embedding quantization.

    The environment samples one task from the pool for the configured
    difficulty, wraps it in a ``Ticket`` and lets the caller drive it with
    ``step``. ``step`` returns an ``(observation, reward, done, info)`` tuple.

    When ``use_quantizer`` is set, the ticket embedding may be passed through
    a ``RotatedQuantizedMemory`` on each step; the resulting distortion is
    scaled by ``distortion_lambda`` / ``inner_product_lambda`` and forwarded
    to ``compute_reward`` as penalties.
    """

    # Seed handed to the quantizer when the caller does not supply one.
    _DEFAULT_QUANTIZER_SEED = 42

    def __init__(
        self,
        difficulty: str = "easy",
        seed: int | None = None,
        use_quantizer: bool = False,
        quant_mode: str = "full",
        quant_every_n_steps: int = 1,
        embedding_dim: int = 16,
        quant_bits: int = 3,
        distortion_lambda: float = 0.2,
        inner_product_lambda: float = 0.1,
    ) -> None:
        """Configure the environment; no episode exists until ``reset``/``step``.

        Args:
            difficulty: Key passed to ``load_tasks`` to obtain the task pool.
            seed: Seed for task sampling and (when given) for the quantizer.
                ``None`` leaves sampling unseeded and seeds the quantizer
                with ``_DEFAULT_QUANTIZER_SEED``.
            use_quantizer: Whether to build a quantizer at all.
            quant_mode: One of ``"off"``, ``"full"``, ``"throttle"``,
                ``"status"`` or ``"hybrid"`` (case-insensitive). Any other
                value is treated like ``"full"``.
            quant_every_n_steps: Schedule period for ``"throttle"`` and
                ``"hybrid"``; values below 1 are clamped to 1.
            embedding_dim: Length of synthesized embeddings; also passed to
                the quantizer constructor.
            quant_bits: Bit budget passed to the quantizer on each call.
            distortion_lambda: Weight applied to the reported ``"mse"``.
            inner_product_lambda: Weight applied to the reported
                ``"inner_product_error"``.
        """
        self.difficulty = difficulty
        self._random = random.Random(seed)
        self._task_pool = load_tasks(difficulty)
        self._state: SystemState | None = None
        self.use_quantizer = use_quantizer
        self.quant_mode = quant_mode
        # Clamp so the modulo in _should_quantize can never divide by zero.
        self.quant_every_n_steps = max(1, quant_every_n_steps)
        self.embedding_dim = embedding_dim
        self.quant_bits = quant_bits
        self.distortion_lambda = distortion_lambda
        self.inner_product_lambda = inner_product_lambda
        self._quantizer = (
            RotatedQuantizedMemory(embedding_dim, self._resolve_quantizer_seed(seed))
            if use_quantizer
            else None
        )

    @classmethod
    def _resolve_quantizer_seed(cls, seed: int | None) -> int:
        """Return the quantizer seed, falling back to the default only for ``None``.

        An explicit ``0`` is a legitimate seed and must be preserved; a
        truthiness test (``seed or default``) would silently replace it.
        """
        return cls._DEFAULT_QUANTIZER_SEED if seed is None else seed

    def _should_quantize(self, previous_ticket: Ticket, current_ticket: Ticket) -> tuple[bool, str]:
        """Decide whether the current step should quantize the embedding.

        Args:
            previous_ticket: Ticket as it was before the action was applied.
            current_ticket: Ticket returned by applying the action.

        Returns:
            ``(should_quantize, decision)`` where ``decision`` is a short
            label explaining the outcome; it is surfaced in ``step``'s info.
        """
        if not self.use_quantizer or not self._quantizer:
            return False, "disabled"

        mode = self.quant_mode.lower()
        if mode == "off":
            return False, "mode_off"
        if current_ticket.embedding is None:
            return False, "no_embedding"
        if mode == "full":
            return True, "mode_full"

        # step() increments step_count before calling this method, so the
        # schedule is evaluated against the post-increment count.
        step_count = self._state.step_count if self._state else 0
        on_schedule = step_count % self.quant_every_n_steps == 0
        status_changed = previous_ticket.status != current_ticket.status

        if mode == "throttle":
            return (on_schedule, "schedule" if on_schedule else "throttled")
        if mode == "status":
            return (status_changed, "status_change" if status_changed else "no_status_change")
        if mode == "hybrid":
            if on_schedule or status_changed:
                return True, "schedule_or_status"
            return False, "throttled_no_status_change"

        # Unrecognised modes behave like "full" rather than raising.
        return True, "unknown_mode_fallback_full"

    def _build_embedding(self, summary: str, severity: str) -> list[float]:
        """Synthesize a deterministic unit-length embedding for a task.

        The vector is derived from the SHA-256 digest of
        ``"{summary}|{severity}|{embedding_dim}"``. Each component maps one
        digest byte into ``[-1.0, 1.0]``; digest bytes are reused cyclically
        when ``embedding_dim`` exceeds the digest length. The result is
        L2-normalized.

        Args:
            summary: Task summary text.
            severity: Task severity label.

        Returns:
            A list of ``embedding_dim`` floats (empty if the dimension is 0).
        """
        key = f"{summary}|{severity}|{self.embedding_dim}".encode("utf-8")
        digest = hashlib.sha256(key).digest()
        values = [
            (digest[i % len(digest)] / 127.5) - 1.0
            for i in range(self.embedding_dim)
        ]
        norm = sum(v * v for v in values) ** 0.5
        if norm > 0:
            values = [v / norm for v in values]
        return values

    def _sample_ticket(self) -> Ticket:
        """Draw a random task from the pool and convert it into a ``Ticket``.

        Tasks are read as mappings with required ``"id"`` and ``"summary"``
        keys. ``"severity"`` defaults to ``"low"`` and ``"max_attempts"`` to
        ``4``. A missing or ``None`` ``"embedding"`` is synthesized with
        ``_build_embedding``.
        """
        task = self._random.choice(self._task_pool)
        severity = task.get("severity", "low")
        embedding = task.get("embedding")
        if embedding is None:
            embedding = self._build_embedding(task["summary"], severity)
        return Ticket(
            id=task["id"],
            summary=task["summary"],
            severity=severity,
            embedding=embedding,
            max_attempts=task.get("max_attempts", 4),
        )

    def _to_observation(self) -> Observation:
        """Project the current ticket into an ``Observation``.

        Raises:
            RuntimeError: If no episode has been started yet.
        """
        if self._state is None:
            raise RuntimeError("Environment has not been reset.")
        ticket = self._state.ticket
        return Observation(
            ticket_id=ticket.id,
            ticket_status=ticket.status,
            attempts_used=ticket.attempts_used,
            # Never report a negative remaining-attempts count.
            attempts_remaining=max(ticket.max_attempts - ticket.attempts_used, 0),
            severity=ticket.severity,
            summary=ticket.summary,
            embedding=ticket.embedding,
        )

    def reset(self) -> Observation:
        """Start a new episode with a freshly sampled ticket.

        Returns:
            The initial observation of the new episode.
        """
        ticket = self._sample_ticket()
        self._state = SystemState(episode_id=str(uuid.uuid4()), ticket=ticket)
        return self._to_observation()

    def step(self, action: Action | dict[str, Any]) -> tuple[Observation, float, bool, dict[str, Any]]:
        """Apply one action and advance the episode.

        If no episode exists yet, one is started implicitly via ``reset``.

        Args:
            action: An ``Action`` instance, or a dict of keyword arguments
                used to construct one.

        Returns:
            ``(observation, reward, done, info)``. ``reward`` is the
            ``value`` attribute of the object returned by ``compute_reward``.
            ``info`` carries ``"episode_id"``, ``"step_count"``, ``"score"``
            and a ``"quantization"`` dict describing what happened to the
            embedding on this step.
        """
        if self._state is None:
            self.reset()
        if isinstance(action, dict):
            action = Action(**action)

        state = self._state
        # Narrowing only: reset() above always assigns self._state.
        assert state is not None

        previous_ticket = state.ticket
        current_ticket = apply_action(previous_ticket, action)
        state.ticket = current_ticket
        state.step_count += 1
        state.done = is_terminal(current_ticket)

        distortion_penalty = 0.0
        inner_product_penalty = 0.0
        quantization_info: dict[str, Any] = {
            "enabled": False,
            "mode": self.quant_mode,
            "compression_bits": self.quant_bits,
            "applied": False,
            "decision": "disabled",
        }

        should_quantize, decision = self._should_quantize(previous_ticket, current_ticket)
        quantization_info["decision"] = decision

        if should_quantize:
            # _should_quantize only returns True when a quantizer exists and
            # the current ticket has an embedding.
            original_embedding = current_ticket.embedding
            # The first returned element (the code) is not used here.
            _, reconstructed_embedding = self._quantizer.quantize_and_dequantize_prod(
                original_embedding,
                self.quant_bits,
            )
            query_embedding = previous_ticket.embedding or original_embedding
            distortion = self._quantizer.compute_distortion(
                original_embedding,
                reconstructed_embedding,
                query_embedding,
            )
            distortion_penalty = self.distortion_lambda * distortion["mse"]
            inner_product_penalty = self.inner_product_lambda * distortion["inner_product_error"]
            # The ticket (and therefore later observations) carries the
            # reconstructed embedding from here on.
            current_ticket.embedding = reconstructed_embedding
            quantization_info = {
                "enabled": True,
                "mode": self.quant_mode,
                "quantizer": "rotated_quantized_memory",
                "bits": self.quant_bits,
                "compression_bits": self.quant_bits,
                "distortion_mse": distortion["mse"],
                "inner_product_error": distortion["inner_product_error"],
                "applied": True,
                "decision": decision,
            }

        reward = compute_reward(
            previous_ticket,
            current_ticket,
            action,
            distortion_penalty=distortion_penalty,
            inner_product_penalty=inner_product_penalty,
        )

        # The episode score is only (re)computed once the episode is over.
        if state.done:
            state.score = grade_episode(state)

        result = StepResult(
            observation=self._to_observation(),
            reward=reward,
            done=state.done,
            info={
                "episode_id": state.episode_id,
                "step_count": state.step_count,
                "score": state.score,
                "quantization": quantization_info,
            },
        )
        return (
            result.observation,
            result.reward.value,
            result.done,
            result.info,
        )

    def get_state(self) -> dict[str, Any]:
        """Return the current state as a dict.

        Returns:
            ``{"initialized": False}`` before the first episode, otherwise
            the result of ``model_dump()`` on the internal state object.
        """
        if self._state is None:
            return {"initialized": False}
        return self._state.model_dump()

    def state(self) -> dict[str, Any]:
        """Alias for ``get_state``."""
        return self.get_state()