Lonelyguyse1 commited on
Commit
df67741
·
verified ·
1 Parent(s): ce9fb65

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +336 -0
inference.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Baseline inference runner for the Ecom returns decision environment.
2
+
3
+ This script follows the required structured stdout format:
4
+ [START] ...
5
+ [STEP] ...
6
+ [END] ...
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import asyncio
12
+ import json
13
+ import os
14
+ import re
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional
17
+
18
+ from openenv.core.client_types import StepResult
19
+
20
+ from ecom import EcomAction, EcomEnv
21
+ from ecom.server.ecom_environment import EcomEnvironment
22
+
23
+ try:
24
+ from openai import OpenAI
25
+ except Exception: # pragma: no cover
26
+ OpenAI = None # type: ignore[assignment]
27
+
28
+
29
+ HF_TOKEN = os.getenv("HF_TOKEN")
30
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
31
+ API_KEY = OPENAI_API_KEY or HF_TOKEN or os.getenv("API_KEY")
32
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
33
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
34
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
35
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL")
36
+
37
+ BENCHMARK = os.getenv("ECOM_BENCHMARK", "ecom_returns_decision")
38
+ MAX_STEPS = 3
39
+ MAX_TOKENS = 180
40
+
41
+ TASKS: List[str] = [
42
+ "easy_policy_compliance",
43
+ "medium_balanced_judgment",
44
+ "hard_conflicting_signals",
45
+ ]
46
+
47
+
48
+ @dataclass
49
+ class EpisodeOutcome:
50
+ success: bool
51
+ steps: int
52
+ rewards: List[float]
53
+
54
+
55
+ class LocalEnvRunner:
56
+ """Fallback runner when Docker or remote endpoint is unavailable."""
57
+
58
+ def __init__(self, mode: str = "medium"):
59
+ self._env = EcomEnvironment(mode=mode) # type: ignore[arg-type]
60
+
61
+ async def reset(self, **kwargs: Any) -> StepResult[Any]:
62
+ obs = self._env.reset(**kwargs)
63
+ return StepResult(observation=obs, reward=obs.reward, done=obs.done)
64
+
65
+ async def step(self, action: EcomAction, **kwargs: Any) -> StepResult[Any]:
66
+ del kwargs
67
+ obs = self._env.step(action)
68
+ return StepResult(observation=obs, reward=obs.reward, done=obs.done)
69
+
70
+ async def close(self) -> None:
71
+ self._env.close()
72
+
73
+
74
+ def log_start(task: str, env: str, model: str) -> None:
75
+ print(f"[START] task={task} env={env} model={model}", flush=True)
76
+
77
+
78
+ def log_step(
79
+ step: int, action: str, reward: float, done: bool, error: Optional[str]
80
+ ) -> None:
81
+ done_val = str(done).lower()
82
+ error_val = "null" if error is None else error
83
+ print(
84
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
85
+ flush=True,
86
+ )
87
+
88
+
89
+ def log_end(success: bool, steps: int, rewards: List[float]) -> None:
90
+ success_val = str(success).lower()
91
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
92
+ print(
93
+ f"[END] success={success_val} steps={steps} rewards={rewards_str}", flush=True
94
+ )
95
+
96
+
97
+ def format_action(action: EcomAction) -> str:
98
+ if action.reason_code is None:
99
+ return action.action_type
100
+ return f"{action.action_type}({action.reason_code})"
101
+
102
+
103
+ def extract_return_window(policy_summary: str) -> int:
104
+ match = re.search(r"within\s+(\d+)\s+days", policy_summary, flags=re.IGNORECASE)
105
+ if match:
106
+ return int(match.group(1))
107
+ return 30
108
+
109
+
110
+ def _safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
111
+ text = text.strip()
112
+ if not text:
113
+ return None
114
+
115
+ try:
116
+ parsed = json.loads(text)
117
+ if isinstance(parsed, dict):
118
+ return parsed
119
+ return None
120
+ except json.JSONDecodeError:
121
+ pass
122
+
123
+ start = text.find("{")
124
+ end = text.rfind("}")
125
+ if start == -1 or end == -1 or end <= start:
126
+ return None
127
+ try:
128
+ parsed = json.loads(text[start : end + 1])
129
+ if isinstance(parsed, dict):
130
+ return parsed
131
+ except json.JSONDecodeError:
132
+ return None
133
+ return None
134
+
135
+
136
+ def _extract_last_action_error(observation: Any) -> Optional[str]:
137
+ if not hasattr(observation, "info"):
138
+ return None
139
+ info = observation.info
140
+ if not isinstance(info, dict):
141
+ return None
142
+ value = info.get("last_action_error")
143
+ if value is None:
144
+ return None
145
+ return str(value)
146
+
147
+
148
+ def heuristic_policy(observation: Any, step: int) -> EcomAction:
149
+ window = extract_return_window(observation.policy_summary)
150
+ notes = observation.product_condition_notes.lower()
151
+ reason = observation.return_reason
152
+ return_rate = float(observation.return_rate)
153
+
154
+ if observation.days_since_purchase > window:
155
+ return EcomAction(action_type="REJECT", reason_code="TIME_EXPIRED")
156
+
157
+ if "restricted class" in notes:
158
+ return EcomAction(action_type="REJECT", reason_code="POLICY_VIOLATION")
159
+
160
+ if return_rate >= 0.60 and observation.product_value == "high":
161
+ return EcomAction(action_type="REJECT", reason_code="SUSPECTED_FRAUD")
162
+
163
+ if reason in ("defective", "wrong-item", "damaged-shipping") and return_rate < 0.55:
164
+ return EcomAction(action_type="APPROVE")
165
+
166
+ ambiguous = (
167
+ ("mixed indicators" in notes)
168
+ or ("conflict" in notes)
169
+ or (0.40 <= return_rate <= 0.60)
170
+ )
171
+ if step == 1 and ambiguous:
172
+ return EcomAction(action_type="REQUEST_INFO")
173
+
174
+ if return_rate >= 0.55:
175
+ return EcomAction(action_type="ESCALATE")
176
+ return EcomAction(action_type="APPROVE")
177
+
178
+
179
+ def model_policy(
180
+ client: Optional[Any], observation: Any, step: int
181
+ ) -> Optional[EcomAction]:
182
+ if client is None:
183
+ return None
184
+
185
+ prompt = (
186
+ "You are a returns operations agent. Choose one action JSON only.\n"
187
+ "Allowed action_type: APPROVE, REJECT, ESCALATE, REQUEST_INFO\n"
188
+ "If action_type is REJECT, include reason_code with one of: "
189
+ "TIME_EXPIRED, POLICY_VIOLATION, SUSPECTED_FRAUD\n"
190
+ "Output ONLY JSON, no prose.\n\n"
191
+ f"Step: {step}\n"
192
+ f"return_reason: {observation.return_reason}\n"
193
+ f"product_category: {observation.product_category}\n"
194
+ f"product_value: {observation.product_value}\n"
195
+ f"days_since_purchase: {observation.days_since_purchase}\n"
196
+ f"user_account_age_days: {observation.user_account_age_days}\n"
197
+ f"product_condition_notes: {observation.product_condition_notes}\n"
198
+ f"return_rate: {observation.return_rate:.3f}\n"
199
+ f"total_orders: {observation.total_orders}\n"
200
+ f"policy_summary: {observation.policy_summary}\n"
201
+ )
202
+
203
+ try:
204
+ response = client.chat.completions.create(
205
+ model=MODEL_NAME,
206
+ messages=[
207
+ {"role": "system", "content": "Respond with valid JSON only."},
208
+ {"role": "user", "content": prompt},
209
+ ],
210
+ temperature=0,
211
+ max_tokens=MAX_TOKENS,
212
+ )
213
+ text = (response.choices[0].message.content or "").strip()
214
+ data = _safe_json_parse(text)
215
+ if data is None:
216
+ return None
217
+
218
+ action_type = str(data.get("action_type", "")).strip().upper()
219
+ reason_code = data.get("reason_code")
220
+ if reason_code is not None:
221
+ reason_code = str(reason_code).strip().upper()
222
+
223
+ if action_type == "REJECT":
224
+ if reason_code not in {
225
+ "TIME_EXPIRED",
226
+ "POLICY_VIOLATION",
227
+ "SUSPECTED_FRAUD",
228
+ }:
229
+ return None
230
+ return EcomAction(action_type="REJECT", reason_code=reason_code)
231
+
232
+ if action_type in {"APPROVE", "ESCALATE", "REQUEST_INFO"}:
233
+ return EcomAction(action_type=action_type)
234
+ except Exception:
235
+ return None
236
+
237
+ return None
238
+
239
+
240
+ async def run_task(task_name: str, client: Optional[Any]) -> EpisodeOutcome:
241
+ rewards: List[float] = []
242
+ steps_taken = 0
243
+ success = False
244
+ env: Optional[Any] = None
245
+
246
+ log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
247
+
248
+ try:
249
+ if ENV_BASE_URL:
250
+ env = EcomEnv(base_url=ENV_BASE_URL)
251
+ await env.connect()
252
+ else:
253
+ if not LOCAL_IMAGE_NAME:
254
+ raise RuntimeError(
255
+ "LOCAL_IMAGE_NAME is required when ENV_BASE_URL is not set"
256
+ )
257
+ env = await EcomEnv.from_docker_image(LOCAL_IMAGE_NAME)
258
+
259
+ result = await env.reset(task_name=task_name)
260
+
261
+ for step in range(1, MAX_STEPS + 1):
262
+ observation = result.observation
263
+
264
+ action = model_policy(client, observation, step)
265
+ if action is None:
266
+ action = heuristic_policy(observation, step)
267
+
268
+ result = await env.step(action)
269
+
270
+ reward = float(result.reward or 0.0)
271
+ done = bool(result.done)
272
+ error = _extract_last_action_error(result.observation)
273
+
274
+ rewards.append(reward)
275
+ steps_taken = step
276
+ log_step(
277
+ step=step,
278
+ action=format_action(action),
279
+ reward=reward,
280
+ done=done,
281
+ error=error,
282
+ )
283
+
284
+ if done:
285
+ success = bool(result.observation.info.get("grader_success", False))
286
+ break
287
+
288
+ except Exception:
289
+ # Fallback to deterministic local execution for reproducible baseline.
290
+ env = LocalEnvRunner(mode="medium")
291
+ result = await env.reset(task_name=task_name)
292
+
293
+ for step in range(1, MAX_STEPS + 1):
294
+ observation = result.observation
295
+ action = heuristic_policy(observation, step)
296
+ result = await env.step(action)
297
+
298
+ reward = float(result.reward or 0.0)
299
+ done = bool(result.done)
300
+ error = _extract_last_action_error(result.observation)
301
+
302
+ rewards.append(reward)
303
+ steps_taken = step
304
+ log_step(
305
+ step=step,
306
+ action=format_action(action),
307
+ reward=reward,
308
+ done=done,
309
+ error=error,
310
+ )
311
+
312
+ if done:
313
+ success = bool(result.observation.info.get("grader_success", False))
314
+ break
315
+ finally:
316
+ if env is not None:
317
+ try:
318
+ await env.close()
319
+ except Exception:
320
+ pass
321
+ log_end(success=success, steps=steps_taken, rewards=rewards)
322
+
323
+ return EpisodeOutcome(success=success, steps=steps_taken, rewards=rewards)
324
+
325
+
326
+ async def main() -> None:
327
+ client = None
328
+ if OpenAI is not None and API_KEY:
329
+ client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
330
+
331
+ for task_name in TASKS:
332
+ await run_task(task_name, client)
333
+
334
+
335
+ if __name__ == "__main__":
336
+ asyncio.run(main())