harshraj22's picture
download
raw
14.5 kB
"""
CropRL Task Definitions.
Three single-agent tasks with the same objective (maximize net worth over
60 months) at different difficulty levels.
Multi-agent task variants are also provided:
easy_2agent, easy_4agent, easy_8agent
medium_2agent, medium_4agent, medium_8agent
hard_2agent, hard_4agent, hard_8agent
"""
from __future__ import annotations
from cropRL import CroprlAction
from typing import Optional
from .config import EnvConfig, MultiAgentConfig
# ── Single-agent tasks ──────────────────────────────────────────────────────
# Registry of task definitions, keyed by task id. Every entry holds a
# human-readable "description" and the "config_overrides" that are passed as
# keyword arguments to EnvConfig when an environment is built for the task.
TASKS: dict[str, dict] = dict(
    easy=dict(
        description=(
            "Maximize net worth over 60 months. Simplified conditions: "
            "no interest on loans, stable weather, generous starting capital, "
            "no inflation."
        ),
        config_overrides=dict(
            initial_cash=15000.0,
            base_interest_rate=0.0,
            weather_sigma=0.05,
            weather_sigma_realisation=0.02,
            market_price_sigma=0.05,
            initial_soil_nitrogen=0.8,
            max_storage_age=12,
            inflation_rate=0.0,
            monthly_fixed_cost=100.0,
        ),
    ),
    medium=dict(
        description=(
            "Maximize net worth over 60 months. Standard conditions."
        ),
        # Intentionally empty: "medium" runs on the EnvConfig defaults.
        config_overrides=dict(),
    ),
    hard=dict(
        description=(
            "Maximize net worth over 60 months. Harsh conditions: "
            "high interest, volatile weather and markets, poor starting soil, "
            "high inflation."
        ),
        config_overrides=dict(
            initial_cash=7000.0,
            base_interest_rate=0.12,
            weather_sigma=0.25,
            weather_sigma_realisation=0.08,
            market_price_sigma=0.20,
            initial_soil_nitrogen=0.35,
            max_storage_age=4,
            inflation_rate=0.07,
            monthly_fixed_cost=300.0,
        ),
    ),
)
# ── Multi-agent task variants ───────────────────────────────────────────────
# Injected after TASKS is defined so config_overrides references are valid.
# Register one multi-agent variant per (difficulty, farm count) pair under the
# id "<difficulty>_<n>agent". Runs at import time, after the base entries of
# TASKS exist, because each variant reads its base task's overrides.
for _difficulty in ("easy", "medium", "hard"):
    for _n in (2, 4, 8):
        _key = f"{_difficulty}_{_n}agent"
        TASKS[_key] = {
            "description": (
                f"Multi-agent ({_n} farms) variant of the '{_difficulty}' task. "
                "Agents compete on a shared market with supply-demand pricing "
                "and hype crop cycles."
            ),
            # Shallow copy rather than the base task's own dict: storing the
            # same object would let a mutation of one task's overrides leak
            # into the base task and all of its sibling variants.
            "config_overrides": dict(TASKS[_difficulty]["config_overrides"]),
            "multi_agent": True,
            "num_agents": _n,
        }
# ── Environment factories ───────────────────────────────────────────────────
def create_env_for_task(
    task_id: str,
    text_mode: bool = False,
    objective_mode: str = "competitive",
) -> "MultiAgentCroprlEnvironment":
    """
    Build a MultiAgentCroprlEnvironment for any registered task.

    The agent count comes from the task's ``"num_agents"`` entry; tasks
    without one (the single-agent ``"easy"``, ``"medium"``, ``"hard"``)
    are created with ``num_agents=1``.

    Parameters
    ----------
    task_id : str
        Key into ``TASKS`` (single-agent or multi-agent).
    text_mode : bool
        Forwarded to ``EnvConfig`` as ``text_mode`` (text observations
        for LLM agents).
    objective_mode : str
        ``"competitive"``, ``"cooperative"``, or ``"mixed"``; forwarded to
        ``MultiAgentConfig``.

    Returns
    -------
    MultiAgentCroprlEnvironment

    Raises
    ------
    KeyError
        If ``task_id`` is not a key of ``TASKS``.
    """
    # Imported inside the function, not at module level.
    from .multi_agent_environment import MultiAgentCroprlEnvironment

    if task_id not in TASKS:
        raise KeyError(
            f"Unknown task '{task_id}'. Available: {list(TASKS.keys())}"
        )

    spec = TASKS[task_id]
    agent_count = spec.get("num_agents", 1)

    # Work on a fresh dict so the registry entry itself is never modified.
    env_kwargs = {**spec["config_overrides"]}
    env_kwargs["text_mode"] = text_mode

    return MultiAgentCroprlEnvironment(
        env_config=EnvConfig(**env_kwargs),
        ma_config=MultiAgentConfig(
            num_agents=agent_count,
            objective_mode=objective_mode,
        ),
        task_id=task_id,
    )
def create_multi_agent_env_for_task(
    task_id: str,
    text_mode: bool = False,
    objective_mode: str = "competitive",
) -> "MultiAgentCroprlEnvironment":
    """
    Build a MultiAgentCroprlEnvironment for a multi-agent task.

    Differs from ``create_env_for_task`` only in the fallback agent count:
    a task entry without ``"num_agents"`` is created with 4 agents here.

    Parameters
    ----------
    task_id : str
        A multi-agent task id, e.g. ``"easy_4agent"`` or ``"hard_8agent"``.
    text_mode : bool
        Forwarded to ``EnvConfig`` as ``text_mode`` (text observations
        for LLM agents).
    objective_mode : str
        ``"competitive"``, ``"cooperative"``, or ``"mixed"``; forwarded to
        ``MultiAgentConfig``.

    Returns
    -------
    MultiAgentCroprlEnvironment

    Raises
    ------
    KeyError
        If ``task_id`` is not a key of ``TASKS``.
    """
    # Imported inside the function, not at module level.
    from .multi_agent_environment import MultiAgentCroprlEnvironment

    if task_id not in TASKS:
        raise KeyError(
            f"Unknown task '{task_id}'. Available: {list(TASKS.keys())}"
        )

    spec = TASKS[task_id]
    agent_count = spec.get("num_agents", 4)

    # Work on a fresh dict so the registry entry itself is never modified.
    env_kwargs = {**spec["config_overrides"]}
    env_kwargs["text_mode"] = text_mode

    return MultiAgentCroprlEnvironment(
        env_config=EnvConfig(**env_kwargs),
        ma_config=MultiAgentConfig(
            num_agents=agent_count,
            objective_mode=objective_mode,
        ),
        task_id=task_id,
    )
def list_tasks() -> dict[str, str]:
    """Map every registered task id to its description text."""
    descriptions: dict[str, str] = {}
    for task_id, spec in TASKS.items():
        descriptions[task_id] = spec["description"]
    return descriptions
# ── Grading ─────────────────────────────────────────────────────────────────
def grader(
    task_id: str,
    final_net_worth: float,
    bankrupt: bool,
    trajectory: list[dict] | None = None,
) -> float:
    """
    Grade the agent's performance on a 0.01 – 0.99 scale.

    Uses an "Empirical Oracle Upper Bound" approach: for every step of the
    trajectory the oracle takes the best per-crop amortized monthly profit,
    ``(price × base_yield − seed_cost) / growth_months`` (never below 0),
    computed from the market prices recorded for that step, and sums these
    over the episode.

    Score = (Agent Net Worth − Baseline) / (Oracle Upper Bound − Baseline)

    where Baseline is the net worth of an agent that only pays the fixed
    monthly cost for ``max_months`` and keeps its initial land value.

    Parameters
    ----------
    task_id : str
        The task that was executed (``"easy"``, ``"medium"``, ``"hard"``).
        Multi-agent ids of the form ``"<difficulty>_<n>agent"`` (e.g.
        ``"medium_4agent"``) are also accepted; the base difficulty is
        extracted from them. Any other id is graded as ``"medium"``.
    final_net_worth : float
        The agent's net worth at the end of the episode.
    bankrupt : bool
        Whether the agent went bankrupt.
    trajectory : list[dict]
        Chronological list of step data. Each entry may hold a ``"prices"``
        list (one price per crop, starting at crop 1); entries without it
        fall back to the configured base market prices.

    Returns
    -------
    float
        A score clamped to ``[0.01, 0.99]``. The floor ``0.01`` is returned
        for bankrupt agents, non-positive net worth, a missing or empty
        trajectory, or a net worth not above the baseline. ``0.5`` is
        returned when the oracle bound does not exceed the baseline.
    """
    if bankrupt or final_net_worth <= 0 or not trajectory:
        return 0.01

    # Strip a trailing "_<n>agent" suffix for any agent count, so that ids
    # outside the pre-registered 2/4/8 variants still resolve to their own
    # difficulty and are not silently graded as "medium".
    base_task, sep, suffix = task_id.rpartition("_")
    has_agent_suffix = (
        bool(sep)
        and suffix.endswith("agent")
        and suffix[: -len("agent")].isdigit()
    )
    if not has_agent_suffix:
        base_task = task_id
    if base_task not in ("easy", "medium", "hard"):
        base_task = "medium"

    # Reconstruct the config for this task to get its initial values.
    overrides = TASKS.get(base_task, {}).get("config_overrides", {})
    cfg = EnvConfig(**overrides)

    # Baseline: net worth if the agent does nothing for the whole episode.
    baseline_cash = cfg.initial_cash - (cfg.max_months * cfg.monthly_fixed_cost)
    baseline_land = cfg.base_land_price * cfg.initial_soil_nitrogen
    baseline_min = baseline_cash + baseline_land
    if final_net_worth <= baseline_min:
        return 0.01

    # Loop-invariant fallback prices, built once instead of on every step.
    # Index 0 of the config tables is skipped: crop ids start at 1.
    default_prices = list(cfg.base_market_prices[1:])
    crop_ids = range(1, cfg.num_crop_types)

    total_oracle_profit = 0.0
    for step_data in trajectory:
        prices = step_data.get("prices", default_prices)
        # The leading 0.0 is the "plant nothing" option, so a step can
        # never contribute a negative amount to the oracle bound.
        candidates = [0.0]
        candidates.extend(
            (prices[i - 1] * cfg.base_yield_tons[i] - cfg.seed_costs[i])
            / float(cfg.growth_months[i])
            for i in crop_ids
        )
        total_oracle_profit += max(candidates)

    # The oracle keeps its land at full value (nitrogen factor 1.0).
    oracle_land = cfg.base_land_price * 1.0
    oracle_max = cfg.initial_cash + oracle_land + total_oracle_profit
    if oracle_max <= baseline_min:
        return 0.5  # edge case: the oracle cannot beat the baseline

    score = (final_net_worth - baseline_min) / (oracle_max - baseline_min)
    return float(max(0.01, min(0.99, score)))
# ── Single-agent grader classes ─────────────────────────────────────────────
class EasyGrader:
    """Grader bound to the ``"easy"`` task."""

    def grade(
        self,
        final_net_worth: float,
        bankrupt: bool,
        trajectory: list[dict] | None = None,
    ) -> float:
        """Score an episode with ``grader("easy", ...)``, kept within [0.01, 0.99]."""
        raw_score = grader("easy", final_net_worth, bankrupt, trajectory)
        capped = min(0.99, raw_score)
        return max(0.01, capped)
class MediumGrader:
    """Grader bound to the ``"medium"`` task."""

    def grade(
        self,
        final_net_worth: float,
        bankrupt: bool,
        trajectory: list[dict] | None = None,
    ) -> float:
        """Score an episode with ``grader("medium", ...)``, kept within [0.01, 0.99]."""
        raw_score = grader("medium", final_net_worth, bankrupt, trajectory)
        capped = min(0.99, raw_score)
        return max(0.01, capped)
class HardGrader:
    """Grader bound to the ``"hard"`` task."""

    def grade(
        self,
        final_net_worth: float,
        bankrupt: bool,
        trajectory: list[dict] | None = None,
    ) -> float:
        """Score an episode with ``grader("hard", ...)``, kept within [0.01, 0.99]."""
        raw_score = grader("hard", final_net_worth, bankrupt, trajectory)
        capped = min(0.99, raw_score)
        return max(0.01, capped)
# ── Multi-agent grader ──────────────────────────────────────────────────────
class MultiAgentGrader:
    """
    Grader for multi-agent episodes.

    Scoring is performed by the environment: ``grade()`` forwards the
    trajectories to ``env.compute_result()`` and returns whatever that
    produces. (Presumably the environment applies the single-agent
    ``grader()`` per agent; that is not visible from this class.)
    """

    def __init__(self, task_id: str = "medium_4agent") -> None:
        # Kept as a public attribute; grade() itself does not read it.
        self.task_id = task_id

    def grade(
        self,
        env: "MultiAgentCroprlEnvironment",  # noqa: F821
        trajectories: Optional[dict] = None,
    ) -> "MultiAgentResult":  # noqa: F821
        """
        Return the environment's result object for a completed episode.

        Parameters
        ----------
        env : MultiAgentCroprlEnvironment
            The environment whose episode is being scored.
        trajectories : dict, optional
            Passed through unchanged to ``env.compute_result()``.
        """
        result = env.compute_result(trajectories)
        return result
# ── Rule-based agent helper ─────────────────────────────────────────────────
def _rule_based_action(obs: "MultiAgentObservation") -> int:  # noqa: F821
    """
    Deterministic fallback policy used for smoke-testing multi-agent episodes.

    Rules, checked in this order:
    1. A crop is planted and is at least 3 months old → harvest and sell.
    2. The field is fallow and cash covers ``cost_seed_1`` → plant corn.
    3. A crop is planted, water level is below 0.3 and cash covers
       ``cost_irrigate`` → irrigate.
    4. Otherwise → wait.

    Returns the matching ``ActionType`` member.
    """
    from .enums import ActionType, CropType

    crop = obs.active_crop_type
    age_months = obs.crop_age_months
    water_level = obs.current_water_level

    field_has_crop = crop != CropType.FALLOW
    field_is_fallow = crop == CropType.FALLOW

    # 3 months is the age from which this policy treats a crop as ready.
    if field_has_crop and age_months >= 3:
        return ActionType.HARVEST_SELL

    if field_is_fallow and obs.cash_balance >= obs.cost_seed_1:
        return ActionType.PLANT_CORN

    needs_irrigation = (
        field_has_crop
        and water_level < 0.3
        and obs.cash_balance >= obs.cost_irrigate
    )
    return ActionType.IRRIGATE if needs_irrigation else ActionType.WAIT
# ── Multi-agent episode runner ──────────────────────────────────────────────
def run_multi_agent_episode(
    task_id: str = "medium_4agent",
    num_agents: Optional[int] = None,
    seed: int = 42,
    agent_fn=None,
    verbose: bool = False,
) -> "MultiAgentResult":  # noqa: F821
    """
    Run a complete multi-agent episode with rule-based (or custom) agents.

    Agents act in the order given by ``env.get_turn_order()`` until every
    agent has reported ``done`` or the step budget is used up. After each
    step the six crop market prices from the returned observation are
    appended to that agent's trajectory.

    Parameters
    ----------
    task_id : str
        Task identifier, e.g. ``"easy_4agent"``, ``"hard_8agent"``.
    num_agents : int, optional
        Override the number of agents. When ``None`` (default) the count
        defined by the task is used.
    seed : int
        Seed passed to ``env.reset()``.
    agent_fn : callable, optional
        ``(agent_id, observation) -> action_id``.
        Defaults to the built-in rule-based agent.
    verbose : bool
        Print step messages when True.

    Returns
    -------
    MultiAgentResult
        Result of ``env.compute_result()`` for the collected trajectories.

    Raises
    ------
    KeyError
        If ``task_id`` is not a key of ``TASKS``.
    """
    from .models import MultiAgentAction

    if num_agents is None:
        env = create_env_for_task(task_id)
    else:
        # Honour the documented override: same construction as
        # create_env_for_task() with its default text/objective settings,
        # but with the caller's agent count instead of the task's.
        from .multi_agent_environment import MultiAgentCroprlEnvironment

        if task_id not in TASKS:
            raise KeyError(
                f"Unknown task '{task_id}'. Available: {list(TASKS.keys())}"
            )
        overrides = dict(TASKS[task_id]["config_overrides"])
        overrides["text_mode"] = False
        env = MultiAgentCroprlEnvironment(
            env_config=EnvConfig(**overrides),
            ma_config=MultiAgentConfig(
                num_agents=num_agents,
                objective_mode="competitive",
            ),
            task_id=task_id,
        )

    env.reset(seed=seed)
    n = env._ma_cfg.num_agents
    trajectories: dict = {i: [] for i in range(n)}

    if agent_fn is None:
        def fn(aid, obs):
            return _rule_based_action(obs)
    else:
        fn = agent_fn

    done_agents: set = set()
    # total_steps counts individual agent steps, hence the budget of
    # max_steps per agent times the number of agents.
    max_steps = env._env_cfg.max_steps * n
    total_steps = 0

    while len(done_agents) < n and total_steps < max_steps:
        for agent_id in env.get_turn_order():
            if agent_id in done_agents:
                # Finished agents still take their turn with action 0.
                # Same action type as for active agents, so the
                # environment sees one consistent kind of action.
                env.step(MultiAgentAction(action_id=0, agent_id=agent_id))
                continue

            obs = env.get_obs(agent_id)
            if obs.done:
                done_agents.add(agent_id)
                continue

            action_id = fn(agent_id, obs)
            action = MultiAgentAction(action_id=action_id, agent_id=agent_id)
            new_obs = env.step(action)

            # One price per crop 1..6, in crop order.
            trajectories[agent_id].append({
                "prices": [
                    getattr(new_obs, f"market_price_crop_{crop_idx}")
                    for crop_idx in range(1, 7)
                ]
            })

            if verbose:
                print(f"[A{agent_id} s{new_obs.current_step}] {new_obs.message[:80]}")
            if new_obs.done:
                done_agents.add(agent_id)
            total_steps += 1

    return env.compute_result(trajectories)

Xet Storage Details

Size:
14.5 kB
·
Xet hash:
5af27fd3358a682f83a70434e49a3ff24f9bf9d8441fccee31c1eb56f4a42d50

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.