Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Baseline inference script for ProcureRL. | |
| Runs an LLM agent against the procurement negotiation environment | |
| and outputs results in exact [START][STEP][END] format. | |
| """ | |
| import os | |
| import sys | |
| import json | |
# Credentials and endpoint are read from the environment. HF_TOKEN wins over
# API_KEY when both are set (an empty HF_TOKEN falls through to API_KEY).
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Benchmark label printed in the [START] line of each task.
BENCHMARK = "procure-rl"
# Upper bound on agent turns per task; the loop in run_task stops at this count.
MAX_STEPS = 10
# The chat client is built at import time. Any failure here, including a
# missing `openai` package, is reported and the process exits with status 1.
try:
    from openai import OpenAI
    client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
except Exception as e:
    print(f"[ERROR] Failed to initialize OpenAI client: {e}")
    sys.exit(1)
# Put this file's directory first on sys.path so the project-local imports
# below resolve regardless of the working directory the script is run from.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from server.Procure_RL_environment import ProcureRLEnvironment
from models import NegotiationAction
# Task identifiers that the __main__ block runs, in this order.
TASKS = ["single_issue", "multi_issue", "adversarial"]
# System message sent with every chat completion request. It tells the model
# to answer with a JSON object carrying move_type, terms and message, which is
# the shape get_agent_action extracts from the reply.
SYSTEM_PROMPT = """You are a professional procurement negotiator. Your goal is to negotiate the best possible deal for your company.
You will receive a supplier's message and current offer terms. You must respond with a JSON action in this exact format:
{
"move_type": "make_offer",
"terms": {"price": 42000, "payment_days": 45},
"message": "Your natural language response to the supplier"
}
move_type must be one of: make_offer, accept, reject, bundle
terms must include price and any other issues being negotiated.
message should be professional and collaborative when possible.
Your buyer constraints will be provided. Do not exceed your budget. Try to reach the target price."""
def get_agent_action(obs_dict: dict) -> dict:
    """Ask the LLM for the next negotiation move and return it as a dict.

    Args:
        obs_dict: Observation fields (see ``obs_to_dict``). Missing keys fall
            back to defaults, so a partial dict is accepted.

    Returns:
        The JSON object found in the model reply, parsed into a dict. If the
        request fails, the reply has no JSON object, or the JSON is invalid,
        a fallback ``make_offer`` action is returned that repeats the current
        offer unchanged.
    """
    task_id = obs_dict.get("task_id", "single_issue")
    supplier_msg = obs_dict.get("supplier_message", "")
    current_offer = obs_dict.get("current_offer", {})
    constraints = obs_dict.get("buyer_constraints", {})
    rapport_hint = obs_dict.get("rapport_hint", "neutral")
    round_num = obs_dict.get("round_number", 0)
    max_rounds = obs_dict.get("max_rounds", 10)

    default_message = "I'd like to continue our discussion."

    def fallback(message: str) -> dict:
        # Safe default move: restate the offer already on the table.
        return {
            "move_type": "make_offer",
            "terms": current_offer,
            "message": message,
        }

    user_content = f"""Task: {task_id}
Round: {round_num}/{max_rounds}
Supplier says: "{supplier_msg}"
Current offer on table: {json.dumps(current_offer)}
Your constraints: {json.dumps(constraints)}
Relationship rapport: {rapport_hint}
Respond with your negotiation action as JSON."""

    # Deliberately broad: any transport/API failure should degrade to the
    # fallback action instead of aborting the episode.
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            max_tokens=300,
            temperature=0.3,
        )
        raw_content = response.choices[0].message.content
    except Exception as e:
        return fallback(f"Error: {str(e)}")

    # The API may return None as content; treat that as an empty reply
    # rather than letting it surface as a bogus "Error: ..." message.
    content = (raw_content or "").strip()

    # Take the outermost {...} span so that prose or code fences around the
    # JSON object are ignored.
    start = content.find("{")
    end = content.rfind("}") + 1
    if start < 0 or end <= start:
        return fallback(content[:200] if content else default_message)

    try:
        return json.loads(content[start:end])
    except ValueError:
        # json.JSONDecodeError subclasses ValueError. Catching only this keeps
        # KeyboardInterrupt/SystemExit from being swallowed.
        return fallback(default_message)
def obs_to_dict(obs) -> dict:
    """Copy the observation attributes used to build the agent prompt into a dict.

    Every listed attribute must exist on ``obs``; a missing one raises
    ``AttributeError``.
    """
    field_names = (
        "task_id",
        "round_number",
        "max_rounds",
        "supplier_message",
        "current_offer",
        "buyer_constraints",
        "rapport_hint",
        "done",
    )
    return {name: getattr(obs, name) for name in field_names}
def run_task(task_id: str) -> dict:
    """Run one negotiation episode and print its [START]/[STEP]/[END] lines.

    Args:
        task_id: Task name passed to ``env.reset``.

    Returns:
        Dict with keys ``task``, ``score`` and ``steps``. ``score`` stays 0.0
        when the episode hits ``MAX_STEPS`` without the environment reporting
        ``done``.
    """
    env = ProcureRLEnvironment()
    obs = env.reset(task_id=task_id, seed=42)
    obs_dict = obs_to_dict(obs)
    print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}")

    rewards = []
    step = 0
    final_score = 0.0
    # The loop ends either via the `break` on obs.done or at the step cap.
    while step < MAX_STEPS:
        step += 1
        action_dict = get_agent_action(obs_dict)
        # `or` rather than a .get() default: the model may emit an explicit
        # JSON null, which .get(key, default) would pass through as None.
        action = NegotiationAction(
            move_type=action_dict.get("move_type") or "make_offer",
            terms=action_dict.get("terms") or {},
            message=action_dict.get("message") or "",
        )
        obs = env.step(action)

        # Normalise a missing reward once; the list and the log line share it.
        reward = obs.reward if obs.reward is not None else 0.0
        rewards.append(reward)

        action_str = f"{action.move_type}({json.dumps(action.terms)})"
        error = obs.metadata.get("error", None) if obs.metadata else None
        print(
            f"[STEP] step={step} action={action_str} reward={reward:.2f} done={str(obs.done).lower()} error={error if error else 'null'}"
        )

        if obs.done:
            # Prefer the terminal reward; if it is not positive, fall back to
            # the best reward seen during the episode.
            final_score = reward if reward > 0 else max(rewards)
            break
        obs_dict = obs_to_dict(obs)

    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    success = final_score > 0.1
    print(
        f"[END] success={str(success).lower()} steps={step} score={final_score:.2f} rewards={rewards_str}"
    )
    return {"task": task_id, "score": final_score, "steps": step}
if __name__ == "__main__":
    # Refuse to start when neither credential variable is set.
    if not API_KEY:
        print("[ERROR] HF_TOKEN or API_KEY environment variable not set")
        sys.exit(1)

    # Run every task; a failing task is recorded with a zero score so the
    # remaining tasks still run.
    results = []
    for task in TASKS:
        try:
            outcome = run_task(task)
        except Exception as e:
            print(f"[ERROR] Task {task} failed: {e}")
            outcome = {"task": task, "score": 0.0, "steps": 0, "error": str(e)}
        results.append(outcome)

    # Per-task score summary.
    print("\nBaseline Results:")
    for entry in results:
        print(f" {entry['task']}: {entry['score']:.3f}")