Cyber-Machine commited on
Commit
fab9447
·
verified ·
1 Parent(s): aea0016

feat: add inference module

Browse files
Files changed (1) hide show
  1. inference.py +274 -0
inference.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass
6
+
7
+ from openai import OpenAI
8
+ from workflow_arena import WorkflowArenaAction, WorkflowArenaEnv
9
+ from workflow_arena.models import (
10
+ DifficultyPreset,
11
+ WorkflowActionType,
12
+ WorkflowArenaObservation,
13
+ WorkflowTaskView,
14
+ )
15
+ from workflow_arena.presets import get_preset_config
16
+
17
+ BENCHMARK = "WorkflowArena"
18
+ PRESETS = [
19
+ DifficultyPreset.EASY,
20
+ DifficultyPreset.MEDIUM,
21
+ DifficultyPreset.HARD,
22
+ ]
23
+ DEFAULT_BASE_URL = os.getenv("WORKFLOW_ARENA_BASE_URL", "http://localhost:8000")
24
+ TEMPERATURE = 0.0
25
+ MAX_STEPS = 256
26
+
27
+ SYSTEM_PROMPT = (
28
+ "You are scheduling a dependency-constrained workflow on limited workers. "
29
+ "Respond with compact JSON only. "
30
+ 'Valid formats: {"action_type":"wait","task_ids":[]} or '
31
+ '{"action_type":"dispatch","task_ids":["task_01","task_02"]}. '
32
+ "Only dispatch task ids that appear in ready_tasks for the current observation. "
33
+ "Never exceed free_workers. "
34
+ 'If free_workers is 0 and running_tasks is non-empty, respond with {"action_type":"wait","task_ids":[]}. '
35
+ "If your previous action was invalid, use validation_error to correct it while still reasoning from the current observation. "
36
+ "Never repeat a previously dispatched task unless it still appears in ready_tasks."
37
+ )
38
+
39
+
40
+ def log_start(task: str, env: str, model: str) -> None:
41
+ print(f"[START] task={task} env={env} model={model}", flush=True)
42
+
43
+
44
+ def log_step(
45
+ step: int, action: str, reward: float, done: bool, error: str | None
46
+ ) -> None:
47
+ error_val = error if error else "null"
48
+ done_val = str(done).lower()
49
+ print(
50
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
51
+ flush=True,
52
+ )
53
+
54
+
55
+ def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
56
+ rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
57
+ print(
58
+ f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
59
+ flush=True,
60
+ )
61
+
62
+
63
+ def compact_task(task: WorkflowTaskView) -> dict[str, object]:
64
+ return {
65
+ "task_id": task.task_id,
66
+ "duration": task.duration,
67
+ "priority": task.priority,
68
+ "deadline": task.deadline,
69
+ "criticality": task.criticality,
70
+ "slack": task.slack,
71
+ "downstream_count": task.downstream_count,
72
+ "dependencies": task.dependencies,
73
+ "attempt_count": task.attempt_count,
74
+ }
75
+
76
+
77
+ def make_user_prompt(observation: WorkflowArenaObservation) -> str:
78
+ must_wait = observation.free_workers == 0 and bool(observation.running_tasks)
79
+ return json.dumps(
80
+ {
81
+ "instruction": observation.instruction,
82
+ "current_time": observation.current_time,
83
+ "effective_workers": observation.effective_workers,
84
+ "degraded_workers": observation.degraded_workers,
85
+ "free_workers": observation.free_workers,
86
+ "time_budget": observation.time_budget,
87
+ "time_remaining": observation.time_remaining,
88
+ "must_wait": must_wait,
89
+ "ready_tasks": [compact_task(task) for task in observation.ready_tasks],
90
+ "running_tasks": [compact_task(task) for task in observation.running_tasks],
91
+ "progress": observation.progress.model_dump(mode="json"),
92
+ "reward_breakdown": observation.last_reward_breakdown.model_dump(
93
+ mode="json"
94
+ ),
95
+ "note": observation.note,
96
+ "validation_error": observation.validation_error,
97
+ "recent_failure_events": [
98
+ event.model_dump(mode="json")
99
+ for event in observation.recent_failure_events
100
+ ],
101
+ "last_action": observation.received_action,
102
+ },
103
+ separators=(",", ":"),
104
+ )
105
+
106
+
107
+ def heuristic_action(observation: WorkflowArenaObservation) -> WorkflowArenaAction:
108
+ if observation.free_workers <= 0 and observation.running_tasks:
109
+ return WorkflowArenaAction(action_type=WorkflowActionType.WAIT, task_ids=[])
110
+
111
+ if not observation.ready_tasks or observation.free_workers <= 0:
112
+ return WorkflowArenaAction(action_type=WorkflowActionType.WAIT, task_ids=[])
113
+
114
+ time_remaining = observation.time_remaining
115
+ ranked = sorted(
116
+ observation.ready_tasks,
117
+ key=lambda task: (
118
+ time_remaining is not None and task.duration > time_remaining,
119
+ max(0, task.duration - time_remaining) if time_remaining is not None else 0,
120
+ task.deadline if task.deadline is not None else 10**9,
121
+ -(task.criticality or 0.0),
122
+ -task.priority,
123
+ task.duration,
124
+ task.task_id,
125
+ ),
126
+ )
127
+ selected = [task.task_id for task in ranked[: observation.free_workers]]
128
+ return WorkflowArenaAction(
129
+ action_type=WorkflowActionType.DISPATCH,
130
+ task_ids=selected,
131
+ )
132
+
133
+
134
+ def parse_action(
135
+ text: str, observation: WorkflowArenaObservation
136
+ ) -> WorkflowArenaAction:
137
+ text = text.strip()
138
+ if not text:
139
+ raise ValueError("Model response did not include JSON action")
140
+ payload = json.loads(text)
141
+ return WorkflowArenaAction.model_validate(payload)
142
+
143
+
144
+ def get_model_action(
145
+ client: OpenAI,
146
+ model_name: str,
147
+ observation: WorkflowArenaObservation,
148
+ ) -> WorkflowArenaAction:
149
+ prompt = make_user_prompt(observation)
150
+ completion = client.chat.completions.create(
151
+ model=model_name,
152
+ messages=[
153
+ {"role": "system", "content": SYSTEM_PROMPT},
154
+ {"role": "user", "content": prompt},
155
+ ],
156
+ temperature=TEMPERATURE,
157
+ max_tokens=120,
158
+ )
159
+ text = (completion.choices[0].message.content or "").strip()
160
+ return parse_action(text, observation)
161
+
162
+
163
+ def action_to_log_string(action: WorkflowArenaAction) -> str:
164
+ payload = action.model_dump(mode="json")
165
+ if payload.get("metadata") == {}:
166
+ payload.pop("metadata", None)
167
+ return json.dumps(payload, separators=(",", ":"))
168
+
169
+
170
+ def compute_score(observation: WorkflowArenaObservation) -> float:
171
+ score = observation.benchmark_score
172
+ if score is None:
173
+ score = observation.success_metrics.benchmark_score
174
+ return max(0.0, min(1.0, float(score or 0.0)))
175
+
176
+
177
+ def is_success(observation: WorkflowArenaObservation) -> bool:
178
+ return bool(
179
+ observation.done
180
+ and observation.success_metrics.makespan is not None
181
+ and observation.termination_reason is None
182
+ )
183
+
184
+
185
+ @dataclass
186
+ class EpisodeResult:
187
+ success: bool
188
+ steps: int
189
+ score: float
190
+ rewards: list[float]
191
+
192
+
193
+ def run_episode(
194
+ client: OpenAI | None,
195
+ model_name: str,
196
+ preset: DifficultyPreset,
197
+ seed: int,
198
+ ) -> EpisodeResult:
199
+ rewards: list[float] = []
200
+ steps_taken = 0
201
+ success = False
202
+ score = 0.0
203
+
204
+ log_start(task=preset.value, env=BENCHMARK, model=model_name)
205
+
206
+ with WorkflowArenaEnv(base_url=DEFAULT_BASE_URL).sync() as env:
207
+ preset_config = get_preset_config(preset)
208
+ result = env.reset(
209
+ seed=seed,
210
+ preset=preset.value,
211
+ worker_count=preset_config.worker_count,
212
+ )
213
+ observation = result.observation
214
+
215
+ while not observation.done and steps_taken < MAX_STEPS:
216
+ try:
217
+ if client is None:
218
+ action = heuristic_action(observation)
219
+ else:
220
+ action = get_model_action(client, model_name, observation)
221
+ except (
222
+ Exception
223
+ ): # pragma: no cover - network/model failures are expected sometimes
224
+ action = heuristic_action(observation)
225
+
226
+ try:
227
+ result = env.step(action)
228
+ except (
229
+ Exception
230
+ ): # pragma: no cover - preserve log format and continue safely
231
+ action = heuristic_action(observation)
232
+ result = env.step(action)
233
+
234
+ observation = result.observation
235
+ reward = float(result.reward or 0.0)
236
+ rewards.append(reward)
237
+ steps_taken += 1
238
+ log_step(
239
+ step=steps_taken,
240
+ action=action_to_log_string(action),
241
+ reward=reward,
242
+ done=bool(result.done),
243
+ error=observation.validation_error,
244
+ )
245
+
246
+ success = is_success(observation)
247
+ score = compute_score(observation) if observation.done else 0.0
248
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
249
+
250
+ return EpisodeResult(
251
+ success=success, steps=steps_taken, score=score, rewards=rewards
252
+ )
253
+
254
+
255
+ def main() -> None:
256
+ api_base_url = os.environ["API_BASE_URL"]
257
+ model_name = os.environ["MODEL_NAME"]
258
+ api_key = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
259
+ if not api_key:
260
+ raise RuntimeError("HF_TOKEN or OPENAI_API_KEY must be set.")
261
+
262
+ client = OpenAI(base_url=api_base_url, api_key=api_key)
263
+
264
+ for index, preset in enumerate(PRESETS):
265
+ run_episode(
266
+ client=client,
267
+ model_name=model_name,
268
+ preset=preset,
269
+ seed=100 + index,
270
+ )
271
+
272
+
273
+ if __name__ == "__main__":
274
+ main()