prithic07 commited on
Commit
a108ef2
·
1 Parent(s): 5011e42

Beginning of project

Browse files
.dockerignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ *.pyc
3
+ .git
4
+ .venv
5
+ venv
6
+ *.md
7
+ .pytest_cache
8
+ .mypy_cache
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ ENV PYTHONUNBUFFERED=1
6
+ ENV PYTHONPATH=/app
7
+
8
+ COPY requirements.txt .
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+
11
+ COPY rag_gc_env ./rag_gc_env
12
+
13
+ EXPOSE 7860
14
+
15
+ CMD sh -c 'uvicorn rag_gc_env.server.app:app --host 0.0.0.0 --port ${PORT:-7860}'
app.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """Hugging Face Spaces default module entry (optional)."""
2
+ from rag_gc_env.server.app import app
3
+
4
+ __all__ = ["app"]
debug_tokens.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rag_gc_env.environment import RAGGCEnvironment
2
+ from rag_gc_env.models import RAGGCAction
3
+
4
+ def test_medium():
5
+ env = RAGGCEnvironment()
6
+ obs = env.reset(task_name="medium_token_compression")
7
+ print(f"Initial tokens: {obs.token_count}, budget: {obs.token_budget}")
8
+
9
+ obs = env.step(RAGGCAction(verb="delete", document_id="m2"))
10
+ print(f"After delete m2: {obs.token_count}")
11
+
12
+ obs = env.step(RAGGCAction(verb="summarize", document_id="m0"))
13
+ print(f"After summarize m0: {obs.token_count}")
14
+
15
+ if obs.token_count > obs.token_budget:
16
+ print("BUG: token_count still above budget!")
17
+ else:
18
+ print("OK: token_count below budget.")
19
+
20
+ if __name__ == "__main__":
21
+ test_medium()
openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: rag_gc_env
3
+ type: space
4
+ runtime: fastapi
5
+ app: rag_gc_env.server.app:app
6
+ port: 8000
output.log ADDED
Binary file (528 Bytes). View file
 
output_utf8.log ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ easy_irrelevant_removal score= 1.0 trace= ['reset', 'delete:d1', 'submit:None']
2
+ medium_token_compression score= 1.0 trace= ['reset', 'delete:m2', 'summarize:m0', 'submit:None']
3
+ hard_contradiction_removal score= 1.0 trace= ['reset', 'delete:h1', 'submit:None']
4
+
rag_gc_env/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rag_gc_env.models import RAGGCAction, RAGGCObservation, RAGGCReward, RAGGCState
2
+ from rag_gc_env.environment import RAGGCEnvironment
3
+
4
+ __all__ = [
5
+ "RAGGCAction",
6
+ "RAGGCObservation",
7
+ "RAGGCReward",
8
+ "RAGGCState",
9
+ "RAGGCEnvironment",
10
+ ]
11
+
rag_gc_env/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (440 Bytes). View file
 
rag_gc_env/__pycache__/environment.cpython-311.pyc ADDED
Binary file (9.34 kB). View file
 
rag_gc_env/__pycache__/grader.cpython-311.pyc ADDED
Binary file (2.75 kB). View file
 
rag_gc_env/__pycache__/inference.cpython-311.pyc ADDED
Binary file (2.66 kB). View file
 
rag_gc_env/__pycache__/models.cpython-311.pyc ADDED
Binary file (3.51 kB). View file
 
rag_gc_env/__pycache__/rewards.cpython-311.pyc ADDED
Binary file (3.92 kB). View file
 
rag_gc_env/__pycache__/tasks.cpython-311.pyc ADDED
Binary file (5.11 kB). View file
 
rag_gc_env/environment.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Optional
4
+ from uuid import uuid4
5
+
6
+ from openenv.core.env_server.interfaces import Environment
7
+
8
+ from rag_gc_env.grader import grade_context
9
+ from rag_gc_env.models import DocumentItem, RAGGCAction, RAGGCObservation, RAGGCReward, RAGGCState
10
+ from rag_gc_env.rewards import step_reward, summarize_deterministic
11
+ from rag_gc_env.tasks import ALL_TASKS, TaskSpec, task_by_seed
12
+
13
+
14
+ class RAGGCEnvironment(Environment[RAGGCAction, RAGGCObservation, RAGGCState]):
15
+ SUPPORTS_CONCURRENT_SESSIONS = True
16
+
17
+ def __init__(self) -> None:
18
+ super().__init__(transform=None, rubric=None)
19
+ self._state = RAGGCState(episode_id=str(uuid4()), step_count=0)
20
+ self._task: TaskSpec = task_by_seed(0)
21
+ self._docs: dict[str, DocumentItem] = {}
22
+ self._removed_critical = False
23
+
24
+ def _load_task(self, spec: TaskSpec) -> None:
25
+ self._docs = {}
26
+ for did, text, tok, _meta in spec.documents:
27
+ self._docs[did] = DocumentItem(document_id=did, text=text, tokens=tok)
28
+
29
+ def reset(
30
+ self,
31
+ seed: Optional[int] = None,
32
+ episode_id: Optional[str] = None,
33
+ task_name: Optional[str] = None,
34
+ **kwargs: Any,
35
+ ) -> RAGGCObservation:
36
+ self._reset_rubric()
37
+ sid = episode_id or str(uuid4())
38
+ if task_name and task_name in ALL_TASKS:
39
+ self._task = ALL_TASKS[task_name]
40
+ elif seed is not None:
41
+ self._task = task_by_seed(int(seed))
42
+ else:
43
+ self._task = task_by_seed(0)
44
+ self._load_task(self._task)
45
+ self._removed_critical = False
46
+ self._state = RAGGCState(
47
+ episode_id=sid,
48
+ step_count=0,
49
+ task_name=self._task.name,
50
+ max_steps=64,
51
+ removed_critical=False,
52
+ submitted=False,
53
+ )
54
+ return self._observe(done=False, reward_value=0.0, msg="ready")
55
+
56
+ def _total_tokens(self) -> int:
57
+ return sum(d.tokens for d in self._docs.values())
58
+
59
+ def _observe(
60
+ self,
61
+ done: bool,
62
+ reward_value: float,
63
+ msg: str,
64
+ reward_detail: Optional[RAGGCReward] = None,
65
+ grader_score: Optional[float] = None,
66
+ ) -> RAGGCObservation:
67
+ docs = sorted(self._docs.values(), key=lambda x: x.document_id)
68
+ return RAGGCObservation(
69
+ done=done,
70
+ reward=reward_value,
71
+ query=self._task.query,
72
+ documents=docs,
73
+ token_count=self._total_tokens(),
74
+ token_budget=self._task.token_budget,
75
+ task_name=self._task.name,
76
+ message=msg,
77
+ grader_score=grader_score,
78
+ reward_detail=reward_detail,
79
+ metadata={
80
+ "relevance": {
81
+ row[0]: row[3].get("relevance", 0.5)
82
+ for row in self._task.documents
83
+ if row[0] in self._docs
84
+ },
85
+ "hints": {row[0]: row[3].get("hint", "") for row in self._task.documents},
86
+ },
87
+ )
88
+
89
+ def step(
90
+ self,
91
+ action: RAGGCAction,
92
+ timeout_s: Optional[float] = None,
93
+ **kwargs: Any,
94
+ ) -> RAGGCObservation:
95
+ self._state.step_count += 1
96
+ docs_before = dict(self._docs)
97
+
98
+ if action.verb == "submit":
99
+ score = grade_context(self._task, list(self._docs.values()))
100
+ self._state.submitted = True
101
+ r = RAGGCReward(
102
+ step_reward=score,
103
+ final_score=score,
104
+ )
105
+ obs = self._observe(
106
+ done=True,
107
+ reward_value=score,
108
+ msg="submitted",
109
+ reward_detail=r,
110
+ grader_score=score,
111
+ )
112
+ return self._apply_transform(obs)
113
+
114
+ if action.document_id is None or action.document_id not in self._docs:
115
+ obs = self._observe(
116
+ done=False,
117
+ reward_value=-0.1,
118
+ msg="unknown_document",
119
+ )
120
+ return self._apply_transform(obs)
121
+
122
+ did = action.document_id
123
+ removed_critical = False
124
+
125
+ if action.verb == "delete":
126
+ if did in self._task.critical_document_ids:
127
+ self._removed_critical = True
128
+ removed_critical = True
129
+ self._docs.pop(did, None)
130
+
131
+ elif action.verb == "keep":
132
+ pass
133
+
134
+ elif action.verb == "summarize":
135
+ item = self._docs[did]
136
+ new_text, new_tok = summarize_deterministic(item.text)
137
+ self._docs[did] = DocumentItem(
138
+ document_id=did,
139
+ text=new_text,
140
+ tokens=new_tok,
141
+ )
142
+ if did in self._task.critical_document_ids:
143
+ for p in self._task.required_phrases:
144
+ if p not in new_text:
145
+ self._removed_critical = True
146
+ removed_critical = True
147
+
148
+ rdetail = step_reward(
149
+ self._task,
150
+ action.verb,
151
+ did,
152
+ docs_before,
153
+ self._docs,
154
+ removed_critical,
155
+ )
156
+ self._state.removed_critical = self._removed_critical
157
+
158
+ over = self._total_tokens() > self._task.token_budget
159
+ if over:
160
+ penalty = -0.08 * (self._total_tokens() - self._task.token_budget)
161
+ rdetail.token_penalty += penalty
162
+ rdetail.step_reward += penalty
163
+ done = self._state.step_count >= self._state.max_steps
164
+ final_score: Optional[float] = None
165
+ if done:
166
+ final_score = grade_context(self._task, list(self._docs.values()))
167
+ rdetail.final_score = final_score
168
+ rdetail.step_reward += final_score * 0.5
169
+
170
+ reward_val = rdetail.step_reward
171
+ if done:
172
+ # When done, the reward is primarily the final grader score,
173
+ # but we can preserve the step-specific bonus we added.
174
+ # final_score is the main signal.
175
+ reward_val = final_score if final_score is not None else rdetail.step_reward
176
+ obs = self._observe(
177
+ done=done,
178
+ reward_value=reward_val,
179
+ msg="over_budget" if over else ("graded" if done else "ok"),
180
+ reward_detail=rdetail,
181
+ grader_score=final_score if done else None,
182
+ )
183
+ return self._apply_transform(obs)
184
+
185
+ @property
186
+ def state(self) -> RAGGCState:
187
+ return self._state
rag_gc_env/grader.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from rag_gc_env.models import DocumentItem
4
+ from rag_gc_env.tasks import TaskSpec
5
+
6
+
7
+ def _joined_text(docs: list[DocumentItem]) -> str:
8
+ return " ".join(d.text for d in docs)
9
+
10
+
11
+ def grade_context(task: TaskSpec, final_documents: list[DocumentItem]) -> float:
12
+ """
13
+ Deterministic score in [0.0, 0.5, 1.0]:
14
+ 1.0 — required facts present, budget respected, efficient (near optimal tokens)
15
+ 0.5 — required facts present but inefficient or borderline budget
16
+ 0.0 — missing facts, forbidden content present, or critical docs removed incorrectly
17
+ """
18
+ text = _joined_text(final_documents)
19
+ total_tokens = sum(d.tokens for d in final_documents)
20
+
21
+ for phrase in task.required_phrases:
22
+ if phrase not in text:
23
+ return 0.0
24
+
25
+ for phrase in task.forbidden_phrases:
26
+ if phrase in text:
27
+ return 0.0
28
+
29
+ for pid in task.poison_document_ids:
30
+ still = any(d.document_id == pid for d in final_documents)
31
+ if still:
32
+ return 0.0
33
+
34
+ if total_tokens > task.token_budget:
35
+ return 0.0
36
+
37
+ if not task.critical_document_ids.issubset({d.document_id for d in final_documents}):
38
+ return 0.0
39
+
40
+ if total_tokens <= task.optimal_max_tokens:
41
+ return 1.0
42
+
43
+ return 0.5
rag_gc_env/inference.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reproducible baseline policy for Adaptive Context Optimization (RAG GC).
3
+ Deterministic: fixed action sequences per task derived from metadata.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from rag_gc_env.environment import RAGGCEnvironment
9
+ from rag_gc_env.models import RAGGCAction
10
+
11
+
12
+ def run_baseline(task_name: str, seed: int = 0) -> tuple[float, list[str]]:
13
+ env = RAGGCEnvironment()
14
+ obs = env.reset(seed=seed, task_name=task_name)
15
+ log: list[str] = ["reset"]
16
+
17
+ def step(verb: str, doc_id: str | None) -> None:
18
+ nonlocal obs
19
+ obs = env.step(RAGGCAction(verb=verb, document_id=doc_id))
20
+ log.append(f"{verb}:{doc_id}")
21
+
22
+ if task_name == "easy_irrelevant_removal":
23
+ step("delete", "d1")
24
+ step("submit", None)
25
+ elif task_name == "medium_token_compression":
26
+ step("delete", "m2")
27
+ while obs.token_count > obs.token_budget and not obs.done:
28
+ step("summarize", "m0")
29
+ if len(log) > 40:
30
+ break
31
+ step("submit", None)
32
+ elif task_name == "hard_contradiction_removal":
33
+ step("delete", "h1")
34
+ step("submit", None)
35
+ else:
36
+ step("submit", None)
37
+
38
+ score = float(obs.grader_score or obs.reward or 0.0)
39
+ return score, log
40
+
41
+
42
+ if __name__ == "__main__":
43
+ for name in (
44
+ "easy_irrelevant_removal",
45
+ "medium_token_compression",
46
+ "hard_contradiction_removal",
47
+ ):
48
+ s, lg = run_baseline(name, seed=0)
49
+ print(name, "score=", s, "trace=", lg)
rag_gc_env/models.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Literal, Optional
4
+
5
+ from openenv.core.env_server.types import Action, Observation, State
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
+ class DocumentItem(BaseModel):
10
+ document_id: str
11
+ text: str
12
+ tokens: int = Field(description="Estimated tokens for this snippet")
13
+
14
+
15
+ class RAGGCAction(Action):
16
+ verb: Literal["keep", "delete", "summarize", "submit"] = Field(
17
+ description="Document operation or submit to finalize and grade"
18
+ )
19
+ document_id: Optional[str] = Field(
20
+ default=None,
21
+ description="Target document for keep/delete/summarize; omit for submit",
22
+ )
23
+
24
+
25
+ class RAGGCReward(BaseModel):
26
+ step_reward: float = 0.0
27
+ relevance: float = 0.0
28
+ compression: float = 0.0
29
+ token_penalty: float = 0.0
30
+ critical_penalty: float = 0.0
31
+ final_score: Optional[float] = Field(
32
+ default=None, description="0.0–1.0 after submit; aligns with grader"
33
+ )
34
+
35
+
36
+ class RAGGCObservation(Observation):
37
+ query: str = ""
38
+ documents: list[DocumentItem] = Field(default_factory=list)
39
+ token_count: int = 0
40
+ token_budget: int = 0
41
+ task_name: str = ""
42
+ reward_detail: Optional[RAGGCReward] = None
43
+ message: str = ""
44
+ grader_score: Optional[float] = Field(
45
+ default=None, description="Deterministic score after episode ends"
46
+ )
47
+
48
+
49
+ class RAGGCState(State):
50
+ task_name: str = ""
51
+ max_steps: int = 64
52
+ removed_critical: bool = False
53
+ submitted: bool = False
rag_gc_env/rewards.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from rag_gc_env.models import DocumentItem, RAGGCReward
4
+ from rag_gc_env.tasks import TaskSpec
5
+
6
+
7
+ def summarize_deterministic(text: str) -> tuple[str, int]:
8
+ """Deterministic compression: first sentence or capped prefix."""
9
+ stripped = text.strip()
10
+ if not stripped:
11
+ return "", 1
12
+ cut = stripped.split(". ")
13
+ first = cut[0] + ("." if not cut[0].endswith(".") else "")
14
+ if len(first) < 40 and len(cut) > 1:
15
+ first = cut[0] + ". " + cut[1] + ("." if not cut[1].endswith(".") else "")
16
+ cap = 280
17
+ out = first[:cap] + ("..." if len(first) > cap else "")
18
+ tokens = max(1, len(out) // 4)
19
+ return out, tokens
20
+
21
+
22
+ def estimate_tokens(text: str) -> int:
23
+ return max(1, len(text) // 4)
24
+
25
+
26
+ def step_reward(
27
+ task: TaskSpec,
28
+ verb: str,
29
+ doc_id: str | None,
30
+ docs_before: dict[str, DocumentItem],
31
+ docs_after: dict[str, DocumentItem],
32
+ removed_critical_flag: bool,
33
+ ) -> RAGGCReward:
34
+ rel = 0.0
35
+ comp = 0.0
36
+ tok_pen = 0.0
37
+ crit = 0.0
38
+
39
+ if removed_critical_flag:
40
+ crit = -3.0
41
+
42
+ if verb == "delete" and doc_id in docs_before:
43
+ meta = next(
44
+ (m for did, _, _, m in task.documents if did == doc_id),
45
+ {},
46
+ )
47
+ # Reward deleting irrelevant or poison documents
48
+ if doc_id in task.irrelevant_document_ids:
49
+ rel += 0.4
50
+ elif doc_id in task.poison_document_ids:
51
+ rel += 0.6
52
+ elif doc_id in task.critical_document_ids:
53
+ crit -= 3.0
54
+ elif meta.get("hint") == "fluff":
55
+ rel += 0.2
56
+
57
+ # Deleting tokens should NOT result in a penalty proportional to the deleted tokens;
58
+ # instead, it removes the 'keep' penalty they would have incurred.
59
+ # We can add a small constant 'action cost' for deleting if desired, but 0.0 is fine here.
60
+ tok_pen = 0.0
61
+
62
+ if verb == "summarize" and doc_id in docs_before:
63
+ before_t = docs_before[doc_id].tokens
64
+ after = docs_after.get(doc_id)
65
+ if after is not None:
66
+ # Reward for the reduction in size (efficiency)
67
+ reduction_ratio = (before_t - after.tokens) / max(before_t, 1)
68
+ comp += 0.3 * max(0.0, reduction_ratio)
69
+
70
+ # The remaining tokens still incur a small penalty
71
+ tok_pen -= 0.01 * after.tokens
72
+
73
+ if doc_id in task.critical_document_ids:
74
+ for p in task.required_phrases:
75
+ if p not in after.text:
76
+ crit -= 2.5
77
+
78
+ if verb == "keep" and doc_id in docs_before:
79
+ # Standard penalty for keeping tokens in context
80
+ tok_pen -= 0.01 * docs_before[doc_id].tokens
81
+
82
+ step = rel + comp + tok_pen + crit
83
+ return RAGGCReward(
84
+ step_reward=step,
85
+ relevance=rel,
86
+ compression=comp,
87
+ token_penalty=tok_pen,
88
+ critical_penalty=crit,
89
+ )
rag_gc_env/server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Server package for OpenEnv HTTP deployment
rag_gc_env/server/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (154 Bytes). View file
 
rag_gc_env/server/__pycache__/app.cpython-311.pyc ADDED
Binary file (1.04 kB). View file
 
rag_gc_env/server/app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from openenv.core.env_server.http_server import create_fastapi_app
4
+
5
+ from rag_gc_env.environment import RAGGCEnvironment
6
+ from rag_gc_env.models import RAGGCAction, RAGGCObservation
7
+
8
+ app = create_fastapi_app(
9
+ RAGGCEnvironment,
10
+ RAGGCAction,
11
+ RAGGCObservation,
12
+ )
13
+
14
+
15
+ def main() -> None:
16
+ import uvicorn
17
+
18
+ port = int(os.environ.get("PORT", "8000"))
19
+ uvicorn.run(app, host="0.0.0.0", port=port)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ main()
rag_gc_env/tasks.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, FrozenSet
5
+
6
+
7
+ @dataclass(frozen=True)
8
+ class TaskSpec:
9
+ name: str
10
+ query: str
11
+ token_budget: int
12
+ documents: list[tuple[str, str, int, dict[str, Any]]]
13
+ # document_id, text, tokens, metadata (relevance, flags)
14
+ required_phrases: FrozenSet[str] = field(default_factory=frozenset)
15
+ forbidden_phrases: FrozenSet[str] = field(default_factory=frozenset)
16
+ critical_document_ids: FrozenSet[str] = field(default_factory=frozenset)
17
+ irrelevant_document_ids: FrozenSet[str] = field(default_factory=frozenset)
18
+ poison_document_ids: FrozenSet[str] = field(default_factory=frozenset)
19
+ optimal_max_tokens: int = 0
20
+
21
+
22
+ def _docs(
23
+ rows: list[tuple[str, str, int, dict[str, Any]]],
24
+ ) -> list[tuple[str, str, int, dict[str, Any]]]:
25
+ return rows
26
+
27
+
28
+ TASK_EASY = TaskSpec(
29
+ name="easy_irrelevant_removal",
30
+ query="What is the capital city of France?",
31
+ token_budget=400,
32
+ documents=_docs(
33
+ [
34
+ (
35
+ "d0",
36
+ "Paris has been the capital of France since political centralization in the country.",
37
+ 24,
38
+ {"relevance": 0.95, "hint": "high"},
39
+ ),
40
+ (
41
+ "d1",
42
+ "Penguins thrive in Antarctica and are unrelated to European geography.",
43
+ 18,
44
+ {"relevance": 0.08, "hint": "noise"},
45
+ ),
46
+ (
47
+ "d2",
48
+ "Lyon is a major French city but not the national capital.",
49
+ 16,
50
+ {"relevance": 0.55, "hint": "partial"},
51
+ ),
52
+ ]
53
+ ),
54
+ required_phrases=frozenset({"Paris"}),
55
+ forbidden_phrases=frozenset(),
56
+ critical_document_ids=frozenset({"d0"}),
57
+ irrelevant_document_ids=frozenset({"d1"}),
58
+ poison_document_ids=frozenset(),
59
+ optimal_max_tokens=120,
60
+ )
61
+
62
+ _LONG_DUP = (
63
+ "Paris is the capital of France. " * 18
64
+ + "This repetition exists only to inflate token usage for compression tests."
65
+ )
66
+
67
+ TASK_MEDIUM = TaskSpec(
68
+ name="medium_token_compression",
69
+ query="Which city is the capital of France?",
70
+ token_budget=120,
71
+ documents=_docs(
72
+ [
73
+ (
74
+ "m0",
75
+ _LONG_DUP,
76
+ max(1, len(_LONG_DUP) // 4),
77
+ {"relevance": 0.9, "hint": "verbose"},
78
+ ),
79
+ (
80
+ "m1",
81
+ "Administrative records list a capital city but this line omits the name intentionally.",
82
+ 14,
83
+ {"relevance": 0.55, "hint": "no_answer"},
84
+ ),
85
+ (
86
+ "m2",
87
+ "French cuisine is diverse; it does not change which city is the capital.",
88
+ 14,
89
+ {"relevance": 0.35, "hint": "fluff"},
90
+ ),
91
+ ]
92
+ ),
93
+ required_phrases=frozenset({"Paris"}),
94
+ forbidden_phrases=frozenset(),
95
+ critical_document_ids=frozenset(),
96
+ irrelevant_document_ids=frozenset({"m2"}),
97
+ poison_document_ids=frozenset(),
98
+ optimal_max_tokens=90,
99
+ )
100
+
101
+ TASK_HARD = TaskSpec(
102
+ name="hard_contradiction_removal",
103
+ query="Compute the integer result of 17 + 25.",
104
+ token_budget=350,
105
+ documents=_docs(
106
+ [
107
+ (
108
+ "h0",
109
+ "Arithmetic trace: 17 + 25 = 42.",
110
+ 12,
111
+ {"relevance": 0.9, "hint": "consistent", "trust": 0.95},
112
+ ),
113
+ (
114
+ "h1",
115
+ "Quick math note: 17 + 25 equals 43 for budgeting purposes.",
116
+ 14,
117
+ {"relevance": 0.88, "hint": "poison", "trust": 0.2},
118
+ ),
119
+ (
120
+ "h2",
121
+ "Addition of integers is associative and commutative.",
122
+ 10,
123
+ {"relevance": 0.4, "hint": "generic"},
124
+ ),
125
+ ]
126
+ ),
127
+ required_phrases=frozenset({"42"}),
128
+ forbidden_phrases=frozenset({"43"}),
129
+ critical_document_ids=frozenset({"h0"}),
130
+ irrelevant_document_ids=frozenset(),
131
+ poison_document_ids=frozenset({"h1"}),
132
+ optimal_max_tokens=200,
133
+ )
134
+
135
+ ALL_TASKS: dict[str, TaskSpec] = {
136
+ TASK_EASY.name: TASK_EASY,
137
+ TASK_MEDIUM.name: TASK_MEDIUM,
138
+ TASK_HARD.name: TASK_HARD,
139
+ }
140
+
141
+
142
+ def task_by_seed(seed: int) -> TaskSpec:
143
+ order = [TASK_EASY, TASK_MEDIUM, TASK_HARD]
144
+ return order[seed % 3]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ openenv-core>=0.1.0
2
+ pydantic>=2.0
3
+ fastapi>=0.104.0
4
+ uvicorn[standard]>=0.24.0
5
+ typing_extensions>=4.8.0
test_reward_logic.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rag_gc_env.environment import RAGGCEnvironment
2
+ from rag_gc_env.models import RAGGCAction
3
+
4
+ def test_reward_logic():
5
+ env = RAGGCEnvironment()
6
+
7
+ # Test 1: Deleting irrelevant document should be POSITIVE
8
+ print("\n--- Test 1: Deleting irrelevant (d1) in Easy Task ---")
9
+ obs = env.reset(task_name="easy_irrelevant_removal")
10
+ obs = env.step(RAGGCAction(verb="delete", document_id="d1"))
11
+ print(f"Step Reward for deleting d1: {obs.reward}")
12
+ if obs.reward > 0:
13
+ print("PASS: Positive reward for deletion.")
14
+ else:
15
+ print("FAIL: Reward should be positive.")
16
+
17
+ # Test 2: Summarizing to save tokens should be POSITIVE
18
+ print("\n--- Test 2: Summarizing (m0) in Medium Task ---")
19
+ obs = env.reset(task_name="medium_token_compression")
20
+ tokens_before = obs.token_count
21
+ obs = env.step(RAGGCAction(verb="summarize", document_id="m0"))
22
+ print(f"Step Reward for summarizing m0: {obs.reward}")
23
+ print(f"Tokens: {tokens_before} -> {obs.token_count}")
24
+ if obs.reward > 0:
25
+ print("PASS: Positive reward for summarization.")
26
+ else:
27
+ print("FAIL: Reward should be positive (was previously negative).")
28
+
29
+ # Test 3: Over budget penalty should be reflected in reward
30
+ print("\n--- Test 3: Over budget penalty in Medium Task ---")
31
+ obs = env.reset(task_name="medium_token_compression")
32
+ # budget is 120, total is 190. over by 70. penalty should be -0.08 * 70 = -5.6
33
+ # 'keep' m1 (14 tokens). 14 * -0.01 = -0.14.
34
+ # total reward should be -5.6 - 0.14 = -5.74
35
+ obs = env.step(RAGGCAction(verb="keep", document_id="m1"))
36
+ print(f"Reward with budget penalty: {obs.reward}")
37
+ if obs.reward < -5:
38
+ print("PASS: Budget penalty correctly reflected in step reward.")
39
+ else:
40
+ print("FAIL: Budget penalty missing or too low.")
41
+
42
+ if __name__ == "__main__":
43
+ test_reward_logic()