# workflow-twin / inference.py
# Snapshot from the Hugging Face repo at commit 1a692ce
# ("fix repo structure for HF", by NDGCodes).
import json
import os
import re
from openai import OpenAI
from baseline.policy import heuristic_policy
from env.environment import OpenEnv
from env.models import Action, Observation
from env.runtime_config import RuntimeConfig
# Base URL of the chat-completions endpoint. When unset, choose_action skips the
# LLM entirely and uses heuristic_policy.
API_BASE_URL = os.getenv("API_BASE_URL")
# Model name sent with every chat-completions request (default "gpt-4o-mini").
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
# API key for the endpoint. It must also be set for choose_action to use the LLM.
HF_TOKEN = os.getenv("HF_TOKEN")
# Environment settings built once at import; presumably read from environment
# variables (RuntimeConfig.from_env) — TODO confirm. run_episode expands them
# into OpenEnv kwargs.
runtime_config = RuntimeConfig.from_env()
# Client is constructed unconditionally at import time. "EMPTY" is a placeholder
# key, presumably so construction succeeds when HF_TOKEN is unset; the client is
# only called when both API_BASE_URL and HF_TOKEN are set.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "EMPTY")
# action_type values accepted from the model. Any other value makes
# choose_action fall back to heuristic_policy.
VALID_ACTIONS = {"triage", "respond", "resolve", "escalate"}
def compute_partial_score(total_reward: float, max_steps: int) -> float:
    """Normalise an episode's cumulative reward into the interval [0.0, 1.0].

    The reward is divided by the step budget, where the budget is ``max_steps``
    as a float floored at 1.0. A zero or negative budget therefore never
    divides by zero. The ratio is then clamped to ``[0.0, 1.0]``.

    Args:
        total_reward: Sum of the per-step rewards collected in the episode.
        max_steps: Step budget the episode was run with.

    Returns:
        The clamped ratio ``total_reward / max(max_steps, 1)``.
    """
    steps_as_float = float(max_steps)
    budget = 1.0 if steps_as_float < 1.0 else steps_as_float
    ratio = total_reward / budget
    # Upper clamp first, then lower clamp (mirrors max(0.0, min(1.0, ratio))).
    capped = ratio if ratio < 1.0 else 1.0
    return capped if capped > 0.0 else 0.0
def safe_parse(content: str) -> dict:
    """Recover the first JSON object from a chat-model response.

    Two strategies are tried in order:

    1. Parse the whole of ``content`` as JSON. The result is accepted only if
       it is a JSON object.
    2. Scan for each ``{`` and decode the first complete JSON value starting
       there, ignoring any text after it. This copes with responses wrapped in
       prose or Markdown code fences, and with responses containing several
       objects (the first decodable object wins).

    Args:
        content: Raw text returned by the chat model.

    Returns:
        The decoded ``dict``, or ``{"action_type": "triage", "note": "fallback"}``
        when no JSON object can be recovered.
    """
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # raw_decode stops at the end of the first complete value, so trailing
    # prose or a second object does not spoil the parse. The previous greedy
    # r"\{.*\}" regex spanned from the first "{" to the LAST "}" and failed on
    # such input.
    decoder = json.JSONDecoder()
    start = content.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(content[start:])
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = content.find("{", start + 1)

    return {"action_type": "triage", "note": "fallback"}
def choose_action(observation: Observation) -> Action:
    """Pick the next action, preferring the LLM and falling back to the heuristic.

    If ``API_BASE_URL`` or ``HF_TOKEN`` is unset, ``heuristic_policy`` is used
    directly. Otherwise the observation is sent to ``MODEL_NAME``. The prompt
    asks for a JSON object with ``action_type`` and ``note``, and the reply is
    parsed with ``safe_parse``.

    The model's ``action_type`` is stripped of surrounding whitespace and
    lower-cased, then checked against ``VALID_ACTIONS``. An unknown value falls
    back to ``heuristic_policy``. Only ``action_type`` and ``note`` are passed
    to ``Action``, so extra keys in the model output cannot break
    construction. Any exception (request, parsing or validation) is reported
    on stderr and also falls back to ``heuristic_policy``.

    Args:
        observation: Current environment observation.

    Returns:
        The ``Action`` to submit to ``env.step``.
    """
    import sys  # local import: only needed for the failure warning below

    if not API_BASE_URL or not HF_TOKEN:
        return heuristic_policy(observation)
    prompt = (
        "Return one action_type from [triage, respond, resolve, escalate] and a short note in JSON "
        "with keys action_type and note. "
        f"Observation: {observation.model_dump_json()}"
    )
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        content = response.choices[0].message.content or "{}"
        payload = safe_parse(content)
        # Tolerate case/whitespace variants ("Respond", " resolve ") rather than
        # discarding an otherwise usable answer.
        action_type = str(payload.get("action_type", "")).strip().lower()
        if action_type not in VALID_ACTIONS:
            return heuristic_policy(observation)
        raw_note = payload.get("note")
        # A missing or null note becomes "" (previously null became "None").
        note = "" if raw_note is None else str(raw_note)
        # Forward only the requested keys so stray model output keys cannot
        # fail Action validation.
        return Action(action_type=action_type, note=note)
    except Exception as exc:  # deliberate best-effort boundary around the LLM call
        # stderr keeps the structured [START]/[STEP]/[END] stdout stream clean.
        print(
            f"[WARN] LLM action selection failed; using heuristic_policy: {exc!r}",
            file=sys.stderr,
        )
        return heuristic_policy(observation)
def run_episode(task: str, max_steps: int = 20) -> dict:
    """Play one episode of ``task`` and return a summary of the outcome.

    The environment is an ``OpenEnv`` built from ``runtime_config``, with
    ``difficulty`` overridden by ``task``. After ``env.reset()``, actions come
    from ``choose_action`` until the environment reports ``done`` or
    ``max_steps`` actions have been taken. Progress goes to stdout as one
    ``[START]`` line, one ``[STEP]`` line per action, and one ``[END]`` line.

    The reported ``score`` is ``env.state()["score"]`` when that is positive.
    Otherwise it is the reward-based ``compute_partial_score`` fallback.

    Args:
        task: Difficulty passed to the environment (this module uses
            ``"easy"``, ``"medium"`` and ``"hard"``).
        max_steps: Maximum number of actions to take.

    Returns:
        Dict with keys ``task``, ``success``, ``steps``, ``rewards``, ``score``,
        ``env_score``, ``partial_score`` and ``openai_client_configured``.
    """
    env_kwargs = {**runtime_config.to_env_kwargs(), "difficulty": task}
    env = OpenEnv(**env_kwargs)
    observation = env.reset()

    llm_enabled = bool(API_BASE_URL and HF_TOKEN)
    client_mode = "enabled" if llm_enabled else "fallback"
    print(
        f"[START] task={task} env=workflow model={MODEL_NAME} "
        f"openai_client={client_mode}"
    )

    # Sequential float accumulation (not sum()) keeps totals bit-identical.
    total_reward = 0.0
    formatted_rewards: list[str] = []
    step_count = 0
    done = False
    # The loop condition alone ends the episode once `done` is truthy.
    while not done and step_count < max_steps:
        action = choose_action(observation)
        observation, reward, done, _info = env.step(action)
        total_reward += reward
        formatted_rewards.append(f"{reward:.2f}")
        step_count += 1
        print(
            f"[STEP] step={step_count} action={action.action_type} "
            f"reward={reward:.2f} done={str(done).lower()} error=null"
        )

    resolved = observation.ticket_status == "resolved"
    print(
        f"[END] success={str(resolved).lower()} steps={step_count} "
        f"rewards={','.join(formatted_rewards)}"
    )

    env_score = float(env.state().get("score", 0.0))
    partial_score = compute_partial_score(total_reward, max_steps)
    if env_score > 0.0:
        final_score = env_score
    else:
        final_score = partial_score

    return {
        "task": task,
        "success": resolved,
        "steps": step_count,
        "rewards": round(total_reward, 2),
        "score": round(final_score, 4),
        "env_score": round(env_score, 4),
        "partial_score": round(partial_score, 4),
        "openai_client_configured": llm_enabled,
    }
def run_all_tasks() -> list[dict]:
    """Run one episode per difficulty level and collect the summaries.

    Returns:
        ``run_episode`` results for ``"easy"``, ``"medium"`` and ``"hard"``,
        in that order.
    """
    return [run_episode(difficulty) for difficulty in ("easy", "medium", "hard")]
if __name__ == "__main__":
    # Script entry point: evaluate every difficulty level, then print the
    # collected per-episode summaries as indented JSON after the log lines.
    summary = run_all_tasks()
    rendered = json.dumps(summary, indent=2)
    print(rendered)