# openenv-rl-environment / inference.py
# (Hugging Face file-viewer header converted to a comment so the module is
#  importable as Python. Original metadata: author "Sid8421", commit ba2722e,
#  "Fix final OpenEnv validator compliance: inference stdout format,
#   Dockerfile deps, API vars, and grader bounds".)
import os
import json
import logging
import asyncio
from typing import List, Optional
from openai import OpenAI
from env.environment import SupportTicketEnv
from env.models import Action
# Module-level logger; basicConfig makes warnings/errors from the retry and
# parsing paths visible on stderr by default.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Endpoint and model selection, overridable via environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
# HF_TOKEN doubles as the API key (read again in main()).
HF_TOKEN = os.getenv("HF_TOKEN")
# NOTE(review): LOCAL_IMAGE_NAME is read but not referenced anywhere else in
# this file — presumably consumed by external tooling; confirm before removing.
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
# Episode limits and grading bounds used by run_task().
MAX_STEPS = 10
MAX_TOTAL_REWARD = 1.0
SUCCESS_SCORE_THRESHOLD = 0.8
def log_start(task: str, env: str, model: str):
    """Emit the '[START]' marker line the OpenEnv validator expects on stdout."""
    fields = [f"task={task}", f"env={env}", f"model={model}"]
    print("[START] " + " ".join(fields), flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str] = None):
    """Emit one '[STEP]' line; booleans are lowercased and a missing error renders as 'null'."""
    err_text = "null" if not error else error
    flag = "true" if done else "false"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={flag} error={err_text}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: list):
    """Emit the final '[END]' summary line with a comma-joined reward trace."""
    reward_trace = ",".join("%.2f" % r for r in rewards)
    outcome = "true" if success else "false"
    print(
        f"[END] success={outcome} steps={steps} score={score:.2f} rewards={reward_trace}",
        flush=True,
    )
def parse_action(text: str) -> Action:
    """Extract the first JSON object embedded in *text* and build an Action.

    Scans for '{' candidates left to right and attempts a raw JSON decode at
    each position. The first decoded dict is validated with Pydantic; when
    validation fails, falls back to manual construction after whitelisting
    ``action_type``. Any unparseable input yields a default close_ticket
    Action rather than raising.

    Args:
        text: Raw model output, possibly with prose surrounding the JSON.

    Returns:
        A validated Action; never raises.
    """
    try:
        decoder = json.JSONDecoder()
        idx = 0
        while True:
            idx = text.find('{', idx)
            if idx == -1:
                break
            try:
                obj, end = decoder.raw_decode(text, idx)
                if isinstance(obj, dict):
                    try:
                        return Action.model_validate(obj)
                    except Exception as val_err:
                        logger.warning("Action validation failed: %s", val_err)
                        # Fallback to manual construction with validation
                        action_type = obj.get("action_type", "close_ticket")
                        valid_actions = ["fetch_user_data", "check_policy", "issue_refund", "reply_to_customer", "escalate", "close_ticket"]
                        if action_type not in valid_actions:
                            logger.error("Invalid action_type: %s. Defaulting to 'close_ticket'.", action_type)
                            action_type = "close_ticket"
                        return Action(action_type=action_type, parameters=obj.get("parameters", {}))
                # Decoded a non-dict value. A decode starting at '{' should
                # always give a dict, but without advancing idx here the scan
                # would re-find the same '{' forever — skip past the value.
                idx = max(end, idx + 1)
            except json.JSONDecodeError:
                idx += 1
    except Exception as e:
        logger.error("Failed to parse action: %s", e)
    # Default fallback if no valid action is found
    return Action(action_type="close_ticket", parameters={})
def get_model_message(client, step: int, env_state: str, history: List[str]) -> str:
    """Query the model for the next action and return its raw text reply.

    Supports two client shapes (``client.chat.completions`` and
    ``client.responses``), retrying transient failures with exponential
    backoff. Always returns a string; ``"{}"`` is the fallback when no
    usable reply could be obtained (parse_action maps ``"{}"`` to the
    default close_ticket action).

    Args:
        client: OpenAI-style client exposing either interface above.
        step: Current step number (currently unused; kept for interface
            stability).
        env_state: JSON dump of the latest environment observation.
        history: Prior "Step N: ..." summary lines for the prompt.
    """
    system_prompt = (
        "You are an AI support agent resolving customer tickets.\n"
        "Available Actions:\n"
        "- fetch_user_data(user_id)\n"
        "- check_policy(issue_type)\n"
        "- issue_refund(amount)\n"
        "- reply_to_customer(message)\n"
        "- escalate(reason)\n"
        "- close_ticket(resolution)\n\n"
        "Must respond with JSON format:\n"
        "{\"action_type\": \"...\", \"parameters\": {\"...\": \"...\"}}"
    )
    history_str = "\n".join(history)
    user_prompt = f"History:\n{history_str}\n\nCurrent Observation:\n{env_state}\n\nWhat is your next action JSON?"
    import time  # local import preserved to avoid touching module-level imports
    # retry/backoff parameters
    max_retries = 3
    backoff_base = 0.5  # seconds; doubled each attempt
    try:
        # Support a few possible client interfaces (chat.completions or responses)
        for attempt in range(1, max_retries + 1):
            try:
                if hasattr(client, "chat") and hasattr(client.chat, "completions"):
                    completion = client.chat.completions.create(
                        model=MODEL_NAME,
                        messages=[
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": user_prompt}
                        ],
                        temperature=0.1
                    )
                    text = (completion.choices[0].message.content or "").strip()
                    return text if text else "{}"
                if hasattr(client, "responses") and hasattr(client.responses, "create"):
                    completion = client.responses.create(model=MODEL_NAME, input=user_prompt, temperature=0.1)
                    text = getattr(completion, "output_text", None)
                    if text:
                        return text.strip()
                    # Assemble text parts from the structured output.
                    # NOTE(review): assumes each output item is dict-like
                    # with .get — confirm against the SDK's response type.
                    out = []
                    for item in getattr(completion, "output", []) or []:
                        for c in item.get("content", []):
                            if c.get("type") == "output_text":
                                out.append(c.get("text", ""))
                    if out:
                        return "".join(out).strip()
                raise RuntimeError("No supported model client method available")
            except Exception as exc:
                logger.warning("Model request attempt %d failed: %s", attempt, exc)
                if attempt == max_retries:
                    break
                sleep_time = backoff_base * (2 ** (attempt - 1))
                time.sleep(sleep_time)
    except Exception as exc:
        logger.exception("Unexpected error in get_model_message: %s", exc)
    # Unconditional fallback: guarantees the annotated -> str contract even
    # after retry exhaustion (previously the fallback's scope was ambiguous
    # and could leave an implicit None return).
    return "{}"
async def run_task(task_id: str, client: OpenAI) -> None:
    """Run one support-ticket episode end-to-end, emitting validator log lines.

    Resets the environment, loops up to MAX_STEPS asking the model for an
    action each step, tracks per-step rewards, and always emits the [END]
    summary (via finally), even if the episode raises.
    """
    env = SupportTicketEnv(task_id=task_id)
    history: List[str] = []
    rewards: List[float] = []
    steps_taken, score, success = 0, 0.0, False
    log_start(task=task_id, env="SupportTicketEnv", model=MODEL_NAME)
    try:
        observation = env.reset()
        last_echoed = observation.model_dump_json(indent=2)
        step = 1
        while step <= MAX_STEPS:
            if env.get_state().is_done:
                break
            reply = get_model_message(client, step, last_echoed, history)
            obs_obj, reward, done, info = env.step(parse_action(reply))
            last_echoed = obs_obj.model_dump_json(indent=2)
            # Grader reward comes from the info dict, not the raw step reward.
            actual_reward = info.get("current_reward", 0.0)
            rewards.append(actual_reward)
            steps_taken = step
            log_step(step=step, action=reply, reward=actual_reward, done=done, error=None)
            history.append(f"Step {step}: {reply!r} -> reward {actual_reward:+.2f}")
            if done:
                score = actual_reward
                break
            step += 1
        # Clamp to grader bounds before deciding success.
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
async def main() -> None:
    """Build the API client and run each benchmark task in sequence."""
    client = OpenAI(base_url=API_BASE_URL, api_key=os.getenv("HF_TOKEN"))
    for task_id in ("task_easy_1", "task_medium_1", "task_hard_1"):
        await run_task(task_id, client)
# Script entry point: run the full benchmark via asyncio.
if __name__ == "__main__":
    asyncio.run(main())