Spaces:
Running
Running
GitHub Actions commited on
Commit 路
0d73bd2
1
Parent(s): c662f80
Auto-deploy from GitHub
Browse files- app.py +47 -2
- models/model_state.keras +2 -2
- models/scalers/scaler_dyn_global_state.pkl +1 -1
- models/scalers/scaler_target_global_state.pkl +1 -1
- models/state_peak.json +1 -0
- municipal_predictor.py +266 -0
- state_predictor.py +193 -0
app.py
CHANGED
|
@@ -9,7 +9,8 @@ import numpy as np
|
|
| 9 |
import json
|
| 10 |
|
| 11 |
from detect import DengueDetector
|
| 12 |
-
from
|
|
|
|
| 13 |
|
| 14 |
def default_json_serializer(obj):
|
| 15 |
if isinstance(obj, np.integer):
|
|
@@ -22,16 +23,23 @@ def default_json_serializer(obj):
|
|
| 22 |
|
| 23 |
detector: DengueDetector = None
|
| 24 |
predictor: DenguePredictor = None
|
|
|
|
| 25 |
|
| 26 |
app = FastAPI()
|
| 27 |
|
| 28 |
# --- evento de startup para carregar os modelos ---
|
| 29 |
@app.on_event("startup")
|
| 30 |
async def startup_event():
|
| 31 |
-
global detector, predictor
|
| 32 |
print("Executando evento de startup: Carregando os m贸dulos de IA...")
|
| 33 |
detector = DengueDetector()
|
| 34 |
predictor = DenguePredictor()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
print("M贸dulos de IA carregados com sucesso. API pronta.")
|
| 36 |
|
| 37 |
# --- CORS ---
|
|
@@ -71,6 +79,7 @@ async def predict_dengue_route(payload: dict = Body(...)):
|
|
| 71 |
raise ValueError("O campo 'ibge_code' 茅 obrigat贸rio.")
|
| 72 |
|
| 73 |
ibge_code = int(ibge_code_str)
|
|
|
|
| 74 |
result = predictor.predict(ibge_code)
|
| 75 |
|
| 76 |
json_content = json.dumps(result, default=default_json_serializer)
|
|
@@ -83,4 +92,40 @@ async def predict_dengue_route(payload: dict = Body(...)):
|
|
| 83 |
return JSONResponse(status_code=500, content={
|
| 84 |
"error": str(e),
|
| 85 |
"traceback": tb_str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
})
|
|
|
|
| 9 |
import json
|
| 10 |
|
| 11 |
from detect import DengueDetector
|
| 12 |
+
from municipal_predictor import DenguePredictor
|
| 13 |
+
from state_predictor import StatePredictor
|
| 14 |
|
| 15 |
def default_json_serializer(obj):
|
| 16 |
if isinstance(obj, np.integer):
|
|
|
|
| 23 |
|
| 24 |
detector: DengueDetector = None
|
| 25 |
predictor: DenguePredictor = None
|
| 26 |
+
state_predictor: StatePredictor = None
|
| 27 |
|
| 28 |
app = FastAPI()
|
| 29 |
|
| 30 |
# --- evento de startup para carregar os modelos ---
|
| 31 |
@app.on_event("startup")
|
| 32 |
async def startup_event():
|
| 33 |
+
global detector, predictor, state_predictor
|
| 34 |
print("Executando evento de startup: Carregando os m贸dulos de IA...")
|
| 35 |
detector = DengueDetector()
|
| 36 |
predictor = DenguePredictor()
|
| 37 |
+
try:
|
| 38 |
+
state_predictor = StatePredictor()
|
| 39 |
+
except Exception as e:
|
| 40 |
+
# N茫o bloqueia a API se o modelo estadual faltar; a rota retornar谩 503
|
| 41 |
+
print("[WARN] StatePredictor n茫o inicializado:", str(e))
|
| 42 |
+
state_predictor = None
|
| 43 |
print("M贸dulos de IA carregados com sucesso. API pronta.")
|
| 44 |
|
| 45 |
# --- CORS ---
|
|
|
|
| 79 |
raise ValueError("O campo 'ibge_code' 茅 obrigat贸rio.")
|
| 80 |
|
| 81 |
ibge_code = int(ibge_code_str)
|
| 82 |
+
# Sempre retorna hist贸rico completo; frontend controla a janela vis铆vel
|
| 83 |
result = predictor.predict(ibge_code)
|
| 84 |
|
| 85 |
json_content = json.dumps(result, default=default_json_serializer)
|
|
|
|
| 92 |
return JSONResponse(status_code=500, content={
|
| 93 |
"error": str(e),
|
| 94 |
"traceback": tb_str
|
| 95 |
+
})
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@app.post("/predict/state/")
|
| 99 |
+
async def predict_dengue_state_route(payload: dict = Body(...)):
|
| 100 |
+
global state_predictor
|
| 101 |
+
if state_predictor is None:
|
| 102 |
+
# Tenta inicializar pregui莽osamente no primeiro uso
|
| 103 |
+
try:
|
| 104 |
+
state_predictor = StatePredictor()
|
| 105 |
+
except Exception as e:
|
| 106 |
+
return JSONResponse(status_code=503, content={"error": f"Preditor estadual ainda n茫o foi inicializado: {str(e)}"})
|
| 107 |
+
try:
|
| 108 |
+
state_sigla = payload.get("state") or payload.get("state_sigla") or payload.get("uf")
|
| 109 |
+
year = payload.get("year")
|
| 110 |
+
week = payload.get("week")
|
| 111 |
+
if not state_sigla:
|
| 112 |
+
raise ValueError("O campo 'state' (sigla) 茅 obrigat贸rio.")
|
| 113 |
+
|
| 114 |
+
# year/week s茫o opcionais; se omitidos, prev锚 ap贸s o 煤ltimo ponto conhecido
|
| 115 |
+
# Sempre retorna hist贸rico completo; frontend controla a janela vis铆vel
|
| 116 |
+
result = state_predictor.predict(
|
| 117 |
+
str(state_sigla).upper(),
|
| 118 |
+
year=int(year) if year is not None else None,
|
| 119 |
+
week=int(week) if week is not None else None,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
json_content = json.dumps(result, default=default_json_serializer)
|
| 123 |
+
return Response(content=json_content, media_type="application/json")
|
| 124 |
+
|
| 125 |
+
except Exception as e:
|
| 126 |
+
tb_str = traceback.format_exc()
|
| 127 |
+
print(tb_str)
|
| 128 |
+
return JSONResponse(status_code=500, content={
|
| 129 |
+
"error": str(e),
|
| 130 |
+
"traceback": tb_str
|
| 131 |
})
|
models/model_state.keras
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e4c86d8e99e3779125ec864816e0fbf96f72a8e324e40a5e170182168a617b30
|
| 3 |
+
size 2536309
|
models/scalers/scaler_dyn_global_state.pkl
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1303
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fc7972df5abd0302686c2d6ff16962ff31a13c5ca5346cbe57633de1ec34f1c1
|
| 3 |
size 1303
|
models/scalers/scaler_target_global_state.pkl
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 719
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a4e97671eeabf05f39cb9a6b53130816103d263c6bfffd9fc7fbee5f9c77178
|
| 3 |
size 719
|
models/state_peak.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"AC": 2036.0, "AL": 2338.0, "AM": 388.0, "AP": 594.0, "BA": 7520.0, "CE": 2754.0, "DF": 6456.0, "ES": 2598.0, "GO": 13984.0, "MA": 1139.0, "MG": 68685.0, "MS": 3781.0, "MT": 1923.0, "PA": 810.0, "PB": 1613.0, "PE": 2249.0, "PI": 2913.0, "PR": 39913.0, "RJ": 12162.0, "RN": 2868.0, "RO": 631.0, "RR": 63.0, "RS": 16798.0, "SC": 26832.0, "SE": 408.0, "SP": 129817.0, "TO": 1289.0}
|
municipal_predictor.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import joblib
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from datetime import timedelta
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
import base64
|
| 10 |
+
import tensorflow as tf
|
| 11 |
+
from tensorflow.keras.utils import register_keras_serializable
|
| 12 |
+
import matplotlib
|
| 13 |
+
matplotlib.use('Agg')
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
|
| 17 |
+
plt.style.use('seaborn-v0_8-darkgrid')
|
| 18 |
+
|
| 19 |
+
@register_keras_serializable(package="Custom", name="asymmetric_mse")
|
| 20 |
+
def asymmetric_mse(y_true, y_pred):
|
| 21 |
+
penalty_factor = 10.0
|
| 22 |
+
error = y_true - y_pred
|
| 23 |
+
denom = tf.maximum(tf.abs(y_true), 1.0)
|
| 24 |
+
rel = tf.abs(error) / denom
|
| 25 |
+
penalty = tf.where(error > 0, 1.0 + penalty_factor * rel, 1.0)
|
| 26 |
+
loss = tf.square(error) * penalty
|
| 27 |
+
return tf.reduce_mean(loss)
|
| 28 |
+
|
| 29 |
+
class DenguePredictor:
|
| 30 |
+
def __init__(self, project_root=None):
|
| 31 |
+
self.project_root = Path(project_root) if project_root else Path(__file__).resolve().parent
|
| 32 |
+
self.sequence_length = 12
|
| 33 |
+
self.horizon = 6
|
| 34 |
+
self.year_min_train = 2014
|
| 35 |
+
self.year_max_train = 2025
|
| 36 |
+
self.dynamic_features = [
|
| 37 |
+
"numero_casos", "casos_velocidade", "casos_aceleracao", "casos_mm_4_semanas",
|
| 38 |
+
"T2M", "T2M_MAX", "T2M_MIN", "PRECTOTCORR", "RH2M", "ALLSKY_SFC_SW_DWN",
|
| 39 |
+
"week_sin", "week_cos", "year_norm", "notificacao"
|
| 40 |
+
]
|
| 41 |
+
self.static_features = ["latitude", "longitude"]
|
| 42 |
+
self.feature_names_pt = {
|
| 43 |
+
"numero_casos": "N潞 de Casos de Dengue",
|
| 44 |
+
"T2M": "Temperatura M茅dia (掳C)",
|
| 45 |
+
"PRECTOTCORR": "Precipita莽茫o (mm)"
|
| 46 |
+
}
|
| 47 |
+
self._loaded = False
|
| 48 |
+
self.load_assets()
|
| 49 |
+
|
| 50 |
+
def load_assets(self):
|
| 51 |
+
models_dir = self.project_root / "models"
|
| 52 |
+
scalers_dir = models_dir / "scalers"
|
| 53 |
+
model_path = models_dir / "model.keras"
|
| 54 |
+
city_map_path = models_dir / "city_to_idx.json"
|
| 55 |
+
|
| 56 |
+
if not scalers_dir.exists():
|
| 57 |
+
raise FileNotFoundError(str(scalers_dir) + " not found")
|
| 58 |
+
|
| 59 |
+
self.scaler_dyn = joblib.load(scalers_dir / "scaler_dyn_global.pkl")
|
| 60 |
+
self.scaler_static = joblib.load(scalers_dir / "scaler_static_global.pkl")
|
| 61 |
+
self.scaler_target = joblib.load(scalers_dir / "scaler_target_global.pkl")
|
| 62 |
+
|
| 63 |
+
if city_map_path.exists():
|
| 64 |
+
with open(city_map_path, "r", encoding="utf-8") as fh:
|
| 65 |
+
self.city_to_idx = {int(k): int(v) for k, v in json.load(fh).items()}
|
| 66 |
+
else:
|
| 67 |
+
self.city_to_idx = {}
|
| 68 |
+
|
| 69 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 70 |
+
inference_path = hf_hub_download(
|
| 71 |
+
repo_id="previdengue/predict_inference_data",
|
| 72 |
+
filename="inference_data.parquet",
|
| 73 |
+
repo_type="dataset",
|
| 74 |
+
token=hf_token
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
df = pd.read_parquet(inference_path)
|
| 78 |
+
df["codigo_ibge"] = df["codigo_ibge"].astype(int)
|
| 79 |
+
df["ano"] = df["ano"].astype(int)
|
| 80 |
+
df["semana"] = df["semana"].astype(int)
|
| 81 |
+
try:
|
| 82 |
+
df["date"] = pd.to_datetime(df["ano"].astype(str) + df["semana"].astype(str) + "0", format="%Y%W%w", errors="coerce")
|
| 83 |
+
except Exception:
|
| 84 |
+
df["date"] = pd.NaT
|
| 85 |
+
|
| 86 |
+
df = df.sort_values(by=["codigo_ibge", "date"]).reset_index(drop=True)
|
| 87 |
+
df["week_sin"] = np.sin(2 * np.pi * df["semana"] / 52)
|
| 88 |
+
df["week_cos"] = np.cos(2 * np.pi * df["semana"] / 52)
|
| 89 |
+
df["year_norm"] = (df["ano"] - self.year_min_train) / (self.year_max_train - self.year_min_train)
|
| 90 |
+
df["notificacao"] = df["ano"].isin([2021, 2022]).astype(float)
|
| 91 |
+
|
| 92 |
+
self.df_master = df
|
| 93 |
+
self.municipios = df[["codigo_ibge", "municipio"]].drop_duplicates().sort_values("codigo_ibge")
|
| 94 |
+
|
| 95 |
+
if not model_path.exists():
|
| 96 |
+
raise FileNotFoundError(str(model_path) + " not found")
|
| 97 |
+
|
| 98 |
+
self.model = tf.keras.models.load_model(model_path, custom_objects={"asymmetric_mse": asymmetric_mse}, compile=False)
|
| 99 |
+
self._loaded = True
|
| 100 |
+
|
| 101 |
+
def plot_to_base64(self, fig):
|
| 102 |
+
buf = BytesIO()
|
| 103 |
+
fig.savefig(buf, format="png", bbox_inches="tight", facecolor=fig.get_facecolor())
|
| 104 |
+
buf.seek(0)
|
| 105 |
+
img_str = base64.b64encode(buf.read()).decode("utf-8")
|
| 106 |
+
plt.close(fig)
|
| 107 |
+
return img_str
|
| 108 |
+
|
| 109 |
+
def _prepare_sequence(self, df_mun):
|
| 110 |
+
df_seq = df_mun.tail(self.sequence_length).copy()
|
| 111 |
+
df_seq["casos_velocidade"] = df_seq["numero_casos"].diff().fillna(0)
|
| 112 |
+
df_seq["casos_aceleracao"] = df_seq["casos_velocidade"].diff().fillna(0)
|
| 113 |
+
df_seq["casos_mm_4_semanas"] = df_seq["numero_casos"].rolling(4, min_periods=1).mean()
|
| 114 |
+
df_seq["week_sin"] = np.sin(2 * np.pi * df_seq["semana"] / 52)
|
| 115 |
+
df_seq["week_cos"] = np.cos(2 * np.pi * df_seq["semana"] / 52)
|
| 116 |
+
df_seq["year_norm"] = (df_seq["ano"] - self.year_min_train) / (self.year_max_train - self.year_min_train)
|
| 117 |
+
if "notificacao" not in df_seq.columns:
|
| 118 |
+
df_seq["notificacao"] = df_seq["ano"].isin([2021, 2022]).astype(float)
|
| 119 |
+
else:
|
| 120 |
+
df_seq["notificacao"] = df_seq["notificacao"].astype(float)
|
| 121 |
+
return df_seq
|
| 122 |
+
|
| 123 |
+
def predict(self, ibge_code: int, show_plot=False, display_history_weeks=None):
|
| 124 |
+
if not self._loaded:
|
| 125 |
+
raise RuntimeError("assets not loaded")
|
| 126 |
+
|
| 127 |
+
df_mun = self.df_master[self.df_master["codigo_ibge"] == int(ibge_code)].copy().reset_index(drop=True)
|
| 128 |
+
if df_mun.empty or len(df_mun) < self.sequence_length:
|
| 129 |
+
raise ValueError(f"No data or insufficient history for ibge {ibge_code}")
|
| 130 |
+
|
| 131 |
+
municipio_row = self.municipios[self.municipios["codigo_ibge"] == int(ibge_code)]
|
| 132 |
+
municipality_name = municipio_row.iloc[0]["municipio"] if not municipio_row.empty else str(ibge_code)
|
| 133 |
+
|
| 134 |
+
df_mun_clean = df_mun.dropna(subset=["numero_casos"]).reset_index(drop=True)
|
| 135 |
+
if len(df_mun_clean) < self.sequence_length:
|
| 136 |
+
raise ValueError(f"Insufficient known-case history for {ibge_code}")
|
| 137 |
+
|
| 138 |
+
seq_df = self._prepare_sequence(df_mun_clean)
|
| 139 |
+
if len(seq_df) < self.sequence_length:
|
| 140 |
+
raise ValueError(f"Insufficient sequence length for {ibge_code}")
|
| 141 |
+
|
| 142 |
+
dynamic_raw = seq_df[self.dynamic_features].values
|
| 143 |
+
static_raw = seq_df[self.static_features].iloc[-1].values.reshape(1, -1)
|
| 144 |
+
|
| 145 |
+
missing_feats = [c for c in self.dynamic_features if c not in seq_df.columns]
|
| 146 |
+
if missing_feats:
|
| 147 |
+
raise ValueError(f"Missing dynamic features in dataframe: {missing_feats}")
|
| 148 |
+
if hasattr(self.scaler_dyn, "n_features_in_") and self.scaler_dyn.n_features_in_ != len(self.dynamic_features):
|
| 149 |
+
raise ValueError(
|
| 150 |
+
f"Dynamic scaler expects {getattr(self.scaler_dyn, 'n_features_in_', 'unknown')} features, "
|
| 151 |
+
f"but predictor assembled {len(self.dynamic_features)}. Ensure training and inference feature sets match."
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
dynamic_scaled = self.scaler_dyn.transform(dynamic_raw).reshape(1, self.sequence_length, -1)
|
| 155 |
+
static_scaled = self.scaler_static.transform(static_raw)
|
| 156 |
+
|
| 157 |
+
city_idx = int(self.city_to_idx.get(int(ibge_code), 0))
|
| 158 |
+
city_input = np.array([[city_idx]], dtype=np.int32)
|
| 159 |
+
|
| 160 |
+
y_pred = self.model.predict([dynamic_scaled, static_scaled, city_input], verbose=0)
|
| 161 |
+
y_pred_reg = y_pred[0] if isinstance(y_pred, (list, tuple)) else y_pred
|
| 162 |
+
|
| 163 |
+
y_pred_flat = y_pred_reg.reshape(-1, 1)
|
| 164 |
+
y_pred_inv_flat = self.scaler_target.inverse_transform(y_pred_flat)
|
| 165 |
+
y_pred_inv = y_pred_inv_flat.reshape(y_pred_reg.shape)
|
| 166 |
+
pred_values = np.maximum(y_pred_inv.flatten(), 0.0)
|
| 167 |
+
|
| 168 |
+
last_known_case = seq_df["numero_casos"].iloc[-1]
|
| 169 |
+
connected_prediction = np.insert(pred_values, 0, last_known_case)
|
| 170 |
+
|
| 171 |
+
last_real_date = seq_df["date"].iloc[-1] if "date" in seq_df.columns else None
|
| 172 |
+
predicted_data = []
|
| 173 |
+
for i, val in enumerate(connected_prediction[1:]):
|
| 174 |
+
pred_date = (last_real_date + timedelta(weeks=i + 1)).strftime("%Y-%m-%d") if pd.notna(last_real_date) else None
|
| 175 |
+
predicted_data.append({"date": pred_date, "predicted_cases": int(round(float(val)))})
|
| 176 |
+
|
| 177 |
+
# Hist贸rico: por padr茫o retorna tudo; se display_history_weeks > 0, limita a janela
|
| 178 |
+
if display_history_weeks is None or (isinstance(display_history_weeks, (int, float)) and display_history_weeks <= 0):
|
| 179 |
+
hist_tail = df_mun.copy()
|
| 180 |
+
else:
|
| 181 |
+
hist_tail = df_mun.tail(min(len(df_mun), int(display_history_weeks))).copy()
|
| 182 |
+
historic_data = []
|
| 183 |
+
for _, row in hist_tail.iterrows():
|
| 184 |
+
historic_data.append({
|
| 185 |
+
"date": row["date"].strftime("%Y-%m-%d") if pd.notna(row.get("date")) else None,
|
| 186 |
+
"cases": int(row["numero_casos"]) if pd.notna(row.get("numero_casos")) else None
|
| 187 |
+
})
|
| 188 |
+
# Insights: lag correlation analysis and strategic summary
|
| 189 |
+
lag_plot_b64, strategic_summary, tipping_points = self.generate_lag_insights(df_mun)
|
| 190 |
+
|
| 191 |
+
insights = {
|
| 192 |
+
"lag_analysis_plot_base64": lag_plot_b64,
|
| 193 |
+
"strategic_summary": strategic_summary,
|
| 194 |
+
"tipping_points": tipping_points
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
return {
|
| 198 |
+
"municipality_name": municipality_name,
|
| 199 |
+
"ibge": int(ibge_code),
|
| 200 |
+
"last_known_index": int(df_mun.index[-1]),
|
| 201 |
+
"historic_data": historic_data,
|
| 202 |
+
"predicted_data": predicted_data,
|
| 203 |
+
"insights": insights,
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
def generate_lag_insights(self, df_mun: pd.DataFrame):
|
| 207 |
+
# Prepare analysis columns
|
| 208 |
+
df_analysis = df_mun.rename(columns={
|
| 209 |
+
"T2M": "Temperature_C",
|
| 210 |
+
"PRECTOTCORR": "Precipitation_mm"
|
| 211 |
+
})
|
| 212 |
+
max_lag = 12
|
| 213 |
+
cases_col = "numero_casos"
|
| 214 |
+
lag_features = ["Temperature_C", "Precipitation_mm"]
|
| 215 |
+
lag_correlations = {}
|
| 216 |
+
|
| 217 |
+
for col in lag_features:
|
| 218 |
+
if col in df_analysis.columns:
|
| 219 |
+
corrs = []
|
| 220 |
+
for lag in range(1, max_lag + 1):
|
| 221 |
+
try:
|
| 222 |
+
corr = df_analysis[cases_col].corr(df_analysis[col].shift(lag))
|
| 223 |
+
except Exception:
|
| 224 |
+
corr = np.nan
|
| 225 |
+
corrs.append(corr)
|
| 226 |
+
lag_correlations[col] = corrs
|
| 227 |
+
else:
|
| 228 |
+
lag_correlations[col] = [np.nan] * max_lag
|
| 229 |
+
|
| 230 |
+
# Plot
|
| 231 |
+
fig, ax = plt.subplots(figsize=(10, 6), facecolor="#18181b")
|
| 232 |
+
ax.set_facecolor("#18181b")
|
| 233 |
+
for feature_name, corrs in lag_correlations.items():
|
| 234 |
+
ax.plot(range(1, max_lag + 1), corrs, marker="o", linestyle="-", label=feature_name)
|
| 235 |
+
ax.set_title("Lag Analysis", color="white")
|
| 236 |
+
ax.set_xlabel("Lag (weeks)", color="white")
|
| 237 |
+
ax.set_ylabel("Correlation with cases", color="white")
|
| 238 |
+
ax.tick_params(colors="white")
|
| 239 |
+
ax.legend(facecolor="#27272a", edgecolor="gray", labelcolor="white")
|
| 240 |
+
ax.grid(True, which="both", linestyle="--", linewidth=0.5, color="#444")
|
| 241 |
+
lag_plot_b64 = self.plot_to_base64(fig)
|
| 242 |
+
|
| 243 |
+
# Summaries
|
| 244 |
+
lag_peaks = {}
|
| 245 |
+
for feature, corrs in lag_correlations.items():
|
| 246 |
+
if corrs and not all(pd.isna(corrs)):
|
| 247 |
+
peak = int(np.nanargmax(np.abs(np.array(corrs))) + 1)
|
| 248 |
+
else:
|
| 249 |
+
peak = "N/A"
|
| 250 |
+
lag_peaks[feature] = peak
|
| 251 |
+
|
| 252 |
+
temp_lag = lag_peaks.get("Temperature_C", "N/A")
|
| 253 |
+
rain_lag = lag_peaks.get("Precipitation_mm", "N/A")
|
| 254 |
+
summary = (
|
| 255 |
+
f"O modelo identifica Temperatura e Precipita莽茫o como fatores clim谩ticos chave. "
|
| 256 |
+
f"Temperatura mostra impacto m谩ximo ap贸s {temp_lag} semanas e precipita莽茫o ap贸s {rain_lag} semanas. "
|
| 257 |
+
"A莽玫es preventivas devem ser intensificadas nessas janelas ap贸s eventos clim谩ticos extremos."
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
tipping_points = [
|
| 261 |
+
{"factor": "Temperatura", "value": f"Maior impacto em {temp_lag} semanas"},
|
| 262 |
+
{"factor": "Precipita莽茫o", "value": f"Maior impacto em {rain_lag} semanas"},
|
| 263 |
+
{"factor": "Umidade", "value": "Aumenta a sobreviv锚ncia de mosquitos adultos"}
|
| 264 |
+
]
|
| 265 |
+
|
| 266 |
+
return lag_plot_b64, summary, tipping_points
|
state_predictor.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import joblib
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from datetime import timedelta
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
from tensorflow.keras.utils import register_keras_serializable
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
|
| 12 |
+
@register_keras_serializable(package="Custom", name="asymmetric_mse")
|
| 13 |
+
def asymmetric_mse(y_true, y_pred):
|
| 14 |
+
penalty_factor = 10.0
|
| 15 |
+
error = y_true - y_pred
|
| 16 |
+
denom = tf.maximum(tf.abs(y_true), 1.0)
|
| 17 |
+
rel = tf.abs(error) / denom
|
| 18 |
+
penalty = tf.where(error > 0, 1.0 + penalty_factor * rel, 1.0)
|
| 19 |
+
loss = tf.square(error) * penalty
|
| 20 |
+
return tf.reduce_mean(loss)
|
| 21 |
+
|
| 22 |
+
class StatePredictor:
|
| 23 |
+
def __init__(self, project_root=None):
|
| 24 |
+
self.project_root = Path(project_root) if project_root else Path(__file__).resolve().parent
|
| 25 |
+
self.sequence_length = 12
|
| 26 |
+
self.horizon = 6
|
| 27 |
+
self.dynamic_features = [
|
| 28 |
+
"casos_norm_log",
|
| 29 |
+
"casos_velocidade", "casos_aceleracao", "casos_mm_4_semanas",
|
| 30 |
+
"T2M_mean","T2M_std","PRECTOTCORR_mean","PRECTOTCORR_std",
|
| 31 |
+
"RH2M_mean","RH2M_std","ALLSKY_SFC_SW_DWN_mean","ALLSKY_SFC_SW_DWN_std",
|
| 32 |
+
"week_sin","week_cos","year_norm","notificacao"
|
| 33 |
+
]
|
| 34 |
+
self.static_features = ["populacao_total"]
|
| 35 |
+
self._loaded = False
|
| 36 |
+
self.load_assets()
|
| 37 |
+
|
| 38 |
+
def load_assets(self):
|
| 39 |
+
models_dir = self.project_root / "models"
|
| 40 |
+
scalers_dir = models_dir / "scalers"
|
| 41 |
+
model_path = models_dir / "model_state.keras"
|
| 42 |
+
state_map_path = models_dir / "state_to_idx.json"
|
| 43 |
+
state_peak_path = models_dir / "state_peak.json"
|
| 44 |
+
|
| 45 |
+
# scalers
|
| 46 |
+
dyn_state = scalers_dir / "scaler_dyn_global_state.pkl"
|
| 47 |
+
static_state = scalers_dir / "scaler_static_global_state.pkl"
|
| 48 |
+
target_state = scalers_dir / "scaler_target_global_state.pkl"
|
| 49 |
+
if not dyn_state.exists() or not static_state.exists() or not target_state.exists():
|
| 50 |
+
raise FileNotFoundError("State scalers not found under models/scalers. Expected *_state.pkl files.")
|
| 51 |
+
self.scaler_dyn = joblib.load(dyn_state)
|
| 52 |
+
self.scaler_static = joblib.load(static_state)
|
| 53 |
+
self.scaler_target = joblib.load(target_state)
|
| 54 |
+
|
| 55 |
+
# mappings
|
| 56 |
+
if state_map_path.exists():
|
| 57 |
+
with open(state_map_path, "r", encoding="utf-8") as fh:
|
| 58 |
+
self.state_to_idx = json.load(fh)
|
| 59 |
+
else:
|
| 60 |
+
self.state_to_idx = {}
|
| 61 |
+
if state_peak_path.exists():
|
| 62 |
+
with open(state_peak_path, "r", encoding="utf-8") as fh:
|
| 63 |
+
self.state_peak_map = json.load(fh)
|
| 64 |
+
else:
|
| 65 |
+
self.state_peak_map = {}
|
| 66 |
+
|
| 67 |
+
# inference dataset (HF only)
|
| 68 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 69 |
+
hf_repo = "previdengue/predict_inference_data_estadual"
|
| 70 |
+
hf_filename = "inference_data_estadual.parquet"
|
| 71 |
+
try:
|
| 72 |
+
hf_path = hf_hub_download(
|
| 73 |
+
repo_id=hf_repo,
|
| 74 |
+
filename=hf_filename,
|
| 75 |
+
repo_type="dataset",
|
| 76 |
+
token=hf_token,
|
| 77 |
+
)
|
| 78 |
+
df_loaded = pd.read_parquet(hf_path)
|
| 79 |
+
except Exception as e:
|
| 80 |
+
raise FileNotFoundError(
|
| 81 |
+
"Could not download 'inference_data_estadual.parquet' from HF repo 'previdengue/predict_inference_data_estadual'. "
|
| 82 |
+
"Ensure the dataset exists and set HF_TOKEN if the repo requires authentication."
|
| 83 |
+
) from e
|
| 84 |
+
|
| 85 |
+
# normalize
|
| 86 |
+
df = df_loaded.copy()
|
| 87 |
+
required = ["estado_sigla", "year", "week", "casos_soma"]
|
| 88 |
+
if any(col not in df.columns for col in required):
|
| 89 |
+
raise ValueError("State dataset missing required columns: ['estado_sigla','year','week','casos_soma']")
|
| 90 |
+
df["estado_sigla"] = df["estado_sigla"].astype(str)
|
| 91 |
+
df = df.sort_values(["estado_sigla", "year", "week"]).reset_index(drop=True)
|
| 92 |
+
if "date" not in df.columns:
|
| 93 |
+
try:
|
| 94 |
+
df["date"] = pd.to_datetime(df["year"].astype(str) + df["week"].astype(str) + "0", format="%Y%W%w", errors="coerce")
|
| 95 |
+
except Exception:
|
| 96 |
+
pass
|
| 97 |
+
if "week_sin" not in df.columns:
|
| 98 |
+
df["week_sin"] = np.sin(2*np.pi*df["week"]/52)
|
| 99 |
+
if "week_cos" not in df.columns:
|
| 100 |
+
df["week_cos"] = np.cos(2*np.pi*df["week"]/52)
|
| 101 |
+
if "year_norm" not in df.columns:
|
| 102 |
+
year_min, year_max = df["year"].min(), df["year"].max()
|
| 103 |
+
df["year_norm"] = (df["year"] - year_min) / max(1.0, (year_max - year_min))
|
| 104 |
+
df["notificacao"] = df["year"].isin([2021, 2022]).astype(float)
|
| 105 |
+
|
| 106 |
+
self.df_state = df
|
| 107 |
+
if not model_path.exists():
|
| 108 |
+
raise FileNotFoundError(str(model_path) + " not found")
|
| 109 |
+
self.model = tf.keras.models.load_model(model_path, custom_objects={"asymmetric_mse": asymmetric_mse}, compile=False)
|
| 110 |
+
self._loaded = True
|
| 111 |
+
|
| 112 |
+
def _prepare_state_sequence(self, df_st: pd.DataFrame, state_sigla: str):
|
| 113 |
+
df_st = df_st.copy()
|
| 114 |
+
df_st['casos_velocidade'] = df_st['casos_soma'].diff().fillna(0)
|
| 115 |
+
df_st['casos_aceleracao'] = df_st['casos_velocidade'].diff().fillna(0)
|
| 116 |
+
df_st['casos_mm_4_semanas'] = df_st['casos_soma'].rolling(4, min_periods=1).mean()
|
| 117 |
+
if "notificacao" not in df_st.columns:
|
| 118 |
+
df_st["notificacao"] = df_st["year"].isin([2021, 2022]).astype(float)
|
| 119 |
+
peak = float(self.state_peak_map.get(state_sigla, 1.0))
|
| 120 |
+
if peak <= 0:
|
| 121 |
+
peak = 1.0
|
| 122 |
+
df_st["casos_norm"] = df_st["casos_soma"] / peak
|
| 123 |
+
df_st["casos_norm_log"] = np.log1p(df_st["casos_norm"])
|
| 124 |
+
return df_st
|
| 125 |
+
|
| 126 |
+
def predict(self, state_sigla: str, year: int = None, week: int = None, display_history_weeks: int | None = None):
|
| 127 |
+
if not self._loaded:
|
| 128 |
+
raise RuntimeError("state assets not loaded")
|
| 129 |
+
st = str(state_sigla).upper()
|
| 130 |
+
df_st = self.df_state[self.df_state["estado_sigla"] == st].copy().sort_values(["year","week"]).reset_index(drop=True)
|
| 131 |
+
if df_st.empty or len(df_st) < self.sequence_length:
|
| 132 |
+
raise ValueError(f"No data or insufficient history for state {st}")
|
| 133 |
+
df_st = self._prepare_state_sequence(df_st, st)
|
| 134 |
+
if year is not None and week is not None:
|
| 135 |
+
idx_list = df_st.index[(df_st['year'] == int(year)) & (df_st['week'] == int(week))].tolist()
|
| 136 |
+
if not idx_list:
|
| 137 |
+
raise ValueError("Prediction point (year/week) not found in state series")
|
| 138 |
+
pred_point_idx = idx_list[0]
|
| 139 |
+
else:
|
| 140 |
+
pred_point_idx = len(df_st)
|
| 141 |
+
last_known_idx = pred_point_idx - 1
|
| 142 |
+
if last_known_idx < self.sequence_length - 1:
|
| 143 |
+
raise ValueError("Insufficient sequence window before prediction point")
|
| 144 |
+
start_idx = last_known_idx - self.sequence_length + 1
|
| 145 |
+
input_seq = df_st.iloc[start_idx:last_known_idx+1].copy()
|
| 146 |
+
for col in self.static_features:
|
| 147 |
+
if col not in input_seq.columns:
|
| 148 |
+
input_seq[col] = 0.0
|
| 149 |
+
static_raw = input_seq[self.static_features].iloc[0].values.reshape(1, -1)
|
| 150 |
+
missing_dyn = [c for c in self.dynamic_features if c not in input_seq.columns]
|
| 151 |
+
if missing_dyn:
|
| 152 |
+
raise ValueError(f"Missing dynamic state features: {missing_dyn}")
|
| 153 |
+
dyn_raw = input_seq[self.dynamic_features].values
|
| 154 |
+
if hasattr(self.scaler_dyn, "n_features_in_") and self.scaler_dyn.n_features_in_ != len(self.dynamic_features):
|
| 155 |
+
raise ValueError(
|
| 156 |
+
f"State dynamic scaler expects {self.scaler_dyn.n_features_in_} features, got {len(self.dynamic_features)}."
|
| 157 |
+
)
|
| 158 |
+
dyn_scaled = self.scaler_dyn.transform(dyn_raw).reshape(1, self.sequence_length, len(self.dynamic_features))
|
| 159 |
+
static_scaled = self.scaler_static.transform(static_raw)
|
| 160 |
+
state_idx = int(self.state_to_idx.get(st, 0))
|
| 161 |
+
state_input = np.array([[state_idx]], dtype=np.int32)
|
| 162 |
+
y_pred = self.model.predict([dyn_scaled, static_scaled, state_input], verbose=0)
|
| 163 |
+
y_pred_reg = y_pred[0] if isinstance(y_pred, (list, tuple)) else y_pred
|
| 164 |
+
y_pred_log_norm = self.scaler_target.inverse_transform(y_pred_reg.reshape(-1,1)).reshape(y_pred_reg.shape)
|
| 165 |
+
y_pred_norm = np.expm1(y_pred_log_norm)
|
| 166 |
+
peak = float(self.state_peak_map.get(st, 1.0))
|
| 167 |
+
if peak <= 0:
|
| 168 |
+
peak = 1.0
|
| 169 |
+
prediction_counts = np.maximum(y_pred_norm.flatten() * peak, 0.0)
|
| 170 |
+
last_known_date = df_st.iloc[last_known_idx]['date'] if 'date' in df_st.columns and last_known_idx < len(df_st) else None
|
| 171 |
+
predicted_data = []
|
| 172 |
+
for i, val in enumerate(prediction_counts):
|
| 173 |
+
if pd.notna(last_known_date):
|
| 174 |
+
pred_date = (last_known_date + timedelta(weeks=i+1)).strftime("%Y-%m-%d")
|
| 175 |
+
else:
|
| 176 |
+
pred_date = None
|
| 177 |
+
predicted_data.append({"date": pred_date, "predicted_cases": int(round(float(val)))})
|
| 178 |
+
if display_history_weeks is None or display_history_weeks <= 0:
|
| 179 |
+
hist_tail = df_st.iloc[:last_known_idx+1].copy()
|
| 180 |
+
else:
|
| 181 |
+
hist_tail = df_st.iloc[max(0, last_known_idx - display_history_weeks): last_known_idx+1].copy()
|
| 182 |
+
historic_data = []
|
| 183 |
+
for _, row in hist_tail.iterrows():
|
| 184 |
+
historic_data.append({
|
| 185 |
+
"date": row["date"].strftime("%Y-%m-%d") if pd.notna(row.get("date")) else None,
|
| 186 |
+
"cases": int(row["casos_soma"]) if pd.notna(row.get("casos_soma")) else None
|
| 187 |
+
})
|
| 188 |
+
return {
|
| 189 |
+
"state": st,
|
| 190 |
+
"last_known_index": int(last_known_idx),
|
| 191 |
+
"historic_data": historic_data,
|
| 192 |
+
"predicted_data": predicted_data,
|
| 193 |
+
}
|