Spaces:
Sleeping
Sleeping
import html
import json
import time

import gradio as gr
import requests
# ── CONFIGURATION & STYLING ──────────────────────────────────────────────────
# Root URL of the remote environment service; the "/reset" and "/step"
# endpoints are POSTed to relative to this address.
BASE_URL = "https://yaser77-ambiguity-env.hf.space"
# Upper bound on the number of agent steps taken in a single episode.
MAX_STEPS = 5
# Stylesheet passed to app.launch(css=...). It defines the classes referenced by
# the HTML fragments this file builds: .header-banner, .step-card, .reward-tag
# (+ .reward-pos / .reward-neg), .action-text, .info-box and .result-banner
# (+ .result-success / .result-fail).
# NOTE(review): .gradio-container is presumably Gradio's own root element class
# (it is not emitted anywhere in this file) -- confirm against the Gradio version in use.
CUSTOM_CSS = """
.gradio-container {
font-family: 'Inter', 'Segoe UI', sans-serif !important;
}
.header-banner {
background: linear-gradient(135deg, #1e1e2e 0%, #313244 100%);
padding: 30px;
border-radius: 12px;
text-align: center;
border: 1px solid #45475a;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
margin-bottom: 20px;
}
.header-banner h1 {
margin: 0;
color: #cdd6f4;
font-weight: 800;
}
.header-banner p {
color: #a6adc8;
font-size: 1.1em;
margin-top: 10px;
}
.step-card {
background: #181825;
border-left: 4px solid #89b4fa;
border-radius: 8px;
padding: 16px 20px;
margin-bottom: 15px;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.reward-tag {
display: inline-block;
padding: 3px 10px;
border-radius: 20px;
font-weight: bold;
font-size: 0.9em;
}
.reward-pos { background-color: rgba(166, 227, 161, 0.15); color: #a6e3a1; }
.reward-neg { background-color: rgba(243, 139, 168, 0.15); color: #f38ba8; }
.action-text {
font-family: monospace;
background: #11111b;
padding: 4px 8px;
border-radius: 4px;
color: #f5c2e7;
}
.info-box {
background-color: rgba(137, 180, 250, 0.1);
border: 1px solid rgba(137, 180, 250, 0.3);
border-radius: 8px;
padding: 15px;
margin-bottom: 20px;
}
.result-banner {
padding: 20px;
border-radius: 12px;
text-align: center;
font-size: 1.25em;
font-weight: bold;
margin-top: 20px;
}
.result-success { background: linear-gradient(135deg, rgba(166, 227, 161, 0.2), rgba(148, 226, 213, 0.2)); border: 1px solid #a6e3a1; color: #a6e3a1; }
.result-fail { background: linear-gradient(135deg, rgba(243, 139, 168, 0.2), rgba(250, 179, 135, 0.2)); border: 1px solid #f38ba8; color: #f38ba8; }
"""
# Maps the labels offered in the "Complexity" dropdown to the task_name
# identifiers sent in the "/reset" payload. Unknown labels fall back to
# "hard_ambiguous" at the lookup site in run_interaction.
TASK_MAPPING = {
    "Easy Explicit": "easy_explicit",
    "Medium Missing Time": "medium_missing_time",
    "Medium Missing Participants": "medium_missing_participants",
    "Hard Ambiguous": "hard_ambiguous"
}
# ── DOMAIN REASONING ─────────────────────────────────────────────────────────
def get_valid_times(constraints: dict) -> list[str]:
    """Return the candidate meeting slots that satisfy *constraints*.

    The candidate slots are fixed: "10 AM", "2 PM" and "4 PM", returned in
    that order.

    Args:
        constraints: Mapping that may contain
            ``"unavailable_times"`` - iterable of slot labels to exclude
            (compared case-insensitively, surrounding whitespace ignored), and
            ``"deadline"`` - when it equals "before 3 PM" (case-insensitive,
            surrounding whitespace ignored) the "4 PM" slot is excluded.
            Missing or ``None`` values (and a ``None`` mapping) are treated
            as "no constraint".

    Returns:
        The list of slots that remain, possibly empty.
    """
    constraints = constraints or {}

    # `or []` / `or "ASAP"` also cover keys that are present but null in the
    # JSON payload, which a plain .get(key, default) would pass through as None.
    blocked = {
        str(slot).strip().upper()
        for slot in (constraints.get("unavailable_times") or [])
    }
    deadline = str(constraints.get("deadline") or "ASAP").strip().upper()
    must_end_before_three = deadline == "BEFORE 3 PM"

    return [
        slot
        for slot in ("10 AM", "2 PM", "4 PM")
        if slot not in blocked and not (must_end_before_three and slot == "4 PM")
    ]
# Slot and team labels recognised in free text, in priority order.
_KNOWN_TIMES = ("10 AM", "2 PM", "4 PM")
_KNOWN_TEAMS = ("TEAM A", "TEAM B", "TEAM C")


def _has_token(haystack: str, token: str) -> bool:
    """Return True if *token* occurs in *haystack* with non-alphanumeric neighbours."""
    start = haystack.find(token)
    while start != -1:
        end = start + len(token)
        left_ok = start == 0 or not haystack[start - 1].isalnum()
        right_ok = end == len(haystack) or not haystack[end].isalnum()
        if left_ok and right_ok:
            return True
        start = haystack.find(token, start + 1)
    return False


def extract_from_text(text: str):
    """Pull a known meeting time and known team names out of free text.

    Matching is case-insensitive and requires each label to stand alone, so
    "12 PM" is not mistaken for "2 PM" and "Team Alpha" is not mistaken for
    "Team A".

    Args:
        text: Instruction text to scan; ``None`` or an empty string yields no
            matches.

    Returns:
        A ``(time, participants)`` tuple. ``time`` is the first of "10 AM",
        "2 PM", "4 PM" (checked in that order, not in order of appearance)
        found in the text, or ``None``. ``participants`` is the list of
        matched teams as "Team A" / "Team B" / "Team C", in that fixed order.
    """
    upper = (text or "").upper()
    time_val = next((slot for slot in _KNOWN_TIMES if _has_token(upper, slot)), None)
    parts = [team.title() for team in _KNOWN_TEAMS if _has_token(upper, team)]
    return time_val, parts
# ── AGENT LOGIC (Mirroring inference.py Intelligence) ────────────────────────
def demo_agent(obs_dict, task_name):
    """Pick the next action for the current observation.

    The agent asks a clarifying question when the task name implies that a
    field (time or participants) is missing, the field is not yet present in
    ``known_info``, and the instruction text does not supply it. Otherwise it
    executes with the best time and participants it can determine.

    Args:
        obs_dict: Observation mapping; the keys read are "instruction",
            "known_info" and "constraints".
        task_name: Task identifier; the substrings "time", "participants"
            and "hard" (case-insensitive) decide which fields may be missing.

    Returns:
        Either ``{"type": "ask", "question": ...}`` or
        ``{"type": "execute", "proposed_time": ..., "proposed_participants": ...}``.
    """
    instruction_text = obs_dict.get("instruction", "")
    known_info = obs_dict.get("known_info", {})
    constraint_map = obs_dict.get("constraints", {})
    text_time, text_parts = extract_from_text(instruction_text)

    task_key = task_name.lower()
    is_hard = "hard" in task_key
    time_missing = ("time" in task_key or is_hard) and "time" not in known_info
    parts_missing = ("participants" in task_key or is_hard) and "participants" not in known_info

    # Clarify first: time takes precedence over participants.
    if time_missing and not text_time:
        return {"type": "ask", "question": "What time works for the meeting?"}
    if parts_missing and not text_parts:
        return {"type": "ask", "question": "Who should attend the meeting?"}

    # Time preference: revealed value, then instruction value, then the first
    # slot that satisfies the constraints ("10 AM" if none does).
    allowed_slots = get_valid_times(constraint_map)
    chosen_time = None
    for candidate in (known_info.get("time"), text_time):
        if candidate and any(candidate.upper() == slot.upper() for slot in allowed_slots):
            chosen_time = candidate
            break
    if chosen_time is None:
        chosen_time = allowed_slots[0] if allowed_slots else "10 AM"

    # Participants preference: revealed comma-separated value, then the
    # instruction's teams, then a default of Team A.
    revealed_people = known_info.get("participants")
    if revealed_people:
        attendees = [name.strip() for name in revealed_people.split(",")]
    elif text_parts:
        attendees = text_parts
    else:
        attendees = ["Team A"]

    return {
        "type": "execute",
        "proposed_time": chosen_time,
        "proposed_participants": attendees,
    }
# ── CORE EXECUTION LOOP ──────────────────────────────────────────────────────
# Seconds to wait on each environment request before giving up, so a sleeping
# or unreachable backend surfaces as an error card instead of a hung UI.
_REQUEST_TIMEOUT_S = 30


def _error_card(label, exc):
    """Return a red-accented step card reporting *exc* (message HTML-escaped)."""
    return (
        "<div class='step-card' style='border-left-color:#f38ba8;'>"
        f"<b>{label}:</b> {html.escape(str(exc))}</div>"
    )


def _session_header(obs):
    """Return the "Session Start" info box for the initial observation."""
    constraints = obs.get("constraints") or {}
    instruction = html.escape(str(obs.get("instruction", "")))
    deadline = html.escape(str(constraints.get("deadline", "None")))
    unavailable = html.escape(
        ", ".join(str(slot) for slot in (constraints.get("unavailable_times") or []))
    ) or "None"
    return f"""
    <div class='info-box'>
    <div style='color:#89b4fa; font-size:0.9em; text-transform:uppercase; font-weight:bold; margin-bottom:5px;'>β Session Start</div>
    <div style='font-size:1.15em; color:#cdd6f4; margin-bottom:10px;'>"{instruction}"</div>
    <div style='font-size:0.9em; color:#a6adc8; border-top:1px solid #45475a; padding-top:8px;'>
    <b>Active Constraints:</b><br>
    β³ Deadline: <span style='color:#f9e2af;'>{deadline}</span><br>
    π« Unavailable: <span style='color:#f38ba8;'>{unavailable}</span>
    </div>
    </div>
    """


def _describe_action(action):
    """Return the one-line HTML summary of an agent action (values escaped)."""
    if action["type"] == "ask":
        question = html.escape(str(action["question"]))
        return (
            "<span style='color:#89b4fa;'>Ask</span> "
            "<span style='color:#6c7086;'>β</span> "
            f"<span class='action-text'>\"{question}\"</span>"
        )
    proposed_time = html.escape(str(action["proposed_time"]))
    participants = html.escape(str(action["proposed_participants"]))
    return (
        "<span style='color:#a6e3a1;'>Execute</span> "
        "<span style='color:#6c7086;'>β</span> "
        f"<span class='action-text'>time='{proposed_time}', parts={participants}</span>"
    )


def _step_card(step, act_str, raw_reward, done, last_response):
    """Return the HTML card for one step, plus the revealed-info note if any."""
    reward_class = "reward-pos" if raw_reward > 0 else "reward-neg"
    if done:
        status_text = "<span style='color:#a6e3a1'>β Resolved</span>"
    else:
        status_text = "<span style='color:#f9e2af'>β‘ Clarifying...</span>"
    card = f"""
    <div class='step-card'>
    <div style='display:flex; justify-content:space-between; align-items:center; margin-bottom:10px;'>
    <span style='color:#bac2de; font-weight:bold;'>Step {step}</span>
    <span class='reward-tag {reward_class}'>{raw_reward:+.2f} Reward</span>
    </div>
    <div style='margin-bottom:8px;'>{act_str}</div>
    <div style='font-size:0.9em;'>{status_text}</div>
    </div>
    """
    # The environment's reply is only shown while the episode is still open.
    if not done and last_response:
        card += f"""
    <div style='margin-left:20px; padding:8px 12px; border-left:3px solid #cba6f7; background:rgba(203,166,247,0.05); margin-bottom:15px; margin-top:-5px;'>
    <span style='color:#cba6f7; font-size:0.85em; text-transform:uppercase; font-weight:bold;'>Revealed Info</span><br>
    <span style='color:#cdd6f4;'>{html.escape(str(last_response))}</span>
    </div>
    """
    return card


def run_interaction(task_display_name, custom_inst, is_demo=False):
    """Run one episode against the remote environment, streaming an HTML trace.

    This is a generator: each yielded string is the complete trace so far,
    suitable for replacing the contents of an HTML output component.

    The episode is started with a POST to ``{BASE_URL}/reset``; then up to
    ``MAX_STEPS`` actions chosen by ``demo_agent`` are POSTed to
    ``{BASE_URL}/step``. When the environment reports ``done``, a final
    banner shows the mean of the per-step rewards ("Success" above 0.5).

    All text coming from the user or from the environment is HTML-escaped
    before being placed in the trace, and every request carries a timeout.

    Args:
        task_display_name: Dropdown label looked up in ``TASK_MAPPING``;
            unknown labels fall back to "hard_ambiguous". Ignored in demo mode.
        custom_inst: Optional instruction text sent with the reset request
            when non-blank. Overridden in demo mode.
        is_demo: When true, run the fixed "hard_ambiguous" task with a
            canned instruction.

    Yields:
        The accumulated HTML trace after each stage. Request failures are
        rendered as an error card and end the run instead of raising.
    """
    if is_demo:
        task_name = "hard_ambiguous"
        custom_inst = "Schedule meeting ASAP with the team"
    else:
        task_name = TASK_MAPPING.get(task_display_name, "hard_ambiguous")

    output_html = "<div><span style='color:#a6adc8;'><i>Initialising environment...</i></span></div>"
    yield output_html

    payload = {"task_name": task_name}
    if custom_inst and custom_inst.strip():
        payload["instruction"] = custom_inst.strip()

    # UI boundary: any failure (network, HTTP status, malformed JSON) becomes
    # an error card rather than an exception inside the Gradio handler.
    try:
        r = requests.post(f"{BASE_URL}/reset", json=payload, timeout=_REQUEST_TIMEOUT_S)
        r.raise_for_status()
        obs = r.json()["observation"]
    except Exception as e:
        yield _error_card("Error", e)
        return

    output_html = _session_header(obs)
    yield output_html

    step = 0
    done = False
    rewards = []
    while not done and step < MAX_STEPS:
        step += 1
        action = demo_agent(obs, task_name)
        act_str = _describe_action(action)

        time.sleep(0.8)  # brief pause between steps so the trace is readable
        try:
            r = requests.post(f"{BASE_URL}/step", json=action, timeout=_REQUEST_TIMEOUT_S)
            r.raise_for_status()
            res = r.json()
            obs = res["observation"]
            reward = res["reward"]
            done = res["done"]
            # Prefer the unshaped reward for display when the server reports one.
            raw_reward = (res.get("info") or {}).get("raw_reward", reward)
            rewards.append(reward)
        except Exception as e:
            output_html += _error_card("Step Error", e)
            yield output_html
            break

        output_html += _step_card(step, act_str, raw_reward, done, obs.get("last_response"))
        yield output_html

    if done:
        score = sum(rewards) / max(len(rewards), 1)
        succeeded = score > 0.5
        banner = "result-success" if succeeded else "result-fail"
        msg = "Success" if succeeded else "Failure"
        output_html += f"<div class='result-banner {banner}'>{msg}! <br><span style='font-size:0.8em; font-weight:normal;'>Final Episode Score: {score:.2f}</span></div>"
        yield output_html
# ── GRADIO UI WRAPPERS (Fixing Generator Pickling) ──────────────────────────
def start_agent_run(task, custom_inst):
    """Stream a regular (non-demo) run for the selected task.

    Args:
        task: Dropdown label of the task to run.
        custom_inst: Optional instruction text from the prompt box.

    Yields:
        Every HTML trace produced by ``run_interaction``.
    """
    yield from run_interaction(
        task_display_name=task,
        custom_inst=custom_inst,
        is_demo=False,
    )
def start_demo_run(task, custom_inst):
    """Stream the canned demo run.

    Both arguments are forwarded unchanged, but ``run_interaction`` replaces
    the task and the instruction with fixed values when ``is_demo`` is true.

    Yields:
        Every HTML trace produced by ``run_interaction``.
    """
    yield from run_interaction(
        task_display_name=task,
        custom_inst=custom_inst,
        is_demo=True,
    )
# ── GRADIO UI LAYOUT ─────────────────────────────────────────────────────────
# Two-column page: controls on the left (scale=1), streamed trace on the
# right (scale=2). Components are created in display order, so the statement
# order below is significant.
with gr.Blocks(title="Ambiguity Resolution Demo") as app:
    # Page header; styled by the .header-banner rules in CUSTOM_CSS.
    gr.HTML("<div class='header-banner'><h1>π§ Ambiguity Resolution Benchmark Demo</h1><p>Visualizing intelligent multi-step decision making under constraints</p></div>")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### βοΈ Context")
            # Choices are the display labels; they are mapped to task names via TASK_MAPPING.
            task_dropdown = gr.Dropdown(label="Complexity", choices=list(TASK_MAPPING.keys()), value="Hard Ambiguous")
            custom_input = gr.Textbox(label="Prompt", placeholder="Schedule meeting ASAP...")
            with gr.Row():
                btn_run = gr.Button("π Start Agent", variant="primary")
                btn_demo = gr.Button("βΆ Quick Demo", variant="secondary")
            gr.Markdown("<br>π‘ **Note:** The agent is deterministic and follows the high-quality reasoning benchmark rules.")
        with gr.Column(scale=2):
            gr.Markdown("### π‘ Trace")
            # Target of both handlers; each yielded HTML string replaces its content.
            output_display = gr.HTML(value="<div style='color:#a6adc8; text-align:center; padding:40px;'>Awaiting trigger...</div>")
    # Both buttons receive the same inputs; the demo handler runs a fixed
    # task and instruction regardless of what is selected.
    btn_run.click(fn=start_agent_run, inputs=[task_dropdown, custom_input], outputs=[output_display])
    btn_demo.click(fn=start_demo_run, inputs=[task_dropdown, custom_input], outputs=[output_display])
if __name__ == "__main__":
    # Script entry point: serve the app on all network interfaces, port 7860.
    # Gradio 5.x/6.x Recommended: Apply theme and CSS in launch()
    # NOTE(review): whether launch() accepts `theme` and `css` depends on the
    # installed Gradio release; some releases take them on gr.Blocks(...)
    # instead -- confirm against the pinned Gradio version.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=gr.themes.Soft(),
        css=CUSTOM_CSS
    )