VIJAYARAGUL commited on
Commit
0980c99
·
1 Parent(s): f6b67b1
Files changed (2) hide show
  1. Dockerfile +23 -8
  2. app.py +104 -206
Dockerfile CHANGED
@@ -1,30 +1,45 @@
 
 
 
1
  FROM python:3.10-slim
2
 
 
3
  # Set working directory
 
4
  WORKDIR /app
5
 
6
- # Set environment variables
 
 
7
  ENV PYTHONDONTWRITEBYTECODE=1
8
  ENV PYTHONUNBUFFERED=1
9
  ENV HF_HOME=/tmp/hf_cache
10
 
 
11
  # Install system dependencies
 
12
  RUN apt-get update && apt-get install -y --no-install-recommends \
13
  gcc \
14
  libc-dev \
15
  && rm -rf /var/lib/apt/lists/*
16
 
17
- # Copy requirements first for better caching
 
 
18
  COPY requirements.txt .
19
-
20
- # Install Python dependencies
21
  RUN pip install --no-cache-dir -r requirements.txt
22
 
23
- # Copy all files from the root directory
 
 
24
  COPY . .
25
 
26
- # Expose HF Spaces port
 
 
27
  EXPOSE 7860
28
 
29
- # Command to run the application
30
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "120"]
 
 
 
1
+ # -------------------------
2
+ # Base image
3
+ # -------------------------
4
  FROM python:3.10-slim
5
 
6
+ # -------------------------
7
  # Set working directory
8
+ # -------------------------
9
  WORKDIR /app
10
 
11
+ # -------------------------
12
+ # Environment variables
13
+ # -------------------------
14
  ENV PYTHONDONTWRITEBYTECODE=1
15
  ENV PYTHONUNBUFFERED=1
16
  ENV HF_HOME=/tmp/hf_cache
17
 
18
+ # -------------------------
19
  # Install system dependencies
20
+ # -------------------------
21
  RUN apt-get update && apt-get install -y --no-install-recommends \
22
  gcc \
23
  libc-dev \
24
  && rm -rf /var/lib/apt/lists/*
25
 
26
+ # -------------------------
27
+ # Copy requirements and install
28
+ # -------------------------
29
  COPY requirements.txt .
 
 
30
  RUN pip install --no-cache-dir -r requirements.txt
31
 
32
+ # -------------------------
33
+ # Copy application code
34
+ # -------------------------
35
  COPY . .
36
 
37
+ # -------------------------
38
+ # Expose Hugging Face port
39
+ # -------------------------
40
  EXPOSE 7860
41
 
42
+ # -------------------------
43
+ # Command to run FastAPI
44
+ # -------------------------
45
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "120"]
app.py CHANGED
@@ -1,17 +1,18 @@
1
  from fastapi import FastAPI, HTTPException, Request, Response
 
 
2
  from pydantic import BaseModel
3
  import pandas as pd
4
  import joblib
5
  import requests
6
- from datetime import timedelta
7
- from math import sin, cos, radians, pi
8
- import logging
9
  import gc
10
  import os
 
 
11
  from contextlib import asynccontextmanager
12
 
13
  # -------------------------
14
- # Logger setup
15
  # -------------------------
16
  logging.basicConfig(
17
  level=logging.INFO,
@@ -19,7 +20,7 @@ logging.basicConfig(
19
  )
20
 
21
  # -------------------------
22
- # Global variables for lazy loading (memory optimization)
23
  # -------------------------
24
  _occurrence_model = None
25
  _occurrence_scaler = None
@@ -27,7 +28,7 @@ _severity_model = None
27
  _severity_scaler = None
28
 
29
  # -------------------------
30
- # NASA POWER setup
31
  # -------------------------
32
  API_BASE = "https://power.larc.nasa.gov/api/temporal/daily/point"
33
  PARAMS = "PRECTOT,T2M,T2M_MAX,T2M_MIN,ALLSKY_SFC_SW_DWN,RH2M,WS2M"
@@ -39,47 +40,26 @@ FEATURE_ORDER = [
39
  ]
40
 
41
  # -------------------------
42
- # Memory management
43
  # -------------------------
44
  def cleanup_memory():
45
- """Force garbage collection to free up memory"""
46
  gc.collect()
47
 
48
  def safe_model_load(filename: str):
49
- """Load model from the same directory as this script with comprehensive error handling"""
50
  try:
51
- # Get the directory of the current script (main.py)
52
  script_dir = os.path.dirname(os.path.abspath(__file__))
53
- model_path = os.path.join(script_dir, filename)
54
- logging.info(f"🔄 Attempting to load {filename} from {model_path}...")
55
-
56
- # Check file exists and has content
57
- if not os.path.exists(model_path):
58
- raise FileNotFoundError(f"File not found: {model_path}")
59
-
60
- file_size = os.path.getsize(model_path)
61
- if file_size == 0:
62
- raise ValueError(f"File is empty: {model_path}")
63
-
64
- logging.info(f"📊 File size: {file_size / (1024 * 1024):.1f} MB")
65
-
66
- # Load the model
67
- model = joblib.load(model_path)
68
- logging.info(f"✅ Successfully loaded {filename}")
69
- return model
70
-
71
  except Exception as e:
72
- logging.error(f"Failed to load {filename}: {str(e)}")
73
- logging.error(f" Error type: {type(e).__name__}")
74
- raise HTTPException(status_code=500, detail=f"Model loading failed: {filename} - {str(e)}")
75
 
76
- # -------------------------
77
- # Lazy loading functions
78
- # -------------------------
79
  def get_occurrence_model_and_scaler():
80
  global _occurrence_model, _occurrence_scaler
81
  if _occurrence_model is None or _occurrence_scaler is None:
82
- logging.info("Loading occurrence model and scaler...")
83
  _occurrence_model = safe_model_load("drought_occurrence_model.joblib")
84
  _occurrence_scaler = safe_model_load("drought_occurrence_scaler.joblib")
85
  cleanup_memory()
@@ -88,49 +68,55 @@ def get_occurrence_model_and_scaler():
88
  def get_severity_model_and_scaler():
89
  global _severity_model, _severity_scaler
90
  if _severity_model is None or _severity_scaler is None:
91
- logging.info("Loading severity model and scaler...")
92
  _severity_model = safe_model_load("drought_severity_model.joblib")
93
  _severity_scaler = safe_model_load("drought_severity_scaler.joblib")
94
  cleanup_memory()
95
  return _severity_model, _severity_scaler
96
 
97
  # -------------------------
98
- # Lifespan event handler
99
  # -------------------------
100
  @asynccontextmanager
101
  async def lifespan(app: FastAPI):
102
- # Startup
103
- logging.info("🚀 Drought API starting - models will load on first request")
104
  cleanup_memory()
105
  yield
106
- # Shutdown
107
- logging.info("🛑 Drought API shutting down")
108
  global _occurrence_model, _occurrence_scaler, _severity_model, _severity_scaler
109
  _occurrence_model = _occurrence_scaler = _severity_model = _severity_scaler = None
110
  cleanup_memory()
111
 
112
  # -------------------------
113
- # Request schema
114
- # -------------------------
115
- class PredictionRequest(BaseModel):
116
- lat: float
117
- lon: float
118
- time: str # YYYY-MM-DD
119
-
120
- # -------------------------
121
- # FastAPI app with lifespan
122
  # -------------------------
123
  app = FastAPI(
124
  title="🌍 Drought Prediction API",
125
  version="2.4",
126
  description="Memory-optimized drought prediction API",
127
- lifespan=lifespan,
128
- docs_url="/docs", # Explicitly enable
129
- redoc_url="/redoc" # Alternative docs
 
 
 
 
 
 
 
 
130
  )
131
 
132
  # -------------------------
133
- # NASA fetcher (memory optimized)
 
 
 
 
 
 
 
 
134
  # -------------------------
135
  def fetch_features(lat, lon, time_str: str) -> dict:
136
  end = pd.to_datetime(time_str)
@@ -146,198 +132,96 @@ def fetch_features(lat, lon, time_str: str) -> dict:
146
  }
147
  try:
148
  response = requests.get(API_BASE, params=params, timeout=30)
149
- if response.status_code != 200:
150
- logging.error(f"NASA API error {response.status_code}")
151
- raise HTTPException(status_code=502, detail="NASA API error")
152
  data = response.json().get("properties", {}).get("parameter", {})
153
- if not data:
154
- raise HTTPException(status_code=502, detail="No data from NASA API")
155
  features = {}
156
- for p, values in data.items():
157
- vals = [v for v in values.values() if v is not None]
158
- if vals:
159
- if p == "PRECTOT":
160
- features["PRECTOTCORR"] = sum(vals)
161
- else:
162
- features[p] = sum(vals) / len(vals)
163
- # Clear response from memory
164
- del data, response, vals
165
- cleanup_memory()
166
- # Derived features
167
  features.update({
168
  "lat_sin": sin(radians(lat)),
169
  "lat_cos": cos(radians(lat)),
170
  "lon_sin": sin(radians(lon)),
171
  "lon_cos": cos(radians(lon)),
172
- "month_sin": sin(2 * pi * end.month / 12),
173
- "month_cos": cos(2 * pi * end.month / 12)
174
  })
175
  missing = [f for f in FEATURE_ORDER if f not in features]
176
  if missing:
177
  raise HTTPException(status_code=500, detail=f"Missing features: {missing}")
 
178
  return features
179
- except HTTPException:
180
- raise
181
  except Exception as e:
182
- logging.error(f"NASA API fetch error: {e}")
183
  raise HTTPException(status_code=502, detail="NASA API request failed")
184
 
185
  # -------------------------
186
- # Prediction endpoint (memory optimized with detailed debugging)
187
  # -------------------------
188
  @app.post("/predict")
189
  async def predict(req: PredictionRequest):
190
  try:
191
- logging.info(f"🔄 Starting prediction for lat={req.lat}, lon={req.lon}, time={req.time}")
192
- # Validate input
193
- try:
194
- pd.to_datetime(req.time)
195
- except Exception as e:
196
- logging.error(f"Invalid time format: {req.time}")
197
- raise HTTPException(status_code=400, detail=f"Invalid time format: {req.time}. Use YYYY-MM-DD")
198
- # Get features
199
- logging.info("📡 Fetching NASA data...")
200
  features = fetch_features(req.lat, req.lon, req.time)
201
- logging.info(f"✅ Features fetched: {len(features)} features")
202
- X = pd.DataFrame([[features[col] for col in FEATURE_ORDER]], columns=FEATURE_ORDER)
203
- logging.info(f"📊 DataFrame created: {X.shape}")
204
- # Occurrence prediction
205
- logging.info("🔮 Loading occurrence model...")
206
- try:
207
- occ_model, occ_scaler = get_occurrence_model_and_scaler()
208
- logging.info("✅ Occurrence model loaded")
209
- except Exception as e:
210
- logging.error(f"❌ Failed to load occurrence model: {e}")
211
- raise HTTPException(status_code=500, detail=f"Failed to load occurrence model: {str(e)}")
212
- try:
213
- X_occ = occ_scaler.transform(X)
214
- occurrence_pred = int(occ_model.predict(X_occ)[0])
215
- occurrence_proba = occ_model.predict_proba(X_occ)[0].tolist()
216
- logging.info(f"✅ Occurrence prediction: {occurrence_pred}")
217
- except Exception as e:
218
- logging.error(f"❌ Occurrence prediction failed: {e}")
219
- raise HTTPException(status_code=500, detail=f"Occurrence prediction failed: {str(e)}")
220
- del X_occ # Free memory
221
- cleanup_memory()
222
- # Severity prediction
223
- logging.info("🔮 Loading severity model...")
224
- try:
225
- sev_model, sev_scaler = get_severity_model_and_scaler()
226
- logging.info("✅ Severity model loaded")
227
- except Exception as e:
228
- logging.error(f"❌ Failed to load severity model: {e}")
229
- raise HTTPException(status_code=500, detail=f"Failed to load severity model: {str(e)}")
230
- try:
231
- X_sev = sev_scaler.transform(X)
232
- severity_pred = int(sev_model.predict(X_sev)[0])
233
- severity_proba = sev_model.predict_proba(X_sev)[0].tolist()
234
- logging.info(f"✅ Severity prediction: {severity_pred}")
235
- except Exception as e:
236
- logging.error(f"❌ Severity prediction failed: {e}")
237
- raise HTTPException(status_code=500, detail=f"Severity prediction failed: {str(e)}")
238
- del X_sev # Free memory
239
- cleanup_memory()
240
  result = {
241
  "input": {"lat": req.lat, "lon": req.lon, "time": req.time},
242
- "occurrence": {
243
- "prediction": occurrence_pred,
244
- "probabilities": occurrence_proba
245
- },
246
- "severity": {
247
- "prediction": severity_pred,
248
- "probabilities": severity_proba
249
- },
250
- "features_used": {k: round(v, 4) for k, v in zip(FEATURE_ORDER, X.iloc[0].tolist())}
251
  }
252
- # Final cleanup
253
- del X, features
254
  cleanup_memory()
255
- logging.info(f"✅ Prediction complete: Occurrence={occurrence_pred}, Severity={severity_pred}")
256
  return result
257
- except HTTPException as http_err:
258
- logging.error(f"HTTP Error: {http_err.detail}")
259
- cleanup_memory()
260
- raise http_err
261
  except Exception as e:
262
- logging.error(f" Unexpected prediction error: {str(e)}")
263
- logging.error(f"❌ Error type: {type(e).__name__}")
264
- import traceback
265
- logging.error(f"❌ Traceback: {traceback.format_exc()}")
266
- cleanup_memory()
267
- raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
268
 
269
  # -------------------------
270
- # Debug endpoint to test individual components
 
 
 
 
 
 
 
 
 
271
  # -------------------------
272
  @app.get("/debug")
273
  async def debug_info():
274
- """Debug endpoint to check system status"""
275
- try:
276
- debug_data = {
277
- "python_version": f"{os.sys.version_info.major}.{os.sys.version_info.minor}.{os.sys.version_info.micro}",
278
- "feature_order": FEATURE_ORDER,
279
- "api_base": API_BASE,
280
- "models_loaded": {
281
- "occurrence_model": _occurrence_model is not None,
282
- "occurrence_scaler": _occurrence_scaler is not None,
283
- "severity_model": _severity_model is not None,
284
- "severity_scaler": _severity_scaler is not None
285
- }
286
- }
287
- # Test NASA API with a simple request
288
- try:
289
- test_response = requests.get("https://power.larc.nasa.gov", timeout=10)
290
- debug_data["nasa_api_accessible"] = test_response.status_code == 200
291
- except:
292
- debug_data["nasa_api_accessible"] = False
293
- # Test local model files
294
- try:
295
- script_dir = os.path.dirname(os.path.abspath(__file__))
296
- model_files = [
297
- "drought_occurrence_model.joblib",
298
- "drought_occurrence_scaler.joblib",
299
- "drought_severity_model.joblib",
300
- "drought_severity_scaler.joblib"
301
- ]
302
- debug_data["model_files_exist"] = {
303
- f: os.path.exists(os.path.join(script_dir, f)) for f in model_files
304
- }
305
- except Exception as e:
306
- debug_data["model_files_error"] = str(e)
307
- return debug_data
308
- except Exception as e:
309
- return {"debug_error": str(e)}
310
 
311
  # -------------------------
312
- # Test prediction with sample data
313
  # -------------------------
314
  @app.get("/test")
315
  async def test_prediction():
316
- """Test endpoint with hardcoded values"""
317
  try:
318
- test_request = PredictionRequest(
319
- lat=40.7128, # New York
320
- lon=-74.0060,
321
- time="2024-08-15"
322
- )
323
- result = await predict(test_request)
324
  return {"test_status": "success", "result": result}
325
  except Exception as e:
326
  return {"test_status": "failed", "error": str(e)}
327
 
328
- # -------------------------
329
- # Health check (lightweight)
330
- # -------------------------
331
- @app.api_route("/health", methods=["GET", "HEAD"])
332
- async def health_check(request: Request):
333
- if request.method == "HEAD":
334
- return Response(status_code=200)
335
- return {
336
- "status": "healthy",
337
- "api_version": "2.4",
338
- "python_version": f"{os.sys.version_info.major}.{os.sys.version_info.minor}"
339
- }
340
-
341
  # -------------------------
342
  # Root endpoint
343
  # -------------------------
@@ -349,6 +233,20 @@ async def root():
349
  "endpoints": {
350
  "predict": "/predict",
351
  "health": "/health",
352
- "docs": "/docs"
 
 
 
353
  }
354
- }
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, HTTPException, Request, Response
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
4
  from pydantic import BaseModel
5
  import pandas as pd
6
  import joblib
7
  import requests
 
 
 
8
  import gc
9
  import os
10
+ import logging
11
+ from math import sin, cos, radians, pi
12
  from contextlib import asynccontextmanager
13
 
14
  # -------------------------
15
+ # Logger
16
  # -------------------------
17
  logging.basicConfig(
18
  level=logging.INFO,
 
20
  )
21
 
22
  # -------------------------
23
+ # Global models
24
  # -------------------------
25
  _occurrence_model = None
26
  _occurrence_scaler = None
 
28
  _severity_scaler = None
29
 
30
  # -------------------------
31
+ # Feature setup
32
  # -------------------------
33
  API_BASE = "https://power.larc.nasa.gov/api/temporal/daily/point"
34
  PARAMS = "PRECTOT,T2M,T2M_MAX,T2M_MIN,ALLSKY_SFC_SW_DWN,RH2M,WS2M"
 
40
  ]
41
 
42
  # -------------------------
43
+ # Utility functions
44
  # -------------------------
45
  def cleanup_memory():
 
46
  gc.collect()
47
 
48
  def safe_model_load(filename: str):
 
49
  try:
 
50
  script_dir = os.path.dirname(os.path.abspath(__file__))
51
+ path = os.path.join(script_dir, filename)
52
+ if not os.path.exists(path):
53
+ raise FileNotFoundError(f"{filename} not found")
54
+ return joblib.load(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  except Exception as e:
56
+ logging.error(f"Failed to load {filename}: {e}")
57
+ raise HTTPException(status_code=500, detail=f"Model loading failed: {filename}")
 
58
 
 
 
 
59
  def get_occurrence_model_and_scaler():
60
  global _occurrence_model, _occurrence_scaler
61
  if _occurrence_model is None or _occurrence_scaler is None:
62
+ logging.info("Loading occurrence model/scaler...")
63
  _occurrence_model = safe_model_load("drought_occurrence_model.joblib")
64
  _occurrence_scaler = safe_model_load("drought_occurrence_scaler.joblib")
65
  cleanup_memory()
 
68
  def get_severity_model_and_scaler():
69
  global _severity_model, _severity_scaler
70
  if _severity_model is None or _severity_scaler is None:
71
+ logging.info("Loading severity model/scaler...")
72
  _severity_model = safe_model_load("drought_severity_model.joblib")
73
  _severity_scaler = safe_model_load("drought_severity_scaler.joblib")
74
  cleanup_memory()
75
  return _severity_model, _severity_scaler
76
 
77
  # -------------------------
78
+ # Lifespan
79
  # -------------------------
80
  @asynccontextmanager
81
  async def lifespan(app: FastAPI):
82
+ logging.info("🚀 Drought API starting (models load on first request)")
 
83
  cleanup_memory()
84
  yield
85
+ logging.info("🛑 Shutting down API")
 
86
  global _occurrence_model, _occurrence_scaler, _severity_model, _severity_scaler
87
  _occurrence_model = _occurrence_scaler = _severity_model = _severity_scaler = None
88
  cleanup_memory()
89
 
90
  # -------------------------
91
+ # FastAPI instance
 
 
 
 
 
 
 
 
92
  # -------------------------
93
  app = FastAPI(
94
  title="🌍 Drought Prediction API",
95
  version="2.4",
96
  description="Memory-optimized drought prediction API",
97
+ lifespan=lifespan
98
+ )
99
+
100
+ # -------------------------
101
+ # CORS middleware for website
102
+ # -------------------------
103
+ app.add_middleware(
104
+ CORSMiddleware,
105
+ allow_origins=["*"], # replace with website URL in production
106
+ allow_methods=["*"],
107
+ allow_headers=["*"]
108
  )
109
 
110
  # -------------------------
111
+ # Request model
112
+ # -------------------------
113
+ class PredictionRequest(BaseModel):
114
+ lat: float
115
+ lon: float
116
+ time: str # YYYY-MM-DD
117
+
118
+ # -------------------------
119
+ # NASA feature fetcher
120
  # -------------------------
121
  def fetch_features(lat, lon, time_str: str) -> dict:
122
  end = pd.to_datetime(time_str)
 
132
  }
133
  try:
134
  response = requests.get(API_BASE, params=params, timeout=30)
135
+ response.raise_for_status()
 
 
136
  data = response.json().get("properties", {}).get("parameter", {})
 
 
137
  features = {}
138
+ for p, vals in data.items():
139
+ values = [v for v in vals.values() if v is not None]
140
+ if values:
141
+ features["PRECTOTCORR" if p=="PRECTOT" else p] = sum(values)/len(values) if p!="PRECTOT" else sum(values)
 
 
 
 
 
 
 
142
  features.update({
143
  "lat_sin": sin(radians(lat)),
144
  "lat_cos": cos(radians(lat)),
145
  "lon_sin": sin(radians(lon)),
146
  "lon_cos": cos(radians(lon)),
147
+ "month_sin": sin(2*pi*end.month/12),
148
+ "month_cos": cos(2*pi*end.month/12)
149
  })
150
  missing = [f for f in FEATURE_ORDER if f not in features]
151
  if missing:
152
  raise HTTPException(status_code=500, detail=f"Missing features: {missing}")
153
+ cleanup_memory()
154
  return features
 
 
155
  except Exception as e:
156
+ logging.error(f"NASA fetch error: {e}")
157
  raise HTTPException(status_code=502, detail="NASA API request failed")
158
 
159
  # -------------------------
160
+ # Prediction endpoint
161
  # -------------------------
162
  @app.post("/predict")
163
  async def predict(req: PredictionRequest):
164
  try:
 
 
 
 
 
 
 
 
 
165
  features = fetch_features(req.lat, req.lon, req.time)
166
+ X = pd.DataFrame([[features[f] for f in FEATURE_ORDER]], columns=FEATURE_ORDER)
167
+ occ_model, occ_scaler = get_occurrence_model_and_scaler()
168
+ sev_model, sev_scaler = get_severity_model_and_scaler()
169
+ X_occ = occ_scaler.transform(X)
170
+ X_sev = sev_scaler.transform(X)
171
+ occurrence_pred = int(occ_model.predict(X_occ)[0])
172
+ occurrence_proba = occ_model.predict_proba(X_occ)[0].tolist()
173
+ severity_pred = int(sev_model.predict(X_sev)[0])
174
+ severity_proba = sev_model.predict_proba(X_sev)[0].tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  result = {
176
  "input": {"lat": req.lat, "lon": req.lon, "time": req.time},
177
+ "occurrence": {"prediction": occurrence_pred, "probabilities": occurrence_proba},
178
+ "severity": {"prediction": severity_pred, "probabilities": severity_proba},
179
+ "features_used": {k: round(v,4) for k,v in zip(FEATURE_ORDER, X.iloc[0].tolist())}
 
 
 
 
 
 
180
  }
 
 
181
  cleanup_memory()
 
182
  return result
183
+ except HTTPException as e:
184
+ raise e
 
 
185
  except Exception as e:
186
+ logging.error(f"Prediction error: {e}")
187
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
188
 
189
  # -------------------------
190
+ # Health check
191
+ # -------------------------
192
+ @app.api_route("/health", methods=["GET", "HEAD"])
193
+ async def health_check(request: Request):
194
+ if request.method == "HEAD":
195
+ return Response(status_code=200)
196
+ return {"status": "healthy", "api_version": "2.4"}
197
+
198
+ # -------------------------
199
+ # Debug endpoint
200
  # -------------------------
201
  @app.get("/debug")
202
  async def debug_info():
203
+ return {
204
+ "models_loaded": {
205
+ "occurrence_model": _occurrence_model is not None,
206
+ "occurrence_scaler": _occurrence_scaler is not None,
207
+ "severity_model": _severity_model is not None,
208
+ "severity_scaler": _severity_scaler is not None
209
+ },
210
+ "feature_order": FEATURE_ORDER
211
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  # -------------------------
214
+ # Test endpoint
215
  # -------------------------
216
  @app.get("/test")
217
  async def test_prediction():
 
218
  try:
219
+ test_req = PredictionRequest(lat=40.7128, lon=-74.0060, time="2024-08-15")
220
+ result = await predict(test_req)
 
 
 
 
221
  return {"test_status": "success", "result": result}
222
  except Exception as e:
223
  return {"test_status": "failed", "error": str(e)}
224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  # -------------------------
226
  # Root endpoint
227
  # -------------------------
 
233
  "endpoints": {
234
  "predict": "/predict",
235
  "health": "/health",
236
+ "debug": "/debug",
237
+ "test": "/test",
238
+ "docs": "/docs",
239
+ "redoc": "/redoc"
240
  }
241
+ }
242
+
243
+ # -------------------------
244
+ # Swagger UI and Redoc
245
+ # -------------------------
246
+ @app.get("/docs", include_in_schema=False)
247
+ async def custom_swagger_ui():
248
+ return get_swagger_ui_html(openapi_url="/openapi.json", title="API Docs")
249
+
250
+ @app.get("/redoc", include_in_schema=False)
251
+ async def custom_redoc():
252
+ return get_redoc_html(openapi_url="/openapi.json", title="ReDoc")