harsharajkumar273 committed on
Commit
b7e4141
·
verified ·
1 Parent(s): cdb633a

Make manual seed input affect visible output

Browse files
cleanops_env/environment.py CHANGED
@@ -2,6 +2,8 @@
2
 
3
  from __future__ import annotations
4
 
 
 
5
  from uuid import uuid4
6
 
7
  from openenv.core.env_server.interfaces import Environment
@@ -50,6 +52,7 @@ class CleanOpsEnvironment(Environment[DataCleaningAction, DataCleaningObservatio
50
  task_id=self._task_spec.task_id,
51
  task_title=self._task_spec.title,
52
  difficulty=self._task_spec.difficulty,
 
53
  max_steps=self._task_spec.max_steps,
54
  submitted=False,
55
  current_score=self._grade.score,
@@ -69,10 +72,11 @@ class CleanOpsEnvironment(Environment[DataCleaningAction, DataCleaningObservatio
69
  task_id: str | None = None,
70
  **kwargs: object,
71
  ) -> DataCleaningObservation:
72
- del seed, kwargs
73
  selected_task_id = task_id or self._task_order[0]
74
  self._task_spec = get_task_spec(selected_task_id)
75
- self._focus_table_name = first_table_name(self._task_spec)
 
76
  self._focus_operation_detail = None
77
  self._done = False
78
  self._grade = grade_tables(self._task_spec, self._task_spec.dirty_tables)
@@ -83,6 +87,7 @@ class CleanOpsEnvironment(Environment[DataCleaningAction, DataCleaningObservatio
83
  task_id=self._task_spec.task_id,
84
  task_title=self._task_spec.title,
85
  difficulty=self._task_spec.difficulty,
 
86
  max_steps=self._task_spec.max_steps,
87
  submitted=False,
88
  current_score=self._grade.score,
@@ -92,7 +97,7 @@ class CleanOpsEnvironment(Environment[DataCleaningAction, DataCleaningObservatio
92
  applied_operation_ids=[],
93
  inspected_tables=[self._focus_table_name],
94
  inspected_operations=[],
95
- recent_history=[f"reset -> loaded task {self._task_spec.task_id} ({self._task_spec.difficulty})"],
96
  )
97
  return self._build_observation(
98
  reward_breakdown=RewardBreakdown(total=0.0),
@@ -273,6 +278,7 @@ class CleanOpsEnvironment(Environment[DataCleaningAction, DataCleaningObservatio
273
  task_id=self._task_spec.task_id,
274
  task_title=self._task_spec.title,
275
  difficulty=self._task_spec.difficulty,
 
276
  objective=self._task_spec.objective,
277
  dataset_context=self._task_spec.dataset_context,
278
  quality_score=self._state.current_score,
@@ -293,6 +299,7 @@ class CleanOpsEnvironment(Environment[DataCleaningAction, DataCleaningObservatio
293
  done=done,
294
  metadata={
295
  "episode_id": self._state.episode_id,
 
296
  "applied_operation_ids": list(self._state.applied_operation_ids),
297
  "submitted": self._state.submitted,
298
  },
@@ -300,10 +307,33 @@ class CleanOpsEnvironment(Environment[DataCleaningAction, DataCleaningObservatio
300
 
301
  def _build_table_view(self, task_spec: TaskSpec, table_name: str) -> TableView:
302
  primary_key = task_spec.primary_keys[table_name]
303
- rows = sorted_rows(self._state.tables.get(table_name, []), primary_key)
304
  columns = sorted({column_name for row in rows for column_name in row})
305
  return TableView(name=table_name, primary_key=primary_key, columns=columns, rows=rows)
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  def _build_operation_detail(
308
  self,
309
  task_spec: TaskSpec,
 
2
 
3
  from __future__ import annotations
4
 
5
+ import copy
6
+ import random
7
  from uuid import uuid4
8
 
9
  from openenv.core.env_server.interfaces import Environment
 
52
  task_id=self._task_spec.task_id,
53
  task_title=self._task_spec.title,
54
  difficulty=self._task_spec.difficulty,
55
+ requested_seed=None,
56
  max_steps=self._task_spec.max_steps,
57
  submitted=False,
58
  current_score=self._grade.score,
 
72
  task_id: str | None = None,
73
  **kwargs: object,
74
  ) -> DataCleaningObservation:
75
+ del kwargs
76
  selected_task_id = task_id or self._task_order[0]
77
  self._task_spec = get_task_spec(selected_task_id)
78
+ normalized_seed = seed if seed is None else max(0, int(seed))
79
+ self._focus_table_name = self._choose_initial_focus_table(self._task_spec, normalized_seed)
80
  self._focus_operation_detail = None
81
  self._done = False
82
  self._grade = grade_tables(self._task_spec, self._task_spec.dirty_tables)
 
87
  task_id=self._task_spec.task_id,
88
  task_title=self._task_spec.title,
89
  difficulty=self._task_spec.difficulty,
90
+ requested_seed=normalized_seed,
91
  max_steps=self._task_spec.max_steps,
92
  submitted=False,
93
  current_score=self._grade.score,
 
97
  applied_operation_ids=[],
98
  inspected_tables=[self._focus_table_name],
99
  inspected_operations=[],
100
+ recent_history=[f"reset -> loaded task {self._task_spec.task_id} ({self._task_spec.difficulty}) seed={normalized_seed}"],
101
  )
102
  return self._build_observation(
103
  reward_breakdown=RewardBreakdown(total=0.0),
 
278
  task_id=self._task_spec.task_id,
279
  task_title=self._task_spec.title,
280
  difficulty=self._task_spec.difficulty,
281
+ requested_seed=self._state.requested_seed,
282
  objective=self._task_spec.objective,
283
  dataset_context=self._task_spec.dataset_context,
284
  quality_score=self._state.current_score,
 
299
  done=done,
300
  metadata={
301
  "episode_id": self._state.episode_id,
302
+ "requested_seed": self._state.requested_seed,
303
  "applied_operation_ids": list(self._state.applied_operation_ids),
304
  "submitted": self._state.submitted,
305
  },
 
307
 
308
  def _build_table_view(self, task_spec: TaskSpec, table_name: str) -> TableView:
309
  primary_key = task_spec.primary_keys[table_name]
310
+ rows = self._preview_rows(task_spec, table_name, self._state.tables.get(table_name, []))
311
  columns = sorted({column_name for row in rows for column_name in row})
312
  return TableView(name=table_name, primary_key=primary_key, columns=columns, rows=rows)
313
 
314
+ def _choose_initial_focus_table(self, task_spec: TaskSpec, seed: int | None) -> str:
315
+ table_names = sorted(task_spec.dirty_tables)
316
+ if not table_names:
317
+ return first_table_name(task_spec)
318
+ if seed is None:
319
+ return table_names[0]
320
+ return table_names[seed % len(table_names)]
321
+
322
def _preview_rows(
    self,
    task_spec: TaskSpec,
    table_name: str,
    rows: list[dict[str, str]],
) -> list[dict[str, str]]:
    """Return the rows to display for ``table_name``, reordered by the episode seed.

    Rows are first ordered with ``sorted_rows`` on the table's primary key.
    When the episode state holds no seed, or there are fewer than two rows,
    that ordering is returned unchanged. Otherwise a deep copy is shuffled
    with an RNG seeded from the episode seed plus the sum of the table
    name's code points.
    """
    baseline = sorted_rows(rows, task_spec.primary_keys[table_name])
    requested_seed = self._state.requested_seed
    if requested_seed is not None and len(baseline) > 1:
        # Shuffle a copy so the ordered rows are left untouched.
        preview = copy.deepcopy(baseline)
        name_offset = sum(map(ord, table_name))
        random.Random(requested_seed + name_offset).shuffle(preview)
        return preview
    return baseline
336
+
337
  def _build_operation_detail(
338
  self,
339
  task_spec: TaskSpec,
cleanops_env/models.py CHANGED
@@ -114,6 +114,7 @@ class DataCleaningObservation(Observation):
114
  task_id: str = Field(..., description="Current task identifier.")
115
  task_title: str = Field(..., description="Human-readable task title.")
116
  difficulty: Literal["easy", "medium", "hard"] = Field(..., description="Task difficulty.")
 
117
  objective: str = Field(..., description="Concrete task objective.")
118
  dataset_context: str = Field(..., description="Why this dataset exists in the real world.")
119
  quality_score: float = Field(default=0.0, description="Current deterministic grader score.")
@@ -138,6 +139,7 @@ class DataCleaningState(State):
138
  task_id: str = Field(..., description="Current task identifier.")
139
  task_title: str = Field(..., description="Current task title.")
140
  difficulty: Literal["easy", "medium", "hard"] = Field(..., description="Current task difficulty.")
 
141
  max_steps: int = Field(..., description="Task step budget.")
142
  submitted: bool = Field(default=False, description="Whether submit was called.")
143
  current_score: float = Field(default=0.0, description="Current deterministic grader score.")
 
114
  task_id: str = Field(..., description="Current task identifier.")
115
  task_title: str = Field(..., description="Human-readable task title.")
116
  difficulty: Literal["easy", "medium", "hard"] = Field(..., description="Task difficulty.")
117
+ requested_seed: int | None = Field(default=None, description="Seed used when resetting the current episode.")
118
  objective: str = Field(..., description="Concrete task objective.")
119
  dataset_context: str = Field(..., description="Why this dataset exists in the real world.")
120
  quality_score: float = Field(default=0.0, description="Current deterministic grader score.")
 
139
  task_id: str = Field(..., description="Current task identifier.")
140
  task_title: str = Field(..., description="Current task title.")
141
  difficulty: Literal["easy", "medium", "hard"] = Field(..., description="Current task difficulty.")
142
+ requested_seed: int | None = Field(default=None, description="Seed used when resetting the current episode.")
143
  max_steps: int = Field(..., description="Task step budget.")
144
  submitted: bool = Field(default=False, description="Whether submit was called.")
145
  current_score: float = Field(default=0.0, description="Current deterministic grader score.")
server/app.py CHANGED
@@ -2,6 +2,9 @@
2
 
3
  from __future__ import annotations
4
 
 
 
 
5
  from openenv.core import create_app
6
  from fastapi.responses import HTMLResponse, JSONResponse
7
 
@@ -20,18 +23,19 @@ app = create_app(
20
 
21
 
22
  @app.get("/demo/compare", include_in_schema=False)
23
- def demo_compare(task_id: str = "customer_contacts_easy", table_name: str | None = None) -> JSONResponse:
24
  task_spec = get_task_spec(task_id)
25
  selected_table = table_name if table_name in task_spec.dirty_tables else first_table_name(task_spec)
26
  primary_key = task_spec.primary_keys[selected_table]
27
- before_rows = sorted_rows(task_spec.dirty_tables[selected_table], primary_key)
28
- after_rows = sorted_rows(task_spec.gold_tables[selected_table], primary_key)
29
  columns = sorted({column_name for row in before_rows + after_rows for column_name in row})
30
  return JSONResponse(
31
  {
32
  "task_id": task_spec.task_id,
33
  "task_title": task_spec.title,
34
  "table_name": selected_table,
 
35
  "available_tables": list(task_spec.dirty_tables.keys()),
36
  "columns": columns,
37
  "before_rows": before_rows[:4],
@@ -43,6 +47,20 @@ def demo_compare(task_id: str = "customer_contacts_easy", table_name: str | None
43
  )
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  @app.get("/", include_in_schema=False)
47
  def root() -> HTMLResponse:
48
  return HTMLResponse(
@@ -718,7 +736,7 @@ def root() -> HTMLResponse:
718
  <p>
719
  The cards and table below are populated from a real
720
  <code>POST /reset</code> response. Use the task buttons above to
721
- switch between benchmark scenarios.
722
  </p>
723
 
724
  <div class="kpis">
@@ -1039,12 +1057,12 @@ def root() -> HTMLResponse:
1039
  const rows = (observation.focus_table?.rows || []).slice(0, 4);
1040
  renderTable(columns, rows);
1041
 
1042
- const compareResponse = await fetch(`/demo/compare?task_id=${encodeURIComponent(taskId)}&table_name=${encodeURIComponent(observation.focus_table?.name || "")}`);
1043
  if (!compareResponse.ok) {
1044
  throw new Error(`Compare HTTP ${compareResponse.status}`);
1045
  }
1046
  const comparePayload = await compareResponse.json();
1047
- compareMetaEl.textContent = `${comparePayload.task_title} • table: ${comparePayload.table_name}`;
1048
  beforeMetaEl.textContent = `${comparePayload.before_row_count} rows`;
1049
  afterMetaEl.textContent = `${comparePayload.after_row_count} rows`;
1050
  renderTableTo(beforeHeadRowEl, beforeBodyEl, comparePayload.columns || [], comparePayload.before_rows || []);
@@ -1056,6 +1074,7 @@ def root() -> HTMLResponse:
1056
  outputEl.textContent = JSON.stringify(
1057
  {
1058
  task_id: observation.task_id,
 
1059
  difficulty: observation.difficulty,
1060
  objective: observation.objective,
1061
  quality_score: observation.quality_score,
@@ -1067,7 +1086,7 @@ def root() -> HTMLResponse:
1067
  null,
1068
  2
1069
  );
1070
- setRunFeedback("success", `Loaded ${observation.task_title || taskId} successfully.`);
1071
  } catch (error) {
1072
  outputEl.textContent = `Request failed: ${error.message}`;
1073
  objectiveEl.textContent = "Request failed.";
 
2
 
3
  from __future__ import annotations
4
 
5
+ import copy
6
+ import random
7
+
8
  from openenv.core import create_app
9
  from fastapi.responses import HTMLResponse, JSONResponse
10
 
 
23
 
24
 
25
  @app.get("/demo/compare", include_in_schema=False)
26
+ def demo_compare(task_id: str = "customer_contacts_easy", table_name: str | None = None, seed: int | None = None) -> JSONResponse:
27
  task_spec = get_task_spec(task_id)
28
  selected_table = table_name if table_name in task_spec.dirty_tables else first_table_name(task_spec)
29
  primary_key = task_spec.primary_keys[selected_table]
30
+ before_rows = _seed_preview_rows(task_spec.dirty_tables[selected_table], primary_key, selected_table, seed)
31
+ after_rows = _seed_preview_rows(task_spec.gold_tables[selected_table], primary_key, selected_table, seed)
32
  columns = sorted({column_name for row in before_rows + after_rows for column_name in row})
33
  return JSONResponse(
34
  {
35
  "task_id": task_spec.task_id,
36
  "task_title": task_spec.title,
37
  "table_name": selected_table,
38
+ "requested_seed": seed,
39
  "available_tables": list(task_spec.dirty_tables.keys()),
40
  "columns": columns,
41
  "before_rows": before_rows[:4],
 
47
  )
48
 
49
 
50
def _seed_preview_rows(
    rows: list[dict[str, str]],
    primary_key: str,
    table_name: str,
    seed: int | None,
) -> list[dict[str, str]]:
    """Return ``rows`` ordered by primary key, then shuffled when a seed is given.

    With ``seed`` of ``None`` or fewer than two rows, the ``sorted_rows``
    ordering is returned as is. Otherwise a deep copy is shuffled with an RNG
    seeded from the seed (clamped to be non-negative) plus the sum of the
    table name's code points.
    """
    baseline = sorted_rows(rows, primary_key)
    if seed is not None and len(baseline) > 1:
        # Shuffle a copy so the ordered rows are left untouched.
        preview = copy.deepcopy(baseline)
        rng_seed = max(0, int(seed)) + sum(map(ord, table_name))
        random.Random(rng_seed).shuffle(preview)
        return preview
    return baseline
62
+
63
+
64
  @app.get("/", include_in_schema=False)
65
  def root() -> HTMLResponse:
66
  return HTMLResponse(
 
736
  <p>
737
  The cards and table below are populated from a real
738
  <code>POST /reset</code> response. Use the task buttons above to
739
+ switch between benchmark scenarios, or choose your own task and seed.
740
  </p>
741
 
742
  <div class="kpis">
 
1057
  const rows = (observation.focus_table?.rows || []).slice(0, 4);
1058
  renderTable(columns, rows);
1059
 
1060
+ const compareResponse = await fetch(`/demo/compare?task_id=${encodeURIComponent(taskId)}&table_name=${encodeURIComponent(observation.focus_table?.name || "")}&seed=${encodeURIComponent(String(seed))}`);
1061
  if (!compareResponse.ok) {
1062
  throw new Error(`Compare HTTP ${compareResponse.status}`);
1063
  }
1064
  const comparePayload = await compareResponse.json();
1065
+ compareMetaEl.textContent = `${comparePayload.task_title} • table: ${comparePayload.table_name} • seed: ${comparePayload.requested_seed ?? seed}`;
1066
  beforeMetaEl.textContent = `${comparePayload.before_row_count} rows`;
1067
  afterMetaEl.textContent = `${comparePayload.after_row_count} rows`;
1068
  renderTableTo(beforeHeadRowEl, beforeBodyEl, comparePayload.columns || [], comparePayload.before_rows || []);
 
1074
  outputEl.textContent = JSON.stringify(
1075
  {
1076
  task_id: observation.task_id,
1077
+ requested_seed: observation.requested_seed ?? seed,
1078
  difficulty: observation.difficulty,
1079
  objective: observation.objective,
1080
  quality_score: observation.quality_score,
 
1086
  null,
1087
  2
1088
  );
1089
+ setRunFeedback("success", `Loaded ${observation.task_title || taskId} successfully with seed ${observation.requested_seed ?? seed}.`);
1090
  } catch (error) {
1091
  outputEl.textContent = `Request failed: ${error.message}`;
1092
  objectiveEl.textContent = "Request failed.";
tests/test_environment.py CHANGED
@@ -10,6 +10,7 @@ def test_reset_step_state_api() -> None:
10
  env = LocalCleanOpsEnv()
11
  observation = env.reset(task_id="customer_contacts_easy", seed=7)
12
  assert observation.task_id == "customer_contacts_easy"
 
13
  assert observation.done is False
14
  assert observation.quality_score < 1.0
15
 
@@ -45,3 +46,16 @@ def test_decoy_operation_lowers_easy_task_quality() -> None:
45
  damaged_grade = grade_tables(task_spec, damaged_tables)
46
  assert clean_grade.score == 1.0
47
  assert damaged_grade.score < clean_grade.score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  env = LocalCleanOpsEnv()
11
  observation = env.reset(task_id="customer_contacts_easy", seed=7)
12
  assert observation.task_id == "customer_contacts_easy"
13
+ assert observation.requested_seed == 7
14
  assert observation.done is False
15
  assert observation.quality_score < 1.0
16
 
 
46
  damaged_grade = grade_tables(task_spec, damaged_tables)
47
  assert clean_grade.score == 1.0
48
  assert damaged_grade.score < clean_grade.score
49
+
50
+
51
def test_seed_changes_visible_preview_rows() -> None:
    """Different seeds reorder the preview rows; the same seed reproduces them."""
    env = LocalCleanOpsEnv()

    def preview_ids(observation):
        # Only the first four rows are compared.
        return [row["customer_id"] for row in observation.focus_table.rows[:4]]

    observation_seed_2 = env.reset(task_id="customer_contacts_easy", seed=2)
    preview_seed_2 = preview_ids(observation_seed_2)

    observation_seed_7 = env.reset(task_id="customer_contacts_easy", seed=7)
    preview_seed_7 = preview_ids(observation_seed_7)

    assert observation_seed_2.requested_seed == 2
    assert observation_seed_7.requested_seed == 7
    assert preview_seed_2 != preview_seed_7

    # A seed is only useful if it is reproducible: resetting with the same
    # seed again must show the same preview.
    observation_seed_2_again = env.reset(task_id="customer_contacts_easy", seed=2)
    assert observation_seed_2_again.requested_seed == 2
    assert preview_ids(observation_seed_2_again) == preview_seed_2