kush5699 committed on
Commit
caa9970
·
verified ·
1 Parent(s): 593f876

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. env/environment.py +4 -4
  2. env/models.py +3 -3
  3. env/tasks.py +1 -1
  4. server/app.py +1 -1
env/environment.py CHANGED
@@ -32,7 +32,7 @@ class DataValidationEnvironment:
32
  max_steps=task["max_steps"],
33
  done=False,
34
  reward_history=[],
35
- cumulative_reward=0.0,
36
  dataset=task["dataset"],
37
  ground_truth=self._ground_truth,
38
  errors=self._errors,
@@ -51,8 +51,8 @@ class DataValidationEnvironment:
51
  errors_fixed=0,
52
  step_count=0,
53
  max_steps=task["max_steps"],
54
- reward=0.0,
55
- cumulative_reward=0.0,
56
  done=False,
57
  last_action_result="Environment reset. Examine errors and fix them.",
58
  task_hint=task["hint"],
@@ -62,7 +62,7 @@ class DataValidationEnvironment:
62
 
63
  def step(self, action: DataCleanAction) -> DataCleanObservation:
64
  if self._state.done:
65
- return self._make_observation(0.0, "Episode already done. Call reset().")
66
 
67
  self._state.step_count += 1
68
 
 
32
  max_steps=task["max_steps"],
33
  done=False,
34
  reward_history=[],
35
+ cumulative_reward=0.01,
36
  dataset=task["dataset"],
37
  ground_truth=self._ground_truth,
38
  errors=self._errors,
 
51
  errors_fixed=0,
52
  step_count=0,
53
  max_steps=task["max_steps"],
54
+ reward=0.01,
55
+ cumulative_reward=0.01,
56
  done=False,
57
  last_action_result="Environment reset. Examine errors and fix them.",
58
  task_hint=task["hint"],
 
62
 
63
  def step(self, action: DataCleanAction) -> DataCleanObservation:
64
  if self._state.done:
65
+ return self._make_observation(0.01, "Episode already done. Call reset().")
66
 
67
  self._state.step_count += 1
68
 
env/models.py CHANGED
@@ -19,8 +19,8 @@ class DataCleanObservation(BaseModel):
19
  errors_fixed: int = Field(default=0)
20
  step_count: int = Field(default=0)
21
  max_steps: int = Field(default=20)
22
- reward: float = Field(default=0.0)
23
- cumulative_reward: float = Field(default=0.0)
24
  done: bool = Field(default=False)
25
  last_action_result: str = Field(default="")
26
  task_hint: str = Field(default="")
@@ -41,7 +41,7 @@ class DataCleanState(BaseModel):
41
  max_steps: int = Field(default=20)
42
  done: bool = Field(default=False)
43
  reward_history: List[float] = Field(default_factory=list)
44
- cumulative_reward: float = Field(default=0.0)
45
  dataset: List[Dict[str, Any]] = Field(default_factory=list)
46
  ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
47
  errors: List[Dict[str, Any]] = Field(default_factory=list)
 
19
  errors_fixed: int = Field(default=0)
20
  step_count: int = Field(default=0)
21
  max_steps: int = Field(default=20)
22
+ reward: float = Field(default=0.01)
23
+ cumulative_reward: float = Field(default=0.01)
24
  done: bool = Field(default=False)
25
  last_action_result: str = Field(default="")
26
  task_hint: str = Field(default="")
 
41
  max_steps: int = Field(default=20)
42
  done: bool = Field(default=False)
43
  reward_history: List[float] = Field(default_factory=list)
44
+ cumulative_reward: float = Field(default=0.01)
45
  dataset: List[Dict[str, Any]] = Field(default_factory=list)
46
  ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
47
  errors: List[Dict[str, Any]] = Field(default_factory=list)
env/tasks.py CHANGED
@@ -250,7 +250,7 @@ def grade_action(action_type: str, target_field: str, target_row: int,
250
 
251
  expected_error_type = action_to_error_map.get(action_type, "")
252
  if expected_error_type != matching_error["error_type"]:
253
- return -0.05, f"Wrong action type '{action_type}' for error type '{matching_error['error_type']}'", False
254
 
255
  gt_value = ground_truth[target_row][target_field]
256
 
 
250
 
251
  expected_error_type = action_to_error_map.get(action_type, "")
252
  if expected_error_type != matching_error["error_type"]:
253
+ return 0.01, f"Wrong action type '{action_type}' for error type '{matching_error['error_type']}'", False
254
 
255
  gt_value = ground_truth[target_row][target_field]
256
 
server/app.py CHANGED
@@ -114,7 +114,7 @@ async def websocket_endpoint(websocket: WebSocket):
114
  response = {
115
  "type": "reset",
116
  "observation": obs.model_dump(),
117
- "reward": 0.0,
118
  "done": False,
119
  }
120
  elif method == "step":
 
114
  response = {
115
  "type": "reset",
116
  "observation": obs.model_dump(),
117
+ "reward": 0.01,
118
  "done": False,
119
  }
120
  elif method == "step":