Spaces:
Sleeping
Sleeping
| """Baseline inference runner for the Ecom returns decision environment.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| import re | |
| import textwrap | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional | |
| from openai import OpenAI | |
| from ecom import EcomAction, EcomEnv | |
| BENCHMARK = "ecom_returns_decision" | |
| MAX_STEPS = 5 | |
| TEMPERATURE = 0 | |
| MAX_TOKENS = 180 | |
| SYSTEM_PROMPT = textwrap.dedent( | |
| """ | |
| You are a returns operations agent. | |
| Choose exactly one action in JSON only. | |
| Allowed action_type values: | |
| - APPROVE | |
| - REJECT | |
| - ESCALATE | |
| - REQUEST_INFO | |
| If action_type is REJECT, include reason_code with one of: | |
| - TIME_EXPIRED | |
| - POLICY_VIOLATION | |
| - SUSPECTED_FRAUD | |
| Output JSON only. No prose, no markdown. | |
| """ | |
| ).strip() | |
| def _model_name() -> str: | |
| return os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") | |
| def _image_name() -> Optional[str]: | |
| return os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME") | |
| def _env_base_url() -> Optional[str]: | |
| return os.getenv("ENV_BASE_URL") | |
| def _task_names() -> List[str]: | |
| task_name = os.getenv("ECOM_TASK_NAME") or os.getenv("ECOM_TASK") | |
| if task_name: | |
| return [task_name] | |
| return [ | |
| "easy_policy_compliance", | |
| "medium_balanced_judgment", | |
| "hard_conflicting_signals", | |
| ] | |
| class EpisodeOutcome: | |
| success: bool | |
| steps: int | |
| score: float | |
| rewards: List[float] | |
| def log_start(task: str, env: str, model: str) -> None: | |
| print(f"[START] task={task} env={env} model={model}", flush=True) | |
| def log_step( | |
| step: int, action: str, reward: float, done: bool, error: Optional[str] | |
| ) -> None: | |
| error_val = error if error else "null" | |
| done_val = str(done).lower() | |
| print( | |
| f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", | |
| flush=True, | |
| ) | |
| def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: | |
| rewards_str = ",".join(f"{reward:.2f}" for reward in rewards) | |
| print( | |
| f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", | |
| flush=True, | |
| ) | |
| def format_action(action: EcomAction) -> str: | |
| if action.reason_code is None: | |
| return action.action_type | |
| return f"{action.action_type}({action.reason_code})" | |
| def extract_return_window(policy_summary: str) -> int: | |
| match = re.search(r"within\s+(\d+)\s+days", policy_summary, flags=re.IGNORECASE) | |
| if match: | |
| return int(match.group(1)) | |
| return 30 | |
| def exception_applies(observation: Any) -> bool: | |
| reason = str(observation.return_reason).lower() | |
| policy_summary = str(observation.policy_summary).lower() | |
| match = re.search(r"exception:\s*([^.]*)", policy_summary) | |
| clause = match.group(1) if match else "" | |
| if reason == "damaged-shipping" and ( | |
| "damage in transit" in clause or "damaged" in clause | |
| ): | |
| return True | |
| if reason == "defective" and "defective" in clause: | |
| return True | |
| return False | |
| def is_restricted_class_case(observation: Any) -> bool: | |
| return "restricted class" in str(observation.product_condition_notes).lower() | |
| def should_reject_time_expired( | |
| observation: Any, window: int, has_exception: bool | |
| ) -> bool: | |
| if observation.days_since_purchase <= window: | |
| return False | |
| if has_exception and not is_restricted_class_case(observation): | |
| return False | |
| return True | |
| def _safe_json_parse(text: str) -> Optional[Dict[str, Any]]: | |
| text = text.strip() | |
| if not text: | |
| return None | |
| try: | |
| parsed = json.loads(text) | |
| if isinstance(parsed, dict): | |
| return parsed | |
| return None | |
| except json.JSONDecodeError: | |
| pass | |
| start = text.find("{") | |
| end = text.rfind("}") | |
| if start == -1 or end == -1 or end <= start: | |
| return None | |
| try: | |
| parsed = json.loads(text[start : end + 1]) | |
| if isinstance(parsed, dict): | |
| return parsed | |
| except json.JSONDecodeError: | |
| return None | |
| return None | |
| def _extract_last_action_error(observation: Any) -> Optional[str]: | |
| info = getattr(observation, "info", None) | |
| if not isinstance(info, dict): | |
| return None | |
| for key in ("last_action_error", "invalid_action"): | |
| value = info.get(key) | |
| if value is not None: | |
| return str(value) | |
| return None | |
| def _extract_available_actions(observation: Any) -> List[str]: | |
| info = getattr(observation, "info", None) | |
| if not isinstance(info, dict): | |
| return [] | |
| raw = info.get("available_actions") | |
| if not isinstance(raw, list): | |
| return [] | |
| return [str(value) for value in raw] | |
| def _extract_reject_reason_codes(observation: Any) -> List[str]: | |
| info = getattr(observation, "info", None) | |
| if not isinstance(info, dict): | |
| return [] | |
| raw = info.get("reject_reason_codes") | |
| if not isinstance(raw, list): | |
| return [] | |
| return [str(value) for value in raw] | |
| def _enforce_action_contract( | |
| observation: Any, action: EcomAction | |
| ) -> Optional[EcomAction]: | |
| available_actions = _extract_available_actions(observation) | |
| if available_actions and action.action_type not in set(available_actions): | |
| return None | |
| if action.action_type == "REJECT": | |
| valid_reasons = set(_extract_reject_reason_codes(observation)) | |
| if valid_reasons and action.reason_code not in valid_reasons: | |
| return None | |
| return action | |
| def heuristic_policy(observation: Any, step: int) -> EcomAction: | |
| available_actions = set(_extract_available_actions(observation)) | |
| window = extract_return_window(observation.policy_summary) | |
| has_exception = exception_applies(observation) | |
| notes = str(observation.product_condition_notes).lower() | |
| reason = str(observation.return_reason) | |
| return_rate = float(observation.return_rate) | |
| ambiguous = ( | |
| ("mixed indicators" in notes) | |
| or ("conflict" in notes) | |
| or (0.40 <= return_rate <= 0.65) | |
| or (observation.days_since_purchase > window and has_exception) | |
| ) | |
| if step == 1 and (not available_actions or "REQUEST_INFO" in available_actions): | |
| if ambiguous: | |
| return EcomAction(action_type="REQUEST_INFO") | |
| if should_reject_time_expired(observation, window, has_exception): | |
| if available_actions and "REJECT" not in available_actions: | |
| return EcomAction(action_type="APPROVE") | |
| return EcomAction(action_type="REJECT", reason_code="TIME_EXPIRED") | |
| if "restricted class" in notes: | |
| if available_actions and "REJECT" not in available_actions: | |
| return EcomAction(action_type="APPROVE") | |
| return EcomAction(action_type="REJECT", reason_code="POLICY_VIOLATION") | |
| if ( | |
| step >= 2 | |
| and observation.product_value == "high" | |
| and return_rate >= 0.50 | |
| and ( | |
| "conflict" in notes | |
| or "disputed evidence" in notes | |
| or reason in ("changed-mind", "wrong-item") | |
| ) | |
| ): | |
| if available_actions and "REJECT" not in available_actions: | |
| return EcomAction(action_type="APPROVE") | |
| return EcomAction(action_type="REJECT", reason_code="SUSPECTED_FRAUD") | |
| if return_rate >= 0.60 and observation.product_value == "high": | |
| if available_actions and "REJECT" not in available_actions: | |
| return EcomAction(action_type="APPROVE") | |
| return EcomAction(action_type="REJECT", reason_code="SUSPECTED_FRAUD") | |
| if reason in ("defective", "wrong-item", "damaged-shipping") and return_rate < 0.55: | |
| if available_actions and "APPROVE" not in available_actions: | |
| return EcomAction(action_type="ESCALATE") | |
| return EcomAction(action_type="APPROVE") | |
| if return_rate >= 0.55: | |
| if available_actions and "ESCALATE" not in available_actions: | |
| return EcomAction(action_type="APPROVE") | |
| return EcomAction(action_type="ESCALATE") | |
| if available_actions and "APPROVE" not in available_actions: | |
| if "ESCALATE" in available_actions: | |
| return EcomAction(action_type="ESCALATE") | |
| if "REJECT" in available_actions: | |
| return EcomAction(action_type="REJECT", reason_code="SUSPECTED_FRAUD") | |
| return EcomAction(action_type="APPROVE") | |
| def build_user_prompt(step: int, observation: Any, history: List[str]) -> str: | |
| history_block = "\n".join(history[-4:]) if history else "None" | |
| available_actions = _extract_available_actions(observation) | |
| reject_reason_codes = _extract_reject_reason_codes(observation) | |
| prompt = textwrap.dedent( | |
| f""" | |
| Step: {step} | |
| return_reason: {observation.return_reason} | |
| product_category: {observation.product_category} | |
| product_value: {observation.product_value} | |
| days_since_purchase: {observation.days_since_purchase} | |
| user_account_age_days: {observation.user_account_age_days} | |
| product_condition_notes: {observation.product_condition_notes} | |
| return_rate: {float(observation.return_rate):.3f} | |
| total_orders: {observation.total_orders} | |
| policy_summary: {observation.policy_summary} | |
| available_actions: {", ".join(available_actions) if available_actions else "None"} | |
| available_reject_reason_codes: {", ".join(reject_reason_codes) if reject_reason_codes else "None"} | |
| Previous steps: | |
| {history_block} | |
| """ | |
| ).strip() | |
| return prompt | |
| def get_model_action( | |
| client: OpenAI, step: int, observation: Any, history: List[str] | |
| ) -> Optional[EcomAction]: | |
| user_prompt = build_user_prompt(step, observation, history) | |
| try: | |
| completion = client.chat.completions.create( | |
| model=_model_name(), | |
| messages=[ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": user_prompt}, | |
| ], | |
| temperature=TEMPERATURE, | |
| max_tokens=MAX_TOKENS, | |
| stream=False, | |
| ) | |
| text = (completion.choices[0].message.content or "").strip() | |
| except Exception: | |
| return None | |
| data = _safe_json_parse(text) | |
| if data is None: | |
| return None | |
| action_type = str(data.get("action_type", "")).strip().upper() | |
| reason_code = data.get("reason_code") | |
| if reason_code is not None: | |
| reason_code = str(reason_code).strip().upper() | |
| if action_type == "REJECT": | |
| if reason_code not in { | |
| "TIME_EXPIRED", | |
| "POLICY_VIOLATION", | |
| "SUSPECTED_FRAUD", | |
| }: | |
| return None | |
| action = EcomAction(action_type="REJECT", reason_code=reason_code) | |
| return _enforce_action_contract(observation, action) | |
| if action_type in {"APPROVE", "ESCALATE", "REQUEST_INFO"}: | |
| action = EcomAction(action_type=action_type) | |
| return _enforce_action_contract(observation, action) | |
| return None | |
| def _build_llm_client() -> OpenAI: | |
| api_base_url = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") | |
| api_key = os.getenv("HF_TOKEN") or os.getenv("API_KEY") | |
| if not api_key: | |
| raise RuntimeError("HF_TOKEN or API_KEY environment variable is required.") | |
| return OpenAI( | |
| base_url=api_base_url, | |
| api_key=api_key, | |
| ) | |
| def _probe_llm_proxy(client: OpenAI) -> None: | |
| try: | |
| client.chat.completions.create( | |
| model=_model_name(), | |
| messages=[ | |
| {"role": "system", "content": "Reply with OK."}, | |
| {"role": "user", "content": "OK"}, | |
| ], | |
| temperature=0, | |
| max_tokens=2, | |
| stream=False, | |
| ) | |
| except Exception as exc: | |
| raise RuntimeError( | |
| "Failed to make an LLM request through API_BASE_URL using HF_TOKEN/API_KEY." | |
| ) from exc | |
| async def run_task(task_name: str, client: OpenAI) -> EpisodeOutcome: | |
| history: List[str] = [] | |
| rewards: List[float] = [] | |
| steps_taken = 0 | |
| score = 0.0 | |
| success = False | |
| env: Optional[EcomEnv] = None | |
| log_start(task=task_name, env=BENCHMARK, model=_model_name()) | |
| try: | |
| env_base_url = _env_base_url() | |
| image_name = _image_name() | |
| if env_base_url: | |
| env = EcomEnv(base_url=env_base_url) | |
| await env.connect() | |
| else: | |
| if not image_name: | |
| raise RuntimeError( | |
| "IMAGE_NAME or LOCAL_IMAGE_NAME is required when ENV_BASE_URL is not set" | |
| ) | |
| env = await EcomEnv.from_docker_image(image_name) | |
| result = await env.reset(task_name=task_name) | |
| for step in range(1, MAX_STEPS + 1): | |
| if result.done: | |
| break | |
| observation = result.observation | |
| action = get_model_action(client, step, observation, history) | |
| if action is None: | |
| action = heuristic_policy(observation, step) | |
| result = await env.step(action) | |
| reward = float(result.reward or 0.0) | |
| done = bool(result.done) | |
| error = _extract_last_action_error(result.observation) | |
| rewards.append(reward) | |
| steps_taken = step | |
| log_step( | |
| step=step, | |
| action=format_action(action), | |
| reward=reward, | |
| done=done, | |
| error=error, | |
| ) | |
| history.append( | |
| f"Step {step}: {format_action(action)} -> reward {reward:.2f} error={error or 'null'}" | |
| ) | |
| if done: | |
| info = result.observation.info | |
| if isinstance(info, dict): | |
| success = bool(info.get("grader_success", False)) | |
| raw_score = info.get("grader_score", 0.0) | |
| try: | |
| score = float(raw_score) | |
| except (TypeError, ValueError): | |
| score = 0.0 | |
| score = max(0.0, min(1.0, score)) | |
| break | |
| except Exception: | |
| success = False | |
| finally: | |
| if env is not None: | |
| try: | |
| await env.close() | |
| except Exception: | |
| pass | |
| log_end(success=success, steps=steps_taken, score=score, rewards=rewards) | |
| return EpisodeOutcome( | |
| success=success, | |
| steps=steps_taken, | |
| score=score, | |
| rewards=rewards, | |
| ) | |
| async def main() -> None: | |
| client = _build_llm_client() | |
| _probe_llm_proxy(client) | |
| if not _env_base_url() and not _image_name(): | |
| raise RuntimeError( | |
| "Set ENV_BASE_URL or IMAGE_NAME/LOCAL_IMAGE_NAME before running inference.py" | |
| ) | |
| for task_name in _task_names(): | |
| await run_task(task_name, client) | |
| if __name__ == "__main__": | |
| asyncio.run(main()) | |