Spaces:
Sleeping
Sleeping
fixed
Browse files- data_cleaning_env.py +11 -1
- inference.py +80 -63
data_cleaning_env.py
CHANGED
|
@@ -214,6 +214,16 @@ class DataCleaningEnvironment:
|
|
| 214 |
tasks_completed=[],
|
| 215 |
total_reward=self._total_reward
|
| 216 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
def _apply_action(self, action: CleaningAction) -> pd.DataFrame:
|
| 219 |
df = deepcopy(self._current_df)
|
|
@@ -223,7 +233,7 @@ class DataCleaningEnvironment:
|
|
| 223 |
col = action.column or "hire_date"
|
| 224 |
if col not in df.columns:
|
| 225 |
raise ValueError(f"Column '{col}' not found.")
|
| 226 |
-
df[col] =
|
| 227 |
elif action.action_type == "remove_outliers":
|
| 228 |
target_cols = [action.column] if action.column and action.column != "all" else ['salary', 'age']
|
| 229 |
for col in target_cols:
|
|
|
|
| 214 |
tasks_completed=[],
|
| 215 |
total_reward=self._total_reward
|
| 216 |
)
|
| 217 |
+
|
| 218 |
+
@staticmethod
|
| 219 |
+
def _parse_date(val):
|
| 220 |
+
for fmt in ["%Y-%m-%d", "%Y/%m/%d", "%Y.%m.%d",
|
| 221 |
+
"%m/%d/%Y", "%m-%d-%Y", "%d/%m/%Y", "%d-%m-%Y", "%d.%m.%Y"]:
|
| 222 |
+
try:
|
| 223 |
+
return datetime.strptime(str(val).strip(), fmt).strftime("%Y-%m-%d")
|
| 224 |
+
except:
|
| 225 |
+
continue
|
| 226 |
+
return None
|
| 227 |
|
| 228 |
def _apply_action(self, action: CleaningAction) -> pd.DataFrame:
|
| 229 |
df = deepcopy(self._current_df)
|
|
|
|
| 233 |
col = action.column or "hire_date"
|
| 234 |
if col not in df.columns:
|
| 235 |
raise ValueError(f"Column '{col}' not found.")
|
| 236 |
+
df[col] = df[col].apply(self._parse_date)
|
| 237 |
elif action.action_type == "remove_outliers":
|
| 238 |
target_cols = [action.column] if action.column and action.column != "all" else ['salary', 'age']
|
| 239 |
for col in target_cols:
|
inference.py
CHANGED
|
@@ -1,38 +1,28 @@
|
|
| 1 |
"""
|
| 2 |
-
inference.py —
|
| 3 |
-
=================================================
|
| 4 |
-
Runs all 3 tasks and produces reproducible scores.
|
| 5 |
-
|
| 6 |
-
Usage:
|
| 7 |
-
export API_BASE_URL=https://router.huggingface.co/v1
|
| 8 |
-
export MODEL_NAME=SohamK18/data-cleaning-grpo
|
| 9 |
-
export HF_TOKEN=hf_your_token_here
|
| 10 |
-
export ENV_URL=https://your-space.hf.space # your HF Space URL
|
| 11 |
-
python inference.py
|
| 12 |
"""
|
| 13 |
-
|
| 14 |
import os, json, sys
|
| 15 |
from openai import OpenAI
|
| 16 |
from data_cleaning_env import DataCleaningEnvironment, CleaningAction
|
| 17 |
|
| 18 |
-
# ── Config ──────────────────────────────────────────────────────
|
| 19 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 20 |
-
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 21 |
-
MODEL_NAME
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
if not HF_TOKEN:
|
| 25 |
-
print("ERROR: HF_TOKEN not set.", file=sys.stderr)
|
| 26 |
-
sys.exit(1)
|
| 27 |
|
| 28 |
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 29 |
|
| 30 |
SYSTEM_PROMPT = (
|
| 31 |
"You are a data cleaning expert. "
|
| 32 |
"Respond ONLY with a valid JSON object, no markdown, no explanation.\n"
|
| 33 |
-
|
| 34 |
)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
def parse_llm_response(text: str, task_id: int) -> CleaningAction:
|
| 38 |
text = text.strip().replace("```json", "").replace("```", "").strip()
|
|
@@ -41,7 +31,11 @@ def parse_llm_response(text: str, task_id: int) -> CleaningAction:
|
|
| 41 |
action_type = data.get("action_type", "remove_nulls")
|
| 42 |
if action_type not in ["remove_nulls", "fix_dates", "remove_outliers"]:
|
| 43 |
action_type = "remove_nulls"
|
| 44 |
-
return CleaningAction(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
except Exception:
|
| 46 |
if "date" in text.lower():
|
| 47 |
return CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date")
|
|
@@ -50,9 +44,20 @@ def parse_llm_response(text: str, task_id: int) -> CleaningAction:
|
|
| 50 |
return CleaningAction(task_id=task_id, action_type="remove_nulls")
|
| 51 |
|
| 52 |
|
| 53 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
env = DataCleaningEnvironment(task_id=task_id, seed=seed)
|
| 55 |
obs = env.reset()
|
|
|
|
|
|
|
| 56 |
|
| 57 |
user_msg = (
|
| 58 |
f"Task {task_id}: {obs.task_description}\n"
|
|
@@ -62,66 +67,78 @@ def run_episode(task_id: int, seed: int) -> float:
|
|
| 62 |
f"Respond with JSON only."
|
| 63 |
)
|
| 64 |
|
|
|
|
| 65 |
try:
|
| 66 |
resp = client.chat.completions.create(
|
| 67 |
model=MODEL_NAME,
|
| 68 |
messages=[
|
| 69 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 70 |
-
{"role": "user", "content": user_msg}
|
| 71 |
],
|
| 72 |
max_tokens=100,
|
| 73 |
-
temperature=0.1
|
| 74 |
)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
# Heuristic fallback
|
| 79 |
-
if obs.null_count > 0:
|
| 80 |
-
action = CleaningAction(task_id=task_id, action_type="remove_nulls")
|
| 81 |
-
elif obs.date_format_errors > 0:
|
| 82 |
-
action = CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date")
|
| 83 |
-
else:
|
| 84 |
-
action = CleaningAction(task_id=task_id, action_type="remove_outliers", column="all")
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
|
|
|
| 88 |
|
|
|
|
|
|
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
| 93 |
|
|
|
|
|
|
|
| 94 |
all_results = {}
|
|
|
|
| 95 |
|
| 96 |
for task_id in [1, 2, 3]:
|
| 97 |
task_name = TASK_NAMES[task_id]
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
with open("scores.json", "w") as f:
|
| 121 |
json.dump({"tasks": all_results, "overall": overall}, f, indent=2)
|
| 122 |
-
|
| 123 |
print(f"[SUMMARY] overall_score={overall} task_scores={all_results}", flush=True)
|
| 124 |
|
| 125 |
|
| 126 |
if __name__ == "__main__":
|
| 127 |
-
main()
|
|
|
|
| 1 |
"""
|
| 2 |
+
inference.py — OpenEnv submission file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
|
|
|
| 4 |
import os, json, sys
|
| 5 |
from openai import OpenAI
|
| 6 |
from data_cleaning_env import DataCleaningEnvironment, CleaningAction
|
| 7 |
|
| 8 |
+
# ── Config ────────────────────────────────────────────────────────────────────
|
| 9 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 10 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 11 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-1.5B-Instruct")
|
| 12 |
+
if HF_TOKEN is None:
|
| 13 |
+
raise ValueError("HF_TOKEN environment variable is required")
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 16 |
|
| 17 |
SYSTEM_PROMPT = (
|
| 18 |
"You are a data cleaning expert. "
|
| 19 |
"Respond ONLY with a valid JSON object, no markdown, no explanation.\n"
|
| 20 |
+
'Format: {"action_type": "<remove_nulls|fix_dates|remove_outliers>", "column": "<col_or_null>"}'
|
| 21 |
)
|
| 22 |
|
| 23 |
+
TASK_NAMES = {1: "remove_nulls", 2: "fix_dates", 3: "remove_outliers"}
|
| 24 |
+
ENV_NAME = "data_cleaning"
|
| 25 |
+
|
| 26 |
|
| 27 |
def parse_llm_response(text: str, task_id: int) -> CleaningAction:
|
| 28 |
text = text.strip().replace("```json", "").replace("```", "").strip()
|
|
|
|
| 31 |
action_type = data.get("action_type", "remove_nulls")
|
| 32 |
if action_type not in ["remove_nulls", "fix_dates", "remove_outliers"]:
|
| 33 |
action_type = "remove_nulls"
|
| 34 |
+
return CleaningAction(
|
| 35 |
+
task_id=task_id,
|
| 36 |
+
action_type=action_type,
|
| 37 |
+
column=data.get("column")
|
| 38 |
+
)
|
| 39 |
except Exception:
|
| 40 |
if "date" in text.lower():
|
| 41 |
return CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date")
|
|
|
|
| 44 |
return CleaningAction(task_id=task_id, action_type="remove_nulls")
|
| 45 |
|
| 46 |
|
| 47 |
+
def heuristic_action(task_id: int, obs) -> CleaningAction:
    """Choose a cleaning action from the observation's counters alone.

    Checks are made in priority order: outstanding nulls first, then date
    format errors, and outlier removal over all columns otherwise.

    Args:
        task_id: Task identifier forwarded to the returned action.
        obs: Observation exposing ``null_count`` and ``date_format_errors``.

    Returns:
        The ``CleaningAction`` selected by the rules above.
    """
    if obs.null_count > 0:
        # No column is passed for null removal, so the action's own default applies.
        params = {"action_type": "remove_nulls"}
    elif obs.date_format_errors > 0:
        params = {"action_type": "fix_dates", "column": "hire_date"}
    else:
        params = {"action_type": "remove_outliers", "column": "all"}
    return CleaningAction(task_id=task_id, **params)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def run_episode(task_id: int, seed: int):
|
| 57 |
env = DataCleaningEnvironment(task_id=task_id, seed=seed)
|
| 58 |
obs = env.reset()
|
| 59 |
+
error_str = "null"
|
| 60 |
+
action = None
|
| 61 |
|
| 62 |
user_msg = (
|
| 63 |
f"Task {task_id}: {obs.task_description}\n"
|
|
|
|
| 67 |
f"Respond with JSON only."
|
| 68 |
)
|
| 69 |
|
| 70 |
+
# ── Primary: LLM via OpenAI client ───────────────────────────────────────
|
| 71 |
try:
|
| 72 |
resp = client.chat.completions.create(
|
| 73 |
model=MODEL_NAME,
|
| 74 |
messages=[
|
| 75 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 76 |
+
{"role": "user", "content": user_msg},
|
| 77 |
],
|
| 78 |
max_tokens=100,
|
| 79 |
+
temperature=0.1,
|
| 80 |
)
|
| 81 |
+
action = parse_llm_response(resp.choices[0].message.content, task_id)
|
| 82 |
+
except Exception as e:
|
| 83 |
+
error_str = str(e).replace("\n", " ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
+
# ── Fallback: heuristic if LLM failed ────────────────────────────────────
|
| 86 |
+
if action is None:
|
| 87 |
+
action = heuristic_action(task_id, obs)
|
| 88 |
|
| 89 |
+
col = action.column if action.column else "null"
|
| 90 |
+
action_str = f"{action.action_type}('{col}')"
|
| 91 |
|
| 92 |
+
_, reward, done, _ = env.step(action)
|
| 93 |
+
if hasattr(env, "close"):
|
| 94 |
+
env.close()
|
| 95 |
+
|
| 96 |
+
return float(reward), action_str, bool(done), error_str
|
| 97 |
|
| 98 |
+
|
| 99 |
+
def main():
    """Run every task for ``N_EPISODES`` seeds, log progress, and save scores.

    For each task id in ``[1, 2, 3]`` this prints a ``[START]`` line, one
    ``[STEP]`` line per episode, and an ``[END]`` line (always emitted, even
    if an episode raises). The per-task score is the mean episode reward
    clamped to ``[0, 1]`` and rounded to two decimals. After all tasks, the
    per-task scores and their mean are written to ``scores.json`` and echoed
    in a ``[SUMMARY]`` line.

    Environment:
        N_EPISODES: Number of episodes (seeds ``0..N-1``) per task; defaults
            to ``10``. A value of zero or less yields a score of ``0.0``.
    """
    all_results = {}
    n_episodes = int(os.getenv("N_EPISODES", "10"))

    for task_id in [1, 2, 3]:
        task_name = TASK_NAMES[task_id]
        print(f"[START] task={task_name} env={ENV_NAME} model={MODEL_NAME}", flush=True)

        episode_rewards = []
        success = False
        score = 0.0

        try:
            for seed in range(n_episodes):
                reward, action_str, done, error_str = run_episode(task_id, seed)
                episode_rewards.append(reward)
                print(
                    f"[STEP] step={seed + 1} action={action_str} "
                    f"reward={reward:.2f} done={str(done).lower()} error={error_str}",
                    flush=True,
                )

            # With N_EPISODES <= 0 no episode runs; keep the 0.0 default
            # instead of dividing by zero.
            if episode_rewards:
                score = sum(episode_rewards) / len(episode_rewards)
            score = round(min(max(score, 0.0), 1.0), 2)
            all_results[task_id] = score
            success = score > 0.0

        finally:
            # The [END] line is emitted even when an episode raised, using
            # whatever rewards were collected before the failure.
            rewards_str = ",".join(f"{r:.2f}" for r in episode_rewards)
            # ── [END] with score= field as required ──────────────────────────
            print(
                f"[END] success={str(success).lower()} "
                f"steps={len(episode_rewards)} "
                f"score={score:.2f} "
                f"rewards={rewards_str}",
                flush=True,
            )

    # max(..., 1) guards the mean against an empty result set.
    overall = round(sum(all_results.values()) / max(len(all_results), 1), 4)
    with open("scores.json", "w", encoding="utf-8") as f:
        json.dump({"tasks": all_results, "overall": overall}, f, indent=2)
    print(f"[SUMMARY] overall_score={overall} task_scores={all_results}", flush=True)
|
| 141 |
|
| 142 |
|
| 143 |
if __name__ == "__main__":
|
| 144 |
+
main()
|