Spaces:
Running
Running
Update inference.py
Browse files- inference.py +36 -23
inference.py
CHANGED
|
@@ -28,7 +28,8 @@ except ImportError:
|
|
| 28 |
|
| 29 |
# ── Config ────────────────────────────────────────────────────────────────────
|
| 30 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
|
| 31 |
-
MODEL_NAME
|
|
|
|
| 32 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 33 |
HF_TOKEN_SOURCE = "HF_TOKEN"
|
| 34 |
if not HF_TOKEN:
|
|
@@ -37,16 +38,18 @@ if not HF_TOKEN:
|
|
| 37 |
if not HF_TOKEN:
|
| 38 |
HF_TOKEN = os.getenv("hf_token")
|
| 39 |
HF_TOKEN_SOURCE = "hf_token"
|
| 40 |
-
|
| 41 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 42 |
-
ENV_URL
|
| 43 |
-
BENCHMARK
|
| 44 |
-
MAX_STEPS
|
| 45 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 46 |
|
| 47 |
client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)
|
| 48 |
|
|
|
|
| 49 |
# ── Logging — STRICT PLAINTEXT FORMAT ─────────────────────────────────────────
|
|
|
|
| 50 |
def _format_bool(value: bool) -> str:
|
| 51 |
return "true" if value else "false"
|
| 52 |
|
|
@@ -72,6 +75,7 @@ def log_start(task_id: str, env: str, model: str) -> None:
|
|
| 72 |
flush=True,
|
| 73 |
)
|
| 74 |
|
|
|
|
| 75 |
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 76 |
print(
|
| 77 |
f"[STEP] step={step} action={_normalize_token(action)} reward={round(reward, 2):.2f} "
|
|
@@ -79,6 +83,7 @@ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[
|
|
| 79 |
flush=True,
|
| 80 |
)
|
| 81 |
|
|
|
|
| 82 |
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 83 |
print(
|
| 84 |
f"[END] success={_format_bool(success)} steps={steps} score={round(score, 2):.2f} "
|
|
@@ -86,12 +91,15 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
|
|
| 86 |
flush=True,
|
| 87 |
)
|
| 88 |
|
|
|
|
| 89 |
# ── Env client ────────────────────────────────────────────────────────────────
|
|
|
|
| 90 |
def env_reset(url: str, difficulty: str) -> dict:
|
| 91 |
r = requests.post(f"{url}/reset", json={"difficulty": difficulty}, timeout=30)
|
| 92 |
r.raise_for_status()
|
| 93 |
return r.json()
|
| 94 |
|
|
|
|
| 95 |
def env_step(url: str, fixed_code: str, explanation: Optional[str] = None) -> dict:
|
| 96 |
payload = {"fixed_code": fixed_code}
|
| 97 |
if explanation:
|
|
@@ -100,7 +108,9 @@ def env_step(url: str, fixed_code: str, explanation: Optional[str] = None) -> di
|
|
| 100 |
r.raise_for_status()
|
| 101 |
return r.json()
|
| 102 |
|
|
|
|
| 103 |
# ── LLM ───────────────────────────────────────────────────────────────────────
|
|
|
|
| 104 |
SYSTEM_PROMPT = """You are an expert Python debugging agent.
|
| 105 |
|
| 106 |
RESPONSE FORMAT — JSON only, no markdown fences, no extra text:
|
|
@@ -120,10 +130,11 @@ COMMON BUG PATTERNS — memorize these:
|
|
| 120 |
- Wrong operator: target-n not target+n for complement
|
| 121 |
- Off-by-one: lst[1] for second element not lst[2]
|
| 122 |
|
| 123 |
-
IMPORTANT: If feedback shows TimeoutError
|
| 124 |
-
IMPORTANT: If Expected shows right-rotated list
|
| 125 |
"""
|
| 126 |
|
|
|
|
| 127 |
def _parse_llm_response(raw: str, buggy_code: str) -> dict:
|
| 128 |
"""Robustly parse LLM response handling control chars and malformed JSON."""
|
| 129 |
# Remove markdown fences
|
|
@@ -174,7 +185,7 @@ def _parse_llm_response(raw: str, buggy_code: str) -> dict:
|
|
| 174 |
exp = exp_match.group(1).replace("\\n", "\n") if exp_match else None
|
| 175 |
return {"fixed_code": code, "explanation": exp}
|
| 176 |
|
| 177 |
-
# Complete fallback
|
| 178 |
return {"fixed_code": buggy_code, "explanation": None}
|
| 179 |
|
| 180 |
|
|
@@ -228,6 +239,7 @@ def call_llm(
|
|
| 228 |
|
| 229 |
|
| 230 |
# ── Episode ───────────────────────────────────────────────────────────────────
|
|
|
|
| 231 |
def run_episode(env_url: str, difficulty: str) -> tuple:
|
| 232 |
"""Run one full episode. Returns (success, steps_taken, rewards)."""
|
| 233 |
data = env_reset(env_url, difficulty)
|
|
@@ -251,19 +263,18 @@ def run_episode(env_url: str, difficulty: str) -> tuple:
|
|
| 251 |
code = action.get("fixed_code") or ""
|
| 252 |
last_code = code
|
| 253 |
|
| 254 |
-
reward = 0.0
|
| 255 |
-
done = False
|
| 256 |
step_error: Optional[str] = None
|
|
|
|
| 257 |
try:
|
| 258 |
result = env_step(env_url, code, action.get("explanation"))
|
| 259 |
reward = result.get("reward", 0.0)
|
| 260 |
-
done
|
| 261 |
-
obs_r
|
| 262 |
if isinstance(obs_r, dict):
|
| 263 |
last_feedback = obs_r.get("feedback", "")
|
| 264 |
-
step_error
|
| 265 |
-
if step_error is None:
|
| 266 |
-
step_error = obs_r.get("error")
|
| 267 |
except Exception as e:
|
| 268 |
step_error = str(e)
|
| 269 |
|
|
@@ -274,10 +285,10 @@ def run_episode(env_url: str, difficulty: str) -> tuple:
|
|
| 274 |
success = True
|
| 275 |
if done:
|
| 276 |
break
|
|
|
|
| 277 |
finally:
|
| 278 |
-
|
| 279 |
-
score
|
| 280 |
-
score = min(max(score, 0.0), 1.0)
|
| 281 |
success = success or (score >= SUCCESS_SCORE_THRESHOLD)
|
| 282 |
log_end(success, steps_taken, score, rewards)
|
| 283 |
|
|
@@ -285,18 +296,21 @@ def run_episode(env_url: str, difficulty: str) -> tuple:
|
|
| 285 |
|
| 286 |
|
| 287 |
# ── Main ──────────────────────────────────────────────────────────────────────
|
|
|
|
| 288 |
def main():
|
| 289 |
parser = argparse.ArgumentParser(description="Code Debug Environment Baseline Agent")
|
| 290 |
parser.add_argument("--url", default=ENV_URL or "http://localhost:7860")
|
| 291 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
|
| 292 |
args = parser.parse_args()
|
| 293 |
url = args.url.rstrip("/")
|
| 294 |
|
| 295 |
if not HF_TOKEN:
|
| 296 |
print(
|
| 297 |
"# Missing API key. Set HF_TOKEN (or API_KEY / lowercase hf_token).",
|
| 298 |
-
file=sys.stderr,
|
| 299 |
-
flush=True,
|
| 300 |
)
|
| 301 |
sys.exit(1)
|
| 302 |
print(f"# Using API key from {HF_TOKEN_SOURCE}", file=sys.stderr, flush=True)
|
|
@@ -328,5 +342,4 @@ def main():
|
|
| 328 |
|
| 329 |
|
| 330 |
if __name__ == "__main__":
|
| 331 |
-
main()
|
| 332 |
-
|
|
|
|
| 28 |
|
| 29 |
# ── Config ────────────────────────────────────────────────────────────────────
|
| 30 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
|
| 31 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
|
| 32 |
+
|
| 33 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 34 |
HF_TOKEN_SOURCE = "HF_TOKEN"
|
| 35 |
if not HF_TOKEN:
|
|
|
|
| 38 |
if not HF_TOKEN:
|
| 39 |
HF_TOKEN = os.getenv("hf_token")
|
| 40 |
HF_TOKEN_SOURCE = "hf_token"
|
| 41 |
+
|
| 42 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 43 |
+
ENV_URL = os.getenv("ENV_URL")
|
| 44 |
+
BENCHMARK = "code-debug-env"
|
| 45 |
+
MAX_STEPS = 5
|
| 46 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 47 |
|
| 48 |
client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)
|
| 49 |
|
| 50 |
+
|
| 51 |
# ── Logging — STRICT PLAINTEXT FORMAT ─────────────────────────────────────────
|
| 52 |
+
|
| 53 |
def _format_bool(value: bool) -> str:
|
| 54 |
return "true" if value else "false"
|
| 55 |
|
|
|
|
| 75 |
flush=True,
|
| 76 |
)
|
| 77 |
|
| 78 |
+
|
| 79 |
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 80 |
print(
|
| 81 |
f"[STEP] step={step} action={_normalize_token(action)} reward={round(reward, 2):.2f} "
|
|
|
|
| 83 |
flush=True,
|
| 84 |
)
|
| 85 |
|
| 86 |
+
|
| 87 |
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 88 |
print(
|
| 89 |
f"[END] success={_format_bool(success)} steps={steps} score={round(score, 2):.2f} "
|
|
|
|
| 91 |
flush=True,
|
| 92 |
)
|
| 93 |
|
| 94 |
+
|
| 95 |
# ── Env client ────────────────────────────────────────────────────────────────
|
| 96 |
+
|
| 97 |
def env_reset(url: str, difficulty: str) -> dict:
|
| 98 |
r = requests.post(f"{url}/reset", json={"difficulty": difficulty}, timeout=30)
|
| 99 |
r.raise_for_status()
|
| 100 |
return r.json()
|
| 101 |
|
| 102 |
+
|
| 103 |
def env_step(url: str, fixed_code: str, explanation: Optional[str] = None) -> dict:
|
| 104 |
payload = {"fixed_code": fixed_code}
|
| 105 |
if explanation:
|
|
|
|
| 108 |
r.raise_for_status()
|
| 109 |
return r.json()
|
| 110 |
|
| 111 |
+
|
| 112 |
# ── LLM ───────────────────────────────────────────────────────────────────────
|
| 113 |
+
|
| 114 |
SYSTEM_PROMPT = """You are an expert Python debugging agent.
|
| 115 |
|
| 116 |
RESPONSE FORMAT — JSON only, no markdown fences, no extra text:
|
|
|
|
| 130 |
- Wrong operator: target-n not target+n for complement
|
| 131 |
- Off-by-one: lst[1] for second element not lst[2]
|
| 132 |
|
| 133 |
+
IMPORTANT: If feedback shows TimeoutError, you have infinite loop. Add visited set.
|
| 134 |
+
IMPORTANT: If Expected shows right-rotated list, use lst[-k:] + lst[:-k].
|
| 135 |
"""
|
| 136 |
|
| 137 |
+
|
| 138 |
def _parse_llm_response(raw: str, buggy_code: str) -> dict:
|
| 139 |
"""Robustly parse LLM response handling control chars and malformed JSON."""
|
| 140 |
# Remove markdown fences
|
|
|
|
| 185 |
exp = exp_match.group(1).replace("\\n", "\n") if exp_match else None
|
| 186 |
return {"fixed_code": code, "explanation": exp}
|
| 187 |
|
| 188 |
+
# Complete fallback
|
| 189 |
return {"fixed_code": buggy_code, "explanation": None}
|
| 190 |
|
| 191 |
|
|
|
|
| 239 |
|
| 240 |
|
| 241 |
# ── Episode ───────────────────────────────────────────────────────────────────
|
| 242 |
+
|
| 243 |
def run_episode(env_url: str, difficulty: str) -> tuple:
|
| 244 |
"""Run one full episode. Returns (success, steps_taken, rewards)."""
|
| 245 |
data = env_reset(env_url, difficulty)
|
|
|
|
| 263 |
code = action.get("fixed_code") or ""
|
| 264 |
last_code = code
|
| 265 |
|
| 266 |
+
reward: float = 0.0
|
| 267 |
+
done: bool = False
|
| 268 |
step_error: Optional[str] = None
|
| 269 |
+
|
| 270 |
try:
|
| 271 |
result = env_step(env_url, code, action.get("explanation"))
|
| 272 |
reward = result.get("reward", 0.0)
|
| 273 |
+
done = result.get("done", False)
|
| 274 |
+
obs_r = result.get("observation", {})
|
| 275 |
if isinstance(obs_r, dict):
|
| 276 |
last_feedback = obs_r.get("feedback", "")
|
| 277 |
+
step_error = obs_r.get("last_action_error") or obs_r.get("error")
|
|
|
|
|
|
|
| 278 |
except Exception as e:
|
| 279 |
step_error = str(e)
|
| 280 |
|
|
|
|
| 285 |
success = True
|
| 286 |
if done:
|
| 287 |
break
|
| 288 |
+
|
| 289 |
finally:
|
| 290 |
+
score = max(rewards) if rewards else 0.0
|
| 291 |
+
score = min(max(score, 0.0), 1.0)
|
|
|
|
| 292 |
success = success or (score >= SUCCESS_SCORE_THRESHOLD)
|
| 293 |
log_end(success, steps_taken, score, rewards)
|
| 294 |
|
|
|
|
| 296 |
|
| 297 |
|
| 298 |
# ── Main ──────────────────────────────────────────────────────────────────────
|
| 299 |
+
|
| 300 |
def main():
|
| 301 |
parser = argparse.ArgumentParser(description="Code Debug Environment Baseline Agent")
|
| 302 |
parser.add_argument("--url", default=ENV_URL or "http://localhost:7860")
|
| 303 |
+
parser.add_argument(
|
| 304 |
+
"--difficulty", default=None,
|
| 305 |
+
choices=["easy", "medium", "hard", "all"],
|
| 306 |
+
)
|
| 307 |
args = parser.parse_args()
|
| 308 |
url = args.url.rstrip("/")
|
| 309 |
|
| 310 |
if not HF_TOKEN:
|
| 311 |
print(
|
| 312 |
"# Missing API key. Set HF_TOKEN (or API_KEY / lowercase hf_token).",
|
| 313 |
+
file=sys.stderr, flush=True,
|
|
|
|
| 314 |
)
|
| 315 |
sys.exit(1)
|
| 316 |
print(f"# Using API key from {HF_TOKEN_SOURCE}", file=sys.stderr, flush=True)
|
|
|
|
| 342 |
|
| 343 |
|
| 344 |
if __name__ == "__main__":
|
| 345 |
+
main()
|
|
|