Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| FastAPI application for the ProcureRL Environment. | |
| This module creates an HTTP server that exposes the ProcureRLEnvironment | |
| over HTTP and WebSocket endpoints, compatible with EnvClient. | |
| Endpoints: | |
| - POST /reset: Reset the environment | |
| - POST /step: Execute an action | |
| - GET /state: Get current environment state | |
| - GET /schema: Get action/observation schemas | |
| - WS /ws: WebSocket endpoint for persistent sessions | |
| Usage: | |
| # Development (with auto-reload): | |
| uvicorn server.app:app --reload --host 0.0.0.0 --port 7860 | |
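| # Production (keep a single worker: each worker process builds its own | |
| # ProcureRLEnvironment, so episode state would not be shared across requests): | |
| uvicorn server.app:app --host 0.0.0.0 --port 7860 | |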
| # Or run directly: | |
| python -m server.app | |
| """ | |
import sys
import os
import json

# Put the project root (the parent of this file's directory) first on sys.path
# so ``models`` and ``server.*`` import regardless of the launch directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    from openenv.core.env_server.http_server import create_app
    import openenv.core.env_server.web_interface as _web_interface
except ImportError as e:
    raise ImportError(
        "openenv is required for the web interface. "
        "Install dependencies with 'uv sync'."
    ) from e

# Port this server listens on (same PORT/7860 default as main() below); the
# quick-start patch advertises it instead of openenv's default 8000.
_PUBLIC_PORT = os.getenv("PORT", "7860")

_orig_quick_start_markdown = getattr(_web_interface, "get_quick_start_markdown", None)

if callable(_orig_quick_start_markdown):

    def _quick_start_markdown_on_public_port(*args, **kwargs):
        """Wrap openenv's quick-start generator, rewriting its localhost:8000 URL."""
        return _orig_quick_start_markdown(*args, **kwargs).replace(
            "http://localhost:8000", f"http://localhost:{_PUBLIC_PORT}"
        )

    # Best-effort cosmetic patch: if openenv renames the helper we simply skip
    # it, rather than failing the import as if openenv were not installed.
    _web_interface.get_quick_start_markdown = _quick_start_markdown_on_public_port

import gradio as gr
from models import NegotiationAction, NegotiationObservation, NegotiationState
from server.Procure_RL_environment import ProcureRLEnvironment

# One environment instance, returned by the factory handed to create_app below.
_env_instance = ProcureRLEnvironment()
| def build_custom_gradio_ui( | |
| web_manager, | |
| action_fields, | |
| metadata, | |
| is_chat_env, | |
| title, | |
| quick_start_md, | |
| ): | |
| """Custom Gradio UI with interactive negotiation simulation.""" | |
| readme_content = _load_readme_content(metadata) | |
| display_title = metadata.name if metadata else title | |
| custom_quick_start_md = """### Connect to this environment | |
| Connect from Python using `ProcureRLEnv`: | |
| ```python | |
| from client import ProcureRLEnv | |
| # Connect to Hugging Face Space | |
| with ProcureRLEnv.from_env("akshaypulla/procure-rl") as env: | |
| result = await env.step(NegotiationAction(...)) | |
| # Or connect to local server | |
| with ProcureRLEnv(base_url="http://localhost:7860") as env: | |
| result = env.step(NegotiationAction(...)) | |
| ``` | |
| ### Web Interface | |
| Access the visual playground at `/web` to: | |
| - **Play Now**: Make offers and negotiate with the supplier | |
| - **Watch Agent**: See a strategic agent negotiate step-by-step | |
| - **Instructions**: Learn how to play and what each field means | |
| ### Quick Tips | |
| - Use **collaborative language** ("partnership", "mutual") to increase rapport | |
| - In **multi_issue**, offering Net-30 payment can get you a better price | |
| - In **adversarial**, avoid 2+ consecutive concessions or opponent hardens | |
| """ | |
| EXAMPLE_1 = { | |
| "move_type": "make_offer", | |
| "terms": {"price": 48000}, | |
| "message": "I value our partnership and believe we can reach a fair agreement together. Let's work collaboratively to find a solution.", | |
| } | |
| EXAMPLE_2 = { | |
| "move_type": "make_offer", | |
| "terms": {"price": 45000}, | |
| "message": "We appreciate your flexibility. Here's our counter-offer to move us closer to a mutual agreement.", | |
| } | |
| AGENT_STRATEGY = [ | |
| ("make_offer", {"price": 48000}, "I value our partnership."), | |
| ("make_offer", {"price": 46000}, "I appreciate your movement."), | |
| ("make_offer", {"price": 44000}, "We're getting closer."), | |
| ("make_offer", {"price": 42000}, "I believe we've found a good deal."), | |
| ("accept", {}, ""), | |
| ] | |
| async def reset_env(task_id, seed): | |
| try: | |
| data = await web_manager.reset_environment( | |
| {"task_id": task_id, "seed": int(seed)} | |
| ) | |
| obs_d = _format_observation_full(data) | |
| conv_h = _build_conversation_hist([]) | |
| price_d = _build_price_display(0, 52000, 36000, 52000) | |
| status = "โ Reset successful! Make your offer." | |
| json_d = json.dumps(data, indent=2) | |
| return obs_d, conv_h, price_d, status, json_d | |
| except Exception as e: | |
| return f"Error: {e}", "", "", f"Error: {e}", "" | |
| async def step_manual(move_type, terms_str, message, conversation_state): | |
| try: | |
| terms = json.loads(terms_str) if terms_str.strip() else {} | |
| action_data = {"move_type": move_type, "terms": terms, "message": message} | |
| data = await web_manager.step_environment(action_data) | |
| new_conv = conversation_state.copy() if conversation_state else [] | |
| new_conv.append( | |
| { | |
| "role": "you", | |
| "message": message or f"[{move_type}: {terms}]", | |
| "terms": terms, | |
| } | |
| ) | |
| if not data.get("observation", {}).get("done"): | |
| supplier_msg = data.get("observation", {}).get("supplier_message", "") | |
| new_conv.append( | |
| { | |
| "role": "supplier", | |
| "message": supplier_msg, | |
| "terms": data.get("observation", {}).get("current_offer", {}), | |
| } | |
| ) | |
| obs = data.get("observation", {}) | |
| current_price = obs.get("current_offer", {}).get("price", 0) | |
| reward = obs.get("reward") | |
| done = obs.get("done", False) | |
| status_msg = f"Step complete! Round {obs.get('round_number', 0)}/{obs.get('max_rounds', 6)}" | |
| if done and reward is not None: | |
| status_msg = f"๐ Deal done! Final score: {reward:.4f}" | |
| elif done: | |
| status_msg = "โ No deal reached." | |
| obs_display = _format_observation_full(data) | |
| conv_hist = _build_conversation_hist(new_conv) | |
| price_disp = _build_price_display( | |
| obs.get("round_number", 0), current_price, 36000, 52000 | |
| ) | |
| json_data = json.dumps(data, indent=2) | |
| return obs_display, conv_hist, price_disp, status_msg, json_data | |
| except json.JSONDecodeError: | |
| return "", "", "", "โ Invalid JSON in terms field", "" | |
| except Exception as e: | |
| return "", "", "", f"Error: {e}", f"Error: {str(e)}" | |
| async def run_agent_example(task_id="single_issue", seed=42): | |
| try: | |
| await web_manager.reset_environment({"task_id": task_id, "seed": seed}) | |
| conv = [] | |
| steps_log = [] | |
| price_points = [] | |
| for i, (move_type, terms, message) in enumerate(AGENT_STRATEGY): | |
| action_data = { | |
| "move_type": move_type, | |
| "terms": terms, | |
| "message": message, | |
| } | |
| data = await web_manager.step_environment(action_data) | |
| obs = data.get("observation", {}) | |
| current_price = obs.get("current_offer", {}).get("price", 0) | |
| price_points.append(current_price) | |
| conv.append( | |
| { | |
| "role": "you", | |
| "message": message or f"[{move_type}: {terms}]", | |
| "terms": terms, | |
| } | |
| ) | |
| steps_log.append( | |
| f"**Step {i + 1}:** `{move_type}` โ ${current_price:,.0f}" | |
| ) | |
| if obs.get("done"): | |
| steps_log.append( | |
| f"โ Deal completed! Reward: **{obs.get('reward', 0):.4f}**" | |
| ) | |
| conv.append( | |
| { | |
| "role": "supplier", | |
| "message": obs.get("supplier_message", ""), | |
| "terms": obs.get("current_offer", {}), | |
| } | |
| ) | |
| break | |
| supplier_msg = obs.get("supplier_message", "") | |
| conv.append( | |
| { | |
| "role": "supplier", | |
| "message": supplier_msg, | |
| "terms": obs.get("current_offer", {}), | |
| } | |
| ) | |
| return ( | |
| _build_agent_demo_result(steps_log, conv, price_points), | |
| json.dumps(data, indent=2), | |
| "โ Agent demo complete!", | |
| ) | |
| except Exception as e: | |
| return f"Error: {e}", "", f"Error: {e}" | |
| def _format_observation_full(data): | |
| if not data: | |
| return "No data" | |
| obs = data.get("observation", data) | |
| lines = [f"## ๐ฏ Round {obs.get('round_number', 0)}/{obs.get('max_rounds', 6)}"] | |
| lines.append(f"**Task:** `{obs.get('task_id', '')}`") | |
| lines.append( | |
| f"**Rapport:** {_get_rapport_emoji(obs.get('rapport_hint', 'neutral'))} {obs.get('rapport_hint', 'neutral')}" | |
| ) | |
| if obs.get("done"): | |
| r = obs.get("reward") | |
| lines.append(f"\n### ๐ Episode Complete!") | |
| if r is not None: | |
| lines.append(f"**Final Score:** `{r:.4f}`") | |
| return "\n".join(lines) | |
| lines.append(f"\n### ๐ฌ Supplier says:") | |
| lines.append(f"> {obs.get('supplier_message', '')}") | |
| offer = obs.get("current_offer", {}) | |
| if offer: | |
| lines.append(f"\n### ๐ Current Offer:") | |
| for k, v in offer.items(): | |
| lines.append( | |
| f"- **{k.title()}:** `{v:,.2f}`" | |
| if isinstance(v, float) | |
| else f"- **{k.title()}:** `{v}`" | |
| ) | |
| constraints = obs.get("buyer_constraints", {}) | |
| if constraints: | |
| lines.append(f"\n### ๐ฏ Your Targets:") | |
| for k, v in constraints.items(): | |
| if isinstance(v, dict): | |
| lines.append( | |
| f"- **{k.title()}:** target `${v.get('target', 'N/A'):,}` | worst `${v.get('worst', 'N/A'):,}`" | |
| ) | |
| return "\n".join(lines) | |
| def _get_rapport_emoji(rapport): | |
| if rapport == "positive": | |
| return "๐" | |
| elif rapport == "negative": | |
| return "๐ค" | |
| return "๐" | |
| def _build_conversation_hist(conv): | |
| if not conv: | |
| return "**Conversation will appear here...**\n\nMake your first offer to start the negotiation!" | |
| lines = ["## ๐ฌ Conversation History\n"] | |
| for msg in conv: | |
| if msg["role"] == "you": | |
| lines.append(f"**๐ง You:** {msg['message']}") | |
| if msg.get("terms"): | |
| lines.append(f" โ Terms: `{json.dumps(msg['terms'])}`") | |
| else: | |
| lines.append(f"**๐ช Supplier:** {msg['message']}") | |
| return "\n".join(lines) | |
| def _build_price_display(round_num, current_price, target, opening): | |
| range_price = opening - target | |
| progress = ( | |
| ((opening - current_price) / range_price * 100) if range_price > 0 else 0 | |
| ) | |
| progress = max(0, min(100, progress)) | |
| bar = "โ" * int(progress / 5) + "โ" * (20 - int(progress / 5)) | |
| lines = [ | |
| f"## ๐ Price Tracker\n", | |
| f"Opening: `${opening:,.0f}`", | |
| f"Target: `${target:,.0f}`", | |
| f"Current: `${current_price:,.0f}`", | |
| f"\n**Progress:** `{progress:.1f}%`", | |
| f"\n[{bar}]", | |
| ] | |
| return "\n".join(lines) | |
| def _build_agent_demo_result(steps_log, conv, price_points): | |
| lines = [ | |
| "## ๐ค Agent Negotiation Demo\n", | |
| "Watch how a strategic agent negotiates:\n", | |
| "### ๐ Steps:", | |
| ] | |
| lines.extend(steps_log) | |
| lines.append("\n### ๐ฌ Full Conversation:") | |
| for msg in conv: | |
| if msg["role"] == "you": | |
| lines.append(f"**๐ง You:** {msg['message']}") | |
| else: | |
| lines.append(f"**๐ช Supplier:** {msg['message']}") | |
| if price_points: | |
| lines.append(f"\n### ๐ Price Journey:") | |
| lines.append(f"`{' โ '.join(f'${p:,.0f}' for p in price_points)}`") | |
| return "\n".join(lines) | |
| with gr.Blocks(title=display_title) as demo: | |
| gr.Markdown(f"# ๐ค {display_title}") | |
| gr.Markdown("### Interactive Procurement Negotiation Simulation") | |
| with gr.Tabs(): | |
| with gr.TabItem("๐ฎ Play Now"): | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| conversation_display = gr.Markdown("*Click Reset to start!*") | |
| price_tracker = gr.Markdown( | |
| "## ๐ Price Tracker\n*Reset to see price tracker*" | |
| ) | |
| obs_display = gr.Markdown("*Reset to see current state*") | |
| with gr.Column(scale=1): | |
| gr.Markdown("### โ๏ธ Controls") | |
| task_dropdown = gr.Dropdown( | |
| choices=["single_issue", "multi_issue", "adversarial"], | |
| value="single_issue", | |
| label="Task", | |
| ) | |
| seed_input = gr.Number(value=42, label="Seed") | |
| move_type_input = gr.Textbox( | |
| label="Move Type", | |
| value="make_offer", | |
| info="make_offer | accept | reject | bundle", | |
| ) | |
| terms_input = gr.Textbox( | |
| label="Terms (JSON)", | |
| value='{"price": 48000}', | |
| info='Example: {"price": 45000}', | |
| ) | |
| message_input = gr.Textbox( | |
| label="Your Message", | |
| value="I value our partnership.", | |
| lines=2, | |
| ) | |
| gr.Markdown("**๐ก Quick Examples:**") | |
| with gr.Row(): | |
| eg1_btn = gr.Button( | |
| "๐ Friendly", variant="secondary", size="sm" | |
| ) | |
| eg2_btn = gr.Button( | |
| "๐ผ Professional", variant="secondary", size="sm" | |
| ) | |
| eg3_btn = gr.Button( | |
| "โก Counter-Offer", variant="secondary", size="sm" | |
| ) | |
| with gr.Row(): | |
| step_btn = gr.Button("๐ค Submit Offer", variant="primary") | |
| accept_btn = gr.Button("โ Accept Deal", variant="primary") | |
| reset_btn = gr.Button("๐ Reset", variant="secondary") | |
| status_output = gr.Textbox( | |
| label="Status", interactive=False, lines=1 | |
| ) | |
| with gr.Accordion("๐ Raw JSON", open=False): | |
| raw_json = gr.Code( | |
| label="", language="json", interactive=False, lines=10 | |
| ) | |
| FRIENDLY_EX = ( | |
| "make_offer", | |
| '{"price": 48000}', | |
| "I truly value our partnership and believe we can find a fair solution.", | |
| ) | |
| PROF_EX = ( | |
| "make_offer", | |
| '{"price": 46000}', | |
| "Based on market research and our long-term relationship, I believe $46,000 is fair.", | |
| ) | |
| COUNTER_EX = ( | |
| "make_offer", | |
| '{"price": 44000}', | |
| "We've made good progress. I can meet you at $44,000.", | |
| ) | |
| def get_friendly(): | |
| return FRIENDLY_EX[0], FRIENDLY_EX[1], FRIENDLY_EX[2] | |
| def get_prof(): | |
| return PROF_EX[0], PROF_EX[1], PROF_EX[2] | |
| def get_counter(): | |
| return COUNTER_EX[0], COUNTER_EX[1], COUNTER_EX[2] | |
| eg1_btn.click( | |
| fn=get_friendly, | |
| outputs=[move_type_input, terms_input, message_input], | |
| ) | |
| eg2_btn.click( | |
| fn=get_prof, outputs=[move_type_input, terms_input, message_input] | |
| ) | |
| eg3_btn.click( | |
| fn=get_counter, | |
| outputs=[move_type_input, terms_input, message_input], | |
| ) | |
| async def do_reset(task_id, seed): | |
| return await reset_env(task_id, seed) | |
| reset_btn.click( | |
| fn=do_reset, | |
| inputs=[task_dropdown, seed_input], | |
| outputs=[ | |
| conversation_display, | |
| price_tracker, | |
| obs_display, | |
| status_output, | |
| raw_json, | |
| ], | |
| ) | |
| async def do_step(mt, ts, msg): | |
| return await step_manual(mt, ts, msg, []) | |
| step_btn.click( | |
| fn=do_step, | |
| inputs=[move_type_input, terms_input, message_input], | |
| outputs=[ | |
| obs_display, | |
| conversation_display, | |
| price_tracker, | |
| status_output, | |
| raw_json, | |
| ], | |
| ) | |
| async def do_accept(): | |
| return await step_manual("accept", "{}", "", []) | |
| accept_btn.click( | |
| fn=do_accept, | |
| outputs=[ | |
| obs_display, | |
| conversation_display, | |
| price_tracker, | |
| status_output, | |
| raw_json, | |
| ], | |
| ) | |
| with gr.TabItem("๐ค Watch Agent"): | |
| gr.Markdown("### Watch a Strategic Agent Negotiate") | |
| gr.Markdown( | |
| "This demo shows how a strategic agent approaches the negotiation." | |
| ) | |
| with gr.Row(): | |
| task_selector = gr.Dropdown( | |
| choices=["single_issue", "multi_issue", "adversarial"], | |
| value="single_issue", | |
| label="Select Task", | |
| ) | |
| run_btn = gr.Button( | |
| "โถ๏ธ Run Agent Demo", variant="primary", size="lg" | |
| ) | |
| agent_result = gr.Markdown( | |
| "*Click 'Run Agent Demo' to watch the agent negotiate*" | |
| ) | |
| agent_json = gr.Code( | |
| label="Full JSON", language="json", interactive=False, lines=15 | |
| ) | |
| agent_status = gr.Textbox(label="Status", interactive=False) | |
| async def do_agent_run(tid): | |
| return await run_agent_example(tid, 42) | |
| run_btn.click( | |
| fn=do_agent_run, | |
| inputs=[task_selector], | |
| outputs=[agent_result, agent_json, agent_status], | |
| ) | |
| with gr.TabItem("๐ Instructions"): | |
| gr.Markdown(""" | |
| ## ๐ฎ How to Play | |
| ### 1. Choose Your Task | |
| - **single_issue**: Negotiate only the price (easiest) | |
| - **multi_issue**: Negotiate price + payment terms (medium) | |
| - **adversarial**: Negotiate price + payment + support (hardest) | |
| ### 2. Make Offers | |
| - **Move Type**: `make_offer` to propose, `accept` to take deal, `reject` to walk away | |
| - **Terms**: JSON with your offered price | |
| - **Message**: Be collaborative for better rapport! | |
| ### 3. Watch the Response | |
| - Your **rapport** changes based on language quality | |
| - Higher rapport โ opponent gives better concessions | |
| ### 4. Goal | |
| - Get price close to your target | |
| - Use fewer rounds for better efficiency score | |
| - **Don't make 2+ consecutive concessions** in adversarial mode! | |
| ## ๐ฏ Quick Tips | |
| | Do | Don't | | |
| |---|---| | |
| | Use collaborative language | Use aggressive language | | |
| | Make strategic concessions | Concede every round | | |
| | Offer Net-30 payment | Ignore payment terms | | |
| """) | |
| with gr.Accordion("๐ Quick Start Guide", open=False): | |
| gr.Markdown(custom_quick_start_md) | |
| with gr.Accordion("๐ Full README", open=False): | |
| gr.Markdown(readme_content) | |
| return demo | |
| def _load_readme_content(metadata): | |
| if metadata and hasattr(metadata, "readme_content") and metadata.readme_content: | |
| return metadata.readme_content | |
| try: | |
| from pathlib import Path | |
| readme_path = Path("/app/README.md") | |
| if readme_path.exists(): | |
| return readme_path.read_text(encoding="utf-8") | |
| except: | |
| pass | |
| return "No README available." | |
def _env_factory():
    """Environment factory for create_app: always the shared module instance."""
    return _env_instance


# The factory hands out one shared environment object, so sessions are capped
# at one; more concurrent sessions would all mutate the same env state.
app = create_app(
    _env_factory,
    NegotiationAction,
    NegotiationObservation,
    env_name="ProcureRL",
    max_concurrent_envs=1,
    gradio_builder=build_custom_gradio_ui,
)
def main():
    """Serve ``app`` with uvicorn on all interfaces, port ``$PORT`` (default 7860)."""
    import uvicorn

    port = int(os.getenv("PORT", "7860"))
    # Pass the app object, not the "server.app:app" import string: under
    # ``python -m server.app`` this module runs as ``__main__``, and the string
    # would make uvicorn import it again as ``server.app``, re-running the
    # module body and building a second ProcureRLEnvironment.
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()