Spaces:
Sleeping
Sleeping
File size: 2,046 Bytes
"""
Base Task and Grader Abstract Classes
======================================
All tasks and graders must inherit from these bases.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional, Any
@dataclass
class TaskConfig:
    """Static description of a single task.

    A plain record: the dataclass-generated ``__init__`` accepts the fields
    in the order declared below, positionally or by keyword. No field has a
    default, so all eight must be supplied.
    """

    task_name: str
    # NOTE(review): the *_key fields look like lookup keys into crop / soil /
    # region tables defined elsewhere -- confirm against the callers.
    crop_key: str
    soil_key: str
    region_key: str
    max_steps: int
    description: str
    success_criteria: str
    difficulty: str  # "easy" | "medium" | "hard"
@dataclass
class EpisodeRecord:
    """Full episode history for grader evaluation.

    Only ``task_name`` and ``crop_key`` are required. Every other field has
    a default, so a record can be created at the start of an episode and
    filled in as it progresses.
    """

    task_name: str
    crop_key: str

    # Per-step histories. ``default_factory=list`` gives every instance its
    # own fresh list, so appending to one record never leaks into another.
    days: list = field(default_factory=list)
    actions: list = field(default_factory=list)
    observations: list = field(default_factory=list)
    step_rewards: list = field(default_factory=list)
    step_infos: list = field(default_factory=list)

    # Episode-level aggregates, all starting at zero.
    # NOTE(review): units are presumably those spelled out in the field-name
    # suffixes (ton/ha, INR/ha, mm, INR) -- confirm against the code that
    # fills them in.
    final_yield_ton_per_ha: float = 0.0
    final_revenue_inr_per_ha: float = 0.0
    total_irrigation_mm: float = 0.0
    total_spray_events: int = 0
    total_cost_inr: float = 0.0
    cumulative_stress_days: int = 0

    # Left as None until a value is assigned.
    harvest_day: Optional[int] = None
    harvest_gdd: Optional[float] = None

    # Spray counters, both starting at zero.
    unnecessary_sprays: int = 0
    correct_sprays: int = 0
class BaseTask(ABC):
    """Abstract base for all AgroEnv tasks.

    A concrete task must provide the ``config`` property and implement
    ``compute_step_reward`` and ``is_done``. Until all three are overridden
    the subclass cannot be instantiated.
    """

    @property
    @abstractmethod
    def config(self) -> TaskConfig:
        """Return the ``TaskConfig`` describing this task."""

    @abstractmethod
    def compute_step_reward(
        self,
        action: Any,
        obs_before: Any,
        obs_after: Any,
        step_info: dict,
    ) -> tuple[float, str]:
        """Score a single step.

        Args:
            action: The action taken on this step.
            obs_before: Observation before the action was applied.
            obs_after: Observation after the action was applied.
            step_info: Extra per-step information as a dict.

        Returns:
            A ``(reward, text)`` pair. NOTE(review): the string is presumably
            a human-readable explanation of the reward -- confirm against
            the concrete implementations.
        """

    @abstractmethod
    def is_done(self, obs: Any, step: int, episode: EpisodeRecord) -> bool:
        """Return True when the episode should end.

        Args:
            obs: The current observation.
            step: The current step number.
            episode: The episode history recorded so far.
        """
class BaseGrader(ABC):
    """Abstract base for all AgroEnv task graders.

    A concrete grader must provide the ``task_name`` property and implement
    ``grade``. ``_clamp`` is a shared helper for keeping scores in range.
    """

    @property
    @abstractmethod
    def task_name(self) -> str:
        """Return the name of the task this grader evaluates."""

    @abstractmethod
    def grade(self, episode: EpisodeRecord) -> tuple[float, dict]:
        """Grade a finished episode.

        Args:
            episode: The recorded episode history.

        Returns:
            A ``(score, details)`` pair. NOTE(review): the dict presumably
            holds a breakdown of the score -- confirm against the concrete
            graders.
        """

    def _clamp(self, value: float, lo: float = 0.0, hi: float = 1.0) -> float:
        """Restrict ``value`` to the closed interval ``[lo, hi]``.

        Args:
            value: The number to clamp.
            lo: Lower bound (default 0.0).
            hi: Upper bound (default 1.0).

        Returns:
            ``lo`` if ``value`` is below ``lo`` or is NaN, ``hi`` if it is
            above ``hi``, otherwise ``value`` itself.
        """
        # NaN compares false against everything, so min(hi, nan) would
        # return hi and a NaN score would silently become the best possible
        # score. Map it to the lower bound instead. `value != value` is true
        # only for NaN and needs no extra import.
        if value != value:
            return lo
        return max(lo, min(hi, value))
|