File size: 13,052 Bytes
fab9447
 
9007754
fab9447
 
9007754
 
fab9447
9007754
fab9447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9007754
 
 
86dae99
 
 
 
b522b5c
fab9447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b522b5c
 
 
 
fab9447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b522b5c
 
 
86dae99
b522b5c
 
 
 
 
86dae99
b522b5c
 
 
 
 
 
 
 
 
 
86dae99
b522b5c
 
 
 
 
 
 
fab9447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9007754
86dae99
9007754
 
86dae99
9007754
 
 
 
 
 
 
 
 
 
 
 
 
86dae99
9007754
 
 
 
 
 
 
 
 
 
 
86dae99
9007754
 
 
 
 
 
 
 
b522b5c
 
 
 
 
 
 
 
 
 
9007754
86dae99
9007754
 
 
b522b5c
 
 
 
9007754
 
 
 
fab9447
 
 
 
 
 
 
 
 
 
 
 
b522b5c
 
 
 
 
 
 
 
 
 
 
 
9007754
fab9447
9007754
 
 
fab9447
9007754
 
 
 
 
 
 
 
 
b522b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fab9447
9007754
 
 
 
 
 
 
 
 
 
 
fab9447
9007754
 
 
fab9447
 
 
 
 
 
9007754
b522b5c
fab9447
9007754
 
 
 
 
 
 
 
 
fab9447
 
 
b522b5c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
from __future__ import annotations

import asyncio
import json
import os
import subprocess
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path

from openai import OpenAI
from workflow_arena import WorkflowArenaAction, WorkflowArenaEnv
from workflow_arena.models import (
    DifficultyPreset,
    WorkflowActionType,
    WorkflowArenaObservation,
    WorkflowTaskView,
)

# Benchmark name reported as ``env=`` in every ``[START]`` log line.
BENCHMARK = "WorkflowArena"
# Difficulty presets that ``main`` runs, one episode each, in this order.
PRESETS = [
    DifficultyPreset.EASY,
    DifficultyPreset.MEDIUM,
    DifficultyPreset.HARD,
]
# Directory containing this script; used as the working directory for docker commands.
PROJECT_DIR = Path(__file__).resolve().parent
# Image tag used for the Docker fallback when LOCAL_IMAGE_NAME is not set.
IMAGE_NAME = "workflow-arena-inference:latest"
# Dockerfile passed to ``docker build -f`` when the image has to be built locally.
DOCKERFILE_PATH = PROJECT_DIR / "server" / "Dockerfile"
# Base URL handed to the OpenAI client (overridable through the environment).
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Model identifier sent with every chat-completion request.
MODEL_NAME = os.getenv("MODEL_NAME", "qwen/qwen3.5-9b")
# One of the accepted API-key sources; see ``resolve_model_client``.
HF_TOKEN = os.getenv("HF_TOKEN")
# Optional override of IMAGE_NAME for the Docker fallback.
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
# Environment server URL that ``managed_env`` tries before falling back to Docker.
DEFAULT_BASE_URL = os.getenv("WORKFLOW_ARENA_BASE_URL", "http://localhost:8000")
# Sampling temperature passed to the chat-completion request.
TEMPERATURE = 0.0
# Upper bound on environment steps per episode in ``run_episode``.
MAX_STEPS = 256

# System message sent with every model request; it describes the two accepted
# JSON action shapes and the constraints the model is asked to respect.
SYSTEM_PROMPT = (
    "You are scheduling a dependency-constrained workflow on limited workers. "
    "Respond with compact JSON only. "
    'Valid formats: {"action_type":"wait","task_ids":[]} or '
    '{"action_type":"dispatch","task_ids":["task_01","task_02"]}. '
    "Only dispatch task ids that appear in ready_tasks for the current observation. "
    "Never exceed free_workers. "
    'If free_workers is 0 and running_tasks is non-empty, respond with {"action_type":"wait","task_ids":[]}. '
    "If your previous action was invalid, use validation_error to correct it while still reasoning from the current observation. "
    "Never repeat a previously dispatched task unless it still appears in ready_tasks."
)


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode's log.

    Args:
        task: Task label for the episode.
        env: Environment / benchmark name.
        model: Name of the policy or model driving the episode.
    """
    fields = (f"task={task}", f"env={env}", f"model={model}")
    print("[START] " + " ".join(fields), flush=True)


def log_step(
    step: int, action: str, reward: float, done: bool, error: str | None
) -> None:
    """Print one ``[STEP]`` line describing a completed environment step.

    Args:
        step: 1-based step counter.
        action: Already-serialized action string.
        reward: Step reward, printed with two decimals.
        done: Whether the episode finished; printed lower-cased.
        error: Validation error text; an empty or missing value prints as ``null``.
    """
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)


def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the ``[END]`` line that closes an episode's log.

    Args:
        success: Episode outcome; printed lower-cased.
        steps: Number of steps taken.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    formatted_rewards = [f"{value:.2f}" for value in rewards]
    line = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={','.join(formatted_rewards)}"
    )
    print(line, flush=True)


def log_warning(message: str) -> None:
    """Print *message* on stdout as a ``[WARN]`` line and flush immediately."""
    line = f"[WARN] {message}"
    print(line, flush=True)


def compact_task(task: WorkflowTaskView) -> dict[str, object]:
    """Project a task view onto the plain dict that is embedded in the prompt.

    Args:
        task: Task view to summarise.

    Returns:
        A dict holding the selected task attributes, keyed by attribute name,
        in the fixed order listed below.
    """
    exported_fields = (
        "task_id",
        "duration",
        "priority",
        "deadline",
        "criticality",
        "slack",
        "downstream_count",
        "dependencies",
        "attempt_count",
    )
    return {field: getattr(task, field) for field in exported_fields}


def make_user_prompt(observation: WorkflowArenaObservation) -> str:
    """Serialize the current observation into the compact JSON user message.

    Args:
        observation: Latest observation returned by the environment.

    Returns:
        A JSON string without whitespace between tokens. Besides the
        observation fields it carries ``must_wait``, which is true when no
        worker is free while tasks are still running.
    """
    must_wait = observation.free_workers == 0 and bool(observation.running_tasks)

    # Fields copied verbatim from the observation, in output order.
    scalar_fields = (
        "instruction",
        "current_time",
        "effective_workers",
        "degraded_workers",
        "free_workers",
        "time_budget",
        "time_remaining",
    )
    payload = {name: getattr(observation, name) for name in scalar_fields}
    payload["must_wait"] = must_wait
    payload["ready_tasks"] = [compact_task(item) for item in observation.ready_tasks]
    payload["running_tasks"] = [
        compact_task(item) for item in observation.running_tasks
    ]
    payload["progress"] = observation.progress.model_dump(mode="json")
    payload["reward_breakdown"] = observation.last_reward_breakdown.model_dump(
        mode="json"
    )
    payload["note"] = observation.note
    payload["validation_error"] = observation.validation_error
    payload["recent_failure_events"] = [
        event.model_dump(mode="json") for event in observation.recent_failure_events
    ]
    payload["last_action"] = observation.received_action

    return json.dumps(payload, separators=(",", ":"))


def heuristic_action(observation: WorkflowArenaObservation) -> WorkflowArenaAction:
    """Choose an action without a model, using a fixed priority ordering.

    Waits when no worker is free or nothing is ready. Otherwise dispatches as
    many ready tasks as there are free workers, taking them in this order:
    tasks that fit in the remaining time first, then smaller overrun, earlier
    deadline (a missing deadline sorts last), higher criticality, higher
    priority, shorter duration, and finally task id.

    Args:
        observation: Latest observation returned by the environment.

    Returns:
        A WAIT action with no task ids, or a DISPATCH action with the chosen ids.
    """
    capacity = observation.free_workers
    candidates = observation.ready_tasks
    if capacity <= 0 or not candidates:
        return WorkflowArenaAction(action_type=WorkflowActionType.WAIT, task_ids=[])

    budget = observation.time_remaining

    def sort_key(task):
        # Without a known time budget no task counts as overrunning.
        if budget is None:
            overruns, overrun_by = False, 0
        else:
            overruns = task.duration > budget
            overrun_by = max(0, task.duration - budget)
        deadline = 10**9 if task.deadline is None else task.deadline
        return (
            overruns,
            overrun_by,
            deadline,
            -(task.criticality or 0.0),
            -task.priority,
            task.duration,
            task.task_id,
        )

    ordered = list(candidates)
    ordered.sort(key=sort_key)
    chosen_ids = [task.task_id for task in ordered[:capacity]]
    return WorkflowArenaAction(
        action_type=WorkflowActionType.DISPATCH,
        task_ids=chosen_ids,
    )


def _extract_json_object(text: str) -> object:
    """Return the first JSON object embedded anywhere in *text*.

    Tries to decode a JSON value at each ``{`` in turn, so text before or after
    the object (markdown fences, explanations) is ignored.

    Raises:
        ValueError: If no ``{`` in *text* starts a decodable JSON object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            payload, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return payload
    raise ValueError("Model response did not include JSON action")


def parse_action(
    text: str, observation: WorkflowArenaObservation
) -> WorkflowArenaAction:
    """Turn a raw model reply into a validated ``WorkflowArenaAction``.

    A reply that is pure JSON is parsed directly. If that fails, the first
    JSON object found inside the reply is used instead, because chat models
    often wrap the answer in code fences or extra prose.

    Args:
        text: Raw message content returned by the model.
        observation: Current observation. Unused; kept for interface compatibility.

    Returns:
        The action produced by ``WorkflowArenaAction.model_validate``.

    Raises:
        ValueError: If the reply is empty or contains no decodable JSON object.
            Errors raised by ``model_validate`` propagate unchanged.
    """
    text = text.strip()
    if not text:
        raise ValueError("Model response did not include JSON action")
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        payload = _extract_json_object(text)
    return WorkflowArenaAction.model_validate(payload)


def get_model_action(
    client: OpenAI,
    model_name: str,
    observation: WorkflowArenaObservation,
) -> WorkflowArenaAction:
    """Ask the chat model for the next action and parse its reply.

    Args:
        client: OpenAI-compatible client used for the request.
        model_name: Model identifier sent with the request.
        observation: Current observation, serialized into the user message.

    Returns:
        The action parsed from the first choice of the completion.
    """
    user_message = make_user_prompt(observation)
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    response = client.chat.completions.create(
        model=model_name,
        messages=conversation,
        temperature=TEMPERATURE,
        max_tokens=120,
    )
    # The message content may be None; treat that as an empty reply.
    reply = response.choices[0].message.content or ""
    return parse_action(reply.strip(), observation)


def action_to_log_string(action: WorkflowArenaAction) -> str:
    """Render an action as compact JSON for the ``[STEP]`` log line.

    An empty ``metadata`` mapping is left out of the output; every other
    dumped field is kept in its original order.
    """
    dumped = action.model_dump(mode="json")
    visible = {
        key: value
        for key, value in dumped.items()
        if key != "metadata" or value != {}
    }
    return json.dumps(visible, separators=(",", ":"))


def resolve_model_client() -> tuple[OpenAI | None, str]:
    """Build the model client, or signal that the heuristic policy must be used.

    The API key is taken from the first non-empty source among the ``API_KEY``
    environment variable, ``HF_TOKEN`` and the ``OPENAI_API_KEY`` environment
    variable.

    Returns:
        ``(client, MODEL_NAME)`` when a key is available and the client could
        be constructed, otherwise ``(None, "heuristic")`` after logging a warning.
    """
    api_key = os.getenv("API_KEY") or HF_TOKEN or os.getenv("OPENAI_API_KEY")

    if not api_key:
        # Name every accepted variable so the warning matches the lookup above.
        log_warning(
            "Missing model configuration (API_KEY, HF_TOKEN or OPENAI_API_KEY). "
            "Falling back to heuristic policy."
        )
        return None, "heuristic"

    try:
        return OpenAI(base_url=API_BASE_URL, api_key=api_key), MODEL_NAME
    except Exception as exc:  # pragma: no cover - defensive initialization fallback
        log_warning(
            f"Failed to initialize model client: {exc}. Falling back to heuristic policy."
        )
        return None, "heuristic"


def compute_score(observation: WorkflowArenaObservation) -> float:
    """Return the observation's benchmark score clamped to ``[0.0, 1.0]``.

    The top-level ``benchmark_score`` is preferred; when it is ``None`` the
    value under ``success_metrics`` is used. A missing or zero score yields 0.0.
    """
    raw = observation.benchmark_score
    if raw is None:
        raw = observation.success_metrics.benchmark_score
    value = float(raw) if raw else 0.0
    capped = min(1.0, value)
    return max(0.0, capped)


def is_success(observation: WorkflowArenaObservation) -> bool:
    """Tell whether the episode behind *observation* ended successfully.

    Success requires all three: the observation is done, a makespan was
    recorded in ``success_metrics``, and no termination reason is set.
    """
    if not observation.done:
        return False
    if observation.success_metrics.makespan is None:
        return False
    return observation.termination_reason is None


@dataclass
class EpisodeResult:
    """Summary of a single episode, as returned by ``run_episode``."""

    # Result of ``is_success`` on the final observation (False if reset failed).
    success: bool
    # Number of environment steps that completed.
    steps: int
    # Clamped benchmark score; 0.0 when the episode did not reach ``done``.
    score: float
    # Reward of each completed step, in order.
    rewards: list[float]


def ensure_local_image() -> None:
    """Make sure the Docker image used for the local fallback exists.

    Runs ``docker image inspect`` for the configured image and, if that exits
    non-zero, builds the image from ``DOCKERFILE_PATH`` with ``PROJECT_DIR``
    as the build context.

    Raises:
        RuntimeError: If the docker executable cannot be run, or if the build
            exits with a non-zero status (the message carries the command,
            exit code and captured output).
    """
    image = LOCAL_IMAGE_NAME or IMAGE_NAME

    inspect_cmd = ["docker", "image", "inspect", image]
    try:
        probe = subprocess.run(
            inspect_cmd,
            cwd=PROJECT_DIR,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError as exc:
        raise RuntimeError(f"Failed to execute docker: {exc}") from exc

    # Exit status 0 from inspect means the image is already present.
    if probe.returncode == 0:
        return

    build_cmd = ["docker", "build", "-t", image, "-f", str(DOCKERFILE_PATH), "."]
    try:
        build = subprocess.run(
            build_cmd,
            cwd=PROJECT_DIR,
            capture_output=True,
            text=True,
            check=False,
        )
    except OSError as exc:
        raise RuntimeError(f"Failed to execute docker build: {exc}") from exc

    if build.returncode == 0:
        return

    report = [
        "Failed to build Docker image for inference.",
        f"Command: docker build -t {image} -f {DOCKERFILE_PATH} .",
        f"Exit code: {build.returncode}",
        f"Stdout: {build.stdout}",
        f"Stderr: {build.stderr}",
    ]
    raise RuntimeError("\n".join(report))


@asynccontextmanager
async def managed_env():
    """Yield a connected environment, preferring the server at ``DEFAULT_BASE_URL``.

    If entering the remote environment fails, a warning is logged and a local
    Docker-backed environment is started instead (building the image first if
    needed). The Docker environment is closed on exit; a failure while closing
    it is logged rather than raised.

    Exceptions raised by the caller's ``async with`` body while the remote
    environment is in use are re-raised unchanged. They must not trigger the
    Docker fallback: a context-manager generator may only yield once, so a
    second ``yield`` would surface as an unrelated ``RuntimeError``.

    Yields:
        The environment client to run episodes against.
    """
    connected = False
    try:
        async with WorkflowArenaEnv(base_url=DEFAULT_BASE_URL) as env:
            connected = True
            yield env
            return
    except Exception as exc:
        if connected:
            # The failure happened after the connection was established (in
            # the caller's body or during remote teardown): propagate it.
            raise
        log_warning(
            f"Failed to connect to environment at {DEFAULT_BASE_URL}: {exc}. "
            "Trying local Docker fallback."
        )

    ensure_local_image()
    env = await WorkflowArenaEnv.from_docker_image(LOCAL_IMAGE_NAME or IMAGE_NAME)
    try:
        yield env
    finally:
        try:
            await env.close()
        except Exception as exc:  # pragma: no cover - teardown failures should not fail inference
            log_warning(f"Failed to close Docker environment cleanly: {exc}")


async def run_episode(
    env,
    client: OpenAI | None,
    model_name: str,
    preset: DifficultyPreset,
    seed: int,
) -> EpisodeResult:
    """Run one episode for *preset* and emit its ``[START]``/``[STEP]``/``[END]`` log.

    Each step asks the model for an action (or the heuristic when *client* is
    ``None``). A failing model call falls back to the heuristic action, and a
    failing ``env.step`` is retried once with the heuristic action; if the
    retry also fails the episode stops early.

    Args:
        env: Environment client exposing awaitable ``reset`` and ``step``.
        client: Model client, or ``None`` to use the heuristic policy only.
        model_name: Name reported in the log and sent to the model.
        preset: Difficulty preset to reset the environment with.
        seed: Seed passed to ``env.reset``.

    Returns:
        An ``EpisodeResult``. ``score`` is 0.0 unless the final observation is
        done; a failed reset yields zero steps and ``success=False``.
    """
    rewards: list[float] = []
    steps_taken = 0
    success = False
    score = 0.0

    log_start(task=preset.value, env=BENCHMARK, model=model_name)

    try:
        result = await env.reset(
            seed=seed,
            preset=preset.value,
        )
    except Exception as exc:  # pragma: no cover - env availability failures are external
        log_warning(f"Failed to reset preset={preset.value}: {exc}")
        log_end(success=False, steps=steps_taken, score=score, rewards=rewards)
        return EpisodeResult(
            success=success, steps=steps_taken, score=score, rewards=rewards
        )

    observation = result.observation
    loop = asyncio.get_running_loop()

    while not observation.done and steps_taken < MAX_STEPS:
        if client is None:
            action = heuristic_action(observation)
        else:
            try:
                # The model client is synchronous; run it in a worker thread so
                # the event loop (and the env connection it services) is not
                # blocked for the duration of the request.
                action = await loop.run_in_executor(
                    None, get_model_action, client, model_name, observation
                )
            except Exception as exc:  # pragma: no cover - network/model failures are expected sometimes
                # Collapse whitespace so a multi-line error stays on one log line.
                detail = " ".join(str(exc).split()) or type(exc).__name__
                log_warning(
                    f"Model action failed for preset={preset.value}: {detail}. "
                    "Using heuristic action."
                )
                action = heuristic_action(observation)

        try:
            result = await env.step(action)
        except Exception as exc:  # pragma: no cover - preserve log format and continue safely
            fallback_action = heuristic_action(observation)
            if fallback_action != action:
                log_warning(
                    f"Step failed for preset={preset.value} with model action: {exc}. "
                    "Retrying with heuristic action."
                )
            action = fallback_action
            try:
                result = await env.step(action)
            except Exception as retry_exc:
                log_warning(
                    f"Step failed for preset={preset.value} even with heuristic action: {retry_exc}"
                )
                break

        observation = result.observation
        reward = float(result.reward or 0.0)
        rewards.append(reward)
        steps_taken += 1
        log_step(
            step=steps_taken,
            action=action_to_log_string(action),
            reward=reward,
            done=bool(result.done),
            error=observation.validation_error,
        )

    success = is_success(observation)
    # An episode cut short by MAX_STEPS or a step failure scores zero.
    score = compute_score(observation) if observation.done else 0.0
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return EpisodeResult(
        success=success, steps=steps_taken, score=score, rewards=rewards
    )


async def main() -> None:
    """Run one episode per entry of ``PRESETS`` against a single environment.

    Seeds start at 100 and increase by one for each preset.
    """
    client, model_name = resolve_model_client()

    async with managed_env() as env:
        for seed, preset in enumerate(PRESETS, start=100):
            await run_episode(
                env=env,
                client=client,
                model_name=model_name,
                preset=preset,
                seed=seed,
            )


# Script entry point: run every episode and report any uncaught exception as a
# single ``[WARN]`` line instead of letting it propagate.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except Exception as exc:  # pragma: no cover - final safeguard for validator stability
        log_warning(f"Fatal inference error: {exc}")