"""
test_log_format.py
Strict verification of [START]/[STEP]/[END] log format in inference.py.
Uses a mock OpenAI client — no real API calls.
"""
import importlib.util
import io
import json
import os
import re
import sys
import types

# UTF-8 stdout
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")


def _load(path, name):
    """Execute the source file at *path* as module *name* and return it.

    The module is registered in ``sys.modules`` under *name* before it is
    executed, so later ``import`` statements for that dotted name resolve
    to the module loaded here.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    spec.loader.exec_module(mod)
    return mod


_load("models/models.py", "models.models")
_load("skills/ambiguity_detection.py", "skills.ambiguity_detection")
_load("skills/conversation_memory.py", "skills.conversation_memory")
_load("skills/reward_system.py", "skills.reward_system")
_load("env/env.py", "env.env")
_load("tasks/tasks.py", "tasks.tasks")
_load("grader/grader.py", "grader.grader")

# ── mock OpenAI + dotenv ──────────────────────────────────────────────────────
# Replies handed out by the fake client, one per create() call, in order.
CANNED = [
    {"type": "ask", "question": "When should this happen?"},
    {"type": "ask", "question": "Who are the participants?"},
    {"type": "execute", "proposed_time": "10 AM", "proposed_participants": ["Team A"]},
]
# Single-element list so _FCmp.create can advance the counter in place.
_idx = [0]


class _FC:
    """Fake choice: exposes ``.message.content``."""

    def __init__(self, t):
        self.message = types.SimpleNamespace(content=t)


class _FR:
    """Fake response: exposes ``.choices`` holding one fake choice."""

    def __init__(self, t):
        self.choices = [_FC(t)]


class _FCmp:
    """Fake ``completions`` endpoint serving the CANNED replies."""

    def create(self, **k):
        # Once the canned replies run out, keep returning the last one.
        i = min(_idx[0], len(CANNED) - 1)
        _idx[0] += 1
        return _FR(json.dumps(CANNED[i]))


class _FChat:
    """Fake ``chat`` namespace exposing ``completions``."""

    completions = _FCmp()


class _FOAI:
    """Fake ``OpenAI`` client; accepts and ignores any keyword arguments."""

    def __init__(self, **k):
        pass

    chat = _FChat()


# Install stand-in modules so `openai` / `dotenv` imports resolve to the fakes.
_fk = types.ModuleType("openai")
_fk.OpenAI = _FOAI
sys.modules["openai"] = _fk
_fd = types.ModuleType("dotenv")
_fd.load_dotenv = lambda: None
sys.modules["dotenv"] = _fd

os.environ["HF_TOKEN"] = "test"
os.environ["MODEL_NAME"] = "Qwen/Qwen2.5-72B-Instruct"
os.environ["API_BASE_URL"] = "http://localhost"

# ── run inference, capture output to file ────────────────────────────────────
_load("inference.py", "__inf__")
inf = sys.modules["__inf__"]

from tasks.tasks import TASKS

LOG_FILE = "_test_log_capture.txt"
try:
    with open(LOG_FILE, "w", encoding="utf-8") as f:
        old_stdout = sys.stdout
        sys.stdout = f
        try:
            inf.run_task(TASKS[3])  # hard_ambiguous — most complete coverage
        finally:
            # Restore stdout even if run_task raises, so the traceback and
            # any later output are not swallowed by the capture file.
            sys.stdout = old_stdout
    with open(LOG_FILE, "r", encoding="utf-8") as f:
        out = f.read()
finally:
    # Never leave the temporary capture file behind.
    if os.path.exists(LOG_FILE):
        os.remove(LOG_FILE)

# ── display ───────────────────────────────────────────────────────────────────
lines = [line for line in out.strip().splitlines() if line.strip()]
print("=== CAPTURED LOG ===")
for line in lines:
    print(" ", line)

# ── format checks ─────────────────────────────────────────────────────────────
print("\n=== FORMAT CHECKS ===")
passed, failed = [], []


def ck(name, cond, got=None):
    """Record and print one check result.

    *name* is appended to ``passed`` when *cond* is truthy, otherwise to
    ``failed``; *got* is shown only on failure as a debugging aid.
    """
    if cond:
        passed.append(name)
        print(f" [PASS] {name}")
    else:
        failed.append(name)
        print(f" [FAIL] {name} got={got!r}")


starts = [line for line in lines if line.startswith("[START]")]
steps = [line for line in lines if line.startswith("[STEP]")]
ends = [line for line in lines if line.startswith("[END]")]

ck("[START] present (1 line)", len(starts) == 1, len(starts))
ck("[STEP]s present (>=1)", len(steps) >= 1, len(steps))
ck("[END] present (1 line)", len(ends) == 1, len(ends))

# [START] fields
# Fall back to an empty line so a missing [START] is reported as failed
# checks instead of crashing with IndexError.
sl = starts[0] if starts else ""
ck("[START] has task=", "task=" in sl)
ck("[START] has env=", "env=" in sl)
ck("[START] has model=", "model=" in sl)

# [STEP] field checks (check every step)
for i, s in enumerate(steps):
    n = i + 1
    ck(f"[STEP]{n} has step=", "step=" in s)
    ck(f"[STEP]{n} has action=", "action=" in s)
    ck(f"[STEP]{n} has reward=", "reward=" in s)
    ck(f"[STEP]{n} has done=", "done=" in s)
    ck(f"[STEP]{n} has error=", "error=" in s)
    m = re.search(r"reward=(\d+\.\d+)", s)
    dp = bool(m) and len(m.group(1).split(".")[1]) == 2
    ck(f"[STEP]{n} reward is 2dp", dp, m and m.group(1))
    ck(f"[STEP]{n} done= is lowercase", "done=true" in s or "done=false" in s)

# [END] field checks
# Same fallback as for [START]: a missing [END] fails the checks cleanly.
el = ends[0] if ends else ""
ck("[END] has success=", "success=" in el)
ck("[END] has steps=", "steps=" in el)
ck("[END] has score=", "score=" in el)
ck("[END] has rewards=", "rewards=" in el)
ck("[END] success is lowercase", "success=true" in el or "success=false" in el)

# score format + value
m2 = re.search(r"score=(\d+\.\d+)", el)
ck("[END] score= is present", bool(m2), el)
if m2:
    sc_str = m2.group(1)
    sc_val = float(sc_str)
    ck("[END] score is 2dp", len(sc_str.split(".")[1]) == 2, sc_str)
    ck("[END] score in [0.0, 1.0]", 0.0 <= sc_val <= 1.0, sc_val)
    # score = mean(rewards) = mean([0.30, 0.30, 1.00]) = 0.53
    # Not 1.0 — that would be the final step reward only.
    # Score > 0.5 → success=true (correct)
    ck("[END] score > 0.5 (success threshold)", sc_val > 0.5, sc_val)
    ck("[END] score = 0.53 (mean of 3 steps)", sc_val == 0.53, sc_val)

# rewards format
m3 = re.search(r"rewards=([\d.,]+)", el)
ck("[END] rewards= values present", bool(m3), el)
if m3:
    parts = m3.group(1).split(",")
    # Every value must be digits, a dot, then exactly two digits; values
    # without a decimal point (e.g. "1") are failures, not skipped.
    ck("[END] all rewards are 2dp",
       all(re.fullmatch(r"\d+\.\d{2}", p) for p in parts), parts)

# Field ORDER: success → steps → score → rewards
try:
    order_ok = (
        el.index("success=") < el.index("steps=")
        < el.index("score=") < el.index("rewards=")
    )
except ValueError:
    order_ok = False
ck("[END] field order: success→steps→score→rewards", order_ok)

# No extra spaces between fields
# Looks for two consecutive spaces; single spaces separate the fields.
ck("[END] no double spaces", bool(el) and "  " not in el)

# ── summary ───────────────────────────────────────────────────────────────────
print(f"\nPassed: {len(passed)}/{len(passed)+len(failed)}")
if failed:
    print("FAILURES:")
    for failure in failed:
        print(f" - {failure}")
    sys.exit(1)
else:
    print("ALL LOG FORMAT CHECKS PASSED")