""" Main OpenEnv-style environment for incident operations and escalation handling. """ from __future__ import annotations from dataclasses import asdict, dataclass, is_dataclass, replace import os from pathlib import Path import re from typing import Any from rag_optimizer_env.corpus import Chunk, load_corpus, resolve_corpus_path from rag_optimizer_env.context_tuner import ContextTunedPlanner from rag_optimizer_env.graders import TaskGrader from rag_optimizer_env.llm_runtime import estimate_tokens from rag_optimizer_env.models import ChunkSummary, RagAction, RagObservation from rag_optimizer_env.retriever import HybridRetriever from rag_optimizer_env.tasks import ALL_TASKS, TASKS_BY_NAME, Task @dataclass(slots=True) class StepResult: observation: RagObservation reward: float done: bool info: dict[str, Any] class RagContextOptimizerEnv: _PROJECT_STOPWORDS = { "the", "and", "for", "with", "that", "this", "from", "into", "your", "have", "will", "using", "used", "use", "into", "they", "them", "their", "about", "while", "where", "when", "what", "which", "should", "would", "could", "there", "here", "then", "than", "each", "such", "only", "also", "been", "being", "does", "did", "done", "just", "more", "most", "very", "over", "under", "like", "same", "across", "because", "through", "make", "made", "many", "much", "some", "into", "onto", "must", "need", "needs", "task", "tasks", "chunk", "chunks", "query", "prompt", "environment", "agent", "agents", "model", "models", } _PROJECT_QUERY_HINTS = { "openenv", "benchmark", "rag-context-optimizer", "readme", "docker", "fastapi", "api", "endpoint", "inference.py", "app.py", "tasks.py", "graders.py", "environment.py", "repo", "repository", "codebase", "ui", "frontend", "backend", "space", "validator", } def __init__( self, task_name: str = "refund_triage_easy", query_override: str | None = None, token_budget_override: int | None = None, max_steps_override: int | None = None, corpus_family_override: str | None = None, ): if task_name not in TASKS_BY_NAME: raise ValueError(f"Unknown task_name: {task_name}") self._corpus_family = corpus_family_override or os.getenv("RAG_CORPUS_FAMILY") or "enterprise_v1" explicit_path = os.getenv("RAG_CORPUS_PATH") self._corpus_path = resolve_corpus_path(explicit_path, family=None if explicit_path else self._corpus_family) self._all_chunks = load_corpus(self._corpus_path) self._query_overridden = bool(query_override and query_override.strip()) self._include_project_chunks = os.getenv("ENABLE_PROJECT_CORPUS", "").strip().lower() in {"1", "true", "yes"} self._project_chunks = self._load_project_chunks() if self._include_project_chunks else [] self.retriever = HybridRetriever(self._all_chunks + self._project_chunks) self.context_tuner = ContextTunedPlanner( self.retriever, self._all_chunks + self._project_chunks, list(ALL_TASKS), ) self.grader = TaskGrader() self.task: Task = self._build_task( TASKS_BY_NAME[task_name], query_override=query_override, token_budget_override=token_budget_override, max_steps_override=max_steps_override, ) self._available_chunks: list[Chunk] = [] self._reviewed_artifacts: list[str] = [] self._selected_chunks: list[str] = [] self._compression_ratios: dict[str, float] = {} self._step_number = 0 self._done = False self._last_action_feedback: str | None = None self._last_answer = "" self._plan_draft = "" self._workflow_stage: str = "triage" self._case_id = f"{self.task.name}-001" self._last_tuning = None @staticmethod def _build_task( base_task: Task, query_override: str | None = None, token_budget_override: int | None = None, max_steps_override: int | None = None, ) -> Task: updated_task = base_task if query_override and query_override.strip(): updated_task = replace(updated_task, query=query_override.strip(), domain_filter=None) if token_budget_override is not None and token_budget_override > 0: updated_task = replace(updated_task, token_budget=token_budget_override) if max_steps_override is not None and max_steps_override > 0: updated_task = replace(updated_task, max_steps=max_steps_override) return updated_task async def reset(self) -> StepResult: candidate_chunks = self._filter_chunks_for_task(self.task) self._available_chunks = self._rank_chunks_for_query(self.task.query, candidate_chunks) if not self._query_overridden: chunk_by_id = {chunk.chunk_id: chunk for chunk in candidate_chunks} for chunk_id in self.task.required_artifact_ids + self.task.optional_artifact_ids: chunk = chunk_by_id.get(chunk_id) if chunk and all(existing.chunk_id != chunk_id for existing in self._available_chunks): self._available_chunks.append(chunk) self._reviewed_artifacts = [] self._selected_chunks = [] self._compression_ratios = {} self._step_number = 0 self._done = False self._last_action_feedback = None self._last_answer = "" self._plan_draft = "" self._workflow_stage = "triage" self._case_id = f"{self.task.name}-001" observation = self._build_observation() return StepResult( observation=observation, reward=0.0, done=False, info={"task": self.task.name, "event": "reset"}, ) async def step(self, action: RagAction) -> StepResult: if self._done: return StepResult( observation=self._build_observation(), reward=0.0, done=True, info={"task": self.task.name, "event": "episode_already_done"}, ) reward = 0.0 info: dict[str, Any] = {"task": self.task.name, "action_type": action.action_type} artifact_id = action.artifact_id or action.chunk_id or "" if action.action_type == "inspect_artifact": reward, info = self._handle_inspect(artifact_id, auto_prioritize=False) elif action.action_type == "select_chunk": reward, info = self._handle_inspect(artifact_id, auto_prioritize=True) elif action.action_type == "prioritize_artifact": reward, info = self._handle_prioritize(artifact_id) elif action.action_type == "deselect_chunk": reward, info = self._handle_deprioritize(artifact_id) elif action.action_type in {"summarize_artifact", "compress_chunk"}: reward, info = self._handle_compress(artifact_id, float(action.compression_ratio or 0.0)) elif action.action_type == "set_resolution_plan": reward, info = self._handle_plan(action.plan or "") elif action.action_type in {"submit_report", "submit_answer"}: self._last_answer = action.answer or "" result = await self._finalize_submission(reason="submit_report") self._step_number += 1 result.observation.step_number = self._step_number return result self._step_number += 1 self._update_workflow_stage() if self._step_number >= self.task.max_steps: return await self._finalize_submission(reason="max_steps_reached") observation = self._build_observation() return StepResult(observation=observation, reward=reward, done=False, info=info) async def state(self) -> dict: prioritized_artifact_details = [] for chunk_id in self._selected_chunks: chunk = self._chunk_map().get(chunk_id) if chunk is None: continue prioritized_artifact_details.append( { "artifact_id": chunk.chunk_id, "chunk_id": chunk.chunk_id, "domain": chunk.domain, "original_tokens": chunk.tokens, "effective_tokens": self._effective_chunk_tokens(chunk_id), "compression_ratio": round(self._compression_ratios.get(chunk_id, 1.0), 3), "text": self._effective_chunk_text(chunk_id), "keywords": chunk.keywords, } ) optimized_prompt = self._build_optimized_prompt() optimized_prompt_tokens = await estimate_tokens(optimized_prompt) if optimized_prompt else 0 return { "task": asdict(self.task) if is_dataclass(self.task) else self.task, "case_id": self._case_id, "case_summary": self.task.case_summary, "objective": self.task.query, "workflow_stage": self._workflow_stage, "customer_tier": self.task.customer_tier, "incident_severity": self.task.incident_severity, "step_number": self._step_number, "done": self._done, "reviewed_artifacts": list(self._reviewed_artifacts), "prioritized_artifacts": list(self._selected_chunks), "selected_chunks": list(self._selected_chunks), "compression_ratios": dict(self._compression_ratios), "plan_draft": self._plan_draft, "report_requirements": list(self.task.report_requirements), "progress_signals": self._progress_signals(), "total_tokens_used": self._total_tokens_used(), "token_budget": self.task.token_budget, "last_action_feedback": self._last_action_feedback, "last_answer": self._last_answer, "corpus_family": self._corpus_family, "corpus_path": str(self._corpus_path), "available_artifact_ids": [chunk.chunk_id for chunk in self._available_chunks], "available_chunk_ids": [chunk.chunk_id for chunk in self._available_chunks], "prioritized_artifact_details": prioritized_artifact_details, "selected_chunk_details": prioritized_artifact_details, "optimized_prompt_preview": optimized_prompt, "optimized_prompt_tokens": optimized_prompt_tokens, "context_tuning": ( { "mode": self._last_tuning.mode, "top_demo_cases": self._last_tuning.top_demo_cases, "suggested_citations": self._last_tuning.suggested_citations, "token_dropout": self._last_tuning.token_dropout, "leave_one_out": self._last_tuning.leave_one_out, } if self._last_tuning is not None else None ), } async def close(self): self._done = True def _filter_chunks_for_task(self, task: Task) -> list[Chunk]: domain_mapping = { "customer_support_operations": "Customer Support Operations", "incident_response_playbooks": "Incident Response Playbooks", "platform_reliability_release_engineering": "Platform Reliability & Release Engineering", } if self._query_overridden: if self._include_project_chunks and self._is_project_query(task.query): return list(self._all_chunks) + list(self._project_chunks) return list(self._all_chunks) if task.domain_filter is None: return list(self._all_chunks) normalized = domain_mapping.get(task.domain_filter, task.domain_filter) return [chunk for chunk in self._all_chunks if chunk.domain == normalized] def _is_project_query(self, query: str) -> bool: lowered = query.lower() return any(hint in lowered for hint in self._PROJECT_QUERY_HINTS) def _rank_chunks_for_query(self, query: str, chunks: list[Chunk], top_k: int = 20) -> list[Chunk]: tuning = self.context_tuner.tune(query, chunks) self._last_tuning = tuning scored = [] for chunk in chunks: tuned = tuning.tuned_scores.get(chunk.chunk_id) score = tuned.final_score if tuned is not None else self.retriever.hybrid_score(query, chunk) if self._include_project_chunks and self._query_overridden and chunk.domain.startswith("Project"): score = min(1.0, score + 0.08) scored.append((chunk, score)) scored.sort(key=lambda item: (-item[1], item[0].tokens, item[0].chunk_id)) return [chunk for chunk, _score in scored[: max(1, min(top_k, len(scored)))]] def _load_project_chunks(self) -> list[Chunk]: root = Path(__file__).resolve().parent.parent chunks: list[Chunk] = [] file_specs = [ ("Project Documentation", root / "README.md", ["project_docs", "readme"]), ("Project Configuration", root / "openenv.yaml", ["project_docs", "config", "openenv_spec"]), ("Project API", root / "app.py", ["project_docs", "api", "server"]), ("Project Baseline", root / "inference.py", ["project_docs", "baseline", "inference"]), ("Project Environment", root / "env" / "environment.py", ["project_docs", "environment", "state_management"]), ("Project Retrieval", root / "env" / "retriever.py", ["project_docs", "retrieval", "ranking"]), ("Project Grading", root / "env" / "graders.py", ["project_docs", "grading", "reward_design"]), ("Project Tasks", root / "env" / "tasks.py", ["project_docs", "tasks", "difficulty"]), ("Project Validation", root / "validate.py", ["project_docs", "validation", "testing"]), ] for domain, path, tags in file_specs: if not path.exists(): continue raw_text = path.read_text(encoding="utf-8", errors="ignore") sections = self._chunk_project_text(raw_text) stem = re.sub(r"[^a-z0-9]+", "_", path.stem.lower()).strip("_") or "file" for index, section in enumerate(sections, start=1): keywords = self._extract_project_keywords(section) or [stem, domain.lower()] chunks.append( Chunk( chunk_id=f"project_{stem}_{index:03d}", domain=domain, text=section, tokens=max(30, len(section) // 4), keywords=keywords[:5], relevance_tags=tags, ) ) return chunks def _chunk_project_text(self, raw_text: str, chunk_words: int = 140, stride_words: int = 100) -> list[str]: cleaned = " ".join(raw_text.split()) words = cleaned.split() if not words: return [] if len(words) <= chunk_words: return [" ".join(words)] chunks: list[str] = [] start = 0 while start < len(words): window = words[start : start + chunk_words] if not window: break chunks.append(" ".join(window)) if start + chunk_words >= len(words): break start += stride_words return chunks def _extract_project_keywords(self, text: str) -> list[str]: terms = re.findall(r"[a-z0-9_]+", text.lower()) counts: dict[str, int] = {} for term in terms: if len(term) < 4 or term in self._PROJECT_STOPWORDS: continue counts[term] = counts.get(term, 0) + 1 ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0])) return [term.replace("_", " ") for term, _count in ranked[:8]] def _build_observation(self) -> RagObservation: available = [ ChunkSummary( chunk_id=chunk.chunk_id, domain=chunk.domain, tokens=self._effective_chunk_tokens(chunk.chunk_id), keywords=chunk.keywords, ) for chunk in self._available_chunks ] return RagObservation( case_id=self._case_id, case_summary=self.task.case_summary, objective=self.task.query, workflow_stage=self._workflow_stage, customer_tier=self.task.customer_tier, incident_severity=self.task.incident_severity, available_artifacts=available, reviewed_artifacts=list(self._reviewed_artifacts), prioritized_artifacts=list(self._selected_chunks), plan_draft=self._plan_draft or None, report_requirements=list(self.task.report_requirements), progress_signals=self._progress_signals(), total_tokens_used=self._total_tokens_used(), token_budget=self.task.token_budget, step_number=self._step_number, task_name=self.task.name, last_action_feedback=self._last_action_feedback, query=self.task.query, available_chunks=available, selected_chunks=list(self._selected_chunks), ) def _chunk_map(self) -> dict[str, Chunk]: return {chunk.chunk_id: chunk for chunk in self._available_chunks} def _effective_chunk_tokens(self, chunk_id: str) -> int: chunk = self._chunk_map().get(chunk_id) if chunk is None: return 0 ratio = self._compression_ratios.get(chunk_id, 1.0) return max(1, int(round(chunk.tokens * ratio))) def _total_tokens_used(self) -> int: return sum(self._effective_chunk_tokens(chunk_id) for chunk_id in self._selected_chunks) def _effective_chunk_text(self, chunk_id: str) -> str: chunk = self._chunk_map().get(chunk_id) if chunk is None: return "" ratio = self._compression_ratios.get(chunk_id, 1.0) text = " ".join(chunk.text.split()) if ratio >= 0.999: return text query_terms = self._query_terms(self.task.query) keyword_terms = self._query_terms(" ".join(chunk.keywords)) sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+", text) if segment.strip()] if not sentences: return self._truncate_words(text, ratio) ranked_sentences: list[tuple[int, float, int, str]] = [] for index, sentence in enumerate(sentences): sentence_terms = self._query_terms(sentence) overlap = len(sentence_terms & query_terms) keyword_overlap = len(sentence_terms & keyword_terms) score = (overlap * 2.0) + keyword_overlap + (0.25 if index == 0 else 0.0) ranked_sentences.append((index, score, len(sentence.split()), sentence)) target_words = max(18, int(len(text.split()) * ratio)) chosen: list[tuple[int, str]] = [] used_words = 0 for index, _score, word_count, sentence in sorted(ranked_sentences, key=lambda item: (-item[1], item[2], item[0])): if used_words >= target_words: break chosen.append((index, sentence)) used_words += word_count if not chosen: return self._truncate_words(text, ratio) chosen.sort(key=lambda item: item[0]) compressed = " ".join(sentence for _index, sentence in chosen) return self._truncate_words(compressed, ratio) @staticmethod def _truncate_words(text: str, ratio: float) -> str: words = text.split() if not words: return "" keep = max(10, int(len(words) * ratio)) truncated = " ".join(words[:keep]) if keep < len(words): return truncated + " ..." return truncated @staticmethod def _query_terms(text: str) -> set[str]: return {token for token in re.findall(r"[a-z0-9]+", text.lower()) if len(token) > 2} def _build_optimized_prompt(self) -> str: sections = [ f"Case: {self.task.case_summary}", f"Objective: {self.task.query}", f"Stage: {self._workflow_stage}", ] if self._plan_draft: sections.extend(["", f"Plan Draft: {self._plan_draft}"]) if self._selected_chunks: sections.extend(["", "Prioritized Evidence:"]) for chunk_id in self._selected_chunks: chunk = self._chunk_map().get(chunk_id) if chunk is None: continue sections.append(f"[{chunk.chunk_id} | {self._effective_chunk_tokens(chunk_id)} tokens] {self._effective_chunk_text(chunk_id)}") return "\n".join(sections).strip() def _is_relevant(self, chunk_id: str) -> tuple[bool, float]: chunk = self._chunk_map().get(chunk_id) if chunk is None: return False, 0.0 score = self.retriever.hybrid_score(self.task.query, chunk) return score >= 0.3, score def _is_required(self, chunk_id: str) -> bool: return chunk_id in set(self.task.required_artifact_ids) def _progress_signals(self) -> dict[str, float]: required = set(self.task.required_artifact_ids) reviewed_hits = len(required & set(self._reviewed_artifacts)) / len(required) if required else 1.0 prioritized_hits = len(required & set(self._selected_chunks)) / len(required) if required else 1.0 plan_keywords = sum(1 for keyword in self.task.required_plan_keywords if keyword.lower() in self._plan_draft.lower()) plan_quality = plan_keywords / len(self.task.required_plan_keywords) if self.task.required_plan_keywords else 1.0 return { "review_coverage": round(reviewed_hits, 3), "priority_coverage": round(prioritized_hits, 3), "plan_quality": round(plan_quality, 3), "budget_headroom": round(max(0.0, 1.0 - (self._total_tokens_used() / self.task.token_budget)), 3), } def _update_workflow_stage(self) -> None: if self._done: self._workflow_stage = "submitted" elif self._plan_draft.strip(): self._workflow_stage = "resolution" elif self._reviewed_artifacts: self._workflow_stage = "analysis" else: self._workflow_stage = "triage" def _handle_inspect(self, chunk_id: str, auto_prioritize: bool) -> tuple[float, dict[str, Any]]: chunk = self._chunk_map().get(chunk_id) if chunk is None: self._last_action_feedback = "artifact_not_found" return -0.1, {"event": "artifact_not_found", "artifact_id": chunk_id} if chunk_id not in self._reviewed_artifacts: self._reviewed_artifacts.append(chunk_id) is_relevant, score = self._is_relevant(chunk_id) reward = 0.03 + (0.08 if self._is_required(chunk_id) else 0.0) + (0.05 if is_relevant else 0.0) info = {"event": "artifact_inspected", "artifact_id": chunk_id, "hybrid_score": score} self._last_action_feedback = "artifact_inspected" if auto_prioritize: priority_reward, priority_info = self._handle_prioritize(chunk_id, inspected=True) reward += priority_reward info["auto_prioritize"] = priority_info return min(reward, 0.2), info def _handle_prioritize(self, chunk_id: str, inspected: bool = False) -> tuple[float, dict[str, Any]]: chunk = self._chunk_map().get(chunk_id) if chunk is None: self._last_action_feedback = "artifact_not_found" return -0.1, {"event": "artifact_not_found", "artifact_id": chunk_id} if chunk_id not in self._reviewed_artifacts and not inspected: self._last_action_feedback = "artifact_not_reviewed" return -0.05, {"event": "artifact_not_reviewed", "artifact_id": chunk_id} if chunk_id in self._selected_chunks: self._last_action_feedback = "artifact_already_prioritized" return 0.0, {"event": "artifact_already_prioritized", "artifact_id": chunk_id} projected_tokens = self._total_tokens_used() + self._effective_chunk_tokens(chunk_id) if projected_tokens > self.task.token_budget: self._last_action_feedback = "exceeded_budget" return -0.1, {"event": "exceeded_budget", "artifact_id": chunk_id} self._selected_chunks.append(chunk_id) is_relevant, score = self._is_relevant(chunk_id) domain_bonus = 0.04 if len({self._chunk_map()[cid].domain for cid in self._selected_chunks if cid in self._chunk_map()}) > 1 else 0.0 reward = (0.10 if self._is_required(chunk_id) else 0.03) + (0.05 if is_relevant else 0.0) + domain_bonus self._last_action_feedback = "artifact_prioritized" return min(reward, 0.18), {"event": "artifact_prioritized", "artifact_id": chunk_id, "hybrid_score": score} def _handle_deprioritize(self, chunk_id: str) -> tuple[float, dict[str, Any]]: if chunk_id not in self._selected_chunks: self._last_action_feedback = "artifact_not_prioritized" return 0.0, {"event": "artifact_not_prioritized", "artifact_id": chunk_id} self._selected_chunks.remove(chunk_id) is_required = self._is_required(chunk_id) reward = -0.06 if is_required else 0.03 self._last_action_feedback = "artifact_deprioritized" return reward, {"event": "artifact_deprioritized", "artifact_id": chunk_id, "required": is_required} def _handle_compress(self, chunk_id: str, compression_ratio: float) -> tuple[float, dict[str, Any]]: chunk = self._chunk_map().get(chunk_id) if chunk is None: self._last_action_feedback = "artifact_not_found" return -0.1, {"event": "artifact_not_found", "artifact_id": chunk_id} if chunk_id not in self._selected_chunks: self._last_action_feedback = "artifact_not_prioritized" return -0.04, {"event": "artifact_not_prioritized", "artifact_id": chunk_id} self._compression_ratios[chunk_id] = compression_ratio is_relevant, score = self._is_relevant(chunk_id) reward = 0.04 if is_relevant else 0.0 if self._is_required(chunk_id) and compression_ratio < 0.45: reward -= 0.06 self._last_action_feedback = "overcompressed_required_artifact" return reward, {"event": "overcompressed_required_artifact", "artifact_id": chunk_id, "hybrid_score": score} self._last_action_feedback = "artifact_summarized" return reward, {"event": "artifact_summarized", "artifact_id": chunk_id, "hybrid_score": score} def _handle_plan(self, plan: str) -> tuple[float, dict[str, Any]]: self._plan_draft = plan.strip() if not self._plan_draft: self._last_action_feedback = "empty_plan" return -0.05, {"event": "empty_plan"} hits = sum(1 for keyword in self.task.required_plan_keywords if keyword.lower() in self._plan_draft.lower()) coverage = hits / len(self.task.required_plan_keywords) if self.task.required_plan_keywords else 1.0 reviewed_bonus = min(0.1, 0.02 * len(self._reviewed_artifacts)) reward = (0.04 + (0.18 * coverage) + reviewed_bonus) self._last_action_feedback = "plan_updated" return min(reward, 0.26), {"event": "plan_updated", "plan_quality": coverage} async def _finalize_submission(self, reason: str) -> StepResult: self._done = True self._update_workflow_stage() if not self._selected_chunks: self._last_action_feedback = "no_prioritized_artifacts" observation = self._build_observation() return StepResult( observation=observation, reward=0.0, done=True, info={"event": reason, "grader": None, "passed": False}, ) grader_result = self.grader.grade( prioritized_artifact_ids=list(self._selected_chunks), reviewed_artifact_ids=list(self._reviewed_artifacts), answer=self._last_answer, plan_draft=self._plan_draft, workflow_stage=self._workflow_stage, token_budget=self.task.token_budget, total_tokens_used=self._total_tokens_used(), retriever=self.retriever, task=self.task, ) self._last_action_feedback = reason observation = self._build_observation() return StepResult( observation=observation, reward=grader_result.score, done=True, info={"event": reason, "grader": grader_result.breakdown, "passed": grader_result.passed}, )