ambiguity-env / app.py
Yaser77
fix: resolve Gradio generator pickling error and deprecation warnings
b922e87
import html
import json
import re
import time

import gradio as gr
import requests
# ── CONFIGURATION & STYLING ──────────────────────────────────────────────────
# Root URL of the hosted environment service; run_interaction POSTs to
# "/reset" and "/step" underneath it.
BASE_URL = "https://yaser77-ambiguity-env.hf.space"
# Upper bound on the number of agent steps run_interaction performs per episode.
MAX_STEPS = 5
# Stylesheet passed to app.launch(css=...). The class names below
# (header-banner, step-card, reward-*, action-text, info-box, result-*)
# are the ones referenced by the HTML fragments this module generates.
CUSTOM_CSS = """
.gradio-container {
font-family: 'Inter', 'Segoe UI', sans-serif !important;
}
.header-banner {
background: linear-gradient(135deg, #1e1e2e 0%, #313244 100%);
padding: 30px;
border-radius: 12px;
text-align: center;
border: 1px solid #45475a;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
margin-bottom: 20px;
}
.header-banner h1 {
margin: 0;
color: #cdd6f4;
font-weight: 800;
}
.header-banner p {
color: #a6adc8;
font-size: 1.1em;
margin-top: 10px;
}
.step-card {
background: #181825;
border-left: 4px solid #89b4fa;
border-radius: 8px;
padding: 16px 20px;
margin-bottom: 15px;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.reward-tag {
display: inline-block;
padding: 3px 10px;
border-radius: 20px;
font-weight: bold;
font-size: 0.9em;
}
.reward-pos { background-color: rgba(166, 227, 161, 0.15); color: #a6e3a1; }
.reward-neg { background-color: rgba(243, 139, 168, 0.15); color: #f38ba8; }
.action-text {
font-family: monospace;
background: #11111b;
padding: 4px 8px;
border-radius: 4px;
color: #f5c2e7;
}
.info-box {
background-color: rgba(137, 180, 250, 0.1);
border: 1px solid rgba(137, 180, 250, 0.3);
border-radius: 8px;
padding: 15px;
margin-bottom: 20px;
}
.result-banner {
padding: 20px;
border-radius: 12px;
text-align: center;
font-size: 1.25em;
font-weight: bold;
margin-top: 20px;
}
.result-success { background: linear-gradient(135deg, rgba(166, 227, 161, 0.2), rgba(148, 226, 213, 0.2)); border: 1px solid #a6e3a1; color: #a6e3a1; }
.result-fail { background: linear-gradient(135deg, rgba(243, 139, 168, 0.2), rgba(250, 179, 135, 0.2)); border: 1px solid #f38ba8; color: #f38ba8; }
"""
# Maps the labels offered in the "Complexity" dropdown to the task_name
# value sent to the environment's /reset endpoint.
TASK_MAPPING = {
    "Easy Explicit": "easy_explicit",
    "Medium Missing Time": "medium_missing_time",
    "Medium Missing Participants": "medium_missing_participants",
    "Hard Ambiguous": "hard_ambiguous"
}
# ── DOMAIN REASONING ─────────────────────────────────────────────────────────
def get_valid_times(constraints: dict) -> list[str]:
    """Return the candidate meeting slots that satisfy ``constraints``.

    The candidate slots are fixed: "10 AM", "2 PM" and "4 PM".  A slot is
    dropped when it is listed in ``constraints["unavailable_times"]``
    (compared case-insensitively, ignoring surrounding whitespace), or when
    the deadline is "before 3 PM" and the slot is "4 PM".

    Args:
        constraints: Constraint mapping taken from an environment
            observation.  ``None``, a missing key, or a ``None`` value all
            mean "no restriction".

    Returns:
        The surviving slots in their original order (possibly empty).
    """
    all_times = ["10 AM", "2 PM", "4 PM"]
    constraints = constraints or {}

    # A JSON ``null`` decodes to None, which ``.get(key, [])`` does not
    # replace with the default -- hence the explicit ``or``.
    blocked = {
        t.strip().upper()
        for t in (constraints.get("unavailable_times") or [])
    }

    deadline = constraints.get("deadline") or "ASAP"
    # Tolerate casing/whitespace differences such as "Before 3 pm ".
    before_three = (
        isinstance(deadline, str)
        and deadline.strip().lower() == "before 3 pm"
    )

    return [
        slot
        for slot in all_times
        if slot not in blocked and not (before_three and slot == "4 PM")
    ]
# Phrases recognised in free-text instructions, in priority order.
_KNOWN_TIMES = ("10 AM", "2 PM", "4 PM")
_KNOWN_TEAMS = ("TEAM A", "TEAM B", "TEAM C")


def _mentions(phrase: str, upper_text: str) -> bool:
    """Return True if ``phrase`` occurs in ``upper_text`` as a whole token."""
    # Word boundaries stop "2 PM" matching inside "12 PM" and "TEAM A"
    # matching inside "TEAM ALPHA", which plain substring tests allow.
    return re.search(rf"\b{re.escape(phrase)}\b", upper_text) is not None


def extract_from_text(text: str) -> tuple:
    """Pull a known time slot and known team names out of free text.

    Matching is case-insensitive and only accepts whole tokens.

    Args:
        text: Instruction text; ``None`` or an empty string yields no matches.

    Returns:
        A ``(time, participants)`` tuple.  ``time`` is the first of
        "10 AM", "2 PM", "4 PM" (checked in that order, not in order of
        appearance) found in the text, or ``None``.  ``participants`` is the
        list of mentioned teams, title-cased ("Team A"), in A/B/C order.
    """
    upper_text = (text or "").upper()

    time_val = next(
        (slot for slot in _KNOWN_TIMES if _mentions(slot, upper_text)),
        None,
    )
    parts = [
        team.title() for team in _KNOWN_TEAMS if _mentions(team, upper_text)
    ]
    return time_val, parts
# ── AGENT LOGIC (Mirroring inference.py Intelligence) ────────────────────────
def demo_agent(obs_dict, task_name):
    """Choose the next action for the current observation.

    The agent asks a clarifying question when the task calls for a time or
    participants that are neither in ``known_info`` nor stated in the
    instruction; otherwise it proposes a meeting.

    Args:
        obs_dict: Observation mapping.  The keys read are ``instruction``,
            ``known_info`` and ``constraints``; missing or ``None`` values
            are treated as empty.
        task_name: Task identifier.  A name containing "time",
            "participants" or "hard" marks which details may be missing.

    Returns:
        Either ``{"type": "ask", "question": str}`` or
        ``{"type": "execute", "proposed_time": str,
        "proposed_participants": list[str]}``.
    """
    # ``or`` (rather than a .get default) also covers keys that are present
    # but null in the decoded JSON.
    instruction = obs_dict.get("instruction") or ""
    known = obs_dict.get("known_info") or {}
    constraints = obs_dict.get("constraints") or {}
    inst_time, inst_parts = extract_from_text(instruction)

    task_key = task_name.lower()
    is_hard = "hard" in task_key
    needs_time = ("time" in task_key or is_hard) and "time" not in known
    needs_parts = (
        ("participants" in task_key or is_hard)
        and "participants" not in known
    )

    if needs_time and not inst_time:
        return {"type": "ask", "question": "What time works for the meeting?"}
    if needs_parts and not inst_parts:
        return {"type": "ask", "question": "Who should attend the meeting?"}

    # Time preference: revealed info, then the instruction, then the first
    # slot allowed by the constraints ("10 AM" if none is allowed).
    valid_times = get_valid_times(constraints)
    revealed_time = known.get("time")
    if revealed_time and any(revealed_time.upper() == vt.upper() for vt in valid_times):
        final_time = revealed_time
    elif inst_time and any(inst_time.upper() == vt.upper() for vt in valid_times):
        final_time = inst_time
    else:
        final_time = valid_times[0] if valid_times else "10 AM"

    # NOTE(review): the original assumed a comma-separated string here; a
    # list is accepted as well since the wire format is not visible in this
    # file -- confirm against the environment's schema.
    revealed_parts = known.get("participants")
    if isinstance(revealed_parts, str):
        revealed_parts = revealed_parts.split(",")
    # Drop blank names such as the trailing entry produced by "Team A, ".
    final_participants = [
        str(p).strip() for p in (revealed_parts or []) if str(p).strip()
    ]
    if not final_participants:
        final_participants = inst_parts if inst_parts else ["Team A"]

    return {
        "type": "execute",
        "proposed_time": final_time,
        "proposed_participants": final_participants,
    }
# ── CORE EXECUTION LOOP ──────────────────────────────────────────────────────
# Seconds to wait for the environment service before giving up; without a
# timeout a stalled request would block the UI generator indefinitely.
_REQUEST_TIMEOUT_S = 30


def run_interaction(task_display_name, custom_inst, is_demo=False):
    """Run one episode against the environment and stream an HTML trace.

    The generator resets the environment via ``POST {BASE_URL}/reset`` and
    then repeatedly asks ``demo_agent`` for an action and submits it via
    ``POST {BASE_URL}/step`` until the environment reports ``done`` or
    ``MAX_STEPS`` actions have been taken.

    Args:
        task_display_name: Dropdown label; looked up in ``TASK_MAPPING``
            (unknown labels fall back to "hard_ambiguous").
        custom_inst: Optional instruction text sent with the reset request
            when non-blank.
        is_demo: When true, both arguments above are overridden with a
            fixed hard task and instruction.

    Yields:
        The HTML accumulated so far, suitable for a ``gr.HTML`` output.
        Request failures are rendered as an error card and end the run.
    """
    if is_demo:
        task_name = "hard_ambiguous"
        custom_inst = "Schedule meeting ASAP with the team"
    else:
        task_name = TASK_MAPPING.get(task_display_name, "hard_ambiguous")

    output_html = "<div><span style='color:#a6adc8;'><i>Initialising environment...</i></span></div>"
    yield output_html

    payload = {"task_name": task_name}
    if custom_inst and custom_inst.strip():
        payload["instruction"] = custom_inst.strip()

    try:
        r = requests.post(f"{BASE_URL}/reset", json=payload, timeout=_REQUEST_TIMEOUT_S)
        r.raise_for_status()
        obs = r.json()["observation"]
    except Exception as e:
        yield f"<div class='step-card' style='border-left-color:#f38ba8;'><b>Error:</b> {html.escape(str(e))}</div>"
        return

    # Everything below originates from user input or the remote service, so
    # it is HTML-escaped before being interpolated into markup.
    constraints = obs.get("constraints") or {}
    instruction_html = html.escape(str(obs.get("instruction", "")))
    deadline_html = html.escape(str(constraints.get("deadline", "None")))
    unavailable_html = html.escape(
        ", ".join(str(t) for t in (constraints.get("unavailable_times") or []))
    ) or "None"

    output_html = f"""
    <div class='info-box'>
    <div style='color:#89b4fa; font-size:0.9em; text-transform:uppercase; font-weight:bold; margin-bottom:5px;'>✅ Session Start</div>
    <div style='font-size:1.15em; color:#cdd6f4; margin-bottom:10px;'>"{instruction_html}"</div>
    <div style='font-size:0.9em; color:#a6adc8; border-top:1px solid #45475a; padding-top:8px;'>
    <b>Active Constraints:</b><br>
    ⏳ Deadline: <span style='color:#f9e2af;'>{deadline_html}</span><br>
    🚫 Unavailable: <span style='color:#f38ba8;'>{unavailable_html}</span>
    </div>
    </div>
    """
    yield output_html

    step = 0
    done = False
    rewards = []
    while not done and step < MAX_STEPS:
        step += 1
        action = demo_agent(obs, task_name)

        if action["type"] == "ask":
            question_html = html.escape(str(action["question"]))
            act_str = f"<span style='color:#89b4fa;'>Ask</span> <span style='color:#6c7086;'>→</span> <span class='action-text'>\"{question_html}\"</span>"
        else:
            time_html = html.escape(str(action["proposed_time"]))
            parts_html = html.escape(str(action["proposed_participants"]))
            act_str = f"<span style='color:#a6e3a1;'>Execute</span> <span style='color:#6c7086;'>→</span> <span class='action-text'>time='{time_html}', parts={parts_html}</span>"

        # Brief pause so the trace appears step by step in the UI.
        time.sleep(0.8)

        try:
            r = requests.post(f"{BASE_URL}/step", json=action, timeout=_REQUEST_TIMEOUT_S)
            r.raise_for_status()
            res = r.json()
            obs = res["observation"]
            reward = res["reward"]
            done = res["done"]
            info = res.get("info") or {}
            # Prefer the unshaped reward for display when the service reports one.
            raw_reward = info.get("raw_reward", reward)
            rewards.append(reward)
        except Exception as e:
            output_html += f"<div class='step-card' style='border-left-color:#f38ba8;'><b>Step Error:</b> {html.escape(str(e))}</div>"
            yield output_html
            break

        reward_class = "reward-pos" if raw_reward > 0 else "reward-neg"
        status_text = "<span style='color:#a6e3a1'>✔ Resolved</span>" if done else "<span style='color:#f9e2af'>⚑ Clarifying...</span>"
        step_block = f"""
        <div class='step-card'>
        <div style='display:flex; justify-content:space-between; align-items:center; margin-bottom:10px;'>
        <span style='color:#bac2de; font-weight:bold;'>Step {step}</span>
        <span class='reward-tag {reward_class}'>{raw_reward:+.2f} Reward</span>
        </div>
        <div style='margin-bottom:8px;'>{act_str}</div>
        <div style='font-size:0.9em;'>{status_text}</div>
        </div>
        """
        if not done and obs.get("last_response"):
            response_html = html.escape(str(obs["last_response"]))
            step_block += f"""
            <div style='margin-left:20px; padding:8px 12px; border-left:3px solid #cba6f7; background:rgba(203,166,247,0.05); margin-bottom:15px; margin-top:-5px;'>
            <span style='color:#cba6f7; font-size:0.85em; text-transform:uppercase; font-weight:bold;'>Revealed Info</span><br>
            <span style='color:#cdd6f4;'>{response_html}</span>
            </div>
            """
        output_html += step_block
        yield output_html

    if done:
        # Episode score is the mean of the per-step rewards.
        score = sum(rewards) / max(len(rewards), 1)
        banner = "result-success" if score > 0.5 else "result-fail"
        msg = "Success" if score > 0.5 else "Failure"
        output_html += f"<div class='result-banner {banner}'>{msg}! <br><span style='font-size:0.8em; font-weight:normal;'>Final Episode Score: {score:.2f}</span></div>"
        yield output_html
# ── GRADIO UI WRAPPERS (Fixing Generator Pickling) ──────────────────────────
def start_agent_run(task, custom_inst):
    """Stream a normal (non-demo) episode for the "Start Agent" button.

    Args:
        task: Display label selected in the complexity dropdown.
        custom_inst: Text entered in the prompt box (may be empty).

    Yields:
        Each HTML snapshot produced by ``run_interaction``.
    """
    episode = run_interaction(
        task_display_name=task,
        custom_inst=custom_inst,
        is_demo=False,
    )
    yield from episode
def start_demo_run(task, custom_inst):
    """Stream the canned demo episode for the "Quick Demo" button.

    Both arguments are forwarded, but ``run_interaction`` replaces them
    with its fixed demo task and instruction when ``is_demo`` is true.

    Args:
        task: Display label selected in the complexity dropdown.
        custom_inst: Text entered in the prompt box.

    Yields:
        Each HTML snapshot produced by ``run_interaction``.
    """
    episode = run_interaction(
        task_display_name=task,
        custom_inst=custom_inst,
        is_demo=True,
    )
    yield from episode
# ── GRADIO UI LAYOUT ─────────────────────────────────────────────────────────
# Two-column page: controls on the left, streamed episode trace on the right.
# NOTE(review): the nesting below is reconstructed from the component order
# (the source had lost its indentation) -- confirm against the intended layout.
with gr.Blocks(title="Ambiguity Resolution Demo") as app:
    gr.HTML("<div class='header-banner'><h1>🧠 Ambiguity Resolution Benchmark Demo</h1><p>Visualizing intelligent multi-step decision making under constraints</p></div>")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Context")
            task_dropdown = gr.Dropdown(label="Complexity", choices=list(TASK_MAPPING.keys()), value="Hard Ambiguous")
            custom_input = gr.Textbox(label="Prompt", placeholder="Schedule meeting ASAP...")
            with gr.Row():
                btn_run = gr.Button("🚀 Start Agent", variant="primary")
                btn_demo = gr.Button("▶ Quick Demo", variant="secondary")
            gr.Markdown("<br>💡 **Note:** The agent is deterministic and follows the high-quality reasoning benchmark rules.")
        with gr.Column(scale=2):
            gr.Markdown("### 📡 Trace")
            output_display = gr.HTML(value="<div style='color:#a6adc8; text-align:center; padding:40px;'>Awaiting trigger...</div>")

    # Both handlers are generators, so the trace panel updates as each
    # snapshot is yielded.
    btn_run.click(fn=start_agent_run, inputs=[task_dropdown, custom_input], outputs=[output_display])
    btn_demo.click(fn=start_demo_run, inputs=[task_dropdown, custom_input], outputs=[output_display])
# Script entry point: serve the UI on all interfaces at port 7860.
if __name__ == "__main__":
    # Gradio 5.x/6.x Recommended: Apply theme and CSS in launch()
    # NOTE(review): confirm the installed Gradio version's launch() accepts
    # ``theme`` and ``css``; older releases take them on gr.Blocks(...) instead.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=gr.themes.Soft(),
        css=CUSTOM_CSS
    )