GitHub Actions committed on
Commit
0d73bd2
1 Parent(s): c662f80

Auto-deploy from GitHub

Browse files
app.py CHANGED
@@ -9,7 +9,8 @@ import numpy as np
9
  import json
10
 
11
  from detect import DengueDetector
12
- from predict import DenguePredictor
 
13
 
14
  def default_json_serializer(obj):
15
  if isinstance(obj, np.integer):
@@ -22,16 +23,23 @@ def default_json_serializer(obj):
22
 
23
  detector: DengueDetector = None
24
  predictor: DenguePredictor = None
 
25
 
26
  app = FastAPI()
27
 
28
  # --- evento de startup para carregar os modelos ---
29
  @app.on_event("startup")
30
  async def startup_event():
31
- global detector, predictor
32
  print("Executando evento de startup: Carregando os m贸dulos de IA...")
33
  detector = DengueDetector()
34
  predictor = DenguePredictor()
 
 
 
 
 
 
35
  print("M贸dulos de IA carregados com sucesso. API pronta.")
36
 
37
  # --- CORS ---
@@ -71,6 +79,7 @@ async def predict_dengue_route(payload: dict = Body(...)):
71
  raise ValueError("O campo 'ibge_code' 茅 obrigat贸rio.")
72
 
73
  ibge_code = int(ibge_code_str)
 
74
  result = predictor.predict(ibge_code)
75
 
76
  json_content = json.dumps(result, default=default_json_serializer)
@@ -83,4 +92,40 @@ async def predict_dengue_route(payload: dict = Body(...)):
83
  return JSONResponse(status_code=500, content={
84
  "error": str(e),
85
  "traceback": tb_str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  })
 
9
  import json
10
 
11
  from detect import DengueDetector
12
+ from municipal_predictor import DenguePredictor
13
+ from state_predictor import StatePredictor
14
 
15
  def default_json_serializer(obj):
16
  if isinstance(obj, np.integer):
 
23
 
24
  detector: DengueDetector = None
25
  predictor: DenguePredictor = None
26
+ state_predictor: StatePredictor = None
27
 
28
  app = FastAPI()
29
 
30
# --- startup event: load the ML modules ---
@app.on_event("startup")
async def startup_event():
    """Load the detection/prediction modules once, when the API boots."""
    global detector, predictor, state_predictor
    print("Executando evento de startup: Carregando os módulos de IA...")
    detector = DengueDetector()
    predictor = DenguePredictor()
    # The state-level model is optional: a failure here must not block the
    # API; the /predict/state/ route will answer 503 until it is available.
    try:
        state_predictor = StatePredictor()
    except Exception as exc:
        print("[WARN] StatePredictor não inicializado:", str(exc))
        state_predictor = None
    print("Módulos de IA carregados com sucesso. API pronta.")
44
 
45
  # --- CORS ---
 
79
  raise ValueError("O campo 'ibge_code' 茅 obrigat贸rio.")
80
 
81
  ibge_code = int(ibge_code_str)
82
+ # Sempre retorna hist贸rico completo; frontend controla a janela vis铆vel
83
  result = predictor.predict(ibge_code)
84
 
85
  json_content = json.dumps(result, default=default_json_serializer)
 
92
  return JSONResponse(status_code=500, content={
93
  "error": str(e),
94
  "traceback": tb_str
95
+ })
96
+
97
+
98
@app.post("/predict/state/")
async def predict_dengue_state_route(payload: dict = Body(...)):
    """Forecast state-level dengue cases for the UF given in the payload.

    Accepts the state abbreviation under any of the keys ``state``,
    ``state_sigla`` or ``uf``; ``year``/``week`` are optional anchors.
    """
    global state_predictor

    # Lazy initialization: retry building the predictor on first use.
    if state_predictor is None:
        try:
            state_predictor = StatePredictor()
        except Exception as exc:
            msg = f"Preditor estadual ainda não foi inicializado: {str(exc)}"
            return JSONResponse(status_code=503, content={"error": msg})

    try:
        uf = payload.get("state") or payload.get("state_sigla") or payload.get("uf")
        if not uf:
            raise ValueError("O campo 'state' (sigla) é obrigatório.")

        year_raw = payload.get("year")
        week_raw = payload.get("week")

        # year/week are optional; when omitted the predictor forecasts past
        # the last known point. Full history is always returned — the
        # frontend controls the visible window.
        result = state_predictor.predict(
            str(uf).upper(),
            year=None if year_raw is None else int(year_raw),
            week=None if week_raw is None else int(week_raw),
        )

        body = json.dumps(result, default=default_json_serializer)
        return Response(content=body, media_type="application/json")

    except Exception as e:
        tb_str = traceback.format_exc()
        print(tb_str)
        return JSONResponse(status_code=500, content={
            "error": str(e),
            "traceback": tb_str
        })
models/model_state.keras CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d23ef65f526f0a2c26a4ad4163d7400bc32c47d83abd9d46bce862b6114ba9af
3
- size 2534633
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4c86d8e99e3779125ec864816e0fbf96f72a8e324e40a5e170182168a617b30
3
+ size 2536309
models/scalers/scaler_dyn_global_state.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2bb2cfb7f78d33fbf9242461bdef7783f31fbbb35a8114b75c341da36b07fa33
3
  size 1303
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc7972df5abd0302686c2d6ff16962ff31a13c5ca5346cbe57633de1ec34f1c1
3
  size 1303
models/scalers/scaler_target_global_state.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:86da4e22650d62cc4806750659f8c83bbd924404800d818015716f751c7e2947
3
  size 719
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a4e97671eeabf05f39cb9a6b53130816103d263c6bfffd9fc7fbee5f9c77178
3
  size 719
models/state_peak.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"AC": 2036.0, "AL": 2338.0, "AM": 388.0, "AP": 594.0, "BA": 7520.0, "CE": 2754.0, "DF": 6456.0, "ES": 2598.0, "GO": 13984.0, "MA": 1139.0, "MG": 68685.0, "MS": 3781.0, "MT": 1923.0, "PA": 810.0, "PB": 1613.0, "PE": 2249.0, "PI": 2913.0, "PR": 39913.0, "RJ": 12162.0, "RN": 2868.0, "RO": 631.0, "RR": 63.0, "RS": 16798.0, "SC": 26832.0, "SE": 408.0, "SP": 129817.0, "TO": 1289.0}
municipal_predictor.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import joblib
4
+ import numpy as np
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ from datetime import timedelta
8
+ from io import BytesIO
9
+ import base64
10
+ import tensorflow as tf
11
+ from tensorflow.keras.utils import register_keras_serializable
12
+ import matplotlib
13
+ matplotlib.use('Agg')
14
+ import matplotlib.pyplot as plt
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ plt.style.use('seaborn-v0_8-darkgrid')
18
+
19
@register_keras_serializable(package="Custom", name="asymmetric_mse")
def asymmetric_mse(y_true, y_pred):
    """Mean squared error with an extra penalty on under-prediction.

    When the model predicts below the true value (positive residual) the
    squared error is scaled up in proportion to the relative error, pushing
    the model to avoid missing case surges.
    """
    penalty_factor = 10.0
    residual = y_true - y_pred
    # Guard against division by ~0 for weeks with (near) zero true cases.
    relative = tf.abs(residual) / tf.maximum(tf.abs(y_true), 1.0)
    weight = tf.where(residual > 0, 1.0 + penalty_factor * relative, 1.0)
    return tf.reduce_mean(tf.square(residual) * weight)
28
+
29
class DenguePredictor:
    """Municipal-level dengue forecaster.

    Loads a Keras sequence model and its fitted scalers from ``models/``,
    pulls the latest inference dataset from the Hugging Face Hub, and
    produces a ``horizon``-week case forecast plus lag-correlation insights
    for a municipality identified by its IBGE code.
    """

    def __init__(self, project_root=None):
        # Directory that contains models/ and models/scalers/.
        self.project_root = Path(project_root) if project_root else Path(__file__).resolve().parent
        self.sequence_length = 12   # weeks of history fed to the model
        self.horizon = 6            # weeks forecast ahead
        # Bounds used to normalize the year feature at training time.
        self.year_min_train = 2014
        self.year_max_train = 2025
        # Must match (order included) the feature set the dynamic scaler was fitted on.
        self.dynamic_features = [
            "numero_casos", "casos_velocidade", "casos_aceleracao", "casos_mm_4_semanas",
            "T2M", "T2M_MAX", "T2M_MIN", "PRECTOTCORR", "RH2M", "ALLSKY_SFC_SW_DWN",
            "week_sin", "week_cos", "year_norm", "notificacao"
        ]
        self.static_features = ["latitude", "longitude"]
        # Portuguese display names for selected features.
        self.feature_names_pt = {
            "numero_casos": "Nº de Casos de Dengue",
            "T2M": "Temperatura Média (°C)",
            "PRECTOTCORR": "Precipitação (mm)"
        }
        self._loaded = False
        self.load_assets()

    def load_assets(self):
        """Load scalers, city mapping, inference data and the Keras model.

        Raises:
            FileNotFoundError: if the scalers directory or model file is missing.
        """
        models_dir = self.project_root / "models"
        scalers_dir = models_dir / "scalers"
        model_path = models_dir / "model.keras"
        city_map_path = models_dir / "city_to_idx.json"

        if not scalers_dir.exists():
            raise FileNotFoundError(str(scalers_dir) + " not found")

        self.scaler_dyn = joblib.load(scalers_dir / "scaler_dyn_global.pkl")
        self.scaler_static = joblib.load(scalers_dir / "scaler_static_global.pkl")
        self.scaler_target = joblib.load(scalers_dir / "scaler_target_global.pkl")

        # Optional IBGE-code -> embedding-index map; unknown cities fall back to 0.
        if city_map_path.exists():
            with open(city_map_path, "r", encoding="utf-8") as fh:
                self.city_to_idx = {int(k): int(v) for k, v in json.load(fh).items()}
        else:
            self.city_to_idx = {}

        hf_token = os.environ.get("HF_TOKEN")
        inference_path = hf_hub_download(
            repo_id="previdengue/predict_inference_data",
            filename="inference_data.parquet",
            repo_type="dataset",
            token=hf_token
        )

        df = pd.read_parquet(inference_path)
        df["codigo_ibge"] = df["codigo_ibge"].astype(int)
        df["ano"] = df["ano"].astype(int)
        df["semana"] = df["semana"].astype(int)
        try:
            # BUGFIX: zero-pad the week so "%Y%W%w" parses weeks 1-9 correctly
            # (e.g. 2024 week 5 must read "2024050", not the ambiguous "202450").
            df["date"] = pd.to_datetime(
                df["ano"].astype(str) + df["semana"].astype(str).str.zfill(2) + "0",
                format="%Y%W%w", errors="coerce"
            )
        except Exception:
            df["date"] = pd.NaT

        df = df.sort_values(by=["codigo_ibge", "date"]).reset_index(drop=True)
        # Cyclic encoding of the epidemiological week.
        df["week_sin"] = np.sin(2 * np.pi * df["semana"] / 52)
        df["week_cos"] = np.cos(2 * np.pi * df["semana"] / 52)
        df["year_norm"] = (df["ano"] - self.year_min_train) / (self.year_max_train - self.year_min_train)
        # Flag for the 2021/2022 notification-regime years.
        df["notificacao"] = df["ano"].isin([2021, 2022]).astype(float)

        self.df_master = df
        self.municipios = df[["codigo_ibge", "municipio"]].drop_duplicates().sort_values("codigo_ibge")

        if not model_path.exists():
            raise FileNotFoundError(str(model_path) + " not found")

        self.model = tf.keras.models.load_model(model_path, custom_objects={"asymmetric_mse": asymmetric_mse}, compile=False)
        self._loaded = True

    def plot_to_base64(self, fig):
        """Render a matplotlib figure to a base64-encoded PNG and close it."""
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", facecolor=fig.get_facecolor())
        buf.seek(0)
        img_str = base64.b64encode(buf.read()).decode("utf-8")
        plt.close(fig)  # avoid leaking figures between requests
        return img_str

    def _prepare_sequence(self, df_mun):
        """Build the last ``sequence_length`` rows with derived model features."""
        df_seq = df_mun.tail(self.sequence_length).copy()
        df_seq["casos_velocidade"] = df_seq["numero_casos"].diff().fillna(0)
        df_seq["casos_aceleracao"] = df_seq["casos_velocidade"].diff().fillna(0)
        df_seq["casos_mm_4_semanas"] = df_seq["numero_casos"].rolling(4, min_periods=1).mean()
        df_seq["week_sin"] = np.sin(2 * np.pi * df_seq["semana"] / 52)
        df_seq["week_cos"] = np.cos(2 * np.pi * df_seq["semana"] / 52)
        df_seq["year_norm"] = (df_seq["ano"] - self.year_min_train) / (self.year_max_train - self.year_min_train)
        if "notificacao" not in df_seq.columns:
            df_seq["notificacao"] = df_seq["ano"].isin([2021, 2022]).astype(float)
        else:
            df_seq["notificacao"] = df_seq["notificacao"].astype(float)
        return df_seq

    def predict(self, ibge_code: int, show_plot=False, display_history_weeks=None):
        """Forecast the next ``horizon`` weeks of cases for one municipality.

        Args:
            ibge_code: IBGE municipality code.
            show_plot: kept for interface compatibility; not used here.
            display_history_weeks: if a positive number, limit the returned
                history to that many trailing weeks; otherwise return all.

        Returns:
            dict with municipality metadata, historic series, predicted
            series and lag-analysis insights.

        Raises:
            RuntimeError: if assets have not been loaded.
            ValueError: on missing/insufficient data or feature mismatches.
        """
        if not self._loaded:
            raise RuntimeError("assets not loaded")

        df_mun = self.df_master[self.df_master["codigo_ibge"] == int(ibge_code)].copy().reset_index(drop=True)
        if df_mun.empty or len(df_mun) < self.sequence_length:
            raise ValueError(f"No data or insufficient history for ibge {ibge_code}")

        municipio_row = self.municipios[self.municipios["codigo_ibge"] == int(ibge_code)]
        municipality_name = municipio_row.iloc[0]["municipio"] if not municipio_row.empty else str(ibge_code)

        df_mun_clean = df_mun.dropna(subset=["numero_casos"]).reset_index(drop=True)
        if len(df_mun_clean) < self.sequence_length:
            raise ValueError(f"Insufficient known-case history for {ibge_code}")

        seq_df = self._prepare_sequence(df_mun_clean)
        if len(seq_df) < self.sequence_length:
            raise ValueError(f"Insufficient sequence length for {ibge_code}")

        # BUGFIX: validate the feature set BEFORE indexing, so callers get a
        # clear ValueError instead of a bare pandas KeyError.
        missing_feats = [c for c in self.dynamic_features if c not in seq_df.columns]
        if missing_feats:
            raise ValueError(f"Missing dynamic features in dataframe: {missing_feats}")
        if hasattr(self.scaler_dyn, "n_features_in_") and self.scaler_dyn.n_features_in_ != len(self.dynamic_features):
            raise ValueError(
                f"Dynamic scaler expects {getattr(self.scaler_dyn, 'n_features_in_', 'unknown')} features, "
                f"but predictor assembled {len(self.dynamic_features)}. Ensure training and inference feature sets match."
            )

        dynamic_raw = seq_df[self.dynamic_features].values
        static_raw = seq_df[self.static_features].iloc[-1].values.reshape(1, -1)

        dynamic_scaled = self.scaler_dyn.transform(dynamic_raw).reshape(1, self.sequence_length, -1)
        static_scaled = self.scaler_static.transform(static_raw)

        city_idx = int(self.city_to_idx.get(int(ibge_code), 0))
        city_input = np.array([[city_idx]], dtype=np.int32)

        y_pred = self.model.predict([dynamic_scaled, static_scaled, city_input], verbose=0)
        # Multi-output models return a list; the regression head comes first.
        y_pred_reg = y_pred[0] if isinstance(y_pred, (list, tuple)) else y_pred

        y_pred_flat = y_pred_reg.reshape(-1, 1)
        y_pred_inv_flat = self.scaler_target.inverse_transform(y_pred_flat)
        y_pred_inv = y_pred_inv_flat.reshape(y_pred_reg.shape)
        # Case counts cannot be negative.
        pred_values = np.maximum(y_pred_inv.flatten(), 0.0)

        last_real_date = seq_df["date"].iloc[-1] if "date" in seq_df.columns else None
        predicted_data = []
        for i, val in enumerate(pred_values):
            pred_date = (last_real_date + timedelta(weeks=i + 1)).strftime("%Y-%m-%d") if pd.notna(last_real_date) else None
            predicted_data.append({"date": pred_date, "predicted_cases": int(round(float(val)))})

        # History: return everything by default; a positive display_history_weeks
        # limits the window.
        if display_history_weeks is None or (isinstance(display_history_weeks, (int, float)) and display_history_weeks <= 0):
            hist_tail = df_mun.copy()
        else:
            hist_tail = df_mun.tail(min(len(df_mun), int(display_history_weeks))).copy()
        historic_data = []
        for _, row in hist_tail.iterrows():
            historic_data.append({
                "date": row["date"].strftime("%Y-%m-%d") if pd.notna(row.get("date")) else None,
                "cases": int(row["numero_casos"]) if pd.notna(row.get("numero_casos")) else None
            })

        # Insights: lag correlation analysis and strategic summary.
        lag_plot_b64, strategic_summary, tipping_points = self.generate_lag_insights(df_mun)

        insights = {
            "lag_analysis_plot_base64": lag_plot_b64,
            "strategic_summary": strategic_summary,
            "tipping_points": tipping_points
        }

        return {
            "municipality_name": municipality_name,
            "ibge": int(ibge_code),
            "last_known_index": int(df_mun.index[-1]),
            "historic_data": historic_data,
            "predicted_data": predicted_data,
            "insights": insights,
        }

    def generate_lag_insights(self, df_mun: pd.DataFrame):
        """Correlate cases with lagged climate drivers (1..12 weeks).

        Returns:
            (plot_base64, summary_text, tipping_points) where the plot shows
            lag-vs-correlation for temperature and precipitation.
        """
        # Rename to analysis-friendly column names.
        df_analysis = df_mun.rename(columns={
            "T2M": "Temperature_C",
            "PRECTOTCORR": "Precipitation_mm"
        })
        max_lag = 12
        cases_col = "numero_casos"
        lag_features = ["Temperature_C", "Precipitation_mm"]
        lag_correlations = {}

        for col in lag_features:
            if col in df_analysis.columns:
                corrs = []
                for lag in range(1, max_lag + 1):
                    try:
                        corr = df_analysis[cases_col].corr(df_analysis[col].shift(lag))
                    except Exception:
                        corr = np.nan
                    corrs.append(corr)
                lag_correlations[col] = corrs
            else:
                lag_correlations[col] = [np.nan] * max_lag

        # Plot (dark theme matching the frontend).
        fig, ax = plt.subplots(figsize=(10, 6), facecolor="#18181b")
        ax.set_facecolor("#18181b")
        for feature_name, corrs in lag_correlations.items():
            ax.plot(range(1, max_lag + 1), corrs, marker="o", linestyle="-", label=feature_name)
        ax.set_title("Lag Analysis", color="white")
        ax.set_xlabel("Lag (weeks)", color="white")
        ax.set_ylabel("Correlation with cases", color="white")
        ax.tick_params(colors="white")
        ax.legend(facecolor="#27272a", edgecolor="gray", labelcolor="white")
        ax.grid(True, which="both", linestyle="--", linewidth=0.5, color="#444")
        lag_plot_b64 = self.plot_to_base64(fig)

        # Summaries: lag with the strongest absolute correlation per feature.
        lag_peaks = {}
        for feature, corrs in lag_correlations.items():
            if corrs and not all(pd.isna(corrs)):
                peak = int(np.nanargmax(np.abs(np.array(corrs))) + 1)
            else:
                peak = "N/A"
            lag_peaks[feature] = peak

        temp_lag = lag_peaks.get("Temperature_C", "N/A")
        rain_lag = lag_peaks.get("Precipitation_mm", "N/A")
        summary = (
            f"O modelo identifica Temperatura e Precipitação como fatores climáticos chave. "
            f"Temperatura mostra impacto máximo após {temp_lag} semanas e precipitação após {rain_lag} semanas. "
            "Ações preventivas devem ser intensificadas nessas janelas após eventos climáticos extremos."
        )

        tipping_points = [
            {"factor": "Temperatura", "value": f"Maior impacto em {temp_lag} semanas"},
            {"factor": "Precipitação", "value": f"Maior impacto em {rain_lag} semanas"},
            {"factor": "Umidade", "value": "Aumenta a sobrevivência de mosquitos adultos"}
        ]

        return lag_plot_b64, summary, tipping_points
state_predictor.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import joblib
4
+ import numpy as np
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ from datetime import timedelta
8
+ import tensorflow as tf
9
+ from tensorflow.keras.utils import register_keras_serializable
10
+ from huggingface_hub import hf_hub_download
11
+
12
@register_keras_serializable(package="Custom", name="asymmetric_mse")
def asymmetric_mse(y_true, y_pred):
    """Asymmetric MSE: under-predictions get a ~10x relative-error penalty."""
    factor = 10.0
    diff = y_true - y_pred
    # Clamp the denominator so near-zero targets do not explode the ratio.
    safe_true = tf.maximum(tf.abs(y_true), 1.0)
    rel_err = tf.abs(diff) / safe_true
    scale = tf.where(diff > 0, 1.0 + factor * rel_err, tf.ones_like(diff))
    return tf.reduce_mean(tf.square(diff) * scale)
21
+
22
+ class StatePredictor:
23
+ def __init__(self, project_root=None):
24
+ self.project_root = Path(project_root) if project_root else Path(__file__).resolve().parent
25
+ self.sequence_length = 12
26
+ self.horizon = 6
27
+ self.dynamic_features = [
28
+ "casos_norm_log",
29
+ "casos_velocidade", "casos_aceleracao", "casos_mm_4_semanas",
30
+ "T2M_mean","T2M_std","PRECTOTCORR_mean","PRECTOTCORR_std",
31
+ "RH2M_mean","RH2M_std","ALLSKY_SFC_SW_DWN_mean","ALLSKY_SFC_SW_DWN_std",
32
+ "week_sin","week_cos","year_norm","notificacao"
33
+ ]
34
+ self.static_features = ["populacao_total"]
35
+ self._loaded = False
36
+ self.load_assets()
37
+
38
+ def load_assets(self):
39
+ models_dir = self.project_root / "models"
40
+ scalers_dir = models_dir / "scalers"
41
+ model_path = models_dir / "model_state.keras"
42
+ state_map_path = models_dir / "state_to_idx.json"
43
+ state_peak_path = models_dir / "state_peak.json"
44
+
45
+ # scalers
46
+ dyn_state = scalers_dir / "scaler_dyn_global_state.pkl"
47
+ static_state = scalers_dir / "scaler_static_global_state.pkl"
48
+ target_state = scalers_dir / "scaler_target_global_state.pkl"
49
+ if not dyn_state.exists() or not static_state.exists() or not target_state.exists():
50
+ raise FileNotFoundError("State scalers not found under models/scalers. Expected *_state.pkl files.")
51
+ self.scaler_dyn = joblib.load(dyn_state)
52
+ self.scaler_static = joblib.load(static_state)
53
+ self.scaler_target = joblib.load(target_state)
54
+
55
+ # mappings
56
+ if state_map_path.exists():
57
+ with open(state_map_path, "r", encoding="utf-8") as fh:
58
+ self.state_to_idx = json.load(fh)
59
+ else:
60
+ self.state_to_idx = {}
61
+ if state_peak_path.exists():
62
+ with open(state_peak_path, "r", encoding="utf-8") as fh:
63
+ self.state_peak_map = json.load(fh)
64
+ else:
65
+ self.state_peak_map = {}
66
+
67
+ # inference dataset (HF only)
68
+ hf_token = os.environ.get("HF_TOKEN")
69
+ hf_repo = "previdengue/predict_inference_data_estadual"
70
+ hf_filename = "inference_data_estadual.parquet"
71
+ try:
72
+ hf_path = hf_hub_download(
73
+ repo_id=hf_repo,
74
+ filename=hf_filename,
75
+ repo_type="dataset",
76
+ token=hf_token,
77
+ )
78
+ df_loaded = pd.read_parquet(hf_path)
79
+ except Exception as e:
80
+ raise FileNotFoundError(
81
+ "Could not download 'inference_data_estadual.parquet' from HF repo 'previdengue/predict_inference_data_estadual'. "
82
+ "Ensure the dataset exists and set HF_TOKEN if the repo requires authentication."
83
+ ) from e
84
+
85
+ # normalize
86
+ df = df_loaded.copy()
87
+ required = ["estado_sigla", "year", "week", "casos_soma"]
88
+ if any(col not in df.columns for col in required):
89
+ raise ValueError("State dataset missing required columns: ['estado_sigla','year','week','casos_soma']")
90
+ df["estado_sigla"] = df["estado_sigla"].astype(str)
91
+ df = df.sort_values(["estado_sigla", "year", "week"]).reset_index(drop=True)
92
+ if "date" not in df.columns:
93
+ try:
94
+ df["date"] = pd.to_datetime(df["year"].astype(str) + df["week"].astype(str) + "0", format="%Y%W%w", errors="coerce")
95
+ except Exception:
96
+ pass
97
+ if "week_sin" not in df.columns:
98
+ df["week_sin"] = np.sin(2*np.pi*df["week"]/52)
99
+ if "week_cos" not in df.columns:
100
+ df["week_cos"] = np.cos(2*np.pi*df["week"]/52)
101
+ if "year_norm" not in df.columns:
102
+ year_min, year_max = df["year"].min(), df["year"].max()
103
+ df["year_norm"] = (df["year"] - year_min) / max(1.0, (year_max - year_min))
104
+ df["notificacao"] = df["year"].isin([2021, 2022]).astype(float)
105
+
106
+ self.df_state = df
107
+ if not model_path.exists():
108
+ raise FileNotFoundError(str(model_path) + " not found")
109
+ self.model = tf.keras.models.load_model(model_path, custom_objects={"asymmetric_mse": asymmetric_mse}, compile=False)
110
+ self._loaded = True
111
+
112
+ def _prepare_state_sequence(self, df_st: pd.DataFrame, state_sigla: str):
113
+ df_st = df_st.copy()
114
+ df_st['casos_velocidade'] = df_st['casos_soma'].diff().fillna(0)
115
+ df_st['casos_aceleracao'] = df_st['casos_velocidade'].diff().fillna(0)
116
+ df_st['casos_mm_4_semanas'] = df_st['casos_soma'].rolling(4, min_periods=1).mean()
117
+ if "notificacao" not in df_st.columns:
118
+ df_st["notificacao"] = df_st["year"].isin([2021, 2022]).astype(float)
119
+ peak = float(self.state_peak_map.get(state_sigla, 1.0))
120
+ if peak <= 0:
121
+ peak = 1.0
122
+ df_st["casos_norm"] = df_st["casos_soma"] / peak
123
+ df_st["casos_norm_log"] = np.log1p(df_st["casos_norm"])
124
+ return df_st
125
+
126
+ def predict(self, state_sigla: str, year: int = None, week: int = None, display_history_weeks: int | None = None):
127
+ if not self._loaded:
128
+ raise RuntimeError("state assets not loaded")
129
+ st = str(state_sigla).upper()
130
+ df_st = self.df_state[self.df_state["estado_sigla"] == st].copy().sort_values(["year","week"]).reset_index(drop=True)
131
+ if df_st.empty or len(df_st) < self.sequence_length:
132
+ raise ValueError(f"No data or insufficient history for state {st}")
133
+ df_st = self._prepare_state_sequence(df_st, st)
134
+ if year is not None and week is not None:
135
+ idx_list = df_st.index[(df_st['year'] == int(year)) & (df_st['week'] == int(week))].tolist()
136
+ if not idx_list:
137
+ raise ValueError("Prediction point (year/week) not found in state series")
138
+ pred_point_idx = idx_list[0]
139
+ else:
140
+ pred_point_idx = len(df_st)
141
+ last_known_idx = pred_point_idx - 1
142
+ if last_known_idx < self.sequence_length - 1:
143
+ raise ValueError("Insufficient sequence window before prediction point")
144
+ start_idx = last_known_idx - self.sequence_length + 1
145
+ input_seq = df_st.iloc[start_idx:last_known_idx+1].copy()
146
+ for col in self.static_features:
147
+ if col not in input_seq.columns:
148
+ input_seq[col] = 0.0
149
+ static_raw = input_seq[self.static_features].iloc[0].values.reshape(1, -1)
150
+ missing_dyn = [c for c in self.dynamic_features if c not in input_seq.columns]
151
+ if missing_dyn:
152
+ raise ValueError(f"Missing dynamic state features: {missing_dyn}")
153
+ dyn_raw = input_seq[self.dynamic_features].values
154
+ if hasattr(self.scaler_dyn, "n_features_in_") and self.scaler_dyn.n_features_in_ != len(self.dynamic_features):
155
+ raise ValueError(
156
+ f"State dynamic scaler expects {self.scaler_dyn.n_features_in_} features, got {len(self.dynamic_features)}."
157
+ )
158
+ dyn_scaled = self.scaler_dyn.transform(dyn_raw).reshape(1, self.sequence_length, len(self.dynamic_features))
159
+ static_scaled = self.scaler_static.transform(static_raw)
160
+ state_idx = int(self.state_to_idx.get(st, 0))
161
+ state_input = np.array([[state_idx]], dtype=np.int32)
162
+ y_pred = self.model.predict([dyn_scaled, static_scaled, state_input], verbose=0)
163
+ y_pred_reg = y_pred[0] if isinstance(y_pred, (list, tuple)) else y_pred
164
+ y_pred_log_norm = self.scaler_target.inverse_transform(y_pred_reg.reshape(-1,1)).reshape(y_pred_reg.shape)
165
+ y_pred_norm = np.expm1(y_pred_log_norm)
166
+ peak = float(self.state_peak_map.get(st, 1.0))
167
+ if peak <= 0:
168
+ peak = 1.0
169
+ prediction_counts = np.maximum(y_pred_norm.flatten() * peak, 0.0)
170
+ last_known_date = df_st.iloc[last_known_idx]['date'] if 'date' in df_st.columns and last_known_idx < len(df_st) else None
171
+ predicted_data = []
172
+ for i, val in enumerate(prediction_counts):
173
+ if pd.notna(last_known_date):
174
+ pred_date = (last_known_date + timedelta(weeks=i+1)).strftime("%Y-%m-%d")
175
+ else:
176
+ pred_date = None
177
+ predicted_data.append({"date": pred_date, "predicted_cases": int(round(float(val)))})
178
+ if display_history_weeks is None or display_history_weeks <= 0:
179
+ hist_tail = df_st.iloc[:last_known_idx+1].copy()
180
+ else:
181
+ hist_tail = df_st.iloc[max(0, last_known_idx - display_history_weeks): last_known_idx+1].copy()
182
+ historic_data = []
183
+ for _, row in hist_tail.iterrows():
184
+ historic_data.append({
185
+ "date": row["date"].strftime("%Y-%m-%d") if pd.notna(row.get("date")) else None,
186
+ "cases": int(row["casos_soma"]) if pd.notna(row.get("casos_soma")) else None
187
+ })
188
+ return {
189
+ "state": st,
190
+ "last_known_index": int(last_known_idx),
191
+ "historic_data": historic_data,
192
+ "predicted_data": predicted_data,
193
+ }