Vittal-M commited on
Commit
91ced0a
·
verified ·
1 Parent(s): 5a735ce

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +253 -0
inference.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference script for the Scheduling Optimisation Environment.
2
+
3
+ Emits exactly three line types per episode:
4
+ [START] task=<task_name> env=<benchmark> model=<model_name>
5
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
6
+ [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn>
7
+
8
+ Required environment variables:
9
+ API_BASE_URL — Base URL for the OpenAI-compatible API endpoint
10
+ MODEL_NAME — Model identifier to use for inference
11
+ HF_TOKEN — Your Hugging Face / API key
12
+
13
+ Usage (oracle mock — no API key needed):
14
+ python inference.py
15
+
16
+ Usage (real LLM):
17
+ API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o-mini HF_TOKEN=sk-... python inference.py
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import os
24
+ import sys
25
+ from typing import List, Optional
26
+
27
+ from openai import OpenAI
28
+
29
+ from environment import INSTANCE_BANK, SchedulingOptEnv
30
+ from models import Action
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # Configuration
34
+ # ---------------------------------------------------------------------------
35
+
36
+ API_BASE_URL: str = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
37
+ MODEL_NAME: str = os.getenv("MODEL_NAME") or "gpt-4o-mini"
38
+ HF_TOKEN: str = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or ""
39
+ BENCHMARK: str = "scheduling-opt-env"
40
+ SUCCESS_THRESHOLD: float = 0.95
41
+
42
+ USE_LLM: bool = bool(HF_TOKEN)
43
+
44
+ if not USE_LLM:
45
+ print("[WARN] HF_TOKEN not set — using oracle mock responses.", file=sys.stderr, flush=True)
46
+
47
+ client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "no-key")
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # Structured log helpers (exact required format)
51
+ # ---------------------------------------------------------------------------
52
+
53
+
54
+ def log_start(task: str, env: str, model: str) -> None:
55
+ print(f"[START] task={task} env={env} model={model}", flush=True)
56
+
57
+
58
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
59
+ error_val = error if error else "null"
60
+ done_val = str(done).lower()
61
+ # Sanitise action: collapse newlines and truncate to keep lines readable
62
+ action_clean = action.replace("\n", " ").replace("\r", "")[:120]
63
+ print(
64
+ f"[STEP] step={step} action={action_clean} reward={reward:.2f} done={done_val} error={error_val}",
65
+ flush=True,
66
+ )
67
+
68
+
69
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
70
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
71
+ print(
72
+ f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
73
+ flush=True,
74
+ )
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # LLM helper
79
+ # ---------------------------------------------------------------------------
80
+
81
+
82
+ def _llm(system: str, user: str) -> str:
83
+ try:
84
+ resp = client.chat.completions.create(
85
+ model=MODEL_NAME,
86
+ messages=[
87
+ {"role": "system", "content": system},
88
+ {"role": "user", "content": user},
89
+ ],
90
+ max_tokens=1024,
91
+ temperature=0.0,
92
+ )
93
+ return (resp.choices[0].message.content or "").strip()
94
+ except Exception as exc:
95
+ print(f"[DEBUG] LLM error: {exc}", file=sys.stderr, flush=True)
96
+ return ""
97
+
98
+
99
+ # ---------------------------------------------------------------------------
100
+ # Oracle mock responses (used when HF_TOKEN is absent)
101
+ # ---------------------------------------------------------------------------
102
+
103
+ _MOCK_FEASIBILITY: dict[int, str] = {
104
+ 0: "infeasible", 1: "infeasible", 2: "infeasible", 3: "infeasible",
105
+ 4: "infeasible", 5: "infeasible", 6: "infeasible", 7: "infeasible",
106
+ 8: "infeasible", 9: "infeasible", 10: "feasible", 11: "feasible",
107
+ }
108
+
109
+ _MOCK_CLASSIFICATION: dict[int, str] = {
110
+ 0: "resource_overload", 1: "deadline_violation",
111
+ 2: "precedence_violation", 3: "availability_conflict",
112
+ 4: "capacity_exceeded", 5: "resource_overload",
113
+ 6: "deadline_violation", 7: "precedence_violation",
114
+ 8: "availability_conflict",9: "capacity_exceeded",
115
+ }
116
+
117
+
118
+ def _mock_repair(idx: int) -> str:
119
+ entry = INSTANCE_BANK[idx]
120
+ sched = entry.get("optimal_schedule") or entry["instance"].get("proposed_schedule", {})
121
+ return json.dumps(sched)
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Per-task agent prompts
126
+ # ---------------------------------------------------------------------------
127
+
128
+
129
+ def _agent_feasibility(instance_str: str, instance_idx: int) -> str:
130
+ if not USE_LLM:
131
+ return _MOCK_FEASIBILITY.get(instance_idx, "infeasible")
132
+ return _llm(
133
+ "You are a scheduling expert. Determine if the proposed schedule satisfies "
134
+ "all constraints. Reply with ONLY 'feasible' or 'infeasible'. No extra text.",
135
+ instance_str,
136
+ )
137
+
138
+
139
+ def _agent_classification(instance_str: str, instance_idx: int) -> str:
140
+ if not USE_LLM:
141
+ return _MOCK_CLASSIFICATION.get(instance_idx, "resource_overload")
142
+ return _llm(
143
+ "You are a scheduling expert. Identify the single constraint violation type. "
144
+ "Reply with ONLY one of: resource_overload, deadline_violation, "
145
+ "precedence_violation, availability_conflict, capacity_exceeded. No extra text.",
146
+ instance_str,
147
+ )
148
+
149
+
150
+ def _agent_repair(instance_str: str, instance_idx: int) -> str:
151
+ if not USE_LLM:
152
+ return _mock_repair(instance_idx)
153
+ return _llm(
154
+ 'You are a scheduling expert. Repair the infeasible schedule. Return ONLY a '
155
+ 'valid JSON object: {"assignments": [{"job_id": "...", "machine_id": "...", '
156
+ '"start_time": <int>}, ...]}. No markdown, no explanation.',
157
+ instance_str,
158
+ )
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Single episode runner
163
+ # ---------------------------------------------------------------------------
164
+
165
+ TASK_CONFIG = {
166
+ "feasibility_check": {"max_steps": 3, "agent": _agent_feasibility},
167
+ "conflict_classification":{"max_steps": 5, "agent": _agent_classification},
168
+ "schedule_repair": {"max_steps": 8, "agent": _agent_repair},
169
+ }
170
+
171
+
172
+ def run_episode(
173
+ env: SchedulingOptEnv,
174
+ task_id: str,
175
+ instance_idx: int,
176
+ instance_entry: dict,
177
+ ) -> None:
178
+ """Run one episode and emit [START] / [STEP]s / [END]."""
179
+ cfg = TASK_CONFIG[task_id]
180
+ max_steps: int = cfg["max_steps"]
181
+ agent_fn = cfg["agent"]
182
+ instance_str = json.dumps(instance_entry["instance"], indent=2)
183
+
184
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
185
+
186
+ obs = env.reset(task_id=task_id)
187
+
188
+ rewards: List[float] = []
189
+ steps_taken = 0
190
+ success = False
191
+
192
+ try:
193
+ for step in range(1, max_steps + 1):
194
+ response = agent_fn(instance_str, instance_idx)
195
+ action = Action(response=response, task_id=task_id)
196
+
197
+ obs, reward, done, info = env.step(action)
198
+
199
+ error = info.get("grading_breakdown", {}).get("feedback") if reward < SUCCESS_THRESHOLD else None
200
+ # Only surface error string for failed/partial steps
201
+ if reward >= SUCCESS_THRESHOLD:
202
+ error = None
203
+
204
+ rewards.append(reward)
205
+ steps_taken = step
206
+ log_step(step=step, action=response, reward=reward, done=done, error=error)
207
+
208
+ if done:
209
+ break
210
+
211
+ final_reward = rewards[-1] if rewards else 0.0
212
+ score = min(max(final_reward, 0.0), 1.0)
213
+ success = score >= SUCCESS_THRESHOLD
214
+
215
+ except Exception as exc:
216
+ print(f"[DEBUG] Episode error: {exc}", file=sys.stderr, flush=True)
217
+ if not rewards:
218
+ rewards = [0.0]
219
+ score = 0.0
220
+
221
+ finally:
222
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
223
+
224
+
225
+ # ---------------------------------------------------------------------------
226
+ # Main — run all 32 episodes across 3 tasks
227
+ # ---------------------------------------------------------------------------
228
+
229
+
230
+ def main() -> None:
231
+ env = SchedulingOptEnv()
232
+
233
+ # Task 1: Feasibility Check — all 12 instances
234
+ for i, entry in enumerate(INSTANCE_BANK):
235
+ run_episode(env, "feasibility_check", i, entry)
236
+
237
+ # Task 2: Conflict Classification — 10 infeasible instances only
238
+ for i, entry in enumerate(INSTANCE_BANK):
239
+ if not entry["is_feasible"]:
240
+ run_episode(env, "conflict_classification", i, entry)
241
+
242
+ # Task 3: Schedule Repair — 10 infeasible instances only
243
+ for i, entry in enumerate(INSTANCE_BANK):
244
+ if not entry["is_feasible"]:
245
+ run_episode(env, "schedule_repair", i, entry)
246
+
247
+
248
+ if __name__ == "__main__":
249
+ try:
250
+ main()
251
+ except Exception as exc:
252
+ print(f"[ERROR] {exc}", file=sys.stderr, flush=True)
253
+ sys.exit(1)