# Spaces: Sleeping / Sleeping  (appears to be hosting-page status text captured with this file; not part of the script)
"""
test_inference_format.py
Validates inference.py log format + fallback path WITHOUT making real API calls.
Monkey-patches the OpenAI client to return canned responses.
"""
import contextlib
import importlib.util
import json
import sys
import types
# ── load modules without triggering openenv-core/gradio ──────────────────────
def load(path, name):
    """Execute the Python file at *path* and register it as module *name*.

    The file is loaded straight from its location rather than through the
    regular import path search.  The module object is stored in
    ``sys.modules`` before its code runs, so imports of *name* made while it
    executes resolve to the same object.

    Args:
        path: Filesystem path of the source file to execute.
        name: Key under which the module is stored in ``sys.modules``.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no loadable spec can be built for *path* (for
            example, an unsupported file suffix).
        Exception: Whatever the module's own top-level code raises.  The
            half-initialised entry is removed from ``sys.modules`` first so
            later imports do not silently pick up a broken module.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"cannot load module {name!r} from {path!r}", name=name, path=path
        )
    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a partially executed module registered.
        sys.modules.pop(name, None)
        raise
    return mod
# Load each project module by file path, in this fixed order, registering it
# under its dotted name so later ``from pkg.mod import ...`` statements find
# it in ``sys.modules``.
for _module_path, _module_name in (
    ("models/models.py", "models.models"),
    ("skills/ambiguity_detection.py", "skills.ambiguity_detection"),
    ("skills/conversation_memory.py", "skills.conversation_memory"),
    ("skills/reward_system.py", "skills.reward_system"),
    ("env/env.py", "env.env"),
    ("tasks/tasks.py", "tasks.tasks"),
    ("grader/grader.py", "grader.grader"),
):
    load(_module_path, _module_name)
del _module_path, _module_name
# ── patch openai before inference.py imports it ──────────────────────────────
# Scripted model replies, handed out in list order by the fake client.
CANNED = [
    dict(type="ask", question="When should this happen?"),
    dict(type="ask", question="Who are the participants?"),
    dict(
        type="execute",
        proposed_time="10 AM",
        proposed_participants=["Team A"],
    ),
]
# One-element list so the counter can be mutated in place from inside a method.
call_count = [0]
class FakeChoice:
    """Stand-in for one completion choice; exposes ``.message.content``."""

    def __init__(self, text):
        message = types.SimpleNamespace()
        message.content = text
        self.message = message
class FakeResp:
    """Stand-in for a chat-completion response holding exactly one choice."""

    def __init__(self, text):
        only_choice = FakeChoice(text)
        self.choices = [only_choice]
class FakeCompletions:
    """Stand-in for ``client.chat.completions`` that replays ``CANNED``."""

    def create(self, **kwargs):
        """Return the next canned reply wrapped in a fake response.

        All keyword arguments are accepted and ignored.  Every call advances
        ``call_count[0]``; once each entry has been served, the last entry is
        returned again on every further call.  The reply is JSON-encoded
        before being wrapped.
        """
        served_so_far = call_count[0]
        call_count[0] = served_so_far + 1
        last_index = len(CANNED) - 1
        # Clamp to the final entry when the script has run out.
        reply = CANNED[served_so_far if served_so_far < last_index else last_index]
        return FakeResp(json.dumps(reply))
class FakeChat:
    """Stand-in for ``client.chat``; exposes a ``completions`` attribute."""

    # Class attribute: a single FakeCompletions instance shared by all FakeChat objects.
    completions = FakeCompletions()
class FakeOpenAI:
    """Replacement for ``openai.OpenAI`` backed entirely by the fakes above."""

    # Class attribute, so every client instance shares the same FakeChat.
    chat = FakeChat()

    def __init__(self, **kwargs):
        """Accept and discard any keyword configuration."""
# Register a synthetic ``openai`` module so that a later ``import openai`` /
# ``from openai import OpenAI`` resolves to the fake client.
fake_openai = types.ModuleType("openai")
setattr(fake_openai, "OpenAI", FakeOpenAI)
sys.modules[fake_openai.__name__] = fake_openai
# Also stub dotenv (presumably imported by inference.py) so no real package
# or .env file is needed.
fake_dotenv = types.ModuleType("dotenv")


def _fake_load_dotenv(*args, **kwargs):
    """No-op replacement for ``dotenv.load_dotenv``.

    Accepts and ignores any positional or keyword arguments, so call sites
    such as ``load_dotenv(override=True)`` do not fail with ``TypeError``.

    Returns:
        ``False`` -- nothing was loaded (falsy, like the ``None`` the
        previous zero-argument stub returned).
    """
    return False


fake_dotenv.load_dotenv = _fake_load_dotenv
sys.modules["dotenv"] = fake_dotenv
import os

# Placeholder settings written into the process environment before
# inference.py is loaded; any existing values for these keys are overwritten.
os.environ.update(
    {
        "HF_TOKEN": "test-token",
        "MODEL_NAME": "TestModel",
        "API_BASE_URL": "http://localhost",
    }
)
# ── capture stdout ───────────────────────────────────────────────────────────
import io

# Load inference.py and run one task with stdout redirected into a buffer.
# redirect_stdout puts back whatever stream was active before (not
# unconditionally sys.__stdout__), and does so even if loading or the task
# raises.
captured = io.StringIO()
try:
    with contextlib.redirect_stdout(captured):
        load("inference.py", "__inference__")
        inf = sys.modules["__inference__"]
        # Run a single task.
        # NOTE(review): the original comment called this "the first task
        # (hard_ambiguous for best coverage)", but index 3 is the fourth
        # entry of TASKS -- confirm this is the intended task.
        from tasks.tasks import TASKS
        inf.run_task(TASKS[3])
finally:
    # Always echo what was captured so a failing run still shows its log.
    output = captured.getvalue()
    print("=== CAPTURED OUTPUT ===")
    print(output)
# ── validate format ──────────────────────────────────────────────────────────
import re

# Matches the reward field; the optional sign lets negative rewards
# (penalties) through instead of failing as "not formatted".
_REWARD_RE = re.compile(r"reward=(-?\d+\.\d+)")

# Keep only non-blank lines, then bucket them by their leading log tag.
lines = [line for line in output.strip().splitlines() if line.strip()]
start_lines = [line for line in lines if line.startswith("[START]")]
step_lines = [line for line in lines if line.startswith("[STEP]")]
end_lines = [line for line in lines if line.startswith("[END]")]

print("=== FORMAT CHECKS ===")
assert len(start_lines) == 1, f"Expected 1 [START], got {len(start_lines)}"
print("[OK] [START] present (1 line)")
assert len(step_lines) >= 1, f"Expected >=1 [STEP], got {len(step_lines)}"
print(f"[OK] [STEP] present ({len(step_lines)} lines)")
assert len(end_lines) == 1, f"Expected 1 [END], got {len(end_lines)}"
print("[OK] [END] present (1 line)")

# Check [START] fields
sl = start_lines[0]
assert "task=" in sl and "env=" in sl and "model=" in sl, f"[START] fields missing: {sl}"
print("[OK] [START] has task= env= model=")

# Check [STEP] fields
for s in step_lines:
    for field in ("step=", "action=", "reward=", "done=", "error="):
        assert field in s, s
    # reward formatted to 2 decimals
    m = _REWARD_RE.search(s)
    assert m, f"reward not formatted: {s}"
    assert len(m.group(1).split(".")[1]) == 2, f"reward not 2dp: {s}"
print("[OK] [STEP] has step= action= reward(2dp) done= error=")

# Check done values are lowercase
for s in step_lines:
    assert "done=true" in s or "done=false" in s, f"done not lowercase: {s}"
print("[OK] done= is lowercase true/false")

# Check [END] fields
el = end_lines[0]
assert "success=" in el and "steps=" in el and "rewards=" in el, f"[END] fields missing: {el}"
print("[OK] [END] has success= steps= rewards=")
print()
print("=== ALL FORMAT CHECKS PASSED ===")