| from __future__ import annotations |
|
|
| import enum |
| from datetime import date, datetime |
| from typing import Any |
|
|
| import math |
|
|
| from pydantic import BaseModel, Field, field_validator, model_validator |
| from shapely.geometry import box as shapely_box |
| from pyproj import Geod |
|
|
|
|
| def _sanitize_float(v: Any) -> Any: |
| """Replace NaN/inf float values with 0.0 for JSON compatibility.""" |
| if isinstance(v, float) and (math.isnan(v) or math.isinf(v)): |
| return 0.0 |
| return v |
|
|
|
|
def sanitize_for_json(obj: Any) -> Any:
    """Recursively replace NaN/inf floats in *obj* with ``0.0``.

    Dicts, lists and tuples are rebuilt with their values sanitized; the
    input is never modified in place. Dict keys are kept as they are. Any
    other type (str, int, None, arbitrary objects) is returned unchanged.

    Args:
        obj: A float, a container, or any other value.

    Returns:
        A sanitized copy for floats and containers, otherwise *obj* itself.
    """
    if isinstance(obj, float):
        # isfinite() is False exactly for NaN and +/-inf.
        return obj if math.isfinite(obj) else 0.0
    if isinstance(obj, dict):
        return {k: sanitize_for_json(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize_for_json(v) for v in obj]
    if isinstance(obj, tuple):
        # Tuples used to be passed through untouched, so a NaN nested in one
        # survived sanitization and still produced invalid JSON.
        cleaned = [sanitize_for_json(v) for v in obj]
        # Named tuples are built from positional fields, plain tuples from
        # a single iterable.
        return type(obj)(*cleaned) if hasattr(obj, "_fields") else tuple(cleaned)
    return obj
|
|
| |
# East Africa bounding box in degrees, ordered
# (min_lon, min_lat, max_lon, max_lat).
EA_BOUNDS = (22.0, -5.0, 52.0, 23.0)

# Longest accepted time range, in days: three 365-day years plus one extra
# day (presumably to allow for a leap day inside the window).
MAX_LOOKBACK_DAYS = 365 * 3 + 1
|
|
|
|
@enum.unique
class StatusLevel(str, enum.Enum):
    """Traffic-light status values.

    Members are ``str`` instances, so they compare equal to their lowercase
    string values.
    """

    GREEN = "green"
    AMBER = "amber"
    RED = "red"
|
|
|
|
@enum.unique
class TrendDirection(str, enum.Enum):
    """Direction-of-change values.

    Members are ``str`` instances, so they compare equal to their lowercase
    string values.
    """

    IMPROVING = "improving"
    STABLE = "stable"
    DETERIORATING = "deteriorating"
|
|
|
|
@enum.unique
class ConfidenceLevel(str, enum.Enum):
    """Confidence grades.

    Members are ``str`` instances, so they compare equal to their lowercase
    string values.
    """

    HIGH = "high"
    MODERATE = "moderate"
    LOW = "low"
|
|
|
|
@enum.unique
class JobStatus(str, enum.Enum):
    """States a job can be in.

    Members are ``str`` instances, so they compare equal to their lowercase
    string values.
    """

    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETE = "complete"
    FAILED = "failed"
|
|
|
|
class AOI(BaseModel):
    """Named area of interest described by a longitude/latitude bounding box.

    ``bbox`` is ``[min_lon, min_lat, max_lon, max_lat]`` in degrees.
    Validation rejects a box that has non-finite coordinates, is inverted or
    empty, is larger than ``MAX_AOI_KM2``, or lies entirely outside
    ``EA_BOUNDS``.
    """

    name: str
    bbox: list[float] = Field(min_length=4, max_length=4)

    @property
    def area_km2(self) -> float:
        """Return the geodesic area of ``bbox`` on the WGS84 ellipsoid in km²."""
        geod = Geod(ellps="WGS84")
        poly = shapely_box(*self.bbox)
        area_m2, _ = geod.geometry_area_perimeter(poly)
        # The geodesic area is signed (depends on ring orientation), so take
        # the magnitude before converting m² to km².
        return abs(area_m2) / 1e6

    @model_validator(mode="after")
    def validate_geography(self) -> AOI:
        """Check bbox sanity, the area limit and overlap with East Africa.

        Returns:
            The validated model itself.

        Raises:
            ValueError: if a coordinate is NaN/inf, the box is inverted or
                empty, its area exceeds ``MAX_AOI_KM2``, or it does not
                intersect ``EA_BOUNDS``.
        """
        min_lon, min_lat, max_lon, max_lat = self.bbox
        ea_min_lon, ea_min_lat, ea_max_lon, ea_max_lat = EA_BOUNDS

        # Every comparison against NaN is False, so without this guard a NaN
        # bbox would pass both the area check and the region check below.
        if not all(math.isfinite(coord) for coord in self.bbox):
            raise ValueError("AOI bbox coordinates must be finite numbers")
        if min_lon >= max_lon or min_lat >= max_lat:
            raise ValueError(
                "AOI bbox must be ordered [min_lon, min_lat, max_lon, max_lat] "
                "with min < max"
            )

        # Imported here rather than at module level — presumably to avoid a
        # circular import with app.config; confirm before hoisting.
        from app.config import MAX_AOI_KM2

        # Compute the geodesic area once; the property recalculates per access.
        area = self.area_km2
        if area > MAX_AOI_KM2:
            raise ValueError(
                f"AOI area ({area:.0f} km²) exceeds {MAX_AOI_KM2:,} km² limit"
            )

        outside_region = (
            max_lon < ea_min_lon
            or min_lon > ea_max_lon
            or max_lat < ea_min_lat
            or min_lat > ea_max_lat
        )
        if outside_region:
            raise ValueError(
                "AOI must intersect the East Africa region "
                f"({ea_min_lon}–{ea_max_lon}°E, {ea_min_lat}–{ea_max_lat}°N)"
            )
        return self
|
|
|
|
class TimeRange(BaseModel):
    """Date window for an analysis, with defaults filled in during validation.

    A missing ``end`` becomes today. A missing ``start`` becomes the same
    calendar day one year before today (28 Feb when today is 29 Feb). The
    window must not be inverted and may span at most ``MAX_LOOKBACK_DAYS``
    days.
    """

    # None is only a placeholder default: the validator below replaces it, so
    # both fields hold a date once the model has been constructed.
    start: date = Field(default=None)
    end: date = Field(default=None)

    @model_validator(mode="after")
    def set_defaults_and_validate(self) -> TimeRange:
        """Fill in missing bounds and enforce ordering and maximum length.

        Returns:
            The validated model itself.

        Raises:
            ValueError: if ``start`` is after ``end`` or the range is longer
                than ``MAX_LOOKBACK_DAYS`` days.
        """
        today = date.today()
        if self.end is None:
            self.end = today
        if self.start is None:
            try:
                self.start = today.replace(year=today.year - 1)
            except ValueError:
                # Today is 29 Feb and the previous year has no such day;
                # fall back to 28 Feb instead of failing validation.
                self.start = today.replace(year=today.year - 1, day=28)
        # A negative span would otherwise pass the length check below.
        if self.start > self.end:
            raise ValueError("Time range start must not be after end")
        if (self.end - self.start).days > MAX_LOOKBACK_DAYS:
            raise ValueError("Time range cannot exceed 3 years")
        return self
|
|
|
|
class JobRequest(BaseModel):
    """Request for an analysis job over one AOI and one or more products.

    ``season_start`` and ``season_end`` are month numbers (1-12) and default
    to the whole year. Input dicts that still use the legacy key
    ``indicator_ids`` are accepted in place of ``product_ids``.
    """

    aoi: AOI
    time_range: TimeRange = Field(default_factory=TimeRange)
    product_ids: list[str]
    email: str
    season_start: int = Field(default=1, ge=1, le=12)
    season_end: int = Field(default=12, ge=1, le=12)

    @model_validator(mode="before")
    @classmethod
    def _accept_legacy_field_names(cls, data: Any) -> Any:
        """Accept old 'indicator_ids' field name from stored database records.

        The rename only happens for dict input that has ``indicator_ids`` and
        no ``product_ids``; anything else is returned unchanged.
        """
        if (
            isinstance(data, dict)
            and "indicator_ids" in data
            and "product_ids" not in data
        ):
            # Rename on a shallow copy so the caller's dict is not mutated.
            data = dict(data)
            data["product_ids"] = data.pop("indicator_ids")
        return data

    @field_validator("product_ids")
    @classmethod
    def require_at_least_one_product(cls, v: list[str]) -> list[str]:
        """Reject an empty product list.

        Raises:
            ValueError: if *v* is empty.
        """
        if not v:
            raise ValueError("At least one EO product must be selected")
        return v

    def season_months(self) -> list[int]:
        """Return ordered list of month numbers in the analysis season.

        Supports year-boundary wrapping: season_start=10, season_end=3
        yields [10, 11, 12, 1, 2, 3].
        """
        first, last = self.season_start, self.season_end
        if first <= last:
            return list(range(first, last + 1))
        # Wrapped season: run to December, then continue from January.
        return [*range(first, 13), *range(1, last + 1)]
|
|
|
|
class ProductResult(BaseModel):
    """Result of one product for a job.

    Dict input is passed through ``sanitize_for_json`` before validation, so
    NaN/inf floats arrive as ``0.0``. The legacy key ``indicator_id`` is
    accepted in place of ``product_id``.
    """

    product_id: str
    headline: str
    status: StatusLevel
    trend: TrendDirection
    confidence: ConfidenceLevel
    map_layer_path: str
    chart_data: dict[str, Any]
    summary: str
    methodology: str
    limitations: list[str]
    data_source: str = "satellite"
    anomaly_months: int = 0
    z_score_current: float = 0.0
    hotspot_pct: float = 0.0
    confidence_factors: dict[str, float] = Field(default_factory=dict)

    @model_validator(mode="before")
    @classmethod
    def _accept_legacy_and_sanitize(cls, data: Any) -> Any:
        """Accept old field names and sanitize NaN/inf floats.

        Non-dict input is returned unchanged.
        """
        if not isinstance(data, dict):
            return data
        # sanitize_for_json returns a new dict, so doing the rename on its
        # result leaves the caller's dict unmodified.
        cleaned = sanitize_for_json(data)
        if "indicator_id" in cleaned and "product_id" not in cleaned:
            cleaned["product_id"] = cleaned.pop("indicator_id")
        return cleaned
|
|
|
|
def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Gives the same value ``datetime.utcnow()`` did, without calling that
    method, which is deprecated since Python 3.12.
    """
    from datetime import timezone

    # Drop tzinfo so the timestamps stay naive, as they were before.
    return datetime.now(timezone.utc).replace(tzinfo=None)


class Job(BaseModel):
    """A submitted job: its request, current status, progress and results.

    ``created_at`` and ``updated_at`` default to the current time as naive
    UTC datetimes; each default is evaluated separately.
    """

    id: str
    request: JobRequest
    status: JobStatus = JobStatus.QUEUED
    created_at: datetime = Field(default_factory=_utcnow)
    updated_at: datetime = Field(default_factory=_utcnow)
    # String-to-string progress map; key/value meaning is not visible in this
    # file — presumably product id -> state, confirm with the writer.
    progress: dict[str, str] = Field(default_factory=dict)
    results: list[ProductResult] = Field(default_factory=list)
    error: str | None = None
|
|
|
|
class ProductMeta(BaseModel):
    """Descriptive metadata for a single product.

    All fields are required.
    """

    id: str
    name: str
    category: str
    # Presumably the question this product answers — confirm with the
    # code that builds these records.
    question: str
    # Presumably an estimated processing time in minutes — confirm.
    estimated_minutes: int
|
|
|
|
class AoiAdviceRequest(BaseModel):
    """Request payload that carries only a bounding box."""

    # Exactly four floats; presumably [min_lon, min_lat, max_lon, max_lat]
    # like AOI.bbox — confirm with the endpoint that consumes this model.
    bbox: list[float] = Field(min_length=4, max_length=4)
|
|
|
|
class CompoundSignal(BaseModel):
    """A named signal, whether it triggered, and supporting figures.

    ``overlap_pct`` and ``affected_ha`` default to ``0.0``; all other fields
    are required.
    """

    name: str
    triggered: bool
    # Free-form string here, not the ConfidenceLevel enum.
    confidence: str
    description: str
    indicators: list[str]
    overlap_pct: float = 0.0
    # Presumably hectares, going by the name — confirm with the producer.
    affected_ha: float = 0.0
|
|