# Aperture / app/models.py
# KSvend
# fix: sanitize NaN/inf floats to prevent JSON serialization crashes
# 2850a7e
from __future__ import annotations

import enum
import math
from datetime import date, datetime, timezone
from typing import Any

from pydantic import BaseModel, Field, field_validator, model_validator
from pyproj import Geod
from shapely.geometry import box as shapely_box
def _sanitize_float(v: Any) -> Any:
"""Replace NaN/inf float values with 0.0 for JSON compatibility."""
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
return 0.0
return v
def sanitize_for_json(obj: Any) -> Any:
    """Recursively replace NaN/inf floats with 0.0 inside nested containers.

    Dicts, lists and tuples are rebuilt as new containers with their values
    sanitized; the input is never modified in place. Dict keys are kept
    as-is. Any other object is returned unchanged.

    Args:
        obj: A float, a (possibly nested) dict/list/tuple, or any other value.

    Returns:
        A sanitized copy for containers, 0.0 for a non-finite float, and
        ``obj`` itself otherwise.
    """
    if isinstance(obj, float):
        # Same rule as _sanitize_float: NaN, +inf and -inf become 0.0.
        return obj if math.isfinite(obj) else 0.0
    if isinstance(obj, dict):
        return {k: sanitize_for_json(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize_for_json(v) for v in obj]
    if isinstance(obj, tuple):
        # Tuples used to be passed through untouched, leaving NaN/inf inside.
        cleaned = [sanitize_for_json(v) for v in obj]
        # namedtuples must be rebuilt through _make() to keep their type.
        return obj._make(cleaned) if hasattr(obj, "_make") else tuple(cleaned)
    return obj
# --- East Africa bounding box (approximate) ---
# Unpacked by AOI.validate_geography, which rejects any AOI whose bbox does
# not intersect this box.
EA_BOUNDS = (22.0, -5.0, 52.0, 23.0)  # (min_lon, min_lat, max_lon, max_lat)
# Maximum span, in days, that TimeRange accepts between start and end.
# NOTE(review): the "+ 1" presumably allows for one leap day — confirm.
MAX_LOOKBACK_DAYS = 3 * 365 + 1  # ~3 years
class StatusLevel(str, enum.Enum):
    """Three-level status (green / amber / red) used by ``ProductResult.status``.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    GREEN = "green"
    AMBER = "amber"
    RED = "red"
class TrendDirection(str, enum.Enum):
    """Direction of change used by ``ProductResult.trend``.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    IMPROVING = "improving"
    STABLE = "stable"
    DETERIORATING = "deteriorating"
class ConfidenceLevel(str, enum.Enum):
    """Confidence rating used by ``ProductResult.confidence``.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    HIGH = "high"
    MODERATE = "moderate"
    LOW = "low"
class JobStatus(str, enum.Enum):
    """Lifecycle state of a ``Job``; new jobs default to ``QUEUED``.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETE = "complete"
    FAILED = "failed"
class AOI(BaseModel):
    """Named area of interest defined by a four-element bounding box.

    ``bbox`` is read as ``[min_lon, min_lat, max_lon, max_lat]``. Validation
    rejects boxes whose area exceeds ``app.config.MAX_AOI_KM2`` and boxes
    that do not intersect ``EA_BOUNDS``.
    """

    name: str
    bbox: list[float] = Field(min_length=4, max_length=4)

    @property
    def area_km2(self) -> float:
        """Return the geodesic area of ``bbox`` on the WGS84 ellipsoid, in km²."""
        geod = Geod(ellps="WGS84")
        poly = shapely_box(*self.bbox)
        area_m2, _ = geod.geometry_area_perimeter(poly)
        # abs(): guards against a negative (signed) area value.
        return abs(area_m2) / 1e6

    @model_validator(mode="after")
    def validate_geography(self) -> AOI:
        """Check the AOI's size and location.

        Raises:
            ValueError: If the area exceeds ``MAX_AOI_KM2``, or if the bbox
                lies entirely outside ``EA_BOUNDS``.
        """
        min_lon, min_lat, max_lon, max_lat = self.bbox
        ea_min_lon, ea_min_lat, ea_max_lon, ea_max_lat = EA_BOUNDS

        # Imported at call time rather than at module import.
        # NOTE(review): presumably to avoid a circular import — confirm.
        from app.config import MAX_AOI_KM2

        # Check area first so the "too large" error takes priority.
        # Each property access reruns the geodesic computation, so read it once.
        area_km2 = self.area_km2
        if area_km2 > MAX_AOI_KM2:
            raise ValueError(
                f"AOI area ({area_km2:.0f} km²) exceeds {MAX_AOI_KM2:,} km² limit"
            )

        # Standard rectangle-disjointness test against the regional box.
        # NOTE(review): bbox ordering (min < max) is not validated here.
        outside_region = (
            max_lon < ea_min_lon
            or min_lon > ea_max_lon
            or max_lat < ea_min_lat
            or min_lat > ea_max_lat
        )
        if outside_region:
            # Bounds are separated by an en dash; previously they ran
            # together (e.g. "22.052.0°E").
            raise ValueError(
                "AOI must intersect the East Africa region "
                f"({ea_min_lon}–{ea_max_lon}°E, {ea_min_lat}–{ea_max_lat}°N)"
            )
        return self
class TimeRange(BaseModel):
    """Date window for an analysis.

    Both bounds are optional on input. A missing ``end`` becomes today, and a
    missing ``start`` becomes one calendar year before ``end``. After
    validation both attributes hold dates.
    """

    # None is only a placeholder default; the validator below fills it in.
    start: date = Field(default=None)
    end: date = Field(default=None)

    @model_validator(mode="after")
    def set_defaults_and_validate(self) -> TimeRange:
        """Fill in missing bounds, then enforce ordering and the lookback limit.

        Raises:
            ValueError: If ``start`` is after ``end``, or the span exceeds
                ``MAX_LOOKBACK_DAYS``.
        """
        if self.end is None:
            self.end = date.today()
        if self.start is None:
            # Anchor on ``end`` (not today) so the default window stays valid
            # when the caller supplies only an end date.
            anchor = self.end
            try:
                self.start = anchor.replace(year=anchor.year - 1)
            except ValueError:
                # 29 February has no counterpart in the previous year.
                self.start = anchor.replace(year=anchor.year - 1, day=28)
        if self.start > self.end:
            # An inverted range has a negative day count and would otherwise
            # slip past the lookback check below.
            raise ValueError("Time range start must not be after end")
        if (self.end - self.start).days > MAX_LOOKBACK_DAYS:
            raise ValueError("Time range cannot exceed 3 years")
        return self
class JobRequest(BaseModel):
    """A request to run one or more products over an AOI and time range.

    Input dicts using the old ``indicator_ids`` key are accepted and mapped
    to ``product_ids``.
    """

    aoi: AOI
    time_range: TimeRange = Field(default_factory=TimeRange)
    product_ids: list[str]
    email: str
    # Inclusive month numbers (1-12); start > end means the season wraps
    # across the year boundary (see season_months).
    season_start: int = Field(default=1, ge=1, le=12)
    season_end: int = Field(default=12, ge=1, le=12)

    @model_validator(mode="before")
    @classmethod
    def _accept_legacy_field_names(cls, data):
        """Accept old 'indicator_ids' field name from stored database records.

        The rename is done on a shallow copy so the caller's dict is not
        modified. If both keys are present, ``product_ids`` wins and the
        input is returned untouched.
        """
        if (
            isinstance(data, dict)
            and "indicator_ids" in data
            and "product_ids" not in data
        ):
            data = dict(data)
            data["product_ids"] = data.pop("indicator_ids")
        return data

    def season_months(self) -> list[int]:
        """Return ordered list of month numbers in the analysis season.

        Supports year-boundary wrapping: season_start=10, season_end=3
        yields [10, 11, 12, 1, 2, 3].
        """
        if self.season_start <= self.season_end:
            return list(range(self.season_start, self.season_end + 1))
        # Wrapped season: run to December, then continue from January.
        return list(range(self.season_start, 13)) + list(range(1, self.season_end + 1))

    @field_validator("product_ids")
    @classmethod
    def require_at_least_one_product(cls, v: list[str]) -> list[str]:
        """Reject an empty product list.

        Raises:
            ValueError: If ``v`` is empty.
        """
        if not v:
            raise ValueError("At least one EO product must be selected")
        return v
class ProductResult(BaseModel):
    """Result of a single product run.

    Dict input is passed through ``sanitize_for_json`` before field
    validation, and the old ``indicator_id`` key is accepted in place of
    ``product_id``.
    """

    product_id: str
    headline: str
    status: StatusLevel
    trend: TrendDirection
    confidence: ConfidenceLevel
    map_layer_path: str
    chart_data: dict[str, Any]
    summary: str
    methodology: str
    limitations: list[str]
    # Fields below have defaults, so input that omits them still validates.
    data_source: str = "satellite"
    anomaly_months: int = 0
    z_score_current: float = 0.0
    hotspot_pct: float = 0.0
    confidence_factors: dict[str, float] = Field(default_factory=dict)

    @model_validator(mode="before")
    @classmethod
    def _accept_legacy_and_sanitize(cls, data):
        """Accept old field names and sanitize NaN/inf floats.

        Only dict input is processed; anything else is returned unchanged.
        Sanitizing runs first because it returns a new dict, so the legacy
        rename below never modifies the caller's dict.
        """
        if not isinstance(data, dict):
            return data
        # Sanitize all float values to prevent JSON serialization errors.
        cleaned = sanitize_for_json(data)
        if "indicator_id" in cleaned and "product_id" not in cleaned:
            cleaned["product_id"] = cleaned.pop("indicator_id")
        return cleaned
def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive datetime.

    Produces the same value as ``datetime.utcnow()``, which is deprecated
    since Python 3.12, so existing naive timestamps stay comparable.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class Job(BaseModel):
    """A submitted job: its request, status, timestamps, progress and results."""

    id: str
    request: JobRequest
    status: JobStatus = JobStatus.QUEUED
    # Naive UTC timestamps, each taken when the instance is created.
    created_at: datetime = Field(default_factory=_utcnow_naive)
    updated_at: datetime = Field(default_factory=_utcnow_naive)
    # String-to-string progress map.
    # NOTE(review): keys are presumably product ids — confirm against the writer.
    progress: dict[str, str] = Field(default_factory=dict)
    results: list[ProductResult] = Field(default_factory=list)
    error: str | None = None
class ProductMeta(BaseModel):
    """Descriptive record for a product; all five fields are required.

    NOTE(review): field meanings are inferred from their names only; this
    model is not referenced elsewhere in this file.
    """

    id: str
    name: str
    category: str
    question: str
    estimated_minutes: int
class AoiAdviceRequest(BaseModel):
    """Request model carrying only a bounding box of exactly four floats."""

    # NOTE(review): presumably the same (min_lon, min_lat, max_lon, max_lat)
    # order as AOI.bbox — confirm. No geographic validation is applied here.
    bbox: list[float] = Field(min_length=4, max_length=4)
class CompoundSignal(BaseModel):
    """A named signal with a triggered flag, description and indicator list."""

    name: str
    triggered: bool
    # Plain string, not an enum: the values listed are not enforced here.
    confidence: str  # "strong", "moderate", "weak"
    description: str
    indicators: list[str]
    # NOTE(review): units are presumably percent and hectares, going by the
    # field names — confirm.
    overlap_pct: float = 0.0
    affected_ha: float = 0.0