Spaces:
Sleeping
Sleeping
| """ | |
| inference.py | |
| Evaluation entry point for the Ambiguity Resolution Environment. | |
| Updated for LLM-driven evaluation via OpenEnv proxy. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import re | |
| from typing import Any, Tuple | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| # ββ load .env ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| load_dotenv() | |
| # OpenEnv Proxy / standard HF endpoints | |
| API_BASE_URL = os.getenv("API_BASE_URL") | |
| API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") | |
| MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") | |
| MAX_STEPS = 5 | |
| if not API_KEY: | |
| print("ERROR: API_KEY or HF_TOKEN not set. Add it to your .env file.", file=sys.stderr) | |
| sys.exit(1) | |
| # Initialize Client | |
| client = OpenAI( | |
| base_url=API_BASE_URL, | |
| api_key=API_KEY | |
| ) | |
| from tasks.tasks import TASKS | |
| from env.env import AmbiguityEnv | |
| from models.models import Action | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # LOGGING | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def log_start(task_name: str) -> None: | |
| print(f"[START] task={task_name} env=ambiguity_env model={MODEL_NAME}", flush=True) | |
| def log_step(step: int, action: str, reward: float, done: bool, error: str | None = None) -> None: | |
| error_val = error if error else "null" | |
| done_val = str(done).lower() | |
| print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True) | |
| def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None: | |
| rewards_str = ",".join(f"{r:.2f}" for r in rewards) | |
| success_val = str(success).lower() | |
| print(f"[END] success={success_val} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # LLM AGENT LOGIC | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def parse_llm_response(raw_text: str) -> Action: | |
| """ | |
| Parses the LLM response into an Action model. | |
| Expected formats: | |
| - ask: <question> | |
| - execute: time=<time>, participants=<p1,p2> | |
| """ | |
| text = raw_text.strip().lower() | |
| # 1. Check for 'ask:' | |
| if text.startswith("ask:"): | |
| question = raw_text[4:].strip() | |
| return Action(type="ask", question=question) | |
| # 2. Check for 'execute:' | |
| if text.startswith("execute:"): | |
| params_str = raw_text[8:].strip() | |
| # Regex to extract time and participants | |
| time_match = re.search(r"time\s*=\s*([^,;]+)", params_str, re.IGNORECASE) | |
| parts_match = re.search(r"participants\s*=\s*([^,;]+)", params_str, re.IGNORECASE) | |
| proposed_time = time_match.group(1).strip() if time_match else "10 AM" | |
| participants_str = parts_match.group(1).strip() if parts_match else "Team A" | |
| proposed_participants = [p.strip() for p in participants_str.split(",")] | |
| return Action( | |
| type="execute", | |
| proposed_time=proposed_time, | |
| proposed_participants=proposed_participants | |
| ) | |
| # 3. Fallback: try to see if it just output JSON anyway | |
| try: | |
| data = json.loads(raw_text) | |
| return Action(**data) | |
| except: | |
| raise ValueError(f"Could not parse LLM response: {raw_text}") | |
| def get_deterministic_fallback(observation, task: dict) -> Action: | |
| """Purely observation-based fallback for stability.""" | |
| known = observation.known_info or {} | |
| missing_fields = task.get("missing_fields", []) | |
| if "time" in missing_fields and not known.get("time"): | |
| return Action(type="ask", question="What time works for the meeting?") | |
| if "participants" in missing_fields and not known.get("participants"): | |
| return Action(type="ask", question="Who should attend the meeting?") | |
| # Execute with what we have | |
| return Action( | |
| type="execute", | |
| proposed_time=known.get("time", "10 AM"), | |
| proposed_participants=[p.strip() for p in known.get("participants", "Team A").split(",")] if known.get("participants") else ["Team A"] | |
| ) | |
| def get_model_action(observation, task: dict) -> Tuple[Action, str | None]: | |
| """ | |
| FORCED LLM CALL EVERY STEP (Except for easy_explicit). | |
| Calls the LLM via proxy and parses the action. | |
| """ | |
| # ββ 1. FAST-PATH: DETECT NO AMBIGUITY ββ | |
| if task.get("name") == "easy_explicit": | |
| return Action( | |
| type="execute", | |
| proposed_time="10 AM", | |
| proposed_participants=["Team A"] | |
| ), None | |
| system_prompt = ( | |
| "You are an agent solving a scheduling task. " | |
| "Ask for missing info or execute when ready. " | |
| "Respond in the following format:\n" | |
| "ask: <question>\n" | |
| "OR\n" | |
| "execute: time=<time>, participants=<p1,p2>\n" | |
| ) | |
| user_content = ( | |
| f"Instruction: {observation.instruction}\n" | |
| f"Known info: {json.dumps(observation.known_info)}\n" | |
| f"Constraints: {json.dumps(observation.constraints)}\n" | |
| f"Last Response: {observation.last_response or 'None'}\n" | |
| ) | |
| try: | |
| response = client.chat.completions.create( | |
| model=MODEL_NAME, | |
| messages=[ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": user_content} | |
| ], | |
| temperature=0, # Determinism | |
| max_tokens=100 | |
| ) | |
| raw_output = response.choices[0].message.content.strip() | |
| action = parse_llm_response(raw_output) | |
| return action, None | |
| except Exception as e: | |
| # Fallback safety | |
| action = get_deterministic_fallback(observation, task) | |
| return action, f"LLM Error: {str(e)}" | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # EPISODE RUNNER | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def run_task(task: dict) -> dict: | |
| env = AmbiguityEnv() | |
| rewards: list[float] = [] | |
| steps = 0 | |
| log_start(task["name"]) | |
| try: | |
| observation = env.reset(task) | |
| for step_idx in range(1, MAX_STEPS + 1): | |
| action, error = get_model_action(observation, task) | |
| res = env.step(action) | |
| observation = res["observation"] | |
| reward = res["reward"] | |
| done = res["done"] | |
| rewards.append(reward) | |
| steps = step_idx | |
| log_step(step_idx, str(action.model_dump()), reward, done, error=error) | |
| if done: break | |
| except Exception as e: | |
| steps = max(steps, 1) | |
| if not rewards: rewards = [0.01] | |
| log_step(steps, "error_fallback", 0.01, True, error=str(e)) | |
| finally: | |
| score = sum(rewards) / max(len(rewards), 1) | |
| log_end(score > 0.5, steps, score, rewards) | |
| return {"name": task["name"], "score": score} | |
| if __name__ == "__main__": | |
| results = [run_task(t) for t in TASKS] | |
| print("\n" + "="*60 + "\nSUMMARY\n" + "="*60) | |
| for r in results: | |
| print(f" [{'PASS' if r['score'] > 0.5 else 'FAIL'}] {r['name']:<35} score={r['score']:.2f}") | |
| avg = sum(r['score'] for r in results) / len(results) | |
| print("-" * 60 + f"\n Average score: {avg:.2f}\n" + "="*60) | |