# smart-farm-env / inference.py
# Author: Andrew2712
# Commit message: Fix inference: run all tasks, respect allowed actions
# Commit: 5fdcb8d
"""
inference.py – Smart Farm Resource Manager
Runs ALL tasks (easy, medium, hard) and emits strictly-formatted OpenEnv logs.
Log format (no deviations):
[START] task=<task_name> env=smart-farm-env model=<model>
[STEP] step=<n> action=<action> reward=<0.xxxx> done=<true|false> error=<null|string>
[END] success=<true|false> steps=<n> score=<0.xxxx> rewards=<r1,r2,...>
Run:
export API_BASE_URL=https://router.huggingface.co/v1
export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
export HF_TOKEN=hf_...
python inference.py
"""
from __future__ import annotations
import argparse
import importlib
import json
import os
import sys
from typing import Any, Dict, List, Optional
# Make the directory containing this script importable, then try to turn
# ``src`` and ``src/envs`` into regular packages by dropping empty
# ``__init__.py`` markers. Creating the markers is best-effort only: Python 3.3+
# can usually import such directories as namespace packages without them, so a
# missing directory or a read-only checkout must not abort start-up here.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _HERE)
for _pkg in ("src", os.path.join("src", "envs")):
    _pkg_dir = os.path.join(_HERE, _pkg)
    _init = os.path.join(_pkg_dir, "__init__.py")
    if os.path.isdir(_pkg_dir) and not os.path.exists(_init):
        try:
            open(_init, "w").close()
        except OSError:
            # Deliberately ignored: see the best-effort note above.
            pass
from openai import OpenAI
from src.envs.smart_farm_env.models import Action, TaskConfig
from src.envs.smart_farm_env.server.environment import SmartFarmEnv
# ---------------------------------------------------------------------------
# Client setup
# ---------------------------------------------------------------------------
# Endpoint, model and credentials are read from the environment, falling back
# to the Hugging Face router defaults. An unset HF_TOKEN yields an empty-string
# key rather than None.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN", "")
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
# ---------------------------------------------------------------------------
# Strict log helpers
# ---------------------------------------------------------------------------
def log_start(task: str, model: str) -> None:
    """Print the ``[START]`` record that opens the log of one task."""
    header = f"[START] task={task} env=smart-farm-env model={model}"
    print(header, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record on a single line.

    Args:
        step: 1-based step counter.
        action: Human-readable description of the action taken.
        reward: Step reward, printed with four decimals.
        done: Whether the episode ended on this step.
        error: Error text, or ``None`` (printed as ``null``).
    """
    # Line breaks inside the free-text fields are collapsed to spaces so that a
    # multi-line message cannot split one record across several physical lines
    # and break line-oriented parsing of the log.
    action_text = " ".join(str(action).splitlines())
    error_text = "null" if error is None else " ".join(str(error).splitlines())
    print(
        f"[STEP] step={step} action={action_text} reward={reward:.4f} "
        f"done={'true' if done else 'false'} error={error_text}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` summary record for one task.

    ``rewards`` is rendered as a comma-separated list of four-decimal values
    (an empty list yields an empty field).
    """
    success_flag = "true" if success else "false"
    reward_csv = ",".join(f"{value:.4f}" for value in rewards)
    print(
        f"[END] success={success_flag} steps={steps} score={score:.4f} rewards={reward_csv}",
        flush=True,
    )
# ---------------------------------------------------------------------------
# LLM Agent — strictly respects allowed_actions
# ---------------------------------------------------------------------------
def build_system_prompt(allowed_actions: List[str]) -> str:
    """Build the system prompt that confines the model to *allowed_actions*.

    The prompt embeds the allowed list as JSON. It then gives one example JSON
    action per allowed action type, always in the fixed order water, fertilize,
    spray_pesticide, inspect, wait. Simple threshold-based decision rules follow.
    """
    example_table = (
        ("water", '{"action_type": "water", "plot_id": 0, "amount": 1.0}'),
        ("fertilize", '{"action_type": "fertilize", "plot_id": 0, "amount": 1.0}'),
        ("spray_pesticide", '{"action_type": "spray_pesticide", "plot_id": 0, "amount": 1.0}'),
        ("inspect", '{"action_type": "inspect", "plot_id": 0, "amount": 1.0}'),
        ("wait", '{"action_type": "wait", "plot_id": null, "amount": 1.0}'),
    )
    examples = "\n".join(
        example for action_type, example in example_table if action_type in allowed_actions
    )
    allowed_json = json.dumps(allowed_actions)
    return f"""You are a smart farm manager AI. You MUST only use these allowed actions:
{allowed_json}
Example valid actions:
{examples}
Decision rules:
- If moisture_est < 0.45 AND "water" is allowed → water that plot
- If nutrients_est < 0.40 AND "fertilize" is allowed → fertilize that plot
- If pest_est > 0.45 AND "spray_pesticide" is allowed → spray that plot
- Otherwise → wait
CRITICAL: NEVER use an action not in the allowed list above.
Return ONLY a single valid JSON object. No explanation. No markdown."""
def _parse_action(raw: str, allowed_actions: List[str]) -> Action:
    """Convert the model's raw reply into an Action, defaulting to ``wait``.

    Returns ``Action(action_type="wait")`` when any of these hold:
    - the reply is not a JSON object;
    - it names an action outside *allowed_actions*;
    - its field values are rejected when building the Action.
    """
    # Models often wrap JSON in ```json ... ``` fences despite the instructions.
    if "```" in raw:
        raw = raw.split("```")[1]
        if raw.startswith("json"):
            raw = raw[4:]
    try:
        action_dict = json.loads(raw.strip())
    except json.JSONDecodeError:
        print(f"[llm_act] unparseable reply: {raw!r}", file=sys.stderr, flush=True)
        return Action(action_type="wait")
    if not isinstance(action_dict, dict):
        return Action(action_type="wait")
    action_type = action_dict.get("action_type", "wait")
    # Safety: if the LLM returns a disallowed action, fall back to wait.
    if action_type not in allowed_actions:
        return Action(action_type="wait")
    # An explicit ``"amount": null`` means "use the default", not a parse error.
    amount = action_dict.get("amount")
    try:
        return Action(
            action_type=action_type,
            plot_id=action_dict.get("plot_id"),
            amount=1.0 if amount is None else float(amount),
        )
    except (TypeError, ValueError) as exc:
        # float() failures land here, as do model validation errors
        # (presumably pydantic, whose ValidationError subclasses ValueError).
        print(f"[llm_act] rejected action {action_dict!r}: {exc}", file=sys.stderr, flush=True)
        return Action(action_type="wait")
def llm_act(
    obs_dict: Dict[str, Any],
    allowed_actions: List[str],
    model: Optional[str] = None,
) -> Action:
    """Ask the LLM for the next action, restricted to *allowed_actions*.

    Args:
        obs_dict: Current observation as a JSON-serialisable dict.
        allowed_actions: Action types the current task permits.
        model: Model to query; defaults to the module-level ``MODEL_NAME``.

    Returns:
        The parsed Action, or ``Action(action_type="wait")`` if the request
        fails or the reply cannot be turned into an allowed action.
    """
    user_msg = f"""Current farm state:
{json.dumps(obs_dict, indent=2)}
Allowed actions: {allowed_actions}
Pick the best action and return ONLY JSON."""
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME if model is None else model,
            messages=[
                {"role": "system", "content": build_system_prompt(allowed_actions)},
                {"role": "user", "content": user_msg},
            ],
            max_tokens=80,
            temperature=0.1,
        )
        # ``content`` may be None (e.g. an empty or filtered completion).
        raw = (response.choices[0].message.content or "").strip()
    except Exception as exc:  # network, auth, rate limiting, malformed response
        # Best effort: a failed request must not abort the episode. The cause
        # goes to stderr so that stdout keeps the strict log format.
        print(f"[llm_act] request failed: {exc!r}", file=sys.stderr, flush=True)
        return Action(action_type="wait")
    return _parse_action(raw, allowed_actions)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def load_task(task_name: str) -> TaskConfig:
    """Load the configuration of *task_name* from the bundled ``tasks.json``.

    Raises:
        ValueError: if *task_name* is not defined in ``tasks.json``.
    """
    tasks_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "src", "envs", "smart_farm_env", "tasks", "tasks.json",
    )
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # fail if the file contains non-ASCII text.
    with open(tasks_path, encoding="utf-8") as f:
        tasks = {t["name"]: t for t in json.load(f)["tasks"]}
    if task_name not in tasks:
        raise ValueError(f"Unknown task '{task_name}'. Choose from: {list(tasks)}")
    return TaskConfig(**tasks[task_name])
def load_grader(task_name: str):
    """Import and return the ``grader_<task_name>`` module for this task."""
    module = f"src.envs.smart_farm_env.server.graders.grader_{task_name}"
    return importlib.import_module(module)
# ---------------------------------------------------------------------------
# Single task runner
# ---------------------------------------------------------------------------
def run(
    task_name: str,
    model: str = MODEL_NAME,
    *,
    success_threshold: float = 0.50,
) -> None:
    """Play one episode of *task_name* with the LLM agent and log it.

    Prints one ``[START]`` line, one ``[STEP]`` line per step, and exactly one
    ``[END]`` line. The ``[END]`` line is printed even when the episode aborts
    with an exception (which is then re-raised), so every ``[START]`` is closed.

    Args:
        task_name: Task name in ``tasks.json``; also selects ``grader_<task_name>``.
        model: Model used for the LLM requests and reported in ``[START]``.
        success_threshold: Minimum grader score reported as ``success=true``.

    Raises:
        ValueError: if *task_name* is unknown (raised before any log line).
    """
    task_cfg = load_task(task_name)
    env = SmartFarmEnv(task_cfg)
    grader = load_grader(task_name)
    log_start(task_name, model)
    rewards: List[float] = []
    step_num = 0
    score = 0.0
    try:
        obs = env.reset()
        done = False
        while not done:
            step_num += 1
            # Forward ``model`` so the model that is logged is the one queried.
            action = llm_act(obs.model_dump(), task_cfg.allowed_actions, model=model)
            action_str = f"{action.action_type}(plot={action.plot_id},amt={action.amount})"
            try:
                result = env.step(action)
            except RuntimeError as exc:
                # Record the aborted step as 0.0, matching its [STEP] line, so
                # that ``rewards`` stays aligned with the reported step count.
                rewards.append(0.0)
                log_step(step_num, action_str, 0.0, True, str(exc))
                break
            reward = result.reward.value
            rewards.append(reward)
            done = result.done
            error = result.info.get("error") if result.info else None
            log_step(step_num, action_str, reward, done, error)
            obs = result.observation
        final_state = env.state()
        score = grader.grade(final_state.get("history", []), final_state)
    finally:
        log_end(score >= success_threshold, step_num, score, rewards)
# ---------------------------------------------------------------------------
# CLI — runs ALL tasks by default
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import traceback

    parser = argparse.ArgumentParser(description="Smart Farm Inference Runner")
    parser.add_argument("--task", default="all", choices=["all", "easy", "medium", "hard"])
    parser.add_argument("--model", default=MODEL_NAME)
    args = parser.parse_args()
    tasks = ["easy", "medium", "hard"] if args.task == "all" else [args.task]
    failed = []
    for task in tasks:
        try:
            run(task_name=task, model=args.model)
        except Exception:
            # Keep going so that one broken task does not hide the results of
            # the others. Diagnostics go to stderr to keep stdout's log format.
            failed.append(task)
            traceback.print_exc(file=sys.stderr)
    if failed:
        print(f"Failed tasks: {', '.join(failed)}", file=sys.stderr, flush=True)
        sys.exit(1)