GlitchGhost Claude Opus 4.6 committed on
Commit
78c6c35
·
1 Parent(s): bc2c650

Improve inference script robustness and update defaults

Browse files

- Use Groq API with Llama 3.3 70B as default (faster, best scores)
- Add retry with backoff for rate limits (429) and connection errors
- Improve JSON parsing: handle markdown fences, try full-text parse first
- Set ENV_BASE_URL default to live HF Space URL

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. inference.py +62 -37
inference.py CHANGED
@@ -15,14 +15,10 @@ import os
15
  import re
16
  import sys
17
  import textwrap
18
- from typing import List, Optional
19
 
20
- from openai import OpenAI
21
-
22
- # ---------------------------------------------------------------------------
23
- # Inline client (HTTP) so inference.py is self-contained
24
- # ---------------------------------------------------------------------------
25
  import requests
 
26
 
27
 
28
  class _StepResult:
@@ -39,16 +35,25 @@ class _SimpleClient:
39
  self.base_url = base_url.rstrip("/")
40
  self.s = requests.Session()
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def reset(self, task_name: str = "easy") -> _StepResult:
43
- r = self.s.post(f"{self.base_url}/reset", json={"task_name": task_name}, timeout=30)
44
- r.raise_for_status()
45
- d = r.json()
46
  return _StepResult(d.get("observation", {}), float(d.get("reward", 0)), bool(d.get("done", False)))
47
 
48
  def step(self, action: dict) -> _StepResult:
49
- r = self.s.post(f"{self.base_url}/step", json=action, timeout=30)
50
- r.raise_for_status()
51
- d = r.json()
52
  return _StepResult(d.get("observation", {}), float(d.get("reward", 0)), bool(d.get("done", False)))
53
 
54
  def close(self):
@@ -58,12 +63,11 @@ class _SimpleClient:
58
  # ---------------------------------------------------------------------------
59
  # Configuration
60
  # ---------------------------------------------------------------------------
61
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
62
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
63
- MODEL_NAME = os.getenv("MODEL_NAME")
64
 
65
- # Where the DataClean env server is running
66
- ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
67
 
68
  MAX_STEPS_PER_TASK = {"easy": 12, "medium": 20, "hard": 30}
69
  TEMPERATURE = 0.1
@@ -109,21 +113,34 @@ RULES:
109
  # ---------------------------------------------------------------------------
110
  # Helpers
111
  # ---------------------------------------------------------------------------
112
- ACTION_JSON_RE = re.compile(r"\{[^}]+\}", re.DOTALL)
 
 
113
 
114
 
115
  def parse_action(text: str) -> dict:
116
  """Extract the first JSON object from the model response."""
117
  if not text:
118
  return {"action_type": "noop"}
119
- m = ACTION_JSON_RE.search(text)
120
- if m:
121
- try:
122
- obj = json.loads(m.group(0))
123
- if "action_type" in obj:
124
- return obj
125
- except json.JSONDecodeError:
126
- pass
 
 
 
 
 
 
 
 
 
 
 
127
  return {"action_type": "noop"}
128
 
129
 
@@ -182,18 +199,26 @@ def run_task(
182
  {"role": "user", "content": user_prompt},
183
  ]
184
 
185
- try:
186
- completion = llm_client.chat.completions.create(
187
- model=MODEL_NAME,
188
- messages=messages,
189
- temperature=TEMPERATURE,
190
- max_tokens=MAX_TOKENS,
191
- stream=False,
192
- )
193
- response_text = completion.choices[0].message.content or ""
194
- except Exception as exc:
195
- print(f" Step {step}: LLM error ({exc}), using noop")
196
- response_text = '{"action_type": "noop"}'
 
 
 
 
 
 
 
 
197
 
198
  action = parse_action(response_text)
199
  print(f" Step {step}: {action.get('action_type', '?')}", end="")
 
15
  import re
16
  import sys
17
  import textwrap
18
+ import time
19
 
 
 
 
 
 
20
  import requests
21
+ from openai import OpenAI
22
 
23
 
24
  class _StepResult:
 
35
  self.base_url = base_url.rstrip("/")
36
  self.s = requests.Session()
37
 
38
+ def _post(self, path: str, payload: dict) -> dict:
39
+ """POST with retry on transient errors."""
40
+ for attempt in range(3):
41
+ try:
42
+ r = self.s.post(f"{self.base_url}{path}", json=payload, timeout=60)
43
+ r.raise_for_status()
44
+ return r.json()
45
+ except (requests.ConnectionError, requests.Timeout) as exc:
46
+ if attempt < 2:
47
+ time.sleep(2 ** attempt)
48
+ continue
49
+ raise
50
+
51
  def reset(self, task_name: str = "easy") -> _StepResult:
52
+ d = self._post("/reset", {"task_name": task_name})
 
 
53
  return _StepResult(d.get("observation", {}), float(d.get("reward", 0)), bool(d.get("done", False)))
54
 
55
  def step(self, action: dict) -> _StepResult:
56
+ d = self._post("/step", action)
 
 
57
  return _StepResult(d.get("observation", {}), float(d.get("reward", 0)), bool(d.get("done", False)))
58
 
59
  def close(self):
 
63
  # ---------------------------------------------------------------------------
64
  # Configuration
65
  # ---------------------------------------------------------------------------
66
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
67
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
68
+ MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
69
 
70
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://glitchghost-dataclean-openenv.hf.space")
 
71
 
72
  MAX_STEPS_PER_TASK = {"easy": 12, "medium": 20, "hard": 30}
73
  TEMPERATURE = 0.1
 
113
  # ---------------------------------------------------------------------------
114
  # Helpers
115
  # ---------------------------------------------------------------------------
116
+ ACTION_JSON_RE = re.compile(r"\{[^{}]*\}", re.DOTALL)
117
+ # Also match JSON that may span multiple lines or have nested content
118
+ ACTION_JSON_GREEDY_RE = re.compile(r"\{.*?\}", re.DOTALL)
119
 
120
 
121
  def parse_action(text: str) -> dict:
122
  """Extract the first JSON object from the model response."""
123
  if not text:
124
  return {"action_type": "noop"}
125
+ # Strip markdown code fences if present
126
+ cleaned = re.sub(r"```(?:json)?\s*", "", text)
127
+ cleaned = re.sub(r"```", "", cleaned).strip()
128
+ # Try parsing the whole cleaned text as JSON first
129
+ try:
130
+ obj = json.loads(cleaned)
131
+ if isinstance(obj, dict) and "action_type" in obj:
132
+ return obj
133
+ except (json.JSONDecodeError, ValueError):
134
+ pass
135
+ # Try regex extraction
136
+ for pattern in [ACTION_JSON_RE, ACTION_JSON_GREEDY_RE]:
137
+ for m in pattern.finditer(cleaned):
138
+ try:
139
+ obj = json.loads(m.group(0))
140
+ if isinstance(obj, dict) and "action_type" in obj:
141
+ return obj
142
+ except (json.JSONDecodeError, ValueError):
143
+ continue
144
  return {"action_type": "noop"}
145
 
146
 
 
199
  {"role": "user", "content": user_prompt},
200
  ]
201
 
202
+ for _attempt in range(3):
203
+ try:
204
+ completion = llm_client.chat.completions.create(
205
+ model=MODEL_NAME,
206
+ messages=messages,
207
+ temperature=TEMPERATURE,
208
+ max_tokens=MAX_TOKENS,
209
+ stream=False,
210
+ )
211
+ response_text = completion.choices[0].message.content or ""
212
+ break
213
+ except Exception as exc:
214
+ if "429" in str(exc) and _attempt < 2:
215
+ wait = 5 * (2 ** _attempt)
216
+ print(f" Step {step}: Rate limited, waiting {wait}s...")
217
+ time.sleep(wait)
218
+ continue
219
+ print(f" Step {step}: LLM error ({exc}), using noop")
220
+ response_text = '{"action_type": "noop"}'
221
+ break
222
 
223
  action = parse_action(response_text)
224
  print(f" Step {step}: {action.get('action_type', '?')}", end="")