Siddh12334 commited on
Commit
204fa23
Β·
verified Β·
1 Parent(s): 5d65fb7

feat: training space with manual start UI

Browse files
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
6
+
7
+ # Training deps β€” separate from server requirements
8
+ RUN pip install --no-cache-dir \
9
+ "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" \
10
+ trl transformers datasets accelerate \
11
+ openenv-core fastapi uvicorn pydantic \
12
+ wandb faker python-dotenv gradio
13
+
14
+ COPY . .
15
+
16
+ RUN python -m data.loader || echo "Will use fallback facts"
17
+
18
+ EXPOSE 7860
19
+
20
+ CMD ["python", "-m", "training.space_runner"]
README.md CHANGED
@@ -1,10 +1,18 @@
1
  ---
2
  title: Context Corruption Training
3
- emoji: πŸŒ–
4
- colorFrom: green
5
- colorTo: purple
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
1
  ---
2
  title: Context Corruption Training
3
+ emoji: πŸ‹οΈ
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
+ # ContextCorruption-Env β€” GRPO Training Space
12
+
13
+ Click **Start Training** in the UI. Set secrets first in Space Settings:
14
+ - `WANDB_API_KEY`
15
+ - `HF_TOKEN`
16
+ - `HF_HUB_MODEL_ID` (e.g. `Siddh12334/qwen-1.5b-context-corruption`)
17
+
18
+ Upgrade hardware to **A10G Small** before starting (~$1.05/hr, ~1.5 hrs total).
data/__init__.py ADDED
File without changes
data/corruption.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import re
3
+
4
+ try:
5
+ from faker import Faker
6
+ except ModuleNotFoundError:
7
+ Faker = None
8
+
9
+
10
+ class _FallbackFaker:
11
+ def name(self) -> str:
12
+ return random.choice(["Alex Morgan", "Jordan Lee", "Taylor Brooks", "Casey Patel"])
13
+
14
+ def last_name(self) -> str:
15
+ return random.choice(["Morgan", "Lee", "Brooks", "Patel", "Reed"])
16
+
17
+ def company(self) -> str:
18
+ return random.choice(
19
+ ["Global Research Institute", "Civic Data Group", "Archive Analytics Lab"]
20
+ )
21
+
22
+ def word(self) -> str:
23
+ return random.choice(["revised", "alternate", "disputed", "corrected"])
24
+
25
+
26
+ fake = Faker() if Faker else _FallbackFaker()
27
+
28
+ COUNTRIES = [
29
+ "France",
30
+ "Germany",
31
+ "Brazil",
32
+ "Japan",
33
+ "Canada",
34
+ "India",
35
+ "Australia",
36
+ "Kenya",
37
+ "Mexico",
38
+ "Norway",
39
+ ]
40
+ CITIES = [
41
+ "Paris",
42
+ "Berlin",
43
+ "Tokyo",
44
+ "Toronto",
45
+ "Mumbai",
46
+ "Sydney",
47
+ "Nairobi",
48
+ "Mexico City",
49
+ "Oslo",
50
+ "Rome",
51
+ ]
52
+ ORGANIZATIONS = [
53
+ "World Health Organization",
54
+ "United Nations",
55
+ "NASA",
56
+ "Oxford University",
57
+ "Reuters",
58
+ "Smithsonian Institution",
59
+ "International Monetary Fund",
60
+ "Royal Society",
61
+ ]
62
+ ANTONYMS = {
63
+ "largest": "smallest",
64
+ "smallest": "largest",
65
+ "first": "last",
66
+ "last": "first",
67
+ "highest": "lowest",
68
+ "lowest": "highest",
69
+ "won": "lost",
70
+ "lost": "won",
71
+ "north": "south",
72
+ "south": "north",
73
+ "east": "west",
74
+ "west": "east",
75
+ "increase": "decrease",
76
+ "decrease": "increase",
77
+ "before": "after",
78
+ "after": "before",
79
+ "true": "false",
80
+ "false": "true",
81
+ "older": "newer",
82
+ "newer": "older",
83
+ "major": "minor",
84
+ "minor": "major",
85
+ }
86
+
87
+
88
+ def _preserve_case(original: str, replacement: str) -> str:
89
+ if original.isupper():
90
+ return replacement.upper()
91
+ if original.istitle():
92
+ return replacement.title()
93
+ if original.islower():
94
+ return replacement.lower()
95
+ return replacement
96
+
97
+
98
+ def _replace_first_case_insensitive(text: str, target: str, replacement: str) -> str:
99
+ pattern = re.compile(re.escape(target), re.IGNORECASE)
100
+
101
+ def repl(match: re.Match[str]) -> str:
102
+ return _preserve_case(match.group(0), replacement)
103
+
104
+ return pattern.sub(repl, text, count=1)
105
+
106
+
107
+ def _different_choice(options: list[str], current: str) -> str:
108
+ viable = [option for option in options if option.lower() != current.lower()]
109
+ return random.choice(viable or options)
110
+
111
+
112
+ def corrupt_number(text: str, answer: str) -> str:
113
+ numbers = re.findall(r"\b\d{4}\b|\b\d+\b", text)
114
+ if not numbers:
115
+ return (
116
+ f"{text} A later statistical revision changed the reported figure "
117
+ f"from {answer} to {random.randint(12, 98)}."
118
+ )
119
+
120
+ original = random.choice(numbers)
121
+ value = int(original)
122
+ if len(original) == 4 and 1900 <= value <= 2030:
123
+ replacement = str(value + random.choice([-20, -10, -5, 5, 10, 20]))
124
+ else:
125
+ mutated = value * random.choice([0.5, 2, 3, 5, 10])
126
+ replacement = str(max(1, int(round(mutated))))
127
+
128
+ return text.replace(original, replacement, 1)
129
+
130
+
131
+ def corrupt_entity(text: str, answer: str) -> str:
132
+ answer = answer.strip()
133
+ pools = [COUNTRIES, CITIES, ORGANIZATIONS]
134
+ if answer and re.search(re.escape(answer), text, re.IGNORECASE):
135
+ for pool in pools:
136
+ if answer in pool:
137
+ replacement = _different_choice(pool, answer)
138
+ return _replace_first_case_insensitive(text, answer, replacement)
139
+
140
+ if len(answer.split()) <= 3:
141
+ generated_names = [fake.name() for _ in range(8)]
142
+ replacement = _different_choice(generated_names, answer)
143
+ return _replace_first_case_insensitive(text, answer, replacement)
144
+
145
+ return (
146
+ f"{text} In a later archive note, researcher {fake.name()} attributed "
147
+ f"the answer to {fake.name()} instead."
148
+ )
149
+
150
+
151
+ def corrupt_inversion(text: str, answer: str) -> str:
152
+ pattern = re.compile(r"\b(" + "|".join(map(re.escape, ANTONYMS)) + r")\b", re.IGNORECASE)
153
+
154
+ def repl(match: re.Match[str]) -> str:
155
+ word = match.group(0)
156
+ replacement = ANTONYMS[word.lower()]
157
+ return _preserve_case(word, replacement)
158
+
159
+ corrupted, count = pattern.subn(repl, text, count=1)
160
+ if count:
161
+ return corrupted
162
+
163
+ return (
164
+ f"{text} This statement contradicts earlier scholarly consensus, "
165
+ f"which identified {answer} as incorrect."
166
+ )
167
+
168
+
169
+ def _generate_wrong_answer(answer: str) -> str:
170
+ answer = answer.strip()
171
+ if not answer:
172
+ return fake.word().title()
173
+
174
+ number_match = re.search(r"\d+", answer)
175
+ if number_match:
176
+ original = number_match.group(0)
177
+ mutated = str(int(original) + random.choice([-5, -2, -1, 1, 2, 5]))
178
+ return answer.replace(original, mutated, 1)
179
+
180
+ words = answer.split()
181
+ if len(words) == 1 and words[0][:1].isupper():
182
+ return fake.last_name()
183
+ if len(words) > 1:
184
+ shuffled = words[:]
185
+ random.shuffle(shuffled)
186
+ if shuffled != words:
187
+ return " ".join(shuffled)
188
+ return f"{answer} Institute"
189
+ return fake.word()
190
+
191
+
192
+ def corrupt_coherent(text: str, answer: str) -> str:
193
+ wrong_answer = _generate_wrong_answer(answer)
194
+ year = random.randint(2015, 2025)
195
+ org = fake.company()
196
+ source = random.choice(
197
+ [
198
+ "a peer-reviewed survey",
199
+ "an institutional archive",
200
+ "a longitudinal review",
201
+ "a Reuters-style fact check",
202
+ ]
203
+ )
204
+
205
+ if answer and re.search(re.escape(answer), text, re.IGNORECASE):
206
+ text = _replace_first_case_insensitive(text, answer, wrong_answer)
207
+
208
+ return (
209
+ f"{text} According to {source} released by {org} in {year}, the verified "
210
+ f"answer is {wrong_answer}, based on revised primary-source evidence."
211
+ )
212
+
213
+
214
+ def corrupt_text(text: str, answer: str, level: int) -> str:
215
+ try:
216
+ if level <= 1:
217
+ return corrupt_number(text, answer)
218
+ if level == 2:
219
+ return corrupt_entity(text, answer)
220
+ if level == 3:
221
+ return corrupt_inversion(text, answer)
222
+ return corrupt_coherent(text, answer)
223
+ except Exception:
224
+ return (
225
+ f"{text} A conflicting secondary source reports a different answer "
226
+ f"than {answer}."
227
+ )
data/generator.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Any
3
+
4
+ from data.corruption import corrupt_text
5
+
6
+
7
+ SOURCES = [
8
+ "Encyclopedia Britannica",
9
+ "Reuters Fact Check",
10
+ "National Geographic",
11
+ "Smithsonian Magazine",
12
+ "BBC Reference Desk",
13
+ "Oxford Reference",
14
+ "World Almanac",
15
+ "Associated Press Archive",
16
+ "Library of Congress Notes",
17
+ "Academic Knowledge Base",
18
+ ]
19
+
20
+ TEMPLATES = [
21
+ "{source} summarizes the question '{question}' and identifies the answer as {answer}.",
22
+ "In its reference entry, {source} states that the correct answer to '{question}' is {answer}.",
23
+ "{source} records {answer} as the accepted answer when asked: '{question}'",
24
+ "A background note from {source} explains that {answer} is the established response to '{question}'",
25
+ "According to {source}, researchers commonly answer '{question}' with {answer}.",
26
+ "{source} lists the verified answer for '{question}' as {answer}, matching standard references.",
27
+ "The archive maintained by {source} gives {answer} as the answer to '{question}'",
28
+ "For the prompt '{question}', {source} reports that the answer is {answer}.",
29
+ ]
30
+
31
+
32
+ def _as_text(value: Any, default: str = "") -> str:
33
+ if value is None:
34
+ return default
35
+ text = str(value).strip()
36
+ return text or default
37
+
38
+
39
+ def generate_documents(
40
+ fact: dict[str, Any],
41
+ num_docs: int = 8,
42
+ corrupt_positions: list[int] | None = None,
43
+ ) -> list[dict[str, Any]]:
44
+ question = _as_text(fact.get("question"), "Unknown question?")
45
+ answer = _as_text(fact.get("answer"), "unknown")
46
+ corrupt_set = set(corrupt_positions or [])
47
+ corrupt_order = {doc_id: idx + 1 for idx, doc_id in enumerate(corrupt_positions or [])}
48
+
49
+ documents: list[dict[str, Any]] = []
50
+ for doc_id in range(num_docs):
51
+ source = random.choice(SOURCES)
52
+ template = random.choice(TEMPLATES)
53
+ content = template.format(source=source, question=question, answer=answer)
54
+ is_corrupt = doc_id in corrupt_set
55
+
56
+ if is_corrupt:
57
+ level = min(corrupt_order[doc_id], 4)
58
+ content = corrupt_text(content, answer, level)
59
+
60
+ documents.append(
61
+ {
62
+ "id": doc_id,
63
+ "title": f"{source} Document {doc_id + 1}",
64
+ "content": content,
65
+ "is_corrupt": is_corrupt,
66
+ }
67
+ )
68
+
69
+ return documents
data/loader.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import urllib.request
4
+ import ast
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+
9
+ FACTS_PATH = Path(__file__).parent / "facts.json"
10
+ FAITHEVAL_COUNTERFACTUAL_URL = (
11
+ "https://raw.githubusercontent.com/SalesforceAIResearch/FaithEval/main/"
12
+ "data/counterfactual.json"
13
+ )
14
+
15
+
16
+ def _load_dataset(*args: Any, **kwargs: Any) -> Any:
17
+ from datasets import load_dataset
18
+
19
+ return load_dataset(*args, **kwargs)
20
+
21
+
22
+ def _first_text(value: Any) -> str | None:
23
+ """Extract the first useful text value from nested dataset fields."""
24
+ if value is None:
25
+ return None
26
+ if isinstance(value, str):
27
+ text = value.strip()
28
+ if text.startswith("[") and text.endswith("]"):
29
+ try:
30
+ parsed = ast.literal_eval(text)
31
+ except (SyntaxError, ValueError):
32
+ parsed = None
33
+ parsed_text = _first_text(parsed)
34
+ if parsed_text:
35
+ return parsed_text
36
+ return text or None
37
+ if isinstance(value, (int, float)):
38
+ return str(value)
39
+ if isinstance(value, dict):
40
+ for key in ("text", "answer", "answers", "value"):
41
+ text = _first_text(value.get(key))
42
+ if text:
43
+ return text
44
+ return None
45
+ if isinstance(value, (list, tuple)):
46
+ for item in value:
47
+ text = _first_text(item)
48
+ if text:
49
+ return text
50
+ return None
51
+
52
+
53
+ def _word_count(text: str) -> int:
54
+ return len(text.split())
55
+
56
+
57
+ def _clean_question(text: Any) -> str | None:
58
+ question = _first_text(text)
59
+ if not question:
60
+ return None
61
+ question = question.strip()
62
+ if not question.endswith("?"):
63
+ question = f"{question}?"
64
+ return question
65
+
66
+
67
+ def _natural_questions_answer(row: dict[str, Any]) -> str | None:
68
+ annotations = row.get("annotations") or {}
69
+ short_answers = annotations.get("short_answers")
70
+ answer = _first_text(short_answers)
71
+ if answer and _word_count(answer) <= 5:
72
+ return answer
73
+ return None
74
+
75
+
76
+ def load_natural_questions(n: int = 300) -> list[dict[str, str]]:
77
+ facts: list[dict[str, str]] = []
78
+ dataset = _load_dataset(
79
+ "google-research-datasets/natural_questions",
80
+ split="train",
81
+ streaming=True,
82
+ )
83
+
84
+ for row in dataset:
85
+ question = _clean_question(row.get("question") or row.get("question_text"))
86
+ answer = _natural_questions_answer(row)
87
+ if not question or not answer:
88
+ continue
89
+
90
+ facts.append(
91
+ {
92
+ "question": question,
93
+ "answer": answer,
94
+ "source": "natural_questions",
95
+ "conflict_type": "entity",
96
+ }
97
+ )
98
+ if len(facts) >= n:
99
+ break
100
+
101
+ return facts
102
+
103
+
104
+ def load_popqa(n: int = 150) -> list[dict[str, str]]:
105
+ facts: list[dict[str, str]] = []
106
+ dataset = _load_dataset("akariasai/PopQA", split="test")
107
+
108
+ for row in dataset:
109
+ question = _clean_question(row.get("question"))
110
+ answer = _first_text(row.get("possible_answers"))
111
+ if not question or not answer:
112
+ continue
113
+
114
+ facts.append(
115
+ {
116
+ "question": question,
117
+ "answer": answer,
118
+ "source": "popqa",
119
+ "conflict_type": "entity",
120
+ "entity": _first_text(row.get("subj") or row.get("entity")) or "",
121
+ "relation": _first_text(row.get("prop") or row.get("relation")) or "",
122
+ }
123
+ )
124
+ if len(facts) >= n:
125
+ break
126
+
127
+ return facts
128
+
129
+
130
+ def _iter_faitheval_items(payload: Any) -> list[dict[str, Any]]:
131
+ if isinstance(payload, list):
132
+ return [item for item in payload if isinstance(item, dict)]
133
+ if isinstance(payload, dict):
134
+ for key in ("data", "examples", "items", "counterfactual"):
135
+ items = payload.get(key)
136
+ if isinstance(items, list):
137
+ return [item for item in items if isinstance(item, dict)]
138
+ return []
139
+
140
+
141
+ def load_faitheval_counterfactual(n: int = 100) -> list[dict[str, str]]:
142
+ try:
143
+ with urllib.request.urlopen(FAITHEVAL_COUNTERFACTUAL_URL, timeout=20) as response:
144
+ payload = json.loads(response.read().decode("utf-8"))
145
+ except Exception:
146
+ return []
147
+
148
+ facts: list[dict[str, str]] = []
149
+ for item in _iter_faitheval_items(payload):
150
+ question = _clean_question(
151
+ item.get("question") or item.get("query") or item.get("claim")
152
+ )
153
+ answer = _first_text(
154
+ item.get("answer")
155
+ or item.get("gold_answer")
156
+ or item.get("label")
157
+ or item.get("target")
158
+ )
159
+ if not question or not answer:
160
+ continue
161
+
162
+ facts.append(
163
+ {
164
+ "question": question,
165
+ "answer": answer,
166
+ "source": "faitheval",
167
+ "conflict_type": "counterfactual",
168
+ "provided_context": _first_text(
169
+ item.get("provided_context")
170
+ or item.get("context")
171
+ or item.get("evidence")
172
+ )
173
+ or "",
174
+ }
175
+ )
176
+ if len(facts) >= n:
177
+ break
178
+
179
+ return facts
180
+
181
+
182
+ def build_fact_database() -> list[dict[str, str]]:
183
+ facts = (
184
+ load_natural_questions()
185
+ + load_popqa()
186
+ + load_faitheval_counterfactual()
187
+ )
188
+ random.shuffle(facts)
189
+
190
+ FACTS_PATH.parent.mkdir(parents=True, exist_ok=True)
191
+ with open(FACTS_PATH, "w", encoding="utf-8") as f:
192
+ json.dump(facts, f, indent=2, ensure_ascii=False)
193
+
194
+ counts: dict[str, int] = {}
195
+ for fact in facts:
196
+ source = fact.get("source", "unknown")
197
+ counts[source] = counts.get(source, 0) + 1
198
+
199
+ print(f"Wrote {len(facts)} facts to {FACTS_PATH}")
200
+ print(f"Source counts: {counts}")
201
+ return facts
202
+
203
+
204
+ if __name__ == "__main__":
205
+ build_fact_database()
environment/__init__.py ADDED
File without changes
environment/actions.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Optional
3
+ from pydantic import BaseModel, field_validator
4
+ from openenv.core import Action, Observation, State
5
+
6
+
7
+ class ActionType(str, Enum):
8
+ read_doc = "read_doc"
9
+ flag_suspicious = "flag_suspicious"
10
+ unflag_doc = "unflag_doc"
11
+ submit_answer = "submit_answer"
12
+
13
+
14
+ class ContextCorruptionAction(Action):
15
+ action_type: ActionType
16
+ doc_id: Optional[int] = None
17
+ answer: Optional[str] = None
18
+ confidence: Optional[float] = None
19
+
20
+ @field_validator("confidence")
21
+ @classmethod
22
+ def confidence_range(cls, v):
23
+ if v is not None and not (0.0 <= v <= 1.0):
24
+ raise ValueError("confidence must be between 0.0 and 1.0")
25
+ return v
26
+
27
+
28
+ class Document(BaseModel):
29
+ id: int
30
+ title: str
31
+ content: str
32
+ is_flagged: bool = False
33
+
34
+
35
+ class EpisodeObservation(Observation):
36
+ question: str = ""
37
+ documents: list[Document] = []
38
+ flagged_ids: list[int] = []
39
+ budget_remaining: int = 0
40
+ turn: int = 0
41
+ message: Optional[str] = None
42
+ # `done` and `reward` inherited from Observation
43
+
44
+
45
+ class ContextCorruptionState(State):
46
+ question: str = ""
47
+ ground_truth: str = ""
48
+ corrupt_ids: list[int] = []
49
+ flagged_ids: list[int] = []
50
+ budget_used: int = 0
51
+ done: bool = False
52
+ reward: Optional[float] = None
53
+ breakdown: Optional[dict] = None
environment/env.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from pathlib import Path
4
+
5
+ from openenv.core import Environment
6
+
7
+ from environment.actions import (
8
+ ActionType, ContextCorruptionAction, Document,
9
+ EpisodeObservation, ContextCorruptionState,
10
+ )
11
+ from environment.reward import ContextCorruptionRubric
12
+
13
+ _FALLBACK_FACTS = [
14
+ {"question": "What is the capital of France?", "answer": "Paris"}
15
+ ]
16
+
17
+
18
+ class ContextCorruptionEnv(Environment[ContextCorruptionAction, EpisodeObservation, ContextCorruptionState]):
19
+ MAX_BUDGET = 12
20
+ NUM_DOCS = 8
21
+ DIFFICULTY_LEVELS = [1, 2, 3, 4]
22
+ SUPPORTS_CONCURRENT_SESSIONS = True
23
+
24
+ def __init__(self, difficulty=None):
25
+ rubric = ContextCorruptionRubric(state_fn=self._state_dict)
26
+ super().__init__(rubric=rubric)
27
+ self.difficulty = difficulty
28
+ facts_path = Path(__file__).parent.parent / "data" / "facts.json"
29
+ if facts_path.exists():
30
+ with open(facts_path, encoding="utf-8") as f:
31
+ self._facts = json.load(f)
32
+ else:
33
+ self._facts = _FALLBACK_FACTS
34
+ self._reset_state()
35
+
36
+ def _reset_state(self):
37
+ self._question = ""
38
+ self._ground_truth = ""
39
+ self._documents: list[dict] = []
40
+ self._corrupt_ids: list[int] = []
41
+ self._flagged_ids: list[int] = []
42
+ self._budget_used = 0
43
+ self._turn = 0
44
+ self._done = False
45
+ self._reward = None
46
+ self._breakdown = None
47
+
48
+ def reset(self, seed=None, episode_id=None, **kwargs) -> EpisodeObservation:
49
+ self._reset_rubric()
50
+ self._reset_state()
51
+ if seed is not None:
52
+ random.seed(seed)
53
+ fact = random.choice(self._facts)
54
+ n_corrupt = self.difficulty if self.difficulty is not None else random.choice(self.DIFFICULTY_LEVELS)
55
+ self._corrupt_ids = random.sample(range(self.NUM_DOCS), n_corrupt)
56
+ self._question = fact["question"]
57
+ self._ground_truth = fact["answer"]
58
+
59
+ try:
60
+ from data.generator import generate_documents
61
+ raw_docs = generate_documents(fact, num_docs=self.NUM_DOCS, corrupt_positions=self._corrupt_ids)
62
+ except Exception:
63
+ raw_docs = [
64
+ {"id": i, "title": f"Document {i}", "content": fact["answer"], "is_corrupt": i in self._corrupt_ids}
65
+ for i in range(self.NUM_DOCS)
66
+ ]
67
+
68
+ self._documents = raw_docs
69
+ return self._apply_transform(self._build_observation())
70
+
71
+ def step(self, action: ContextCorruptionAction, timeout_s=None, **kwargs) -> EpisodeObservation:
72
+ if self._done:
73
+ return self._apply_transform(self._build_observation(message="Episode already done."))
74
+
75
+ self._turn += 1
76
+ self._budget_used += 1
77
+
78
+ if action.action_type == ActionType.read_doc:
79
+ pass
80
+
81
+ elif action.action_type == ActionType.flag_suspicious:
82
+ if action.doc_id is not None and action.doc_id not in self._flagged_ids:
83
+ self._flagged_ids.append(action.doc_id)
84
+
85
+ elif action.action_type == ActionType.unflag_doc:
86
+ if action.doc_id in self._flagged_ids:
87
+ self._flagged_ids.remove(action.doc_id)
88
+
89
+ elif action.action_type == ActionType.submit_answer:
90
+ self._done = True
91
+
92
+ # Force-submit on budget exhaustion
93
+ if self._budget_used >= self.MAX_BUDGET and not self._done:
94
+ self._done = True
95
+
96
+ obs = self._build_observation()
97
+
98
+ if obs.done:
99
+ obs.reward = self._apply_rubric(action, obs)
100
+ self._reward = obs.reward
101
+ self._breakdown = self.rubric.last_breakdown if self.rubric else None
102
+
103
+ return self._apply_transform(obs)
104
+
105
+ @property
106
+ def state(self) -> ContextCorruptionState:
107
+ return ContextCorruptionState(
108
+ question=self._question,
109
+ ground_truth=self._ground_truth,
110
+ corrupt_ids=list(self._corrupt_ids),
111
+ flagged_ids=list(self._flagged_ids),
112
+ budget_used=self._budget_used,
113
+ done=self._done,
114
+ reward=self._reward,
115
+ breakdown=self._breakdown,
116
+ )
117
+
118
+ def _state_dict(self) -> dict:
119
+ return {
120
+ "ground_truth": self._ground_truth,
121
+ "flagged_ids": list(self._flagged_ids),
122
+ "corrupt_ids": list(self._corrupt_ids),
123
+ "budget_used": self._budget_used,
124
+ "max_budget": self.MAX_BUDGET,
125
+ }
126
+
127
+ def _build_observation(self, message=None) -> EpisodeObservation:
128
+ docs = [
129
+ Document(
130
+ id=d["id"],
131
+ title=d["title"],
132
+ content=d["content"],
133
+ is_flagged=d["id"] in self._flagged_ids,
134
+ )
135
+ for d in self._documents
136
+ ]
137
+ return EpisodeObservation(
138
+ question=self._question,
139
+ documents=docs,
140
+ flagged_ids=list(self._flagged_ids),
141
+ budget_remaining=self.MAX_BUDGET - self._budget_used,
142
+ turn=self._turn,
143
+ done=self._done,
144
+ reward=self._reward,
145
+ message=message,
146
+ )
environment/reward.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from openenv.core.rubrics import Rubric
3
+
4
+
5
+ def _normalize(text: str) -> str:
6
+ text = text.lower()
7
+ text = re.sub(r"[^\w\s]", "", text)
8
+ text = re.sub(r"\s+", " ", text).strip()
9
+ return text
10
+
11
+
12
+ def compute_reward(
13
+ submitted_answer: str,
14
+ ground_truth_answer: str,
15
+ flagged_ids: list[int],
16
+ corrupt_ids: list[int],
17
+ confidence: float,
18
+ budget_used: int,
19
+ max_budget: int,
20
+ ) -> tuple[float, dict]:
21
+ correct = _normalize(submitted_answer) == _normalize(ground_truth_answer)
22
+ answer_score = 0.4 if correct else 0.0
23
+
24
+ true_positives = [i for i in flagged_ids if i in corrupt_ids]
25
+ recall = len(true_positives) / len(corrupt_ids) if corrupt_ids else 0.0
26
+ recall_score = 0.3 * recall
27
+
28
+ false_positives = [i for i in flagged_ids if i not in corrupt_ids]
29
+ precision_score = max(0.0, 0.2 - 0.1 * len(false_positives))
30
+
31
+ confidence = confidence or 0.0
32
+ calibration_score = (0.1 * confidence) if correct else (-0.2 * confidence)
33
+
34
+ efficiency_score = 0.05 * (1 - budget_used / max_budget)
35
+
36
+ total = answer_score + recall_score + precision_score + calibration_score + efficiency_score
37
+
38
+ breakdown = {
39
+ "answer_correctness": round(answer_score, 4),
40
+ "flag_recall": round(recall_score, 4),
41
+ "false_positive_penalty": round(precision_score, 4),
42
+ "confidence_calibration": round(calibration_score, 4),
43
+ "efficiency": round(efficiency_score, 4),
44
+ "total": round(total, 4),
45
+ }
46
+ return round(total, 4), breakdown
47
+
48
+
49
+ class ContextCorruptionRubric(Rubric):
50
+ """Scores a completed episode using compute_reward().
51
+
52
+ Requires a state_fn closure to access ground-truth env state that is
53
+ intentionally hidden from the agent's observation.
54
+ """
55
+
56
+ def __init__(self, state_fn):
57
+ super().__init__()
58
+ self._state_fn = state_fn
59
+ self.last_breakdown: dict = {}
60
+
61
+ def forward(self, action, observation) -> float:
62
+ if not observation.done:
63
+ return 0.0
64
+ s = self._state_fn()
65
+ reward, breakdown = compute_reward(
66
+ submitted_answer=getattr(action, "answer", None) or "",
67
+ ground_truth_answer=s["ground_truth"],
68
+ flagged_ids=s["flagged_ids"],
69
+ corrupt_ids=s["corrupt_ids"],
70
+ confidence=getattr(action, "confidence", None) or 0.0,
71
+ budget_used=s["budget_used"],
72
+ max_budget=s["max_budget"],
73
+ )
74
+ self.last_breakdown = breakdown
75
+ return reward
environment/server.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from openenv.core import create_app
4
+ import uvicorn
5
+
6
+ load_dotenv()
7
+
8
+ from environment.actions import ContextCorruptionAction, EpisodeObservation
9
+ from environment.env import ContextCorruptionEnv
10
+
11
+ _difficulty_env = os.getenv("DIFFICULTY")
12
+ _difficulty = int(_difficulty_env) if _difficulty_env else None
13
+ _max_sessions = int(os.getenv("MAX_CONCURRENT_ENVS", "64"))
14
+
15
+ app = create_app(
16
+ env=lambda: ContextCorruptionEnv(difficulty=_difficulty),
17
+ action_cls=ContextCorruptionAction,
18
+ observation_cls=EpisodeObservation,
19
+ env_name="ContextCorruption-Env",
20
+ max_concurrent_envs=_max_sessions,
21
+ )
22
+
23
+ if __name__ == "__main__":
24
+ uvicorn.run("environment.server:app", host="0.0.0.0", port=7860, reload=False)
training/ContextCorruption_GRPO.ipynb ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# ContextCorruption-Env β€” GRPO Training\n",
8
+ "> **OpenEnv Hackathon | Meta Γ— HuggingFace Γ— PyTorch**\n",
9
+ "\n",
10
+ "Fine-tunes **Qwen2-1.5B-Instruct** with GRPO to identify corrupted documents and answer questions correctly.\n",
11
+ "\n",
12
+ "**Reward signal (fully deterministic, no LLM judge):**\n",
13
+ "| Component | Weight |\n",
14
+ "|---|---|\n",
15
+ "| Answer correctness (exact match after normalisation) | +0.40 |\n",
16
+ "| Corruption detection recall | +0.30 |\n",
17
+ "| False-positive penalty | +0.20 |\n",
18
+ "| Confidence calibration | Β±0.10 |\n",
19
+ "| Efficiency bonus | +0.05 |\n",
20
+ "\n",
21
+ "**Random baseline:** avg reward β‰ˆ 0.13 β€” beat this to show improvement.\n",
22
+ "\n",
23
+ "---\n",
24
+ "⚠️ Requires **GPU runtime** (A100 recommended). Go to `Runtime β†’ Change runtime type β†’ GPU`."
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "## 1. Install dependencies"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "%%capture\n",
41
+ "!pip install openenv-core==0.2.3 unsloth trl transformers datasets wandb faker python-dotenv"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "markdown",
46
+ "metadata": {},
47
+ "source": [
48
+ "## 2. Clone repo and generate facts"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "import os\n",
58
+ "\n",
59
+ "REPO_URL = \"https://github.com/sas-dev5/context-corruption-env.git\"\n",
60
+ "\n",
61
+ "!git clone {REPO_URL}\n",
62
+ "%cd context-corruption-env\n",
63
+ "\n",
64
+ "# Generate facts.json (pulls NQ + PopQA)\n",
65
+ "!python -m data.loader"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "metadata": {},
71
+ "source": [
72
+ "## 3. Authenticate WandB and HuggingFace"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "import wandb\n",
82
+ "from huggingface_hub import login\n",
83
+ "\n",
84
+ "# Paste your keys here or set as Colab secrets\n",
85
+ "WANDB_API_KEY = os.getenv(\"WANDB_API_KEY\", \"\")\n",
86
+ "HF_TOKEN = os.getenv(\"HF_TOKEN\", \"\")\n",
87
+ "HF_HUB_MODEL_ID = \"\" # e.g. \"your-username/qwen-1.5b-context-corruption\" β€” leave blank to skip\n",
88
+ "\n",
89
+ "if WANDB_API_KEY:\n",
90
+ " wandb.login(key=WANDB_API_KEY)\n",
91
+ "else:\n",
92
+ " wandb.login() # interactive prompt\n",
93
+ "\n",
94
+ "if HF_TOKEN:\n",
95
+ " login(token=HF_TOKEN)\n",
96
+ "\n",
97
+ "os.environ[\"HF_HUB_MODEL_ID\"] = HF_HUB_MODEL_ID"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "markdown",
102
+ "metadata": {},
103
+ "source": [
104
+ "## 4. Verify environment (smoke test)"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "from environment.env import ContextCorruptionEnv\n",
114
+ "from environment.actions import ContextCorruptionAction, ActionType\n",
115
+ "\n",
116
+ "env = ContextCorruptionEnv(difficulty=2)\n",
117
+ "obs = env.reset()\n",
118
+ "assert len(obs.documents) == 8\n",
119
+ "obs = env.step(ContextCorruptionAction(action_type=ActionType.submit_answer, answer=\"test\", confidence=0.5))\n",
120
+ "assert obs.done and obs.reward is not None\n",
121
+ "print(f\"βœ… Smoke test passed | reward: {obs.reward:.4f}\")\n",
122
+ "print(f\" Question: {env.state.question}\")"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "metadata": {},
128
+ "source": [
129
+ "## 5. Preview training dataset"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "import sys\n",
139
+ "sys.path.insert(0, \".\")\n",
140
+ "from training.train_grpo import build_dataset, SYSTEM_PROMPT\n",
141
+ "\n",
142
+ "sample_ds = build_dataset(n_episodes=5, seed=0)\n",
143
+ "sample = sample_ds[0]\n",
144
+ "print(\"System:\", sample[\"messages\"][0][\"content\"][:200], \"...\")\n",
145
+ "print(\"\\nUser message (first 400 chars):\", sample[\"messages\"][1][\"content\"][:400], \"...\")\n",
146
+ "print(\"\\nGround truth:\", sample[\"ground_truth\"])\n",
147
+ "print(\"Corrupt doc IDs:\", sample[\"corrupt_ids\"])"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "markdown",
152
+ "metadata": {},
153
+ "source": [
154
+ "## 6. Run GRPO training\n",
155
+ "\n",
156
+ "Expected time on A100: ~45–60 min for 3 epochs over 500 episodes."
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "from training.train_grpo import main\n",
166
+ "main()"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "metadata": {},
172
+ "source": [
173
+ "## 7. View training curves"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "from IPython.display import Image, display\n",
183
+ "\n",
184
+ "display(Image(\"assets/reward_curve.png\"))\n",
185
+ "display(Image(\"assets/loss_curve.png\"))"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "markdown",
190
+ "metadata": {},
191
+ "source": [
192
+ "## 8. Evaluate trained model vs baseline"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "import json, torch, re\n",
202
+ "from unsloth import FastLanguageModel\n",
203
+ "from training.train_grpo import (\n",
204
+ " MODEL_NAME, MAX_SEQ_LENGTH, OUTPUT_DIR,\n",
205
+ " build_dataset, SYSTEM_PROMPT, _parse_completion\n",
206
+ ")\n",
207
+ "from environment.reward import compute_reward\n",
208
+ "\n",
209
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
210
+ " model_name=f\"{OUTPUT_DIR}-final\",\n",
211
+ " max_seq_length=MAX_SEQ_LENGTH,\n",
212
+ " load_in_4bit=True,\n",
213
+ ")\n",
214
+ "FastLanguageModel.for_inference(model)\n",
215
+ "\n",
216
+ "eval_ds = build_dataset(n_episodes=50, seed=999)\n",
217
+ "rewards = []\n",
218
+ "\n",
219
+ "for row in eval_ds:\n",
220
+ " prompt = tokenizer.apply_chat_template(\n",
221
+ " row[\"messages\"], tokenize=False, add_generation_prompt=True\n",
222
+ " )\n",
223
+ " inputs = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
224
+ " with torch.no_grad():\n",
225
+ " out = model.generate(**inputs, max_new_tokens=256, temperature=0.1, do_sample=True)\n",
226
+ " completion = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
227
+ " parsed = _parse_completion(completion)\n",
228
+ " if parsed:\n",
229
+ " reward, _ = compute_reward(\n",
230
+ " parsed.get(\"answer\", \"\"), row[\"ground_truth\"],\n",
231
+ " [int(x) for x in parsed.get(\"suspicious_docs\", [])],\n",
232
+ " row[\"corrupt_ids\"], float(parsed.get(\"confidence\", 0.5)),\n",
233
+ " budget_used=1, max_budget=12\n",
234
+ " )\n",
235
+ " else:\n",
236
+ " reward = 0.0\n",
237
+ " rewards.append(reward)\n",
238
+ "\n",
239
+ "avg = sum(rewards) / len(rewards)\n",
240
+ "print(f\"\\n{'='*50}\")\n",
241
+ "print(f\"Trained model avg reward : {avg:.4f}\")\n",
242
+ "print(f\"Random baseline avg : 0.1302\")\n",
243
+ "print(f\"Improvement : {avg - 0.1302:+.4f}\")\n",
244
+ "print(f\"{'='*50}\")"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "markdown",
249
+ "metadata": {},
250
+ "source": [
251
+ "## 9. Commit plots and results"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": null,
257
+ "metadata": {},
258
+ "outputs": [],
259
+ "source": [
260
+ "trained_avg = avg # from cell above\n",
261
+ "\n",
262
+ "results = {\n",
263
+ " \"baseline_avg_reward\": 0.1302,\n",
264
+ " \"trained_avg_reward\": round(trained_avg, 4),\n",
265
+ " \"improvement\": round(trained_avg - 0.1302, 4),\n",
266
+ " \"n_eval_episodes\": 50,\n",
267
+ " \"model\": \"Qwen2-1.5B-Instruct + LoRA r=16 GRPO\",\n",
268
+ "}\n",
269
+ "with open(\"eval/trained_results.json\", \"w\") as f:\n",
270
+ " json.dump(results, f, indent=2)\n",
271
+ "\n",
272
+ "!git config user.email \"colab@training\"\n",
273
+ "!git config user.name \"Colab Training Run\"\n",
274
+ "!git add assets/reward_curve.png assets/loss_curve.png eval/trained_results.json\n",
275
+ "!git commit -m \"results: add training curves and eval results\"\n",
276
+ "!git push origin main\n",
277
+ "print(\"Done β€” plots and results committed.\")"
278
+ ]
279
+ }
280
+ ],
281
+ "metadata": {
282
+ "accelerator": "GPU",
283
+ "colab": {
284
+ "gpuType": "A100",
285
+ "name": "ContextCorruption_GRPO.ipynb",
286
+ "provenance": []
287
+ },
288
+ "kernelspec": {
289
+ "display_name": "Python 3",
290
+ "language": "python",
291
+ "name": "python3"
292
+ },
293
+ "language_info": {
294
+ "name": "python",
295
+ "version": "3.11.0"
296
+ }
297
+ },
298
+ "nbformat": 4,
299
+ "nbformat_minor": 4
300
+ }
training/space_runner.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI for the training Space.
3
+ Training does NOT start automatically β€” user must click "Start Training".
4
+ """
5
+ import os
6
+ import sys
7
+ import threading
8
+ import time
9
+ from pathlib import Path
10
+
11
+ import gradio as gr
12
+ from dotenv import load_dotenv
13
+
14
+ load_dotenv()
15
+
16
+ _log_lines: list[str] = []
17
+ _training_status = "idle" # idle | running | complete | failed
18
+
19
+
20
+ def _append_log(msg: str):
21
+ ts = time.strftime("%H:%M:%S")
22
+ _log_lines.append(f"[{ts}] {msg}")
23
+
24
+
25
+ def _run_training():
26
+ global _training_status
27
+ _training_status = "running"
28
+ _append_log("Training started.")
29
+ try:
30
+ # Redirect stdout so log lines appear in the UI
31
+ import io
32
+ import contextlib
33
+
34
+ sys.path.insert(0, str(Path(__file__).parent.parent))
35
+ from training.train_grpo import main
36
+
37
+ # Capture print output
38
+ old_stdout = sys.stdout
39
+ old_stderr = sys.stderr
40
+
41
+ class Tee:
42
+ def __init__(self, orig):
43
+ self._orig = orig
44
+
45
+ def write(self, msg):
46
+ if msg.strip():
47
+ _append_log(msg.rstrip())
48
+ self._orig.write(msg)
49
+
50
+ def flush(self):
51
+ self._orig.flush()
52
+
53
+ sys.stdout = Tee(old_stdout)
54
+ sys.stderr = Tee(old_stderr)
55
+
56
+ try:
57
+ main()
58
+ finally:
59
+ sys.stdout = old_stdout
60
+ sys.stderr = old_stderr
61
+
62
+ _training_status = "complete"
63
+ _append_log("βœ… Training complete. Check WandB for curves.")
64
+ except Exception as e:
65
+ _training_status = "failed"
66
+ _append_log(f"❌ Training failed: {e}")
67
+
68
+
69
+ def start_training():
70
+ global _training_status
71
+ if _training_status == "running":
72
+ return "⚠️ Training is already running.", _get_logs()
73
+ if _training_status == "complete":
74
+ return "βœ… Training already complete.", _get_logs()
75
+
76
+ missing = []
77
+ if not os.getenv("WANDB_API_KEY"):
78
+ missing.append("WANDB_API_KEY")
79
+ if not os.getenv("HF_TOKEN"):
80
+ missing.append("HF_TOKEN")
81
+ if not os.getenv("HF_HUB_MODEL_ID"):
82
+ missing.append("HF_HUB_MODEL_ID")
83
+ if missing:
84
+ return f"❌ Missing secrets: {', '.join(missing)}. Set them in Space Settings β†’ Variables and secrets.", _get_logs()
85
+
86
+ threading.Thread(target=_run_training, daemon=True).start()
87
+ return "πŸš€ Training started! Logs updating below...", _get_logs()
88
+
89
+
90
+ def _get_logs() -> str:
91
+ return "\n".join(_log_lines[-80:]) if _log_lines else "No logs yet."
92
+
93
+
94
+ def get_status() -> str:
95
+ icons = {"idle": "⏸️ Idle", "running": "πŸ”„ Training in progress...",
96
+ "complete": "βœ… Complete", "failed": "❌ Failed"}
97
+ return icons.get(_training_status, _training_status)
98
+
99
+
100
+ def refresh():
101
+ return get_status(), _get_logs()
102
+
103
+
104
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
105
+
106
+ with gr.Blocks(title="ContextCorruption Training") as demo:
107
+ gr.Markdown("""
108
+ # ContextCorruption-Env β€” GRPO Training
109
+ **Qwen2-1.5B-Instruct** fine-tuned to identify corrupted documents and resist misleading context.
110
+
111
+ Before starting, ensure these secrets are set in **Space Settings β†’ Variables and secrets**:
112
+ - `WANDB_API_KEY`
113
+ - `HF_TOKEN`
114
+ - `HF_HUB_MODEL_ID` (e.g. `Siddh12334/qwen-1.5b-context-corruption`)
115
+ """)
116
+
117
+ status_box = gr.Textbox(label="Status", value="⏸️ Idle", interactive=False)
118
+ log_box = gr.Textbox(label="Training Logs", lines=20, interactive=False,
119
+ value="Waiting to start...")
120
+ msg_box = gr.Textbox(label="Message", interactive=False)
121
+
122
+ with gr.Row():
123
+ start_btn = gr.Button("πŸš€ Start Training", variant="primary", scale=2)
124
+ refresh_btn = gr.Button("πŸ”„ Refresh Logs", scale=1)
125
+
126
+ gr.Markdown("""
127
+ ---
128
+ **Config:** 500 episodes Β· 3 epochs Β· Qwen2-1.5B Β· LoRA r=16 Β· A10G ~1.5 hrs Β· ~$2
129
+ """)
130
+
131
+ start_btn.click(fn=start_training, outputs=[msg_box, log_box])
132
+ refresh_btn.click(fn=refresh, outputs=[status_box, log_box])
133
+
134
+ # Auto-refresh every 10s while running
135
+ demo.load(fn=refresh, outputs=[status_box, log_box], every=10)
136
+
137
+
138
+ if __name__ == "__main__":
139
+ demo.launch(server_name="0.0.0.0", server_port=7860)
training/train_grpo.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GRPO fine-tuning of Qwen2-1.5B-Instruct on ContextCorruption-Env.
3
+
4
+ Architecture:
5
+ - Single-turn formulation: model sees question + all 8 docs, responds with
6
+ JSON {"answer": "...", "suspicious_docs": [0, 3], "confidence": 0.85}
7
+ - Two reward signals: correctness (from compute_reward) + format (valid JSON)
8
+ - WandB logs metrics + sample completions every LOGGING_STEPS
9
+ - Pushes final model to HF Hub after training
10
+
11
+ Usage (on GPU machine / HF Space):
12
+ pip install -r requirements.txt
13
+ WANDB_API_KEY=... HF_TOKEN=... python -m training.train_grpo
14
+ """
15
+
16
+ import json
17
+ import os
18
+ import random
19
+ import re
20
+ import sys
21
+ from pathlib import Path
22
+
23
+ import wandb
24
+
25
+ # ── Config ────────────────────────────────────────────────────────────────────
26
+ MODEL_NAME = "unsloth/Qwen2-1.5B-Instruct"
27
+ MAX_SEQ_LENGTH = 2048
28
+ LOAD_IN_4BIT = True
29
+ LORA_R = 16
30
+ LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]
31
+
32
+ NUM_TRAIN_EPOCHS = 3
33
+ BATCH_SIZE = 4
34
+ GRAD_ACCUM = 4
35
+ LR = 5e-5
36
+ MAX_COMPLETION_LENGTH = 512
37
+ NUM_GENERATIONS = 8
38
+ LOGGING_STEPS = 10
39
+ SAVE_STEPS = 50
40
+ SEED = 42
41
+
42
+ N_TRAIN_EPISODES = 500
43
+ N_EVAL_EPISODES = 50
44
+
45
+ WANDB_PROJECT = "context-corruption-env"
46
+ WANDB_RUN_NAME = "qwen-1.5b-grpo-run1"
47
+ OUTPUT_DIR = "checkpoints/grpo-qwen-1.5b"
48
+ HF_HUB_MODEL_ID = os.getenv("HF_HUB_MODEL_ID", "") # set to push model after training
49
+
50
+ SYSTEM_PROMPT = """\
51
+ You are an epistemic agent. You will receive a factual question and a set of reference documents.
52
+ Some documents may contain deliberately corrupted information (wrong numbers, wrong names, inverted facts).
53
+
54
+ Your task:
55
+ 1. Answer the question correctly using your knowledge and the documents.
56
+ 2. Identify which document IDs contain false or corrupted information.
57
+ 3. Report your confidence (0.0–1.0).
58
+
59
+ Respond ONLY with valid JSON on a single line:
60
+ {"answer": "<your answer>", "suspicious_docs": [<doc ids>], "confidence": <0.0-1.0>}
61
+
62
+ Examples:
63
+ {"answer": "Paris", "suspicious_docs": [2, 5], "confidence": 0.95}
64
+ {"answer": "1969", "suspicious_docs": [], "confidence": 0.8}"""
65
+
66
+
67
+ # ── Dataset builder ────────────────────────────────────────────────────────────
68
+
69
+ def _format_user_message(question: str, docs: list[dict]) -> str:
70
+ docs_text = "\n\n".join(
71
+ f"[Doc {d['id']}] {d['title']}\n{d['content']}" for d in docs
72
+ )
73
+ return f"Question: {question}\n\nDocuments:\n{docs_text}"
74
+
75
+
76
+ def build_dataset(n_episodes: int, seed: int = SEED) -> "datasets.Dataset":
77
+ from datasets import Dataset
78
+ from data.generator import generate_documents
79
+
80
+ random.seed(seed)
81
+ facts_path = Path(__file__).parent.parent / "data" / "facts.json"
82
+ if not facts_path.exists():
83
+ raise FileNotFoundError(
84
+ "data/facts.json not found. Run: python -m data.loader"
85
+ )
86
+ facts = json.loads(facts_path.read_text(encoding="utf-8"))
87
+
88
+ rows = []
89
+ for _ in range(n_episodes):
90
+ fact = random.choice(facts)
91
+ n_corrupt = random.choice([1, 2, 3, 4])
92
+ corrupt_ids = random.sample(range(8), n_corrupt)
93
+ try:
94
+ docs = generate_documents(fact, num_docs=8, corrupt_positions=corrupt_ids)
95
+ except Exception:
96
+ docs = [
97
+ {"id": i, "title": f"Doc {i}", "content": fact["answer"],
98
+ "is_corrupt": i in corrupt_ids}
99
+ for i in range(8)
100
+ ]
101
+ rows.append({
102
+ "messages": [
103
+ {"role": "system", "content": SYSTEM_PROMPT},
104
+ {"role": "user", "content": _format_user_message(fact["question"], docs)},
105
+ ],
106
+ "ground_truth": fact["answer"],
107
+ "corrupt_ids": corrupt_ids,
108
+ })
109
+
110
+ return Dataset.from_list(rows)
111
+
112
+
113
+ # ── Reward functions ───────────────────────────────────────────────────────────
114
+
115
+ def _parse_completion(text: str) -> dict | None:
116
+ """Extract first JSON object from completion text."""
117
+ # Strip any <think>...</think> blocks (chain-of-thought models)
118
+ text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
119
+ # Try direct parse first
120
+ try:
121
+ return json.loads(text)
122
+ except json.JSONDecodeError:
123
+ pass
124
+ # Find first {...} block
125
+ match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
126
+ if match:
127
+ try:
128
+ return json.loads(match.group())
129
+ except json.JSONDecodeError:
130
+ pass
131
+ return None
132
+
133
+
134
+ def format_reward(prompts, completions, **kwargs) -> list[float]:
135
+ """Small bonus for structurally valid responses β€” teaches the output format."""
136
+ rewards = []
137
+ for completion in completions:
138
+ parsed = _parse_completion(completion)
139
+ if parsed is None:
140
+ rewards.append(-0.1)
141
+ continue
142
+ has_answer = isinstance(parsed.get("answer"), str) and parsed["answer"].strip()
143
+ has_docs = isinstance(parsed.get("suspicious_docs"), list)
144
+ has_conf = isinstance(parsed.get("confidence"), (int, float))
145
+ rewards.append(0.1 if (has_answer and has_docs and has_conf) else 0.0)
146
+ return rewards
147
+
148
+
149
+ def correctness_reward(prompts, completions, ground_truth, corrupt_ids, **kwargs) -> list[float]:
150
+ """Main reward: calls compute_reward() from environment/reward.py."""
151
+ from environment.reward import compute_reward
152
+
153
+ rewards = []
154
+ for completion, gt, cids in zip(completions, ground_truth, corrupt_ids):
155
+ parsed = _parse_completion(completion)
156
+ if parsed is None:
157
+ rewards.append(0.0)
158
+ continue
159
+ answer = str(parsed.get("answer", "")).strip()
160
+ flagged = [int(x) for x in parsed.get("suspicious_docs", [])
161
+ if isinstance(x, (int, float))]
162
+ confidence = float(parsed.get("confidence", 0.5))
163
+ confidence = max(0.0, min(1.0, confidence))
164
+ cids_list = list(cids) if not isinstance(cids, list) else cids
165
+ reward, _ = compute_reward(
166
+ submitted_answer=answer,
167
+ ground_truth_answer=gt,
168
+ flagged_ids=flagged,
169
+ corrupt_ids=cids_list,
170
+ confidence=confidence,
171
+ budget_used=1,
172
+ max_budget=12,
173
+ )
174
+ rewards.append(float(reward))
175
+ return rewards
176
+
177
+
178
+ # ── Plot saving ────────────────────────────────────────────────────────────────
179
+
180
+ def save_training_plots(run_id: str):
181
+ """Download reward + loss curves from WandB and save to assets/."""
182
+ try:
183
+ import matplotlib
184
+ matplotlib.use("Agg")
185
+ import matplotlib.pyplot as plt
186
+ api = wandb.Api()
187
+ run = api.run(f"{WANDB_PROJECT}/{run_id}")
188
+ history = run.history(keys=["train/reward", "train/loss"], pandas=True)
189
+ assets = Path(__file__).parent.parent / "assets"
190
+ assets.mkdir(exist_ok=True)
191
+
192
+ if "train/reward" in history.columns:
193
+ fig, ax = plt.subplots(figsize=(8, 4))
194
+ ax.plot(history["_step"], history["train/reward"])
195
+ ax.set_xlabel("Training step")
196
+ ax.set_ylabel("Mean episode reward")
197
+ ax.set_title("GRPO Training Reward β€” Qwen2-1.5B")
198
+ ax.grid(True, alpha=0.3)
199
+ fig.tight_layout()
200
+ fig.savefig(assets / "reward_curve.png", dpi=150)
201
+ plt.close(fig)
202
+ print(f"Saved reward_curve.png")
203
+
204
+ if "train/loss" in history.columns:
205
+ fig, ax = plt.subplots(figsize=(8, 4))
206
+ ax.plot(history["_step"], history["train/loss"])
207
+ ax.set_xlabel("Training step")
208
+ ax.set_ylabel("GRPO loss")
209
+ ax.set_title("GRPO Training Loss β€” Qwen2-1.5B")
210
+ ax.grid(True, alpha=0.3)
211
+ fig.tight_layout()
212
+ fig.savefig(assets / "loss_curve.png", dpi=150)
213
+ plt.close(fig)
214
+ print(f"Saved loss_curve.png")
215
+ except Exception as e:
216
+ print(f"[warn] Could not save plots: {e}")
217
+
218
+
219
+ # ── Main ───────────────────────────────────────────────────────────────────────
220
+
221
+ def main():
222
+ # Guard: must have GPU
223
+ try:
224
+ import torch
225
+ if not torch.cuda.is_available():
226
+ print("[error] No GPU detected. Training requires CUDA. Exiting.")
227
+ sys.exit(1)
228
+ except ImportError:
229
+ pass
230
+
231
+ from unsloth import FastLanguageModel
232
+ from trl import GRPOTrainer, GRPOConfig
233
+
234
+ run = wandb.init(
235
+ project=WANDB_PROJECT,
236
+ name=WANDB_RUN_NAME,
237
+ config={
238
+ "model": MODEL_NAME,
239
+ "lora_r": LORA_R,
240
+ "epochs": NUM_TRAIN_EPOCHS,
241
+ "batch_size": BATCH_SIZE,
242
+ "grad_accum": GRAD_ACCUM,
243
+ "lr": LR,
244
+ "num_generations": NUM_GENERATIONS,
245
+ "n_train_episodes": N_TRAIN_EPISODES,
246
+ "seed": SEED,
247
+ },
248
+ )
249
+
250
+ print("Building training dataset...")
251
+ train_dataset = build_dataset(N_TRAIN_EPISODES, seed=SEED)
252
+ eval_dataset = build_dataset(N_EVAL_EPISODES, seed=SEED + 1)
253
+ print(f"Train: {len(train_dataset)} episodes | Eval: {len(eval_dataset)} episodes")
254
+
255
+ print("Loading model with Unsloth...")
256
+ model, tokenizer = FastLanguageModel.from_pretrained(
257
+ model_name=MODEL_NAME,
258
+ max_seq_length=MAX_SEQ_LENGTH,
259
+ load_in_4bit=LOAD_IN_4BIT,
260
+ )
261
+ model = FastLanguageModel.get_peft_model(
262
+ model,
263
+ r=LORA_R,
264
+ target_modules=LORA_TARGET_MODULES,
265
+ lora_dropout=0.0,
266
+ use_gradient_checkpointing="unsloth",
267
+ )
268
+
269
+ push_to_hub = bool(HF_HUB_MODEL_ID and os.getenv("HF_TOKEN"))
270
+
271
+ config = GRPOConfig(
272
+ output_dir=OUTPUT_DIR,
273
+ num_train_epochs=NUM_TRAIN_EPOCHS,
274
+ per_device_train_batch_size=BATCH_SIZE,
275
+ gradient_accumulation_steps=GRAD_ACCUM,
276
+ learning_rate=LR,
277
+ max_completion_length=MAX_COMPLETION_LENGTH,
278
+ num_generations=NUM_GENERATIONS,
279
+ report_to="wandb",
280
+ logging_steps=LOGGING_STEPS,
281
+ save_steps=SAVE_STEPS,
282
+ save_total_limit=2,
283
+ seed=SEED,
284
+ # Deployment logs: log completions to WandB every logging step
285
+ log_completions=True,
286
+ num_completions_to_print=2,
287
+ # Push to HF Hub if token provided
288
+ push_to_hub=push_to_hub,
289
+ hub_model_id=HF_HUB_MODEL_ID if push_to_hub else None,
290
+ hub_strategy="end",
291
+ bf16=True,
292
+ remove_unused_columns=False,
293
+ )
294
+
295
+ trainer = GRPOTrainer(
296
+ model=model,
297
+ args=config,
298
+ processing_class=tokenizer,
299
+ train_dataset=train_dataset,
300
+ eval_dataset=eval_dataset,
301
+ reward_funcs=[correctness_reward, format_reward],
302
+ )
303
+
304
+ print("Starting GRPO training...")
305
+ trainer.train()
306
+
307
+ print("Saving final model...")
308
+ model.save_pretrained(f"{OUTPUT_DIR}-final")
309
+ tokenizer.save_pretrained(f"{OUTPUT_DIR}-final")
310
+
311
+ if push_to_hub:
312
+ model.push_to_hub(HF_HUB_MODEL_ID)
313
+ tokenizer.push_to_hub(HF_HUB_MODEL_ID)
314
+ print(f"Model pushed to HF Hub: {HF_HUB_MODEL_ID}")
315
+
316
+ print("Saving training plots...")
317
+ save_training_plots(run.id)
318
+
319
+ wandb.finish()
320
+ print("Training complete.")
321
+
322
+
323
+ if __name__ == "__main__":
324
+ main()