File size: 6,596 Bytes
b0496f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""
test_log_format.py
Strict verification of [START]/[STEP]/[END] log format in inference.py.
Uses a mock OpenAI client — no real API calls.

Run from the project root: the project modules and inference.py are loaded by
relative path. Exits with status 1 if any format check fails.
"""
import sys, io, json, types, os, re, importlib.util

# UTF-8 stdout, so non-ASCII text in the report (e.g. the "→" in a check name)
# prints on consoles whose default encoding is narrower. reconfigure() switches
# the existing stream in place, keeping its buffering mode, instead of stacking
# a second TextIOWrapper on the same buffer (which also breaks when stdout has
# no .buffer). Streams without reconfigure() are left unchanged.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")

def _load(path, name):
    """Execute the Python file at *path* as a module registered under *name*.

    Follows the importlib "import a source file directly" recipe. The module
    object is placed in ``sys.modules[name]`` before its code runs, so the file
    can be imported under that dotted name afterwards.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no loader can be determined for *path* (for example,
            the file suffix is not a recognised Python source suffix).
        Exception: anything raised while executing the module. In that case
            the partially initialised module is removed from ``sys.modules``,
            so later lookups of *name* do not silently get a broken module.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {name!r} from {path!r}", name=name, path=path)
    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Mirror the import system: a failed import leaves no trace behind.
        sys.modules.pop(name, None)
        raise
    return mod

# Pre-register each project file under its dotted module name, in this fixed
# order (presumably later files import earlier ones; NOTE(review): confirm).
for _mod_path, _mod_name in (
    ("models/models.py",              "models.models"),
    ("skills/ambiguity_detection.py", "skills.ambiguity_detection"),
    ("skills/conversation_memory.py", "skills.conversation_memory"),
    ("skills/reward_system.py",       "skills.reward_system"),
    ("env/env.py",                    "env.env"),
    ("tasks/tasks.py",                "tasks.tasks"),
    ("grader/grader.py",              "grader.grader"),
):
    _load(_mod_path, _mod_name)

# ── mock OpenAI + dotenv ──────────────────────────────────────────────────────
# Scripted agent replies, returned in order by the fake chat-completions client:
# two clarifying questions, then an execute action. Once the script is used up,
# the last entry is returned again on every further call.
CANNED = [
    {"type": "ask",     "question": "When should this happen?"},
    {"type": "ask",     "question": "Who are the participants?"},
    {"type": "execute", "proposed_time": "10 AM", "proposed_participants": ["Team A"]},
]
_idx = [0]  # call counter kept in a list so create() can bump it without `global`


class _FC:
    """Fake ``response.choices[i]``: exposes ``.message.content``."""

    def __init__(self, t):
        self.message = types.SimpleNamespace(content=t)


class _FR:
    """Fake completion response whose ``.choices`` holds a single choice."""

    def __init__(self, t):
        self.choices = [_FC(t)]


class _FCmp:
    """Fake ``client.chat.completions``."""

    def create(self, *a, **k):
        # Request arguments (model, messages, ...) are accepted and ignored.
        i = min(_idx[0], len(CANNED) - 1)
        _idx[0] += 1
        return _FR(json.dumps(CANNED[i]))


class _FChat:
    completions = _FCmp()


class _FOAI:
    """Stand-in for ``openai.OpenAI``; accepts and ignores any constructor args."""

    def __init__(self, *a, **k):
        pass

    chat = _FChat()


# Install the fakes in sys.modules before inference.py is loaded below, so that
# (presumably) its `openai` / `dotenv` imports resolve to these stand-ins.
_fk = types.ModuleType("openai"); _fk.OpenAI = _FOAI; sys.modules["openai"] = _fk
_fd = types.ModuleType("dotenv")
# The real load_dotenv takes an optional path plus keyword options (override,
# ...); accept and ignore them so any call form works against the fake.
_fd.load_dotenv = lambda *a, **k: None
sys.modules["dotenv"] = _fd
# Dummy configuration; presumably read by inference.py at import/run time.
os.environ["HF_TOKEN"]     = "test"
os.environ["MODEL_NAME"]   = "Qwen/Qwen2.5-72B-Instruct"
os.environ["API_BASE_URL"] = "http://localhost"

# ── run inference, capture output to file ────────────────────────────────────
_load("inference.py", "__inf__")
inf = sys.modules["__inf__"]
from tasks.tasks import TASKS

LOG_FILE = "_test_log_capture.txt"
# Redirect stdout into a scratch file for one task run, then read it back.
# The inner finally always restores stdout, so a crash inside run_task cannot
# swallow the rest of the script's output. The outer finally always deletes the
# scratch file, even if the run or the read-back fails.
old_stdout = sys.stdout
try:
    with open(LOG_FILE, "w", encoding="utf-8") as f:
        sys.stdout = f
        try:
            inf.run_task(TASKS[3])   # hard_ambiguous — most complete coverage
        finally:
            sys.stdout = old_stdout

    with open(LOG_FILE, "r", encoding="utf-8") as f:
        out = f.read()
finally:
    if os.path.exists(LOG_FILE):
        os.remove(LOG_FILE)

# ── display ───────────────────────────────────────────────────────────────────
# Non-blank captured lines, in order; every check below reads from this list.
lines = list(filter(str.strip, out.strip().splitlines()))

print("=== CAPTURED LOG ===")
for captured in lines:
    print(" ", captured)

# ── format checks ─────────────────────────────────────────────────────────────
print("\n=== FORMAT CHECKS ===")
passed = []   # names of checks that passed, in run order
failed = []   # names of checks that failed, in run order

def ck(name, cond, got=None):
    """Record and print the outcome of one named check.

    *cond* is judged by truthiness. A passing check is appended to ``passed``.
    A failing check is appended to ``failed`` and printed with ``repr(got)``,
    so the offending value is visible in the report.
    """
    if not cond:
        failed.append(name)
        print(f"  [FAIL] {name}  got={got!r}")
        return
    passed.append(name)
    print(f"  [PASS] {name}")

# Bucket the captured lines by their leading tag.
starts, steps, ends = (
    [ln for ln in lines if ln.startswith(tag)]
    for tag in ("[START]", "[STEP]", "[END]")
)

# Exactly one [START] and one [END]; at least one [STEP].
_n_start, _n_step, _n_end = len(starts), len(steps), len(ends)
ck("[START] present (1 line)", _n_start == 1, _n_start)
ck("[STEP]s present (>=1)",    _n_step >= 1,  _n_step)
ck("[END] present (1 line)",   _n_end == 1,   _n_end)

# [START] fields
# Fall back to "" when no [START] line was captured. The presence check has
# already recorded that failure; the field checks then fail cleanly instead of
# aborting the whole report with an IndexError.
sl = starts[0] if starts else ""
ck("[START] has task=",   "task="  in sl)
ck("[START] has env=",    "env="   in sl)
ck("[START] has model=",  "model=" in sl)

# [STEP] field checks (check every step)
# The reward value may carry a leading minus sign, in case the environment emits
# negative rewards. Either way it must have exactly two decimal places.
_STEP_REWARD_RE = re.compile(r"reward=(-?\d+\.\d+)")
for n, s in enumerate(steps, start=1):
    for field in ("step=", "action=", "reward=", "done=", "error="):
        ck(f"[STEP]{n} has {field}", field in s)
    m = _STEP_REWARD_RE.search(s)
    dp = m and len(m.group(1).split(".")[1]) == 2
    ck(f"[STEP]{n} reward is 2dp",        dp,                             m and m.group(1))
    ck(f"[STEP]{n} done= is lowercase",   "done=true" in s or "done=false" in s)

# [END] field checks
# Fall back to "" when no [END] line was captured. The presence check has
# already recorded that failure; the field checks (and the score/rewards/order
# checks that also read `el`) then fail cleanly instead of aborting the report
# with an IndexError.
el = ends[0] if ends else ""
ck("[END] has success=",  "success="  in el)
ck("[END] has steps=",    "steps="    in el)
ck("[END] has score=",    "score="    in el)
ck("[END] has rewards=",  "rewards="  in el)
ck("[END] success is lowercase",    "success=true" in el or "success=false" in el)

# score format + value
# score is expected to be the mean of the step rewards, shown to 2dp. Per the
# original expectation, the canned run yields rewards 0.30, 0.30 and 1.00, so
# score = 0.5333… → "0.53". A 1.00 here would mean only the final step's reward
# was reported. Since 0.53 clears the 0.5 threshold, success=true is correct.
score_match = re.search(r"score=(\d+\.\d+)", el)
ck("[END] score= is present", score_match is not None, el)
if score_match is not None:
    sc_str = score_match.group(1)
    _, _, sc_frac = sc_str.partition(".")
    sc_val = float(sc_str)
    ck("[END] score is 2dp",                    len(sc_frac) == 2,     sc_str)
    ck("[END] score in [0.0, 1.0]",             0.0 <= sc_val <= 1.0,  sc_val)
    ck("[END] score > 0.5 (success threshold)", sc_val > 0.5,          sc_val)
    ck("[END] score = 0.53 (mean of 3 steps)",  sc_val == 0.53,        sc_val)

# rewards format
# A comma-separated list in which every entry (optionally signed) has exactly
# two decimal places. Entries with no decimal point (e.g. "1") and empty entries
# from a stray comma count as failures rather than being skipped.
m3 = re.search(r"rewards=([-\d.,]+)", el)
ck("[END] rewards= values present", bool(m3), el)
if m3:
    parts = m3.group(1).split(",")
    ck("[END] all rewards are 2dp",
       all(re.fullmatch(r"-?\d+\.\d{2}", p) for p in parts), parts)

# Field ORDER: success → steps → score → rewards
# All four keys must be present, and their first occurrences must appear in
# strictly increasing position.
_END_FIELD_ORDER = ("success=", "steps=", "score=", "rewards=")
if all(key in el for key in _END_FIELD_ORDER):
    _positions = [el.index(key) for key in _END_FIELD_ORDER]
    order_ok = all(a < b for a, b in zip(_positions, _positions[1:]))
else:
    order_ok = False
ck("[END] field order: success→steps→score→rewards", order_ok)

# No run of two or more spaces anywhere in the [END] line.
ck("[END] no double spaces", el.find("  ") == -1)

# ── summary ───────────────────────────────────────────────────────────────────
# Print totals; list every failed check and exit with status 1 if any failed.
n_passed = len(passed)
print(f"\nPassed: {n_passed}/{n_passed + len(failed)}")
if not failed:
    print("ALL LOG FORMAT CHECKS PASSED")
else:
    print("FAILURES:")
    for failure in failed:
        print(f"  - {failure}")
    sys.exit(1)