Roshan Kumar committed on
Commit
bc0dd7f
·
unverified ·
1 Parent(s): d755709

update inference

Browse files
Files changed (1) hide show
  1. inference.py +246 -9
inference.py CHANGED
@@ -1,11 +1,248 @@
1
- def parse_action(text):
2
- parts = text.split()
 
 
 
3
 
4
- if parts[0] == "assign_job":
5
- return Action(
6
- action_type="assign_job",
7
- job_id=parts[1],
8
- machine_id=parts[2],
9
- )
10
 
11
- return Action(action_type="wait")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Factory Environment Inference Script
3
+ ===================================
4
+ Follows OpenEnv evaluation format strictly.
5
+ """
6
 
7
+ import asyncio
8
+ import os
9
+ import textwrap
10
+ from typing import List, Optional
 
 
11
 
12
+ from openai import OpenAI
13
+
14
+ from factory_env.env import FactoryEnv
15
+ from factory_env.models import Action
16
+
17
+ # =========================
18
+ # ENV VARIABLES (MANDATORY)
19
+ # =========================
20
# Credentials and endpoint: an unset *or empty* variable falls through to the
# next candidate / the built-in default.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.environ.get("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"

# Task selection: here only an *unset* variable yields the default.
TASK_NAME = os.environ.get("FACTORY_TASK", "easy")
BENCHMARK = "factory_env"

# Episode and sampling limits.
MAX_STEPS = 20
TEMPERATURE = 0.2
MAX_TOKENS = 100
SUCCESS_SCORE_THRESHOLD = 0.5

# =========================
# PROMPTS
# =========================
# System message sent with every request; spelled out line by line so the
# exact text does not depend on source indentation.
SYSTEM_PROMPT = "\n".join(
    (
        "You are controlling a factory scheduling system.",
        "",
        "Your goal:",
        "- Assign jobs to machines efficiently",
        "- Minimize idle machines",
        "- Finish all jobs as fast as possible",
        "",
        "Available actions:",
        "1. assign_job <job_id> <machine_id>",
        "2. wait",
        "",
        "Rules:",
        "- Only assign jobs that exist",
        "- Only assign to idle machines",
        "- One action per step",
        "",
        "Respond ONLY with the action string.",
        "Example:",
        "assign_job J1 M1",
    )
)
58
+
59
+
60
+ # =========================
61
+ # LOGGING FUNCTIONS (STRICT FORMAT)
62
+ # =========================
63
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record naming the task, environment and model.

    The record is written to stdout as a single line and flushed immediately.
    """
    record = "[START] task={} env={} model={}".format(task, env, model)
    print(record, flush=True)
65
+
66
+
67
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record for a completed environment step.

    Args:
        step: 1-based step index.
        action: Text describing the action; may be raw model output.
        reward: Reward for the step, printed with two decimals.
        done: Whether the episode finished; printed as ``true``/``false``.
        error: Error message, or a falsy value to print ``null``.

    The record must stay on a single line, so every run of whitespace
    (including newlines) inside ``action`` and ``error`` is collapsed to one
    space before printing.
    """
    # Model replies can span several lines; keep the record on one line.
    action_val = " ".join(str(action).split())
    error_val = (" ".join(str(error).split()) if error else "") or "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action_val} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
74
+
75
+
76
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` record for an episode.

    ``success`` is printed lower-cased, ``score`` with three decimals, and
    ``rewards`` as a comma-separated list of two-decimal values (empty when
    there are no rewards). The line is flushed immediately.
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    record = "[END] success={} steps={} score={} rewards={}".format(
        str(success).lower(),
        steps,
        format(score, ".3f"),
        ",".join(formatted_rewards),
    )
    print(record, flush=True)
82
+
83
+
84
+ # =========================
85
+ # PROMPT BUILDER
86
+ # =========================
87
def build_user_prompt(step, obs, last_reward):
    """Render the per-step user message describing the factory state.

    Args:
        step: Current step number.
        obs: Observation exposing ``time``, ``machines`` (items with ``id``,
            ``status``, ``current_job``) and ``pending_jobs`` (items with
            ``id``, ``remaining_time``, ``deadline``).
        last_reward: Reward of the previous step, shown with two decimals.

    Returns:
        The prompt text with no leading or trailing whitespace.

    The prompt is assembled from explicit lines instead of dedenting an
    indented f-string: interpolated multi-line values (two or more machines
    or jobs) carry no indentation on their continuation lines, which stops
    ``textwrap.dedent`` from removing the source indent of every other line.
    """
    machine_lines = [
        f"{m.id}: {m.status} (job={m.current_job})" for m in obs.machines
    ]
    job_lines = [
        f"{j.id}: remaining={j.remaining_time}, deadline={j.deadline}"
        for j in obs.pending_jobs
    ]

    prompt_lines = [
        f"Step: {step}",
        "",
        f"Current Time: {obs.time}",
        "",
        "Machines:",
        "\n".join(machine_lines),
        "",
        "Pending Jobs:",
        # An empty queue is shown explicitly as "None".
        "\n".join(job_lines) or "None",
        "",
        f"Last reward: {last_reward:.2f}",
        "",
        "What action do you take?",
    ]
    return "\n".join(prompt_lines).strip()
113
+
114
+
115
+ # =========================
116
+ # LLM CALL
117
+ # =========================
118
def get_model_action(client: OpenAI, step, obs, last_reward) -> str:
    """Ask the chat model for the next action and return its reply text.

    The request carries ``SYSTEM_PROMPT`` plus the prompt built from the
    current step, observation and last reward. The stripped reply is
    returned; an empty reply, or any exception raised while building the
    prompt or calling the API, yields ``"wait"`` (exceptions are also
    reported on stdout with a ``[DEBUG]`` prefix).
    """
    default_action = "wait"
    try:
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_user_prompt(step, obs, last_reward)},
        ]
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=conversation,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        content = response.choices[0].message.content
        reply = content.strip() if content else ""
    except Exception as e:
        print(f"[DEBUG] LLM error: {e}", flush=True)
        return default_action

    return reply or default_action
138
+
139
+
140
+ # =========================
141
+ # ACTION PARSER
142
+ # =========================
143
def parse_action(text: str) -> Action:
    """Convert a model reply into an ``Action``, defaulting to ``wait``.

    Recognised forms (verb matched case-insensitively, surrounding
    backticks/quotes ignored):

    * ``assign_job <job_id> <machine_id>`` - exactly three tokens
    * ``wait`` - as the first token

    The whole reply is tried first, then each line on its own, so a valid
    command accompanied by commentary on other lines is still honoured.
    Anything unrecognised - including empty or ``None`` input - yields a
    ``wait`` action.
    """
    reply = text or ""

    for candidate in [reply, *reply.splitlines()]:
        tokens = candidate.strip().strip("`'\"").split()
        if not tokens:
            continue

        verb = tokens[0].lower()
        if verb == "wait":
            break

        if verb == "assign_job" and len(tokens) == 3:
            try:
                return Action(
                    action_type="assign_job",
                    job_id=tokens[1],
                    machine_id=tokens[2],
                )
            except Exception as exc:
                # Best effort: if the Action cannot be built, report it and
                # fall back to waiting instead of aborting the episode.
                print(f"[DEBUG] could not build action from {candidate!r}: {exc}", flush=True)
                break

    # fallback safe action
    return Action(action_type="wait")
162
+
163
+
164
+ # =========================
165
+ # SIMPLE HEURISTIC FALLBACK
166
+ # =========================
167
def heuristic_action(obs) -> Action:
    """Pick a simple fallback action from the observation.

    Walks pending jobs in order and, for each, machines in order; the first
    pairing whose machine has status ``"idle"`` becomes an ``assign_job``
    action. When no such pairing exists the action is ``wait``.
    """
    candidates = (
        (job, machine)
        for job in obs.pending_jobs
        for machine in obs.machines
        if machine.status == "idle"
    )
    first_match = next(candidates, None)

    if first_match is None:
        return Action(action_type="wait")

    job, machine = first_match
    return Action(
        action_type="assign_job",
        job_id=job.id,
        machine_id=machine.id,
    )
177
+
178
+
179
+ # =========================
180
+ # MAIN LOOP
181
+ # =========================
182
async def main():
    """Run one episode of the factory environment driven by the chat model.

    Flow: emit ``[START]``; reset the environment; for up to ``MAX_STEPS``
    steps ask the model for an action, parse it, optionally replace a
    ``wait`` with the heuristic assignment when jobs are pending, step the
    environment and emit a ``[STEP]`` record. The score is the mean step
    reward clamped to ``[0, 1]``; success means the score reaches
    ``SUCCESS_SCORE_THRESHOLD``. The environment is closed and ``[END]`` is
    emitted even if the episode raises.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    env = FactoryEnv(task=TASK_NAME)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    try:
        result = await env.reset()
        obs = result.observation
        last_reward = 0.0

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            # LLM decision
            action_text = get_model_action(client, step, obs, last_reward)

            # Parse action
            action = parse_action(action_text)

            # Fallback: a wait while jobs are pending is replaced by the
            # heuristic's choice.
            if action.action_type == "wait" and len(obs.pending_jobs) > 0:
                fallback = heuristic_action(obs)
                if fallback.action_type != "wait":
                    action = fallback
                    action_text = "heuristic_assign"
                else:
                    # The heuristic found nothing to assign either, so the
                    # step really is a wait; log it as such rather than as
                    # an assignment that never happened.
                    action_text = "wait"

            # Step env
            result = await env.step(action)

            obs = result.observation
            reward = result.reward or 0.0
            done = result.done
            error = None

            rewards.append(reward)
            steps_taken = step
            last_reward = reward

            log_step(step, action_text, reward, done, error)

            if done:
                break

        # Normalize score: mean reward clamped into [0, 1]
        if rewards:
            score = sum(rewards) / len(rewards)
            score = max(0.0, min(1.0, score))

        success = score >= SUCCESS_SCORE_THRESHOLD

    finally:
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close error: {e}", flush=True)

        log_end(success, steps_taken, score, rewards)


if __name__ == "__main__":
    asyncio.run(main())