# NetZero-Nav / inference.py
# Author: Aryanshh
# Compliance: Force strict (0, 1) score range in both env.py and inference.py logs
# Commit: d57c77b
import json
import os
import sys
import time
import textwrap
from typing import List, Optional
from openai import OpenAI
import httpx
# ---------------------------------------------------------------------------
# Config (MANDATORY per Checklist)
# ---------------------------------------------------------------------------
# Inference endpoint; defaults to the Hugging Face OpenAI-compatible router.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Model identifier passed to the chat-completions API.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# HF_TOKEN takes precedence; API_KEY is the generic fallback.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Environment Server URL
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
if not API_KEY:
    # We print and exit to avoid unhandled exceptions later
    print("ERROR: HF_TOKEN or API_KEY environment variable is required", flush=True)
    sys.exit(1)
# OpenAI Client configured via environment variables
client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
# ---------------------------------------------------------------------------
# Logging Utilities
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Emit the mandatory [START] marker line to stdout."""
    fields = f"task={task} env={env} model={model}"
    print(f"[START] {fields}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] line; a falsy error renders as the literal 'null'."""
    fields = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    ]
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the final mandatory [END] summary line to stdout."""
    joined = ",".join(format(r, ".2f") for r in rewards)
    flag = "true" if success else "false"
    print(f"[END] success={flag} steps={steps} score={score:.3f} rewards={joined}", flush=True)
# ---------------------------------------------------------------------------
# Agent Logic
# ---------------------------------------------------------------------------
# System message sent on every model call (see get_action). It spells out
# the exact JSON action schema the model must emit; the string content is
# part of the runtime contract and must not be reformatted.
SYSTEM_PROMPT = """You are an Eco-Resilient Logistics Agent.
Your goal is to fulfill orders while minimizing CO2.
Available Actions: {"action_type": "order_parts | produce | offset | skip", "part_type": "chips | sensors | batteries | casing", "quantity": count, "mode": "sea | air | rail | road", "product": "EcoPhone | GreenTab"}
Respond ONLY with a valid JSON object."""
def get_action(obs) -> dict:
    """Query the model for the next action given the current observation.

    Sends SYSTEM_PROMPT plus the serialized observation and parses the
    JSON reply. On any failure (network, timeout, malformed JSON) it
    logs to stderr and returns a safe skip action so the runner never
    crashes mid-episode.
    """
    user_msg = f"Current Observation: {json.dumps(obs)}\nChoose next action:"
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            response_format={"type": "json_object"},
            timeout=15.0,
        )
        raw = completion.choices[0].message.content
        return json.loads(raw)
    except Exception as exc:
        # Emergency fallback to prevent script crash
        print(f"[DEBUG] Model error: {exc}", file=sys.stderr)
        return {"action_type": "skip"}
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
def _clamp_score(value: float) -> float:
    """Clamp a score strictly inside (0, 1) — the compliance requirement
    is that logged scores never touch 0.0 or 1.0."""
    return max(0.01, min(0.99, value))


def run_task(task_name: str, max_steps: int = 50):
    """Run one episode of `task_name` against the environment server.

    Resets the environment, then loops: ask the model for an action
    (get_action), POST it to /step, and log every transition. The
    mandatory [END] line is always emitted, even when an exception
    aborts the episode mid-way.

    Args:
        task_name: Environment task identifier (e.g. "easy" / "medium" / "hard").
        max_steps: Hard cap on episode length; defaults to the original 50.
    """
    success = False
    score = 0.01  # Initialize to valid strictly-positive value
    steps_taken = 0
    rewards: List[float] = []
    log_start(task=task_name, env="netzero-nav", model=MODEL_NAME)
    try:
        with httpx.Client(base_url=ENV_URL, timeout=30.0) as app:
            # Reset environment. raise_for_status() surfaces HTTP errors
            # directly instead of failing later on JSON-decoding an error page.
            resp = app.post("/reset", json={"task": task_name})
            resp.raise_for_status()
            obs = resp.json()
            done = False
            while not done and steps_taken < max_steps:
                steps_taken += 1
                action_json = get_action(obs)
                # Take step
                step_resp = app.post("/step", json=action_json)
                step_resp.raise_for_status()
                payload = step_resp.json()
                obs = payload["observation"]
                reward = float(payload["reward"] or 0.0)  # tolerate null reward
                done = payload["done"]
                info = payload.get("info", {})
                error = info.get("error")
                rewards.append(reward)
                # Format action for logs, e.g. "order_parts-chips"
                act_type = action_json.get("action_type", "skip")
                act_part = action_json.get("part_type", "")
                act_str = f"{act_type}-{act_part}" if act_part else act_type
                log_step(step=steps_taken, action=act_str, reward=reward, done=done, error=error)
                if done:
                    # `or` guards against an explicit null final_score (the old
                    # float(None) raised TypeError); a literal 0.0 would be
                    # clamped up to 0.01 either way, so behavior is preserved.
                    score = _clamp_score(float(info.get("final_score") or 0.01))
                    success = score >= 0.95
    except Exception as e:
        print(f"[DEBUG] Runtime error during task {task_name}: {e}", file=sys.stderr)
    finally:
        # Final safety clamp before mandatory STDOUT
        score = _clamp_score(score)
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
if __name__ == "__main__":
    # Run each difficulty tier sequentially.
    for difficulty in ("easy", "medium", "hard"):
        run_task(difficulty)