umar-sharif821 committed on
Commit
3589be9
·
1 Parent(s): 06c3a1d

feat: update inference logic in inference.py

Browse files
Files changed (1) hide show
  1. inference.py +7 -47
inference.py CHANGED
@@ -27,7 +27,7 @@ HF_TOKEN = os.getenv("HF_TOKEN")
27
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
28
 
29
  if not HF_TOKEN:
30
- print("[WARN] HF_TOKEN not set. Using API_BASE_URL without auth header override.")
31
 
32
  client = OpenAI(
33
  base_url=API_BASE_URL,
@@ -127,17 +127,7 @@ def run_task(task_id: str) -> dict:
127
  step_num = 0
128
 
129
  # ── [START] ──
130
- print(json.dumps({
131
- "type": "START",
132
- "task_id": task_id,
133
- "task_name": config.name,
134
- "difficulty": config.difficulty,
135
- "episode_length": config.episode_length,
136
- "cache_capacity_mb": config.cache_capacity_mb,
137
- "model": MODEL_NAME,
138
- "seed": SEED,
139
- }))
140
- sys.stdout.flush()
141
 
142
  while True:
143
  action = llm_action(obs, step_num)
@@ -146,26 +136,7 @@ def run_task(task_id: str) -> dict:
146
  total_reward += result.reward.total
147
 
148
  # ── [STEP] ──
149
- print(json.dumps({
150
- "type": "STEP",
151
- "task_id": task_id,
152
- "step": step_num,
153
- "action": {"evict_file_id": action.evict_file_id},
154
- "cache_hit": result.observation.cache_hit,
155
- "reward": result.reward.total,
156
- "reward_breakdown": {
157
- "cache_hit_bonus": result.reward.cache_hit_bonus,
158
- "eviction_penalty": result.reward.eviction_penalty,
159
- "thrash_penalty": result.reward.thrash_penalty,
160
- "bandwidth_saved": result.reward.bandwidth_saved,
161
- "wasted_capacity_penalty": result.reward.wasted_capacity_penalty,
162
- },
163
- "cumulative_reward": round(total_reward, 4),
164
- "hit_rate": result.observation.recent_hit_rate,
165
- "cache_fill": result.observation.cache_fill_ratio,
166
- "done": result.done,
167
- }))
168
- sys.stdout.flush()
169
 
170
  obs = result.observation
171
  step_num += 1
@@ -173,29 +144,18 @@ def run_task(task_id: str) -> dict:
173
  if result.done:
174
  break
175
 
176
- final_state = env.state()
177
  final_hit_rate = final_state["hit_rate"]
 
178
 
179
  # ── [END] ──
180
- print(json.dumps({
181
- "type": "END",
182
- "task_id": task_id,
183
- "task_name": config.name,
184
- "total_steps": step_num,
185
- "total_reward": round(total_reward, 4),
186
- "final_hit_rate": round(final_hit_rate, 4),
187
- "bandwidth_saved_mb": round(final_state["bandwidth_saved_mb"], 2),
188
- "total_hits": final_state["hits"],
189
- "total_misses": final_state["misses"],
190
- "score": round(min(1.0, final_hit_rate / {"task_easy": 0.60, "task_medium": 0.55, "task_hard": 0.45}[task_id]), 4),
191
- }))
192
- sys.stdout.flush()
193
 
194
  return {
195
  "task_id": task_id,
196
  "total_reward": round(total_reward, 4),
197
  "final_hit_rate": round(final_hit_rate, 4),
198
- "score": round(min(1.0, final_hit_rate / {"task_easy": 0.60, "task_medium": 0.55, "task_hard": 0.45}[task_id]), 4),
199
  }
200
 
201
 
 
27
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
28
 
29
  if not HF_TOKEN:
30
+ print("[WARN] HF_TOKEN not set. Using API_BASE_URL without auth header override.", file=sys.stderr)
31
 
32
  client = OpenAI(
33
  base_url=API_BASE_URL,
 
127
  step_num = 0
128
 
129
  # ── [START] ──
130
+ print(f"[START] task={task_id}", flush=True)
 
 
 
 
 
 
 
 
 
 
131
 
132
  while True:
133
  action = llm_action(obs, step_num)
 
136
  total_reward += result.reward.total
137
 
138
  # ── [STEP] ──
139
+ print(f"[STEP] step={step_num} reward={round(result.reward.total, 4)}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  obs = result.observation
142
  step_num += 1
 
144
  if result.done:
145
  break
146
 
147
+ final_state = env.state()
148
  final_hit_rate = final_state["hit_rate"]
149
+ score = round(min(1.0, final_hit_rate / {"task_easy": 0.60, "task_medium": 0.55, "task_hard": 0.45}[task_id]), 4)
150
 
151
  # ── [END] ──
152
+ print(f"[END] task={task_id} score={score} steps={step_num}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  return {
155
  "task_id": task_id,
156
  "total_reward": round(total_reward, 4),
157
  "final_hit_rate": round(final_hit_rate, 4),
158
+ "score": score,
159
  }
160
 
161