Spaces:
Sleeping
Sleeping
Beginning of project
Browse files- .dockerignore +8 -0
- Dockerfile +15 -0
- app.py +4 -0
- debug_tokens.py +21 -0
- openenv.yaml +6 -0
- output.log +0 -0
- output_utf8.log +4 -0
- rag_gc_env/__init__.py +11 -0
- rag_gc_env/__pycache__/__init__.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/environment.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/grader.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/inference.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/models.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/rewards.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/tasks.cpython-311.pyc +0 -0
- rag_gc_env/environment.py +187 -0
- rag_gc_env/grader.py +43 -0
- rag_gc_env/inference.py +49 -0
- rag_gc_env/models.py +53 -0
- rag_gc_env/rewards.py +89 -0
- rag_gc_env/server/__init__.py +1 -0
- rag_gc_env/server/__pycache__/__init__.cpython-311.pyc +0 -0
- rag_gc_env/server/__pycache__/app.cpython-311.pyc +0 -0
- rag_gc_env/server/app.py +23 -0
- rag_gc_env/tasks.py +144 -0
- requirements.txt +5 -0
- test_reward_logic.py +43 -0
.dockerignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.pyc
|
| 3 |
+
.git
|
| 4 |
+
.venv
|
| 5 |
+
venv
|
| 6 |
+
*.md
|
| 7 |
+
.pytest_cache
|
| 8 |
+
.mypy_cache
|
Dockerfile
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
ENV PYTHONUNBUFFERED=1
|
| 6 |
+
ENV PYTHONPATH=/app
|
| 7 |
+
|
| 8 |
+
COPY requirements.txt .
|
| 9 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 10 |
+
|
| 11 |
+
COPY rag_gc_env ./rag_gc_env
|
| 12 |
+
|
| 13 |
+
EXPOSE 7860
|
| 14 |
+
|
| 15 |
+
# Exec-form CMD: exactly one shell runs, expands ${PORT:-7860}, and then `exec`s
# uvicorn so the server replaces the shell and receives stop signals directly.
# (The previous shell-form `CMD sh -c '...'` was itself wrapped in `/bin/sh -c`.)
CMD ["sh", "-c", "exec uvicorn rag_gc_env.server.app:app --host 0.0.0.0 --port ${PORT:-7860}"]
|
app.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Spaces default module entry (optional)."""
|
| 2 |
+
from rag_gc_env.server.app import app
|
| 3 |
+
|
| 4 |
+
__all__ = ["app"]
|
debug_tokens.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag_gc_env.environment import RAGGCEnvironment
|
| 2 |
+
from rag_gc_env.models import RAGGCAction
|
| 3 |
+
|
| 4 |
+
def test_medium():
    """Walk the medium task through 'delete m2' then 'summarize m0', printing token counts.

    This is a print-only debugging aid: it reports whether the final token
    count is within the task budget but does not raise on failure.
    """
    env = RAGGCEnvironment()

    initial = env.reset(task_name="medium_token_compression")
    print(f"Initial tokens: {initial.token_count}, budget: {initial.token_budget}")

    after_delete = env.step(RAGGCAction(verb="delete", document_id="m2"))
    print(f"After delete m2: {after_delete.token_count}")

    after_summary = env.step(RAGGCAction(verb="summarize", document_id="m0"))
    print(f"After summarize m0: {after_summary.token_count}")

    within_budget = after_summary.token_count <= after_summary.token_budget
    print(
        "OK: token_count below budget."
        if within_budget
        else "BUG: token_count still above budget!"
    )
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
test_medium()
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: rag_gc_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: rag_gc_env.server.app:app
|
| 6 |
+
port: 8000
|
output.log
ADDED
|
Binary file (528 Bytes). View file
|
|
|
output_utf8.log
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
easy_irrelevant_removal score= 1.0 trace= ['reset', 'delete:d1', 'submit:None']
|
| 2 |
+
medium_token_compression score= 1.0 trace= ['reset', 'delete:m2', 'summarize:m0', 'submit:None']
|
| 3 |
+
hard_contradiction_removal score= 1.0 trace= ['reset', 'delete:h1', 'submit:None']
|
| 4 |
+
|
rag_gc_env/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag_gc_env.models import RAGGCAction, RAGGCObservation, RAGGCReward, RAGGCState
|
| 2 |
+
from rag_gc_env.environment import RAGGCEnvironment
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"RAGGCAction",
|
| 6 |
+
"RAGGCObservation",
|
| 7 |
+
"RAGGCReward",
|
| 8 |
+
"RAGGCState",
|
| 9 |
+
"RAGGCEnvironment",
|
| 10 |
+
]
|
| 11 |
+
|
rag_gc_env/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (440 Bytes). View file
|
|
|
rag_gc_env/__pycache__/environment.cpython-311.pyc
ADDED
|
Binary file (9.34 kB). View file
|
|
|
rag_gc_env/__pycache__/grader.cpython-311.pyc
ADDED
|
Binary file (2.75 kB). View file
|
|
|
rag_gc_env/__pycache__/inference.cpython-311.pyc
ADDED
|
Binary file (2.66 kB). View file
|
|
|
rag_gc_env/__pycache__/models.cpython-311.pyc
ADDED
|
Binary file (3.51 kB). View file
|
|
|
rag_gc_env/__pycache__/rewards.cpython-311.pyc
ADDED
|
Binary file (3.92 kB). View file
|
|
|
rag_gc_env/__pycache__/tasks.cpython-311.pyc
ADDED
|
Binary file (5.11 kB). View file
|
|
|
rag_gc_env/environment.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Optional
|
| 4 |
+
from uuid import uuid4
|
| 5 |
+
|
| 6 |
+
from openenv.core.env_server.interfaces import Environment
|
| 7 |
+
|
| 8 |
+
from rag_gc_env.grader import grade_context
|
| 9 |
+
from rag_gc_env.models import DocumentItem, RAGGCAction, RAGGCObservation, RAGGCReward, RAGGCState
|
| 10 |
+
from rag_gc_env.rewards import step_reward, summarize_deterministic
|
| 11 |
+
from rag_gc_env.tasks import ALL_TASKS, TaskSpec, task_by_seed
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RAGGCEnvironment(Environment[RAGGCAction, RAGGCObservation, RAGGCState]):
    """Environment in which an agent prunes a retrieved-document context.

    Each episode loads one ``TaskSpec``. The agent issues ``keep`` / ``delete`` /
    ``summarize`` actions against individual documents and finally ``submit``s,
    at which point the remaining documents are scored by ``grade_context``.
    An episode also ends (and is graded) once ``max_steps`` actions were taken.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        super().__init__(transform=None, rubric=None)
        self._state = RAGGCState(episode_id=str(uuid4()), step_count=0)
        self._task: TaskSpec = task_by_seed(0)
        # Documents are only populated by reset(); until then the context is empty.
        self._docs: dict[str, DocumentItem] = {}
        self._removed_critical = False

    def _load_task(self, spec: TaskSpec) -> None:
        """Replace the working document set with fresh copies from *spec*."""
        self._docs = {
            did: DocumentItem(document_id=did, text=text, tokens=tok)
            for did, text, tok, _meta in spec.documents
        }

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> RAGGCObservation:
        """Start a new episode and return its first observation.

        Task selection precedence: a ``task_name`` found in ``ALL_TASKS`` wins;
        otherwise ``seed`` (if given) is mapped through ``task_by_seed``;
        otherwise the seed-0 task is used. An unrecognised ``task_name`` is not
        an error, it simply falls through to the seed/default rule.

        Args:
            seed: Optional integer used to pick a task when no valid name is given.
            episode_id: Optional id for the episode; a UUID4 is generated if omitted.
            task_name: Optional key into ``ALL_TASKS``.
            **kwargs: Accepted for interface compatibility and ignored here.

        Returns:
            Observation with ``done=False``, reward ``0.0`` and message ``"ready"``.
        """
        # Base-class hook; presumably clears rubric state -- defined outside this module.
        self._reset_rubric()
        sid = episode_id or str(uuid4())

        if task_name and task_name in ALL_TASKS:
            self._task = ALL_TASKS[task_name]
        elif seed is not None:
            self._task = task_by_seed(int(seed))
        else:
            self._task = task_by_seed(0)

        self._load_task(self._task)
        self._removed_critical = False
        self._state = RAGGCState(
            episode_id=sid,
            step_count=0,
            task_name=self._task.name,
            max_steps=64,
            removed_critical=False,
            submitted=False,
        )
        return self._observe(done=False, reward_value=0.0, msg="ready")

    def _total_tokens(self) -> int:
        """Return the summed token count of all documents still in context."""
        return sum(d.tokens for d in self._docs.values())

    def _grade(self) -> float:
        """Score the documents currently in context with the deterministic grader."""
        return grade_context(self._task, list(self._docs.values()))

    def _observe(
        self,
        done: bool,
        reward_value: float,
        msg: str,
        reward_detail: Optional[RAGGCReward] = None,
        grader_score: Optional[float] = None,
    ) -> RAGGCObservation:
        """Build an observation of the current context.

        Documents are listed sorted by ``document_id``. ``metadata["relevance"]``
        covers only documents still present (defaulting to 0.5 when a task row
        has no relevance entry); ``metadata["hints"]`` covers every document of
        the task, including ones already deleted.
        """
        docs = sorted(self._docs.values(), key=lambda x: x.document_id)
        relevance = {
            row[0]: row[3].get("relevance", 0.5)
            for row in self._task.documents
            if row[0] in self._docs
        }
        hints = {row[0]: row[3].get("hint", "") for row in self._task.documents}
        return RAGGCObservation(
            done=done,
            reward=reward_value,
            query=self._task.query,
            documents=docs,
            token_count=self._total_tokens(),
            token_budget=self._task.token_budget,
            task_name=self._task.name,
            message=msg,
            grader_score=grader_score,
            reward_detail=reward_detail,
            metadata={"relevance": relevance, "hints": hints},
        )

    def step(
        self,
        action: RAGGCAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> RAGGCObservation:
        """Apply one action and return the resulting observation.

        Behaviour by case:

        * Episode already finished (submitted, or step limit reached): nothing
          is mutated; a terminal observation with reward ``0.0``, message
          ``"episode_done"`` and the current grade is returned.
        * ``submit``: grades the context, marks the episode submitted, returns
          ``done=True`` with the grade as reward.
        * Missing/unknown ``document_id``: reward ``-0.1``, message
          ``"unknown_document"``; still counts toward the step limit, and the
          episode ends (with ``grader_score`` set) when the limit is hit.
        * ``delete`` / ``keep`` / ``summarize``: shaped reward from
          ``step_reward`` plus an over-budget penalty of ``-0.08`` per token
          above the budget. On the final allowed step the observation reward is
          replaced by the grader score.

        Args:
            action: The action to apply.
            timeout_s: Unused by this environment.
            **kwargs: Accepted for interface compatibility and ignored here.
        """
        # Never mutate or re-score a finished episode; callers must reset() first.
        if self._state.submitted or self._state.step_count >= self._state.max_steps:
            obs = self._observe(
                done=True,
                reward_value=0.0,
                msg="episode_done",
                grader_score=self._grade(),
            )
            return self._apply_transform(obs)

        self._state.step_count += 1
        # Shallow snapshot: DocumentItem objects are replaced, never edited in place.
        docs_before = dict(self._docs)

        if action.verb == "submit":
            score = self._grade()
            self._state.submitted = True
            obs = self._observe(
                done=True,
                reward_value=score,
                msg="submitted",
                reward_detail=RAGGCReward(step_reward=score, final_score=score),
                grader_score=score,
            )
            return self._apply_transform(obs)

        limit_reached = self._state.step_count >= self._state.max_steps

        if action.document_id is None or action.document_id not in self._docs:
            # Invalid actions consume a step, so they must be able to end the episode too.
            obs = self._observe(
                done=limit_reached,
                reward_value=-0.1,
                msg="unknown_document",
                grader_score=self._grade() if limit_reached else None,
            )
            return self._apply_transform(obs)

        did = action.document_id
        removed_critical = False

        if action.verb == "delete":
            if did in self._task.critical_document_ids:
                self._removed_critical = True
                removed_critical = True
            self._docs.pop(did, None)

        elif action.verb == "summarize":
            new_text, new_tok = summarize_deterministic(self._docs[did].text)
            self._docs[did] = DocumentItem(
                document_id=did,
                text=new_text,
                tokens=new_tok,
            )
            # Summarising a critical document counts as losing it when any
            # required phrase no longer appears in the shortened text.
            if did in self._task.critical_document_ids and any(
                p not in new_text for p in self._task.required_phrases
            ):
                self._removed_critical = True
                removed_critical = True

        # "keep" is an explicit no-op on the document set.

        rdetail = step_reward(
            self._task,
            action.verb,
            did,
            docs_before,
            self._docs,
            removed_critical,
        )
        self._state.removed_critical = self._removed_critical

        total = self._total_tokens()
        over = total > self._task.token_budget
        if over:
            penalty = -0.08 * (total - self._task.token_budget)
            rdetail.token_penalty += penalty
            rdetail.step_reward += penalty

        final_score: Optional[float] = None
        reward_val = rdetail.step_reward
        if limit_reached:
            final_score = self._grade()
            rdetail.final_score = final_score
            # reward_detail keeps the shaped step reward plus half the grade,
            # while the observation's scalar reward is the grade itself.
            rdetail.step_reward += final_score * 0.5
            reward_val = final_score

        obs = self._observe(
            done=limit_reached,
            reward_value=reward_val,
            msg="over_budget" if over else ("graded" if limit_reached else "ok"),
            reward_detail=rdetail,
            grader_score=final_score,
        )
        return self._apply_transform(obs)

    @property
    def state(self) -> RAGGCState:
        """Current episode state (id, step count, task name, flags)."""
        return self._state
|
rag_gc_env/grader.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from rag_gc_env.models import DocumentItem
|
| 4 |
+
from rag_gc_env.tasks import TaskSpec
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _joined_text(docs: list[DocumentItem]) -> str:
|
| 8 |
+
return " ".join(d.text for d in docs)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def grade_context(task: TaskSpec, final_documents: list[DocumentItem]) -> float:
|
| 12 |
+
"""
|
| 13 |
+
Deterministic score in [0.0, 0.5, 1.0]:
|
| 14 |
+
1.0 — required facts present, budget respected, efficient (near optimal tokens)
|
| 15 |
+
0.5 — required facts present but inefficient or borderline budget
|
| 16 |
+
0.0 — missing facts, forbidden content present, or critical docs removed incorrectly
|
| 17 |
+
"""
|
| 18 |
+
text = _joined_text(final_documents)
|
| 19 |
+
total_tokens = sum(d.tokens for d in final_documents)
|
| 20 |
+
|
| 21 |
+
for phrase in task.required_phrases:
|
| 22 |
+
if phrase not in text:
|
| 23 |
+
return 0.0
|
| 24 |
+
|
| 25 |
+
for phrase in task.forbidden_phrases:
|
| 26 |
+
if phrase in text:
|
| 27 |
+
return 0.0
|
| 28 |
+
|
| 29 |
+
for pid in task.poison_document_ids:
|
| 30 |
+
still = any(d.document_id == pid for d in final_documents)
|
| 31 |
+
if still:
|
| 32 |
+
return 0.0
|
| 33 |
+
|
| 34 |
+
if total_tokens > task.token_budget:
|
| 35 |
+
return 0.0
|
| 36 |
+
|
| 37 |
+
if not task.critical_document_ids.issubset({d.document_id for d in final_documents}):
|
| 38 |
+
return 0.0
|
| 39 |
+
|
| 40 |
+
if total_tokens <= task.optimal_max_tokens:
|
| 41 |
+
return 1.0
|
| 42 |
+
|
| 43 |
+
return 0.5
|
rag_gc_env/inference.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reproducible baseline policy for Adaptive Context Optimization (RAG GC).
|
| 3 |
+
Deterministic: fixed action sequences per task derived from metadata.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from rag_gc_env.environment import RAGGCEnvironment
|
| 9 |
+
from rag_gc_env.models import RAGGCAction
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def run_baseline(task_name: str, seed: int = 0) -> tuple[float, list[str]]:
    """Run the fixed baseline action script for *task_name* and return its score.

    The script per task is hard-coded: delete the known distractor document,
    (for the medium task) summarize ``m0`` until the context fits the budget,
    then submit. Unknown task names just submit immediately.

    Args:
        task_name: Name of the task to run.
        seed: Forwarded to ``env.reset``.

    Returns:
        ``(score, trace)`` where *score* is the grader score of the last
        observation (falling back to its reward when no grade is attached) and
        *trace* lists ``"reset"`` followed by one ``"verb:doc_id"`` per action.
    """
    env = RAGGCEnvironment()
    obs = env.reset(seed=seed, task_name=task_name)
    log: list[str] = ["reset"]

    def step(verb: str, doc_id: str | None) -> None:
        nonlocal obs
        obs = env.step(RAGGCAction(verb=verb, document_id=doc_id))
        log.append(f"{verb}:{doc_id}")

    if task_name == "easy_irrelevant_removal":
        step("delete", "d1")
        step("submit", None)
    elif task_name == "medium_token_compression":
        step("delete", "m2")
        while obs.token_count > obs.token_budget and not obs.done:
            step("summarize", "m0")
            # Safety valve in case summarizing stops shrinking the context.
            if len(log) > 40:
                break
        step("submit", None)
    elif task_name == "hard_contradiction_removal":
        step("delete", "h1")
        step("submit", None)
    else:
        step("submit", None)

    # A grade of 0.0 is a real result, so test for None rather than truthiness.
    if obs.grader_score is not None:
        score = float(obs.grader_score)
    else:
        score = float(obs.reward or 0.0)
    return score, log
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
for name in (
|
| 44 |
+
"easy_irrelevant_removal",
|
| 45 |
+
"medium_token_compression",
|
| 46 |
+
"hard_contradiction_removal",
|
| 47 |
+
):
|
| 48 |
+
s, lg = run_baseline(name, seed=0)
|
| 49 |
+
print(name, "score=", s, "trace=", lg)
|
rag_gc_env/models.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Literal, Optional
|
| 4 |
+
|
| 5 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DocumentItem(BaseModel):
|
| 10 |
+
document_id: str
|
| 11 |
+
text: str
|
| 12 |
+
tokens: int = Field(description="Estimated tokens for this snippet")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RAGGCAction(Action):
|
| 16 |
+
verb: Literal["keep", "delete", "summarize", "submit"] = Field(
|
| 17 |
+
description="Document operation or submit to finalize and grade"
|
| 18 |
+
)
|
| 19 |
+
document_id: Optional[str] = Field(
|
| 20 |
+
default=None,
|
| 21 |
+
description="Target document for keep/delete/summarize; omit for submit",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class RAGGCReward(BaseModel):
|
| 26 |
+
step_reward: float = 0.0
|
| 27 |
+
relevance: float = 0.0
|
| 28 |
+
compression: float = 0.0
|
| 29 |
+
token_penalty: float = 0.0
|
| 30 |
+
critical_penalty: float = 0.0
|
| 31 |
+
final_score: Optional[float] = Field(
|
| 32 |
+
default=None, description="0.0–1.0 after submit; aligns with grader"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RAGGCObservation(Observation):
|
| 37 |
+
query: str = ""
|
| 38 |
+
documents: list[DocumentItem] = Field(default_factory=list)
|
| 39 |
+
token_count: int = 0
|
| 40 |
+
token_budget: int = 0
|
| 41 |
+
task_name: str = ""
|
| 42 |
+
reward_detail: Optional[RAGGCReward] = None
|
| 43 |
+
message: str = ""
|
| 44 |
+
grader_score: Optional[float] = Field(
|
| 45 |
+
default=None, description="Deterministic score after episode ends"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class RAGGCState(State):
|
| 50 |
+
task_name: str = ""
|
| 51 |
+
max_steps: int = 64
|
| 52 |
+
removed_critical: bool = False
|
| 53 |
+
submitted: bool = False
|
rag_gc_env/rewards.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from rag_gc_env.models import DocumentItem, RAGGCReward
|
| 4 |
+
from rag_gc_env.tasks import TaskSpec
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def estimate_tokens(text: str) -> int:
    """Estimate the token count of *text*: one token per four characters, minimum 1."""
    return max(1, len(text) // 4)


def _terminate_sentence(sentence: str) -> str:
    """Return *sentence* ending in terminal punctuation, appending '.' only if needed."""
    if sentence.endswith((".", "!", "?")):
        return sentence
    return sentence + "."


def summarize_deterministic(text: str, *, cap: int = 280) -> tuple[str, int]:
    """Deterministically compress *text* to its leading sentence(s).

    The text is split on ``". "``. The first sentence is kept; if that is
    shorter than 40 characters and a second sentence exists, the second one is
    kept as well. The result is cut to *cap* characters, with ``"..."``
    appended after the cut when truncation happened (so a truncated summary is
    ``cap + 3`` characters long).

    Args:
        text: Source text; surrounding whitespace is ignored.
        cap: Maximum number of characters kept before the ellipsis.

    Returns:
        ``(summary, tokens)`` where *tokens* is ``estimate_tokens(summary)``.
        Blank input yields ``("", 1)``.

    Raises:
        ValueError: If *cap* is negative.
    """
    if cap < 0:
        raise ValueError("cap must be non-negative")

    stripped = text.strip()
    if not stripped:
        return "", estimate_tokens("")

    sentences = stripped.split(". ")
    summary = _terminate_sentence(sentences[0])
    if len(summary) < 40 and len(sentences) > 1:
        # A very short opener carries little content, so keep the next sentence too.
        summary = sentences[0] + ". " + _terminate_sentence(sentences[1])

    out = summary[:cap] + ("..." if len(summary) > cap else "")
    return out, estimate_tokens(out)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def step_reward(
    task: TaskSpec,
    verb: str,
    doc_id: str | None,
    docs_before: dict[str, DocumentItem],
    docs_after: dict[str, DocumentItem],
    removed_critical_flag: bool,
) -> RAGGCReward:
    """Compute the shaped reward for a single keep/delete/summarize action.

    Components (summed into ``step_reward``):

    * relevance: +0.4 for deleting an irrelevant document, +0.6 for a poison
      document, +0.2 for a document whose task metadata hint is ``"fluff"``.
    * compression: 0.3 x the fractional token reduction achieved by summarize.
    * token_penalty: -0.01 per token of a kept document, or per token left
      after summarizing.
    * critical_penalty: -3.0 when *removed_critical_flag* is set, a further
      -3.0 for deleting a critical document, and -2.5 per required phrase
      missing from a summarized critical document.

    Nothing is scored for a *doc_id* that is not in *docs_before*, apart from
    the flag-driven critical penalty.

    NOTE(review): the flag penalty and the per-branch critical penalties appear
    to describe the same event and therefore stack (e.g. -6.0 for deleting a
    critical document) -- confirm this is intended.
    """
    relevance = 0.0
    compression = 0.0
    token_penalty = 0.0
    critical = -3.0 if removed_critical_flag else 0.0

    known = doc_id in docs_before

    if known and verb == "delete":
        # Metadata of the task row matching this id ({} when no row matches).
        meta = next(
            (row_meta for row_id, _, _, row_meta in task.documents if row_id == doc_id),
            {},
        )
        # First matching category wins; deletion itself carries no token penalty.
        if doc_id in task.irrelevant_document_ids:
            relevance += 0.4
        elif doc_id in task.poison_document_ids:
            relevance += 0.6
        elif doc_id in task.critical_document_ids:
            critical -= 3.0
        elif meta.get("hint") == "fluff":
            relevance += 0.2

    elif known and verb == "summarize":
        tokens_before = docs_before[doc_id].tokens
        summarized = docs_after.get(doc_id)
        if summarized is not None:
            # Reward only genuine shrinkage; growth is clamped to zero.
            shrink = (tokens_before - summarized.tokens) / max(tokens_before, 1)
            compression += 0.3 * max(0.0, shrink)

            # Whatever remains still costs context space.
            token_penalty -= 0.01 * summarized.tokens

            if doc_id in task.critical_document_ids:
                for phrase in task.required_phrases:
                    if phrase not in summarized.text:
                        critical -= 2.5

    elif known and verb == "keep":
        # Keeping a document costs its full token footprint.
        token_penalty -= 0.01 * docs_before[doc_id].tokens

    total = relevance + compression + token_penalty + critical
    return RAGGCReward(
        step_reward=total,
        relevance=relevance,
        compression=compression,
        token_penalty=token_penalty,
        critical_penalty=critical,
    )
|
rag_gc_env/server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Server package for OpenEnv HTTP deployment
|
rag_gc_env/server/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (154 Bytes). View file
|
|
|
rag_gc_env/server/__pycache__/app.cpython-311.pyc
ADDED
|
Binary file (1.04 kB). View file
|
|
|
rag_gc_env/server/app.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from openenv.core.env_server.http_server import create_fastapi_app
|
| 4 |
+
|
| 5 |
+
from rag_gc_env.environment import RAGGCEnvironment
|
| 6 |
+
from rag_gc_env.models import RAGGCAction, RAGGCObservation
|
| 7 |
+
|
| 8 |
+
app = create_fastapi_app(
|
| 9 |
+
RAGGCEnvironment,
|
| 10 |
+
RAGGCAction,
|
| 11 |
+
RAGGCObservation,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main() -> None:
    """Serve ``app`` with uvicorn on 0.0.0.0, using $PORT or 8000 when unset."""
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn

    listen_port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
main()
|
rag_gc_env/tasks.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Any, FrozenSet
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True)
|
| 8 |
+
class TaskSpec:
|
| 9 |
+
name: str
|
| 10 |
+
query: str
|
| 11 |
+
token_budget: int
|
| 12 |
+
documents: list[tuple[str, str, int, dict[str, Any]]]
|
| 13 |
+
# document_id, text, tokens, metadata (relevance, flags)
|
| 14 |
+
required_phrases: FrozenSet[str] = field(default_factory=frozenset)
|
| 15 |
+
forbidden_phrases: FrozenSet[str] = field(default_factory=frozenset)
|
| 16 |
+
critical_document_ids: FrozenSet[str] = field(default_factory=frozenset)
|
| 17 |
+
irrelevant_document_ids: FrozenSet[str] = field(default_factory=frozenset)
|
| 18 |
+
poison_document_ids: FrozenSet[str] = field(default_factory=frozenset)
|
| 19 |
+
optimal_max_tokens: int = 0
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _docs(
|
| 23 |
+
rows: list[tuple[str, str, int, dict[str, Any]]],
|
| 24 |
+
) -> list[tuple[str, str, int, dict[str, Any]]]:
|
| 25 |
+
return rows
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
TASK_EASY = TaskSpec(
|
| 29 |
+
name="easy_irrelevant_removal",
|
| 30 |
+
query="What is the capital city of France?",
|
| 31 |
+
token_budget=400,
|
| 32 |
+
documents=_docs(
|
| 33 |
+
[
|
| 34 |
+
(
|
| 35 |
+
"d0",
|
| 36 |
+
"Paris has been the capital of France since political centralization in the country.",
|
| 37 |
+
24,
|
| 38 |
+
{"relevance": 0.95, "hint": "high"},
|
| 39 |
+
),
|
| 40 |
+
(
|
| 41 |
+
"d1",
|
| 42 |
+
"Penguins thrive in Antarctica and are unrelated to European geography.",
|
| 43 |
+
18,
|
| 44 |
+
{"relevance": 0.08, "hint": "noise"},
|
| 45 |
+
),
|
| 46 |
+
(
|
| 47 |
+
"d2",
|
| 48 |
+
"Lyon is a major French city but not the national capital.",
|
| 49 |
+
16,
|
| 50 |
+
{"relevance": 0.55, "hint": "partial"},
|
| 51 |
+
),
|
| 52 |
+
]
|
| 53 |
+
),
|
| 54 |
+
required_phrases=frozenset({"Paris"}),
|
| 55 |
+
forbidden_phrases=frozenset(),
|
| 56 |
+
critical_document_ids=frozenset({"d0"}),
|
| 57 |
+
irrelevant_document_ids=frozenset({"d1"}),
|
| 58 |
+
poison_document_ids=frozenset(),
|
| 59 |
+
optimal_max_tokens=120,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
_LONG_DUP = (
|
| 63 |
+
"Paris is the capital of France. " * 18
|
| 64 |
+
+ "This repetition exists only to inflate token usage for compression tests."
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
TASK_MEDIUM = TaskSpec(
|
| 68 |
+
name="medium_token_compression",
|
| 69 |
+
query="Which city is the capital of France?",
|
| 70 |
+
token_budget=120,
|
| 71 |
+
documents=_docs(
|
| 72 |
+
[
|
| 73 |
+
(
|
| 74 |
+
"m0",
|
| 75 |
+
_LONG_DUP,
|
| 76 |
+
max(1, len(_LONG_DUP) // 4),
|
| 77 |
+
{"relevance": 0.9, "hint": "verbose"},
|
| 78 |
+
),
|
| 79 |
+
(
|
| 80 |
+
"m1",
|
| 81 |
+
"Administrative records list a capital city but this line omits the name intentionally.",
|
| 82 |
+
14,
|
| 83 |
+
{"relevance": 0.55, "hint": "no_answer"},
|
| 84 |
+
),
|
| 85 |
+
(
|
| 86 |
+
"m2",
|
| 87 |
+
"French cuisine is diverse; it does not change which city is the capital.",
|
| 88 |
+
14,
|
| 89 |
+
{"relevance": 0.35, "hint": "fluff"},
|
| 90 |
+
),
|
| 91 |
+
]
|
| 92 |
+
),
|
| 93 |
+
required_phrases=frozenset({"Paris"}),
|
| 94 |
+
forbidden_phrases=frozenset(),
|
| 95 |
+
critical_document_ids=frozenset(),
|
| 96 |
+
irrelevant_document_ids=frozenset({"m2"}),
|
| 97 |
+
poison_document_ids=frozenset(),
|
| 98 |
+
optimal_max_tokens=90,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
TASK_HARD = TaskSpec(
|
| 102 |
+
name="hard_contradiction_removal",
|
| 103 |
+
query="Compute the integer result of 17 + 25.",
|
| 104 |
+
token_budget=350,
|
| 105 |
+
documents=_docs(
|
| 106 |
+
[
|
| 107 |
+
(
|
| 108 |
+
"h0",
|
| 109 |
+
"Arithmetic trace: 17 + 25 = 42.",
|
| 110 |
+
12,
|
| 111 |
+
{"relevance": 0.9, "hint": "consistent", "trust": 0.95},
|
| 112 |
+
),
|
| 113 |
+
(
|
| 114 |
+
"h1",
|
| 115 |
+
"Quick math note: 17 + 25 equals 43 for budgeting purposes.",
|
| 116 |
+
14,
|
| 117 |
+
{"relevance": 0.88, "hint": "poison", "trust": 0.2},
|
| 118 |
+
),
|
| 119 |
+
(
|
| 120 |
+
"h2",
|
| 121 |
+
"Addition of integers is associative and commutative.",
|
| 122 |
+
10,
|
| 123 |
+
{"relevance": 0.4, "hint": "generic"},
|
| 124 |
+
),
|
| 125 |
+
]
|
| 126 |
+
),
|
| 127 |
+
required_phrases=frozenset({"42"}),
|
| 128 |
+
forbidden_phrases=frozenset({"43"}),
|
| 129 |
+
critical_document_ids=frozenset({"h0"}),
|
| 130 |
+
irrelevant_document_ids=frozenset(),
|
| 131 |
+
poison_document_ids=frozenset({"h1"}),
|
| 132 |
+
optimal_max_tokens=200,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
ALL_TASKS: dict[str, TaskSpec] = {
|
| 136 |
+
TASK_EASY.name: TASK_EASY,
|
| 137 |
+
TASK_MEDIUM.name: TASK_MEDIUM,
|
| 138 |
+
TASK_HARD.name: TASK_HARD,
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def task_by_seed(seed: int) -> TaskSpec:
    """Map an integer seed onto a built-in task, cycling easy -> medium -> hard.

    The modulus follows the length of the task sequence, so adding a task here
    cannot leave it unreachable through a stale constant.
    """
    order = (TASK_EASY, TASK_MEDIUM, TASK_HARD)
    return order[seed % len(order)]
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core>=0.1.0
|
| 2 |
+
pydantic>=2.0
|
| 3 |
+
fastapi>=0.104.0
|
| 4 |
+
uvicorn[standard]>=0.24.0
|
| 5 |
+
typing_extensions>=4.8.0
|
test_reward_logic.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag_gc_env.environment import RAGGCEnvironment
|
| 2 |
+
from rag_gc_env.models import RAGGCAction
|
| 3 |
+
|
| 4 |
+
def test_reward_logic():
    """Check the sign and rough magnitude of step rewards in three scenarios.

    Each scenario prints its observed reward and then asserts on it, so the
    function fails under pytest (or as a script) instead of only printing
    a FAIL line.
    """
    env = RAGGCEnvironment()

    # Test 1: Deleting irrelevant document should be POSITIVE
    print("\n--- Test 1: Deleting irrelevant (d1) in Easy Task ---")
    env.reset(task_name="easy_irrelevant_removal")
    obs = env.step(RAGGCAction(verb="delete", document_id="d1"))
    print(f"Step Reward for deleting d1: {obs.reward}")
    assert obs.reward > 0, "FAIL: Reward should be positive."
    print("PASS: Positive reward for deletion.")

    # Test 2: Summarizing to save tokens should be POSITIVE
    print("\n--- Test 2: Summarizing (m0) in Medium Task ---")
    obs = env.reset(task_name="medium_token_compression")
    tokens_before = obs.token_count
    obs = env.step(RAGGCAction(verb="summarize", document_id="m0"))
    print(f"Step Reward for summarizing m0: {obs.reward}")
    print(f"Tokens: {tokens_before} -> {obs.token_count}")
    assert obs.token_count < tokens_before, "FAIL: Summarizing should reduce tokens."
    assert obs.reward > 0, "FAIL: Reward should be positive (was previously negative)."
    print("PASS: Positive reward for summarization.")

    # Test 3: Over budget penalty should be reflected in reward
    print("\n--- Test 3: Over budget penalty in Medium Task ---")
    env.reset(task_name="medium_token_compression")
    # Budget is 120 and the initial context totals 190 tokens, i.e. 70 over:
    # over-budget penalty = -0.08 * 70 = -5.6.
    # 'keep' m1 (14 tokens) adds 14 * -0.01 = -0.14.
    # Expected total reward = -5.6 - 0.14 = -5.74.
    obs = env.step(RAGGCAction(verb="keep", document_id="m1"))
    print(f"Reward with budget penalty: {obs.reward}")
    assert obs.reward < -5, "FAIL: Budget penalty missing or too low."
    print("PASS: Budget penalty correctly reflected in step reward.")
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
test_reward_logic()
|