# Retrieved from a Hugging Face file view (uploader: harshraj22; raw download size: 17.1 kB).
"""
Inference Script for CropRL Environment
=================================================
STDOUT FORMAT
- The script must emit exactly three line types to stdout, in this order:
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn>
"""
import os
import re
import sys
import argparse
from pathlib import Path
from typing import Any, List, Optional, Dict
# Ensure the root directory is on the path so cropRL module works anywhere
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from openai import OpenAI
from cropRL.tasks import create_env_for_task, grader, TASKS
from cropRL.models import MultiAgentAction
from cropRL.enums import ActionType, CropType
# ── Configuration ──────────────────────────────────────────────
# Base URL of the OpenAI-compatible endpoint; defaults to a server on localhost:11434.
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:11434/v1")
# HF_TOKEN is used when set and non-empty; otherwise API_KEY, falling back to "ollama".
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "ollama")
# Default model name; main() overwrites this global with the --model argument.
MODEL_NAME = os.getenv("MODEL_NAME", "gemma4:e4b")
# Set to 0 to prevent erratic thinking tokens.
TEMPERATURE = 0.0
# Small completion budget: discourages rambling/"thinking" output while still
# leaving room for a short forum message after the action number.
MAX_TOKENS = 50
# True only when the SHAPE_REWARDS env var equals "true" (case-insensitive; default "true").
# When enabled, the episode runners report the per-step change in farm net worth
# instead of the reward returned by the environment.
SHAPE_REWARDS = os.getenv("SHAPE_REWARDS", "true").lower() == "true"
# Base system prompt sent to the model on every call. It lists the action ids
# 0-14 that parse_action() accepts; get_agent_system_prompt() appends a
# per-agent identity section to it. This text is runtime data: edit with care.
SYSTEM_PROMPT = """\
You are an expert farm manager AI. You manage a small Indian farm over 60 months.
You may be competing or cooperating with other AI farmers in the village.
OBJECTIVE: Maximize your net worth (cash + land value + crop value - debt) by the end of 60 months.
ACTIONS (reply with ONLY the action number, or if action 11, reply with: 11 <your message>):
0: Wait / No-Op — Do nothing but consume 1 action slot.
1: Plant Corn — High cost, high yield, depletes soil nitrogen heavily.
2: Plant Wheat — Moderate cost/yield, mild nitrogen drain. Best in Winter.
3: Plant Chickpea — Low cost, lower yield, RESTORES soil nitrogen.
4: Irrigate — Adds water to field instantly. Critical during dry months.
5: Fertilize — Boosts soil nitrogen by 0.15 instantly.
6: Harvest & Store — Harvest crop and store it (auto-sells old storage).
7: Harvest & Sell — Harvest crop and queue sale for month-end clearing.
8: Sell Inventory — Queue stored crops for month-end sale.
9: Take Loan — Get cash (only if no active loan). Interest locked at current rate.
10: Repay Loan — Pay off full debt (must have enough cash).
11: Post Forum Message — Send a short intent message to other agents. Format: 11 <your message>
12: Plant Matcha (Hype Crop) — High hype premium but saturates fast.
13: Plant Quinoa (Hype Crop) — Moderate hype premium.
14: Plant Turmeric (Hype Crop) — Moderate hype premium.
KEY RULES:
- Action 0 (Wait) consumes an action slot and does nothing else. The month advances ONLY when all agents expend all configured action slots.
- Actions cost 1 action slot each month.
- Crops queued to sell are cleared at the END of the month. High supply drops the market clearing price for everyone.
- Hype crops follow unpredictable cycles. Monitor Social Media Trends.
- Can only plant on fallow (empty) land.
- Can only harvest crops aged >= 1 month.
- Storage rots after 6 months. Only one slot.
- One loan at a time. Must repay full amount. Interest uses rate when loan was taken.
- Soil nitrogen is crucial: low N = poor yields. Chickpeas restore N, Corn destroys it.
- Water level matters.
- Growing crops in their optimal season gives much better yields.
- Inflation increases costs each year.
- Monthly fixed costs are deducted every month.
- Bankruptcy (negative cash + loan) ends the game with heavy penalty.
CRITICAL INSTRUCTION:
DO NOT use <think> tags.
DO NOT output any reasoning, chain-of-thought, or explanation.
Respond IMMEDIATELY with ONLY a single integer (0-14), or if using action 11, the integer followed by your message.
"""
def log_start(task: str, env: str, model: str) -> None:
    """Write the opening ``[START]`` protocol line to stdout and flush it.

    Args:
        task: Task identifier shown as ``task=``.
        env: Benchmark name shown as ``env=``.
        model: Model name shown as ``model=``.
    """
    line = "[START] task={} env={} model={}".format(task, env, model)
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Write one ``[STEP]`` protocol line to stdout and flush it.

    Args:
        step: Step counter shown as ``step=``.
        action: Action label shown as ``action=``.
        reward: Step reward, rendered with two decimals.
        done: Episode-finished flag, rendered in lowercase.
        error: Error text; any falsy value is rendered as ``null``.
    """
    fields = (
        "step={}".format(step),
        "action={}".format(action),
        "reward={}".format(format(reward, ".2f")),
        "done={}".format(str(done).lower()),
        "error={}".format(error or "null"),
    )
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Write the closing ``[END]`` protocol line to stdout and flush it.

    Args:
        success: Outcome flag, rendered in lowercase.
        steps: Number of steps taken.
        score: Final score, rendered with three decimals.
        rewards: Values rendered with two decimals and joined by commas;
            an empty list yields an empty ``rewards=`` field.
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    line = "[END] success={} steps={} score={} rewards={}".format(
        str(success).lower(),
        steps,
        format(score, ".3f"),
        ",".join(formatted_rewards),
    )
    print(line, flush=True)
def _pick_planting_action(obs):
    """Choose what to do with fallow land based on cash and soil nitrogen."""
    nitrogen = obs.soil_nitrogen
    cash = obs.cash_balance
    chickpea_cost = getattr(obs, "cost_seed_3", 200.0)

    # Depleted soil: prefer the nitrogen-restoring crop when affordable.
    if nitrogen < 0.4 and cash >= chickpea_cost:
        return ActionType.PLANT_CHICKPEA
    # Decent soil: corn, either with a large cash cushion or enough for the seed.
    # Hype crops are deliberately not chosen by this rule set.
    if nitrogen > 0.5 and (cash >= 1500 or cash >= getattr(obs, "cost_seed_1", 800.0)):
        return ActionType.PLANT_CORN
    # Otherwise the moderate option.
    if cash >= getattr(obs, "cost_seed_2", 500.0):
        return ActionType.PLANT_WHEAT
    # Cannot even afford chickpea and carries no debt: borrow.
    if cash < chickpea_cost and obs.current_debt == 0:
        return ActionType.TAKE_LOAN
    return ActionType.WAIT


def _pick_tending_action(obs):
    """Choose what to do while a crop is in the ground."""
    age = obs.crop_age_months
    # Harvest at 4+ months, or at 3+ months when the yield potential is high.
    if age >= 4 or (age >= 3 and obs.expected_yield_potential > 0.8):
        return ActionType.HARVEST_SELL
    # Very low nitrogen: fertilize when affordable.
    if obs.soil_nitrogen < 0.2 and obs.cash_balance >= getattr(obs, "cost_fertilize", 300.0):
        return ActionType.FERTILIZE
    # Low water: irrigate when affordable.
    if obs.current_water_level < 0.2 and obs.cash_balance >= getattr(obs, "cost_irrigate", 300.0):
        return ActionType.IRRIGATE
    return ActionType.WAIT


def rule_based_agent(obs) -> int:
    """Deterministic rule-based policy for the CropRL environment.

    Priority order: sell any stored inventory, then plant if the land is
    fallow, then tend or harvest the growing crop, else wait. Cost
    thresholds are read from the observation's ``cost_*`` attributes when
    present and fall back to fixed defaults otherwise.

    Args:
        obs: Observation exposing the farm fields read below.

    Returns:
        The chosen ``ActionType`` member.
    """
    if obs.stored_amount > 0:
        return ActionType.SELL_INVENTORY
    if obs.active_crop_type == CropType.FALLOW:
        return _pick_planting_action(obs)
    if obs.active_crop_type != CropType.FALLOW:
        return _pick_tending_action(obs)
    return ActionType.WAIT
# Leading "<id>" optionally followed by a separator (colon, whitespace, dash)
# and free text. The negative lookahead rejects longer numbers, so "123" is
# not mistaken for action 12.
_LEADING_ACTION_RE = re.compile(r"^(\d{1,2})(?!\d)(?:[:\s-]+(.+))?")
# Any standalone one- or two-digit number anywhere in the reply.
_STANDALONE_NUMBER_RE = re.compile(r"\b(\d{1,2})\b")


def parse_action(response_text: str, fallback_action: int) -> tuple[int, Optional[str]]:
    """Extract an action integer and optional message from the LLM response.

    Parsing happens in two passes:

    1. If the reply starts with a one- or two-digit number in the range
       0-14, that number is the action and any text after a separator on
       the same line is returned as the message.
    2. Otherwise the first standalone one- or two-digit number in range
       anywhere in the reply is used, with no message.

    Numbers with three or more digits are never truncated into an action id.

    Args:
        response_text: Raw model reply; ``None`` is treated as empty.
        fallback_action: Action returned when no valid id is found.

    Returns:
        ``(action_id, message)`` where ``message`` is ``None`` unless text
        followed a leading action id.
    """
    cleaned = (response_text or "").strip()

    leading = _LEADING_ACTION_RE.match(cleaned)
    if leading:
        action_id = int(leading.group(1))
        if 0 <= action_id <= 14:
            trailing = leading.group(2)
            message = trailing.strip() if trailing else None
            return action_id, message

    # No usable leading id: scan the whole reply for the first in-range number.
    for candidate in _STANDALONE_NUMBER_RE.findall(cleaned):
        action_id = int(candidate)
        if 0 <= action_id <= 14:
            return action_id, None

    return fallback_action, None
def get_agent_system_prompt(agent_id: int, num_agents: int) -> str:
    """Return ``SYSTEM_PROMPT`` with an agent-identity section appended.

    Args:
        agent_id: Index of the agent the prompt is built for.
        num_agents: Total number of agents mentioned in the identity text.
    """
    identity_section = "".join(
        (
            "\n\nAGENT IDENTITY:\n",
            f"You are Agent {agent_id} (out of {num_agents} farmers in this village).\n",
            "Your farm is independent — you have your own land, cash, and crops.\n",
            "You can see what other agents plant (via the observation) and \n",
            "communicate via the Forum. Coordinate to avoid saturating the market \n",
            "with the same crop — if multiple agents sell the same crop, the \n",
            "clearing price drops for everyone. Messages are limited to 150 chars\n",
        )
    )
    return SYSTEM_PROMPT + identity_section
def get_model_action(
    client: OpenAI, obs, history: List[str],
    agent_id: Optional[int] = None, num_agents: int = 1,
) -> tuple[int, Optional[str]]:
    """Ask the model for the next action, falling back to the rule-based policy.

    The user message is the observation's ``text_summary`` (or ``str(obs)``
    when that is missing or empty) followed by the last 12 history entries.
    Any exception raised while querying or parsing is reported on stderr
    and answered with the rule-based action.

    Args:
        client: OpenAI-compatible client.
        obs: Current observation for the acting agent.
        history: Prior step descriptions for this agent.
        agent_id: When given, the per-agent system prompt is used.
        num_agents: Agent count passed to the per-agent prompt builder.

    Returns:
        ``(action_id, forum_message)``; the message is ``None`` on fallback.
    """
    fallback = rule_based_agent(obs)

    observation_text = getattr(obs, "text_summary", None) or str(obs)
    recent_history = "\n".join(history[-12:]) if history else "None"
    user_msg = observation_text + f"\n\nRecent History:\n{recent_history}"

    # Per-agent prompt in multi-agent mode, shared prompt otherwise.
    system_prompt = (
        SYSTEM_PROMPT if agent_id is None
        else get_agent_system_prompt(agent_id, num_agents)
    )

    try:
        chat_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_msg},
        ]
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=chat_messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        reply = completion.choices[0].message.content or ""
        return parse_action(reply, fallback)
    except Exception as exc:
        print(f"[DEBUG] LLM error: {exc}", file=sys.stderr)
        return fallback, None
def run_single_agent_episode(client: OpenAI, task_id: str):
    """Run a single-agent episode using MultiAgentCroprlEnvironment with num_agents=1.

    The environment is created for ``task_id`` in text mode and reset with
    seed 42. Agent 0 is driven by ``get_model_action`` for at most
    ``env._env_cfg.max_steps`` steps. With ``SHAPE_REWARDS`` enabled the
    per-step reward is the change in the farm's net worth; otherwise it is
    the reward reported by the environment (a missing reward counts as 0.0).

    Only the ``[START]``, ``[STEP]`` and ``[END]`` protocol lines are written
    to stdout. Observation dumps and error diagnostics go to stderr so the
    documented stdout format is not polluted.

    Args:
        client: OpenAI-compatible client used to query the model.
        task_id: Task identifier passed to ``create_env_for_task``.
    """
    env = create_env_for_task(task_id, text_mode=True)
    env.reset(seed=42)
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_id, env="croprl", model=MODEL_NAME)
    max_steps = env._env_cfg.max_steps
    trajectory: list = []
    prev_net_worth = env._farms[0].compute_net_worth() if SHAPE_REWARDS else 0.0
    try:
        for step in range(1, max_steps + 1):
            # Always fetch a fresh observation.
            obs = env.get_obs(0)
            if obs.done:
                break
            obs_details = obs.text_summary if getattr(obs, "text_summary", None) else str(obs)
            # Diagnostic only: keep it off stdout, which is reserved for protocol lines.
            print(f"\n[OBSERVATION - Step {step}]\n{obs_details}\n", file=sys.stderr, flush=True)

            action_id, forum_message = get_model_action(client, obs, history, agent_id=0, num_agents=1)
            action_names = env._env_cfg.action_names
            action_name = action_names[action_id] if action_id < len(action_names) else f"Action {action_id}"
            action = MultiAgentAction(action_id=action_id, agent_id=0, forum_message=forum_message)
            result_obs = env.step(action)

            if SHAPE_REWARDS:
                # Shaped reward: net-worth delta since the previous step.
                current_net_worth = env._farms[0].compute_net_worth()
                reward = current_net_worth - prev_net_worth
                prev_net_worth = current_net_worth
            else:
                reward = result_obs.reward or 0.0
            done = result_obs.done

            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_name, reward=reward, done=done, error=None)
            history.append(f"Step {step}: Selected '{action_name}' -> Reward {reward:+.2f}")
            trajectory.append({
                "step": step,
                "action_id": action_id,
                "reward": reward,
                "cash": result_obs.cash_balance,
                "debt": result_obs.current_debt,
                "soil_n": result_obs.soil_nitrogen,
                "prices": [
                    result_obs.market_price_crop_1,
                    result_obs.market_price_crop_2,
                    result_obs.market_price_crop_3,
                    result_obs.market_price_crop_4,
                    result_obs.market_price_crop_5,
                    result_obs.market_price_crop_6,
                ],
            })
            if done:
                break

        # Use compute_result for consistent scoring.
        result = env.compute_result({0: trajectory})
        score = result.aggregate_score
        success = score >= 0.1
    except Exception as e:
        # Best-effort: report on stderr and still emit [END] with what we have.
        print(f"[DEBUG] Error during episode execution: {e}", file=sys.stderr, flush=True)
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
def run_multi_agent_episode_llm(client: OpenAI, task_id: str):
    """Run a multi-agent episode with LLM agents.

    The environment is created for ``task_id`` in text mode and reset with
    seed 42. Agents act in the order given by ``env.get_turn_order()`` until
    every agent is done or ``env._env_cfg.max_steps * num_agents`` steps have
    been taken. Agents whose observation is already done submit action 0 so
    they keep consuming their slots. With ``SHAPE_REWARDS`` enabled each
    agent's reward is the change in its own farm's net worth.

    Only the ``[START]``, ``[STEP]`` and ``[END]`` protocol lines are written
    to stdout. Observation dumps, per-agent terminal profit and error
    diagnostics go to stderr so the documented stdout format is not polluted.
    On error, ``[END]`` reports ``success=false``, ``score=0.000`` and no rewards.

    Args:
        client: OpenAI-compatible client used to query the model.
        task_id: Task identifier passed to ``create_env_for_task``.
    """
    env = create_env_for_task(task_id, text_mode=True)
    env.reset(seed=42)
    n = env._ma_cfg.num_agents
    histories: Dict[int, List[str]] = {i: [] for i in range(n)}
    trajectories: Dict[int, List[dict]] = {i: [] for i in range(n)}
    done_agents: set = set()
    max_steps = env._env_cfg.max_steps * n
    total_steps = 0
    score = 0.0
    success = False
    log_start(task=task_id, env="croprl_multi_agent", model=MODEL_NAME)
    prev_net_worths = {i: env._farms[i].compute_net_worth() for i in range(n)} if SHAPE_REWARDS else {}
    try:
        while len(done_agents) < n and total_steps < max_steps:
            for agent_id in env.get_turn_order():
                # Always fetch a fresh observation — no caching needed.
                obs = env.get_obs(agent_id)
                if obs.done:
                    done_agents.add(agent_id)
                    # Dead/done agents automatically wait out their slots so
                    # they don't block the TimeController.
                    action_id = 0
                    forum_message = None
                else:
                    action_id, forum_message = get_model_action(
                        client, obs, histories[agent_id], agent_id=agent_id, num_agents=n
                    )

                action_names = env._env_cfg.action_names
                action_name = action_names[action_id] if action_id < len(action_names) else f"Action {action_id}"
                action = MultiAgentAction(action_id=action_id, agent_id=agent_id, forum_message=forum_message)
                new_obs = env.step(action)

                if SHAPE_REWARDS:
                    # Shaped reward: this agent's net-worth delta since its last action.
                    current_net_worth = env._farms[agent_id].compute_net_worth()
                    reward = current_net_worth - prev_net_worths[agent_id]
                    prev_net_worths[agent_id] = current_net_worth
                else:
                    reward = new_obs.reward or 0.0

                total_steps += 1
                log_step(
                    step=total_steps,
                    action=f"A{agent_id}:{action_name}",
                    reward=reward,
                    done=new_obs.done,
                    error=None,
                )
                histories[agent_id].append(
                    f"Step {new_obs.current_step}: Selected '{action_name}' -> Reward {reward:+.2f}"
                )
                # Trajectory bookkeeping.
                trajectories[agent_id].append({
                    "step": new_obs.current_step,
                    "action_id": action_id,
                    "reward": reward,
                    "cash": new_obs.cash_balance,
                    "debt": new_obs.current_debt,
                    "soil_n": new_obs.soil_nitrogen,
                    "prices": [
                        new_obs.market_price_crop_1,
                        new_obs.market_price_crop_2,
                        new_obs.market_price_crop_3,
                        new_obs.market_price_crop_4,
                        new_obs.market_price_crop_5,
                        new_obs.market_price_crop_6,
                    ],
                })
                # Only dump the observation if the agent actually made a choice
                # (was not already done). Diagnostic output goes to stderr.
                if not obs.done:
                    obs_details = new_obs.text_summary if getattr(new_obs, "text_summary", None) else str(new_obs)
                    print(
                        f"\n[OBSERVATION - A{agent_id} Step {new_obs.current_step}]\n{obs_details}\n",
                        file=sys.stderr,
                        flush=True,
                    )
                if new_obs.done:
                    done_agents.add(agent_id)

        result = env.compute_result(trajectories)
        score = result.aggregate_score
        success = score >= 0.1
        for agent_id in range(n):
            terminal_profit = env._farms[agent_id].compute_terminal_value()
            # Informational line, not part of the stdout protocol.
            print(f"[AGENT {agent_id}] Terminal Profit: {terminal_profit:.4f}", file=sys.stderr, flush=True)
        log_end(success=success, steps=total_steps, score=score, rewards=list(result.agent_scores.values()))
    except Exception as e:
        # Best-effort: report on stderr and still close the protocol with [END].
        print(f"[DEBUG] Error during multi-agent episode execution: {e}", file=sys.stderr, flush=True)
        log_end(success=False, steps=total_steps, score=0.0, rewards=[])
def run_episode(client: OpenAI, task_id: str):
    """Dispatch ``task_id`` to the multi-agent or single-agent runner.

    A task is treated as multi-agent when its ``TASKS`` entry has a truthy
    ``"multi_agent"`` value; unknown tasks use the single-agent runner.

    Args:
        client: OpenAI-compatible client forwarded to the runner.
        task_id: Task identifier looked up in ``TASKS``.
    """
    is_multi_agent = TASKS.get(task_id, {}).get("multi_agent", False)
    runner = run_multi_agent_episode_llm if is_multi_agent else run_single_agent_episode
    runner(client, task_id)
def main():
    """Parse ``--task`` / ``--model`` and run one episode.

    The parsed model name replaces the module-level ``MODEL_NAME`` so that
    every later request and log line uses it.
    """
    global MODEL_NAME

    cli = argparse.ArgumentParser(description="Run CropRL inference")
    cli.add_argument("--task", type=str, default="easy_2agent", help="Task ID to run")
    cli.add_argument("--model", type=str, default=MODEL_NAME, help="Model name")
    options = cli.parse_args()

    MODEL_NAME = options.model
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Run task
    run_episode(llm_client, options.task)


if __name__ == "__main__":
    main()

# Xet Storage Details
#
# Size: 17.1 kB
# Xet hash: fd95c7004fecac7dfe1e417786b46104e618469a3f3e5efc720e52f029a231ca
#
# Xet efficiently stores files, intelligently splitting them into unique
# chunks and accelerating uploads and downloads.