File size: 4,693 Bytes
c06cf60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
test_inference_format.py
Validates inference.py log format + fallback path WITHOUT making real API calls.
Monkey-patches the OpenAI client to return canned responses.
"""
import sys, importlib.util, types, json

# ── load modules without triggering openenv-core/gradio ──────────────────────
def load(path, name):
    """Execute the Python source file at *path* and register it as module *name*.

    The module object is placed in ``sys.modules`` *before* its body runs, so
    code executed later can import it by that dotted name without the real
    package being importable.

    Args:
        path: Filesystem path to the ``.py`` file (relative paths resolve
            against the current working directory).
        name: Module name to register in ``sys.modules``.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import spec/loader can be built for *path*
            (e.g. the file has no recognised Python suffix).
        Any exception raised while executing the module body propagates; the
        partially initialised module is removed from ``sys.modules`` first,
        mirroring what the regular import system does.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path!r}", name=name, path=path)
    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Don't leave a half-executed module behind for later imports to find.
        sys.modules.pop(name, None)
        raise
    return mod

# Register each project module under its dotted name, straight from its source
# file, in the order listed. Presumably later modules import earlier ones by
# these names, so keep the order — TODO confirm before reordering.
for _src_path, _module_name in (
    ("models/models.py", "models.models"),
    ("skills/ambiguity_detection.py", "skills.ambiguity_detection"),
    ("skills/conversation_memory.py", "skills.conversation_memory"),
    ("skills/reward_system.py", "skills.reward_system"),
    ("env/env.py", "env.env"),
    ("tasks/tasks.py", "tasks.tasks"),
    ("grader/grader.py", "grader.grader"),
):
    load(_src_path, _module_name)

# ── patch openai before inference.py imports it ──────────────────────────────
# Scripted model replies. The fake client serves them in order, then keeps
# repeating the final one once the list is exhausted.
CANNED = [
    dict(type="ask", question="When should this happen?"),
    dict(type="ask", question="Who are the participants?"),
    dict(
        type="execute",
        proposed_time="10 AM",
        proposed_participants=["Team A"],
    ),
]
# Number of fake completion calls made so far. It is a one-element list so
# the fake client can bump it in place without a ``global`` statement.
call_count = [0]

class FakeChoice:
    """Stand-in for one entry of a chat-completion response's ``choices``.

    Only ``.message.content`` is provided.
    """

    def __init__(self, text):
        msg = types.SimpleNamespace()
        msg.content = text
        self.message = msg

class FakeResp:
    """Stand-in for a chat-completion response carrying exactly one choice."""

    def __init__(self, text):
        only_choice = FakeChoice(text)
        self.choices = [only_choice]

class FakeCompletions:
    """Fake ``chat.completions`` endpoint that replays ``CANNED`` replies."""

    def create(self, **kwargs):
        """Return the next canned reply, JSON-encoded, wrapped in a FakeResp.

        All keyword arguments are accepted and ignored. Once ``CANNED`` is
        exhausted, the last entry is returned on every further call. Each
        call increments ``call_count[0]``.
        """
        last = len(CANNED) - 1
        position = call_count[0]
        call_count[0] = position + 1
        reply = CANNED[position if position < last else last]
        return FakeResp(json.dumps(reply))

class FakeChat:
    """Fake ``client.chat`` namespace.

    ``completions`` is a class attribute, so every FakeChat shares one
    FakeCompletions instance.
    """

    completions = FakeCompletions()

class FakeOpenAI:
    """Drop-in replacement for the ``OpenAI`` client class.

    Any constructor keyword arguments are accepted and discarded. ``chat`` is
    a class attribute, so all instances share the same fake chat namespace.
    """

    chat = FakeChat()

    def __init__(self, **kwargs):
        del kwargs  # deliberately unused

# Register a fake ``openai`` module exposing FakeOpenAI, so that when
# inference.py is loaded afterwards its ``openai`` import resolves to it.
fake_openai = types.ModuleType("openai")
fake_openai.OpenAI = FakeOpenAI
sys.modules["openai"] = fake_openai

# also stub dotenv. The stub accepts any arguments (a path, override=True, ...)
# so that whatever call style inference.py uses is a no-op rather than a
# TypeError, as the previous zero-argument lambda would have been.
fake_dotenv = types.ModuleType("dotenv")
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
sys.modules["dotenv"] = fake_dotenv

import os
# Dummy configuration, presumably read by inference.py. No real endpoint is
# contacted because the OpenAI client above is fake.
os.environ["HF_TOKEN"]      = "test-token"
os.environ["MODEL_NAME"]    = "TestModel"
os.environ["API_BASE_URL"]  = "http://localhost"

# ── capture stdout ────────────────────────────────────────────────────────────
import io
captured = io.StringIO()
# Remember whichever stdout is active right now (a test runner may already
# have redirected it) so that exact stream is restored, not sys.__stdout__.
_original_stdout = sys.stdout
sys.stdout = captured
try:
    load("inference.py", "__inference__")
    inf = sys.modules["__inference__"]

    # Run a single task, TASKS[3]. An earlier note described it as the
    # hard_ambiguous task chosen for best coverage — TODO confirm the index.
    from tasks.tasks import TASKS
    inf.run_task(TASKS[3])
except BaseException:
    # Surface whatever inference.py logged before failing; otherwise it would
    # be lost along with the StringIO.
    sys.stderr.write(captured.getvalue())
    raise
finally:
    # Always undo the redirection, even on failure, so the host process's
    # stdout is not left pointing at the StringIO.
    sys.stdout = _original_stdout

output = captured.getvalue()

import re

print("=== CAPTURED OUTPUT ===")
print(output)


def _check(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Used instead of bare ``assert`` so the checks still run under
    ``python -O``, which strips assert statements. Without this, the script
    would report success without checking anything.
    """
    if not condition:
        raise AssertionError(message)


# Optional leading minus so a negative reward such as ``reward=-0.10`` is
# still recognised; the fractional part is captured for the 2-dp check.
_REWARD_RE = re.compile(r"reward=(-?\d+\.\d+)")
# ``done=`` must be exactly the lowercase token ``true`` or ``false``.
_DONE_RE = re.compile(r"\bdone=(?:true|false)\b")

# ── validate format ───────────────────────────────────────────────────────────
lines = [line for line in output.strip().splitlines() if line.strip()]

start_lines = [line for line in lines if line.startswith("[START]")]
step_lines  = [line for line in lines if line.startswith("[STEP]")]
end_lines   = [line for line in lines if line.startswith("[END]")]

print("=== FORMAT CHECKS ===")
_check(len(start_lines) == 1, f"Expected 1 [START], got {len(start_lines)}")
print("[OK] [START] present (1 line)")

_check(len(step_lines) >= 1, f"Expected >=1 [STEP], got {len(step_lines)}")
print(f"[OK] [STEP] present ({len(step_lines)} lines)")

_check(len(end_lines) == 1, f"Expected 1 [END], got {len(end_lines)}")
print("[OK] [END] present (1 line)")

# Check [START] fields
sl = start_lines[0]
_check("task=" in sl and "env=" in sl and "model=" in sl, f"[START] fields missing: {sl}")
print("[OK] [START] has task= env= model=")

# Check [STEP] fields
for s in step_lines:
    for field in ("step=", "action=", "reward=", "done=", "error="):
        _check(field in s, s)
    # reward formatted to 2 decimals
    m = _REWARD_RE.search(s)
    _check(m, f"reward not formatted: {s}")
    _check(len(m.group(1).split(".")[1]) == 2, f"reward not 2dp: {s}")
print("[OK] [STEP] has step= action= reward(2dp) done= error=")

# Check done values are lowercase
for s in step_lines:
    _check(_DONE_RE.search(s), f"done not lowercase: {s}")
print("[OK] done= is lowercase true/false")

# Check [END] fields
el = end_lines[0]
_check(
    "success=" in el and "steps=" in el and "rewards=" in el,
    f"[END] fields missing: {el}",
)
print("[OK] [END] has success= steps= rewards=")

print()
print("=== ALL FORMAT CHECKS PASSED ===")