from __future__ import annotations import enum from datetime import date, datetime from typing import Any import math from pydantic import BaseModel, Field, field_validator, model_validator from shapely.geometry import box as shapely_box from pyproj import Geod def _sanitize_float(v: Any) -> Any: """Replace NaN/inf float values with 0.0 for JSON compatibility.""" if isinstance(v, float) and (math.isnan(v) or math.isinf(v)): return 0.0 return v def sanitize_for_json(obj: Any) -> Any: """Recursively sanitize a data structure, replacing NaN/inf with 0.0.""" if isinstance(obj, float): return _sanitize_float(obj) if isinstance(obj, dict): return {k: sanitize_for_json(v) for k, v in obj.items()} if isinstance(obj, list): return [sanitize_for_json(v) for v in obj] return obj # --- East Africa bounding box (approximate) --- EA_BOUNDS = (22.0, -5.0, 52.0, 23.0) # (min_lon, min_lat, max_lon, max_lat) MAX_LOOKBACK_DAYS = 3 * 365 + 1 # ~3 years class StatusLevel(str, enum.Enum): GREEN = "green" AMBER = "amber" RED = "red" class TrendDirection(str, enum.Enum): IMPROVING = "improving" STABLE = "stable" DETERIORATING = "deteriorating" class ConfidenceLevel(str, enum.Enum): HIGH = "high" MODERATE = "moderate" LOW = "low" class JobStatus(str, enum.Enum): QUEUED = "queued" PROCESSING = "processing" COMPLETE = "complete" FAILED = "failed" class AOI(BaseModel): name: str bbox: list[float] = Field(min_length=4, max_length=4) @property def area_km2(self) -> float: geod = Geod(ellps="WGS84") poly = shapely_box(*self.bbox) area_m2, _ = geod.geometry_area_perimeter(poly) return abs(area_m2) / 1e6 @model_validator(mode="after") def validate_geography(self) -> AOI: min_lon, min_lat, max_lon, max_lat = self.bbox ea_min_lon, ea_min_lat, ea_max_lon, ea_max_lat = EA_BOUNDS # Check area first so "too large" error takes priority from app.config import MAX_AOI_KM2 if self.area_km2 > MAX_AOI_KM2: raise ValueError( f"AOI area ({self.area_km2:.0f} km²) exceeds {MAX_AOI_KM2:,} km² limit" ) if ( max_lon < ea_min_lon or min_lon > ea_max_lon or max_lat < ea_min_lat or min_lat > ea_max_lat ): raise ValueError( "AOI must intersect the East Africa region " f"({ea_min_lon}–{ea_max_lon}°E, {ea_min_lat}–{ea_max_lat}°N)" ) return self class TimeRange(BaseModel): start: date = Field(default=None) end: date = Field(default=None) @model_validator(mode="after") def set_defaults_and_validate(self) -> TimeRange: today = date.today() if self.end is None: self.end = today if self.start is None: self.start = date(today.year - 1, today.month, today.day) if (self.end - self.start).days > MAX_LOOKBACK_DAYS: raise ValueError("Time range cannot exceed 3 years") return self class JobRequest(BaseModel): aoi: AOI time_range: TimeRange = Field(default_factory=TimeRange) product_ids: list[str] email: str @model_validator(mode="before") @classmethod def _accept_legacy_field_names(cls, data): """Accept old 'indicator_ids' field name from stored database records.""" if isinstance(data, dict): if "indicator_ids" in data and "product_ids" not in data: data["product_ids"] = data.pop("indicator_ids") return data season_start: int = Field(default=1, ge=1, le=12) season_end: int = Field(default=12, ge=1, le=12) def season_months(self) -> list[int]: """Return ordered list of month numbers in the analysis season. Supports year-boundary wrapping: season_start=10, season_end=3 yields [10, 11, 12, 1, 2, 3]. """ if self.season_start <= self.season_end: return list(range(self.season_start, self.season_end + 1)) else: return list(range(self.season_start, 13)) + list(range(1, self.season_end + 1)) @field_validator("product_ids") @classmethod def require_at_least_one_product(cls, v: list[str]) -> list[str]: if len(v) == 0: raise ValueError("At least one EO product must be selected") return v class ProductResult(BaseModel): product_id: str headline: str @model_validator(mode="before") @classmethod def _accept_legacy_and_sanitize(cls, data): """Accept old field names and sanitize NaN/inf floats.""" if isinstance(data, dict): if "indicator_id" in data and "product_id" not in data: data["product_id"] = data.pop("indicator_id") # Sanitize all float values to prevent JSON serialization errors return sanitize_for_json(data) return data status: StatusLevel trend: TrendDirection confidence: ConfidenceLevel map_layer_path: str chart_data: dict[str, Any] summary: str methodology: str limitations: list[str] data_source: str = "satellite" anomaly_months: int = 0 z_score_current: float = 0.0 hotspot_pct: float = 0.0 confidence_factors: dict[str, float] = Field(default_factory=dict) class Job(BaseModel): id: str request: JobRequest status: JobStatus = JobStatus.QUEUED created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) progress: dict[str, str] = Field(default_factory=dict) results: list[ProductResult] = Field(default_factory=list) error: str | None = None class ProductMeta(BaseModel): id: str name: str category: str question: str estimated_minutes: int class AoiAdviceRequest(BaseModel): bbox: list[float] = Field(min_length=4, max_length=4) class CompoundSignal(BaseModel): name: str triggered: bool confidence: str # "strong", "moderate", "weak" description: str indicators: list[str] overlap_pct: float = 0.0 affected_ha: float = 0.0