# ambiguity-env / test_inference_format.py
# Yaser77's picture
# feat: ambiguity resolution environment v1.0 - OpenEnv Hackathon
# c06cf60
"""
test_inference_format.py
Validates inference.py log format + fallback path WITHOUT making real API calls.
Monkey-patches the OpenAI client to return canned responses.
"""
import sys, importlib.util, types, json
# ── load modules without triggering openenv-core/gradio ──────────────────────
def load(path, name):
    """Execute the Python file at *path* and register it in ``sys.modules`` as *name*.

    The module object is placed in ``sys.modules`` before its code runs, so
    imports of *name* made while the file executes resolve to this object.

    Args:
        path: Filesystem path of the source file to execute.
        name: Module name (may be dotted) to register the module under.

    Returns:
        The loaded module object.

    Raises:
        ImportError: If no import spec or loader can be built for *path*
            (for example, an unrecognised file suffix).
        Exception: Whatever the file's own top-level code raises; the
            partially initialised entry is removed from ``sys.modules``
            before the exception propagates.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ImportError(
            f"cannot load module {name!r} from {path!r}", name=name, path=path
        )
    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module visible to later imports.
        sys.modules.pop(name, None)
        raise
    return mod
# (file path, registered module name) pairs, loaded in exactly this order.
_MODULE_FILES = (
    ("models/models.py", "models.models"),
    ("skills/ambiguity_detection.py", "skills.ambiguity_detection"),
    ("skills/conversation_memory.py", "skills.conversation_memory"),
    ("skills/reward_system.py", "skills.reward_system"),
    ("env/env.py", "env.env"),
    ("tasks/tasks.py", "tasks.tasks"),
    ("grader/grader.py", "grader.grader"),
)
for _file_path, _module_name in _MODULE_FILES:
    load(_file_path, _module_name)
# ── patch openai before inference.py imports it ──────────────────────────────
# Scripted model replies: two clarifying questions, then an execute action.
# Key order is kept as written because the replies are serialised with json.dumps.
CANNED = [
    dict(type="ask", question="When should this happen?"),
    dict(type="ask", question="Who are the participants?"),
    dict(
        type="execute",
        proposed_time="10 AM",
        proposed_participants=["Team A"],
    ),
]
# One-element list used as a mutable counter of fake completion calls.
call_count = [0]
class FakeChoice:
    """Stand-in for a single completion choice; exposes ``message.content``."""

    def __init__(self, text):
        reply = types.SimpleNamespace()
        reply.content = text
        self.message = reply
class FakeResp:
    """Stand-in for a chat-completion response holding exactly one choice."""

    def __init__(self, text):
        only_choice = FakeChoice(text)
        self.choices = [only_choice]
class FakeCompletions:
    """Stand-in for ``client.chat.completions`` that replays ``CANNED``."""

    def create(self, **kwargs):
        """Return the next canned reply as JSON text inside a ``FakeResp``.

        Keyword arguments are accepted and ignored. Once every entry of
        ``CANNED`` has been served, the last entry is returned on every
        further call. The shared ``call_count`` counter is advanced by one.
        """
        served = call_count[0]
        call_count[0] = served + 1
        last = len(CANNED) - 1
        # Clamp to the final entry once the script is exhausted.
        reply = CANNED[served] if served < last else CANNED[last]
        return FakeResp(json.dumps(reply))
class FakeChat:
    """Stand-in for ``client.chat``."""

    # Created once at class definition; shared by every FakeChat instance.
    completions = FakeCompletions()
class FakeOpenAI:
    """Stand-in for the ``OpenAI`` client class; constructor keywords are ignored."""

    # Class-level attribute, so every client instance shares one FakeChat.
    chat = FakeChat()

    def __init__(self, **kwargs):
        return None
# Register a stub ``openai`` module so later imports of ``openai`` get the
# fake client class instead of the real package.
fake_openai = types.ModuleType("openai")
setattr(fake_openai, "OpenAI", FakeOpenAI)
sys.modules[fake_openai.__name__] = fake_openai
# also stub dotenv
def _fake_load_dotenv(*args, **kwargs):
    """No-op replacement for ``dotenv.load_dotenv``.

    Accepts and ignores any positional or keyword arguments, so calls such
    as ``load_dotenv(".env")`` or ``load_dotenv(override=True)`` do not raise
    ``TypeError`` against the stub.

    Returns:
        False, matching the real function's bool return type when no
        variable was loaded.
    """
    return False


fake_dotenv = types.ModuleType("dotenv")
fake_dotenv.load_dotenv = _fake_load_dotenv
sys.modules["dotenv"] = fake_dotenv
import os

# Dummy settings placed in the process environment before inference.py is
# loaded (presumably read there at import time — confirm against inference.py).
os.environ.update(
    {
        "HF_TOKEN": "test-token",
        "MODEL_NAME": "TestModel",
        "API_BASE_URL": "http://localhost",
    }
)
# ── capture stdout ────────────────────────────────────────────────────────────
import contextlib
import importlib
import io

captured = io.StringIO()
try:
    # redirect_stdout restores the previous sys.stdout on exit, including when
    # loading inference.py or running the task raises.
    with contextlib.redirect_stdout(captured):
        load("inference.py", "__inference__")
        inf = sys.modules["__inference__"]
        # Run a single task: TASKS[3] (index 3, i.e. the fourth entry). The
        # original note calls it hard_ambiguous, chosen for best coverage —
        # confirm against tasks/tasks.py.
        from tasks.tasks import TASKS
        inf.run_task(TASKS[3])
finally:
    # Always show what was captured, so a failure inside run_task still
    # leaves the partial log visible next to the traceback.
    output = captured.getvalue()
    print("=== CAPTURED OUTPUT ===")
    print(output)
# ── validate format ───────────────────────────────────────────────────────────
# Keep only non-blank lines, then bucket them by their leading tag.
lines = [raw for raw in output.strip().splitlines() if raw.strip()]
start_lines, step_lines, end_lines = (
    [entry for entry in lines if entry.startswith(tag)]
    for tag in ("[START]", "[STEP]", "[END]")
)
print("=== FORMAT CHECKS ===")
n_start = len(start_lines)
assert n_start == 1, f"Expected 1 [START], got {n_start}"
print("[OK] [START] present (1 line)")
n_step = len(step_lines)
assert n_step >= 1, f"Expected >=1 [STEP], got {n_step}"
print(f"[OK] [STEP] present ({n_step} lines)")
n_end = len(end_lines)
assert n_end == 1, f"Expected 1 [END], got {n_end}"
print("[OK] [END] present (1 line)")
# Check [START] fields
sl = start_lines[0]
has_start_fields = all(field in sl for field in ("task=", "env=", "model="))
assert has_start_fields, f"[START] fields missing: {sl}"
print("[OK] [START] has task= env= model=")
# Check [STEP] fields
import re

# Optional leading minus so that negative (penalty) rewards such as
# "reward=-0.50" are validated instead of being reported as unformatted.
# Compiled once, outside the loop.
_REWARD_RE = re.compile(r"reward=(-?\d+\.\d+)")
_STEP_FIELDS = ("step=", "action=", "reward=", "done=", "error=")

for s in step_lines:
    for field in _STEP_FIELDS:
        assert field in s, s
    # reward formatted to 2 decimals
    m = _REWARD_RE.search(s)
    assert m, f"reward not formatted: {s}"
    # \d+ is greedy, so the captured fraction holds every decimal digit.
    assert len(m.group(1).split(".")[1]) == 2, f"reward not 2dp: {s}"
print("[OK] [STEP] has step= action= reward(2dp) done= error=")
# Check done values are lowercase
for s in step_lines:
    has_lowercase_done = any(flag in s for flag in ("done=true", "done=false"))
    assert has_lowercase_done, f"done not lowercase: {s}"
print("[OK] done= is lowercase true/false")
# Check [END] fields
el = end_lines[0]
assert all(field in el for field in ("success=", "steps=", "rewards="))
print("[OK] [END] has success= steps= rewards=")
# Leading newline yields the same blank separator line as a bare print().
print("\n=== ALL FORMAT CHECKS PASSED ===")