GlitchGhost Claude Opus 4.6 commited on
Commit
185c876
·
1 Parent(s): b62a150

Fix Docker build + rewrite inference.py to follow OpenEnv sample pattern

Browse files

- Dockerfile: python:3.11-slim -> python:3.12-slim (more reliable registry pull)
- inference.py: Use DataCleanEnv client with from_docker_image() support
- inference.py: Support LOCAL_IMAGE_NAME env var for validator
- inference.py: Use HF router as default API_BASE_URL
- inference.py: Keep [START]/[STEP]/[END] structured output markers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. Dockerfile +3 -3
  2. inference.py +86 -97
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.11-slim
2
 
3
  # Create non-root user (HF Spaces requirement)
4
  RUN useradd -m -u 1000 user
@@ -7,11 +7,11 @@ ENV HOME=/home/user \
7
 
8
  WORKDIR /app
9
 
10
- # Install dependencies
11
  COPY requirements.txt .
12
  RUN pip install --no-cache-dir -r requirements.txt
13
 
14
- # Copy source
15
  COPY . .
16
 
17
  # Install the package
 
1
+ FROM python:3.12-slim
2
 
3
  # Create non-root user (HF Spaces requirement)
4
  RUN useradd -m -u 1000 user
 
7
 
8
  WORKDIR /app
9
 
10
+ # Copy and install dependencies
11
  COPY requirements.txt .
12
  RUN pip install --no-cache-dir -r requirements.txt
13
 
14
+ # Copy source code
15
  COPY . .
16
 
17
  # Install the package
inference.py CHANGED
@@ -1,13 +1,15 @@
1
  """
2
  Inference Script — DataClean Environment
3
  =========================================
4
- MANDATORY:
5
- - Before submitting, ensure the following variables are defined:
6
- API_BASE_URL The API endpoint for the LLM.
7
  MODEL_NAME The model identifier to use for inference.
8
- HF_TOKEN Your Hugging Face / API key.
9
- - This script must be named `inference.py` and placed in the root directory.
10
- - Uses OpenAI Client for all LLM calls.
 
 
 
11
  """
12
 
13
  import json
@@ -17,56 +19,16 @@ import sys
17
  import textwrap
18
  import time
19
 
20
- import requests
21
  from openai import OpenAI
22
 
23
-
24
- class _StepResult:
25
- def __init__(self, observation: dict, reward: float, done: bool):
26
- self.observation = observation
27
- self.reward = reward
28
- self.done = done
29
-
30
-
31
- class _SimpleClient:
32
- """Minimal sync HTTP client for the DataClean environment."""
33
-
34
- def __init__(self, base_url: str):
35
- self.base_url = base_url.rstrip("/")
36
- self.s = requests.Session()
37
-
38
- def _post(self, path: str, payload: dict) -> dict:
39
- """POST with retry on transient errors."""
40
- for attempt in range(3):
41
- try:
42
- r = self.s.post(f"{self.base_url}{path}", json=payload, timeout=60)
43
- r.raise_for_status()
44
- return r.json()
45
- except (requests.ConnectionError, requests.Timeout) as exc:
46
- if attempt < 2:
47
- time.sleep(2 ** attempt)
48
- continue
49
- raise
50
-
51
- def reset(self, task_name: str = "easy") -> _StepResult:
52
- d = self._post("/reset", {"task_name": task_name})
53
- return _StepResult(d.get("observation", {}), float(d.get("reward", 0)), bool(d.get("done", False)))
54
-
55
- def step(self, action: dict) -> _StepResult:
56
- d = self._post("/step", action)
57
- return _StepResult(d.get("observation", {}), float(d.get("reward", 0)), bool(d.get("done", False)))
58
-
59
- def close(self):
60
- self.s.close()
61
-
62
-
63
  # ---------------------------------------------------------------------------
64
  # Configuration
65
  # ---------------------------------------------------------------------------
66
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
67
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
68
- MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
69
 
 
70
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://glitchghost-dataclean-openenv.hf.space")
71
 
72
  MAX_STEPS_PER_TASK = {"easy": 12, "medium": 20, "hard": 30}
@@ -110,12 +72,27 @@ RULES:
110
  """).strip()
111
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  # ---------------------------------------------------------------------------
114
  # Helpers
115
  # ---------------------------------------------------------------------------
116
  ACTION_JSON_RE = re.compile(r"\{[^{}]*\}", re.DOTALL)
117
- # Also match JSON that may span multiple lines or have nested content
118
- ACTION_JSON_GREEDY_RE = re.compile(r"\{.*?\}", re.DOTALL)
119
 
120
 
121
  def parse_action(text: str) -> dict:
@@ -133,31 +110,48 @@ def parse_action(text: str) -> dict:
133
  except (json.JSONDecodeError, ValueError):
134
  pass
135
  # Try regex extraction
136
- for pattern in [ACTION_JSON_RE, ACTION_JSON_GREEDY_RE]:
137
- for m in pattern.finditer(cleaned):
138
- try:
139
- obj = json.loads(m.group(0))
140
- if isinstance(obj, dict) and "action_type" in obj:
141
- return obj
142
- except (json.JSONDecodeError, ValueError):
143
- continue
144
  return {"action_type": "noop"}
145
 
146
 
147
- def build_user_prompt(obs: dict, step_num: int) -> str:
148
  """Build the user prompt from the observation."""
149
- parts = [
150
- f"TASK: {obs.get('task_description', '')}",
151
- f"DIFFICULTY: {obs.get('difficulty', '')}",
152
- f"STEP: {step_num}/{obs.get('max_steps', '?')}",
153
- f"CURRENT SCORE: {obs.get('current_score', 0.0)}",
154
- "",
155
- "CURRENT DATA:",
156
- obs.get("data_preview", "(no data)"),
157
- "",
158
- obs.get("quality_report", ""),
159
- ]
160
- history = obs.get("action_history", [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  if history:
162
  parts.append("")
163
  parts.append("RECENT ACTIONS:")
@@ -174,7 +168,7 @@ def build_user_prompt(obs: dict, step_num: int) -> str:
174
  # ---------------------------------------------------------------------------
175
  def run_task(
176
  llm_client: OpenAI,
177
- env_client: _SimpleClient,
178
  task_name: str,
179
  max_steps: int,
180
  ) -> float:
@@ -182,19 +176,12 @@ def run_task(
182
  # Structured output: START marker (required by validator)
183
  print(f"[START] task={task_name}", flush=True)
184
 
185
- print(f"\n{'='*60}", flush=True)
186
- print(f" TASK: {task_name.upper()}", flush=True)
187
- print(f"{'='*60}", flush=True)
188
-
189
  result = env_client.reset(task_name)
190
  obs = result.observation
191
- print(f" Task: {obs.get('task_description', '')[:80]}...", flush=True)
192
- print(f" Max steps: {max_steps}", flush=True)
193
 
194
  step_count = 0
195
  for step in range(1, max_steps + 1):
196
  if result.done:
197
- print(f" Episode done at step {step - 1}", flush=True)
198
  break
199
 
200
  user_prompt = build_user_prompt(obs, step)
@@ -203,6 +190,7 @@ def run_task(
203
  {"role": "user", "content": user_prompt},
204
  ]
205
 
 
206
  for _attempt in range(3):
207
  try:
208
  completion = llm_client.chat.completions.create(
@@ -224,18 +212,22 @@ def run_task(
224
  response_text = '{"action_type": "noop"}'
225
  break
226
 
227
- action = parse_action(response_text)
228
- print(f" Step {step}: {action.get('action_type', '?')}", end="", flush=True)
229
- if action.get("row_index") is not None:
230
- print(f" row={action['row_index']}", end="", flush=True)
231
- if action.get("column_name"):
232
- print(f" col={action['column_name']}", end="", flush=True)
233
- if action.get("new_value"):
234
- print(f" val={action['new_value']}", end="", flush=True)
235
-
 
 
 
236
  result = env_client.step(action)
237
  obs = result.observation
238
  step_count = step
 
239
  print(f" -> reward={result.reward:.4f} done={result.done}", flush=True)
240
 
241
  # Structured output: STEP marker (required by validator)
@@ -247,11 +239,12 @@ def run_task(
247
  # If agent never submitted, force submit
248
  if not result.done:
249
  step_count += 1
250
- result = env_client.step({"action_type": "submit"})
 
251
  print(f"[STEP] step={step_count} reward={result.reward:.4f}", flush=True)
252
 
253
  final_score = result.reward
254
- print(f"\n FINAL SCORE ({task_name}): {final_score:.4f}", flush=True)
255
 
256
  # Structured output: END marker (required by validator)
257
  print(f"[END] task={task_name} score={final_score:.4f} steps={step_count}", flush=True)
@@ -266,17 +259,13 @@ def main() -> None:
266
  if not API_KEY:
267
  print("ERROR: HF_TOKEN or API_KEY environment variable not set", flush=True)
268
  sys.exit(1)
269
- if not MODEL_NAME:
270
- print("ERROR: MODEL_NAME environment variable not set", flush=True)
271
- sys.exit(1)
272
 
273
- print("DataClean Environment Baseline Inference", flush=True)
274
  print(f" API: {API_BASE_URL}", flush=True)
275
  print(f" Model: {MODEL_NAME}", flush=True)
276
- print(f" Env: {ENV_BASE_URL}", flush=True)
277
 
278
  llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
279
- env_client = _SimpleClient(ENV_BASE_URL)
280
 
281
  scores = {}
282
  try:
 
1
  """
2
  Inference Script — DataClean Environment
3
  =========================================
4
+ MANDATORY environment variables:
5
+ API_BASE_URL The API endpoint for the LLM (default: HF router).
 
6
  MODEL_NAME The model identifier to use for inference.
7
+ HF_TOKEN Your Hugging Face / API key (no default).
8
+ OPTIONAL:
9
+ LOCAL_IMAGE_NAME Docker image to use with from_docker_image().
10
+ ENV_BASE_URL Direct URL if environment is already running.
11
+
12
+ Uses OpenAI Client for all LLM calls.
13
  """
14
 
15
  import json
 
19
  import textwrap
20
  import time
21
 
 
22
  from openai import OpenAI
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # ---------------------------------------------------------------------------
25
  # Configuration
26
  # ---------------------------------------------------------------------------
27
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/novita/v3/openai")
28
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
29
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/llama-3.3-70b-instruct")
30
 
31
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "")
32
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://glitchghost-dataclean-openenv.hf.space")
33
 
34
  MAX_STEPS_PER_TASK = {"easy": 12, "medium": 20, "hard": 30}
 
72
  """).strip()
73
 
74
 
75
+ # ---------------------------------------------------------------------------
76
+ # Environment client helpers
77
+ # ---------------------------------------------------------------------------
78
+ def _connect_env():
79
+ """Connect to the DataClean environment using the best available method."""
80
+ from dataclean_env.client import DataCleanEnv
81
+
82
+ # Option 1: Spin up from a local Docker image (validator may set this)
83
+ if LOCAL_IMAGE_NAME:
84
+ print(f" Starting environment from Docker image: {LOCAL_IMAGE_NAME}", flush=True)
85
+ return DataCleanEnv.from_docker_image(image=LOCAL_IMAGE_NAME)
86
+
87
+ # Option 2: Connect to a running instance (HF Space or local)
88
+ print(f" Connecting to environment at: {ENV_BASE_URL}", flush=True)
89
+ return DataCleanEnv(base_url=ENV_BASE_URL)
90
+
91
+
92
  # ---------------------------------------------------------------------------
93
  # Helpers
94
  # ---------------------------------------------------------------------------
95
  ACTION_JSON_RE = re.compile(r"\{[^{}]*\}", re.DOTALL)
 
 
96
 
97
 
98
  def parse_action(text: str) -> dict:
 
110
  except (json.JSONDecodeError, ValueError):
111
  pass
112
  # Try regex extraction
113
+ for m in ACTION_JSON_RE.finditer(cleaned):
114
+ try:
115
+ obj = json.loads(m.group(0))
116
+ if isinstance(obj, dict) and "action_type" in obj:
117
+ return obj
118
+ except (json.JSONDecodeError, ValueError):
119
+ continue
 
120
  return {"action_type": "noop"}
121
 
122
 
123
+ def build_user_prompt(obs, step_num: int) -> str:
124
  """Build the user prompt from the observation."""
125
+ # obs can be a DataCleanObservation object or a dict
126
+ if hasattr(obs, "task_description"):
127
+ # It's a DataCleanObservation object
128
+ parts = [
129
+ f"TASK: {obs.task_description}",
130
+ f"DIFFICULTY: {obs.difficulty}",
131
+ f"STEP: {step_num}/{obs.max_steps}",
132
+ f"CURRENT SCORE: {obs.current_score}",
133
+ "",
134
+ "CURRENT DATA:",
135
+ obs.data_preview or "(no data)",
136
+ "",
137
+ obs.quality_report or "",
138
+ ]
139
+ history = obs.action_history or []
140
+ else:
141
+ # It's a dict
142
+ parts = [
143
+ f"TASK: {obs.get('task_description', '')}",
144
+ f"DIFFICULTY: {obs.get('difficulty', '')}",
145
+ f"STEP: {step_num}/{obs.get('max_steps', '?')}",
146
+ f"CURRENT SCORE: {obs.get('current_score', 0.0)}",
147
+ "",
148
+ "CURRENT DATA:",
149
+ obs.get("data_preview", "(no data)"),
150
+ "",
151
+ obs.get("quality_report", ""),
152
+ ]
153
+ history = obs.get("action_history", [])
154
+
155
  if history:
156
  parts.append("")
157
  parts.append("RECENT ACTIONS:")
 
168
  # ---------------------------------------------------------------------------
169
  def run_task(
170
  llm_client: OpenAI,
171
+ env_client,
172
  task_name: str,
173
  max_steps: int,
174
  ) -> float:
 
176
  # Structured output: START marker (required by validator)
177
  print(f"[START] task={task_name}", flush=True)
178
 
 
 
 
 
179
  result = env_client.reset(task_name)
180
  obs = result.observation
 
 
181
 
182
  step_count = 0
183
  for step in range(1, max_steps + 1):
184
  if result.done:
 
185
  break
186
 
187
  user_prompt = build_user_prompt(obs, step)
 
190
  {"role": "user", "content": user_prompt},
191
  ]
192
 
193
+ response_text = '{"action_type": "noop"}'
194
  for _attempt in range(3):
195
  try:
196
  completion = llm_client.chat.completions.create(
 
212
  response_text = '{"action_type": "noop"}'
213
  break
214
 
215
+ action_dict = parse_action(response_text)
216
+ print(f" Step {step}: {action_dict.get('action_type', '?')}", end="", flush=True)
217
+ if action_dict.get("row_index") is not None:
218
+ print(f" row={action_dict['row_index']}", end="", flush=True)
219
+ if action_dict.get("column_name"):
220
+ print(f" col={action_dict['column_name']}", end="", flush=True)
221
+ if action_dict.get("new_value"):
222
+ print(f" val={action_dict['new_value']}", end="", flush=True)
223
+
224
+ # Step the environment using the proper client
225
+ from dataclean_env.models import DataCleanAction
226
+ action = DataCleanAction(**action_dict)
227
  result = env_client.step(action)
228
  obs = result.observation
229
  step_count = step
230
+
231
  print(f" -> reward={result.reward:.4f} done={result.done}", flush=True)
232
 
233
  # Structured output: STEP marker (required by validator)
 
239
  # If agent never submitted, force submit
240
  if not result.done:
241
  step_count += 1
242
+ from dataclean_env.models import DataCleanAction
243
+ result = env_client.step(DataCleanAction(action_type="submit"))
244
  print(f"[STEP] step={step_count} reward={result.reward:.4f}", flush=True)
245
 
246
  final_score = result.reward
247
+ print(f" FINAL SCORE ({task_name}): {final_score:.4f}", flush=True)
248
 
249
  # Structured output: END marker (required by validator)
250
  print(f"[END] task={task_name} score={final_score:.4f} steps={step_count}", flush=True)
 
259
  if not API_KEY:
260
  print("ERROR: HF_TOKEN or API_KEY environment variable not set", flush=True)
261
  sys.exit(1)
 
 
 
262
 
263
+ print("DataClean Environment - Inference", flush=True)
264
  print(f" API: {API_BASE_URL}", flush=True)
265
  print(f" Model: {MODEL_NAME}", flush=True)
 
266
 
267
  llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
268
+ env_client = _connect_env()
269
 
270
  scores = {}
271
  try: