nikethanreddy committed on
Commit
d7ffece
·
verified ·
1 Parent(s): 103ba23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -79
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
3
  os.environ['JAX_PLATFORMS'] = 'cpu'
@@ -135,16 +136,16 @@ def calculate_overall_aqi(row, aqi_breakpoints):
135
  def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: float):
136
  print(f"Attempting to retrieve data for the last {sequence_length} hours from Open-Meteo for Lat: {latitude}, Lon: {longitude}")
137
 
138
- end_time = datetime.now(pytz.utc)
139
- # Fetch slightly more data to allow for resampling and ensure sequence_length is met
140
- fetch_hours = sequence_length + 5
141
- start_time = end_time - timedelta(hours=fetch_hours)
 
 
142
 
143
- # Format timestamps for API request (ISO 8601)
144
- start_time_str = start_time.isoformat().split('.')[0] + 'Z'
145
- end_time_str = end_time.isoformat().split('.')[0] + 'Z'
146
 
147
- print(f"Requesting data from {start_time_str} to {end_time_str}")
 
148
 
149
  # Open-Meteo Air Quality API
150
  air_quality_url = "https://air-quality-api.open-meteo.com/v1/air-quality"
@@ -153,32 +154,30 @@ def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: f
153
  "longitude": longitude,
154
  "hourly": ["pm2_5", "pm10", "carbon_monoxide"],
155
  "timezone": "UTC",
156
- "start_date": start_time.strftime('%Y-%m-%d'), # Use YYYY-MM-DD format
157
- "end_date": end_time.strftime('%Y-%m-%d'),
158
- "past_hours": fetch_hours
159
  }
160
 
161
- # Open-Meteo Historical Weather API for Temperature
162
  weather_url = "https://archive-api.open-meteo.com/v1/archive"
163
  weather_params = {
164
  "latitude": latitude,
165
  "longitude": longitude,
166
  "hourly": ["temperature_2m"],
167
  "timezone": "UTC",
168
- "start_date": start_time.strftime('%Y-%m-%d'),
169
- "end_date": end_time.strftime('%Y-%m-%d')
170
  }
171
 
172
  try:
173
  # Fetch Air Quality Data
174
- print(f"Fetching air quality data from: {air_quality_url}")
175
  air_quality_response = requests.get(air_quality_url, params=air_quality_params)
176
  air_quality_response.raise_for_status()
177
  air_quality_data = air_quality_response.json()
178
  print("Air quality data retrieved.")
179
 
180
  # Fetch Temperature Data
181
- print(f"Fetching temperature data from: {weather_url}")
182
  weather_response = requests.get(weather_url, params=weather_params)
183
  weather_response.raise_for_status()
184
  weather_data = weather_response.json()
@@ -208,7 +207,6 @@ def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: f
208
 
209
 
210
  # Resample to ensure consistent hourly frequency and fill missing data
211
- # Use 'h' for hourly resampling
212
  df_processed = df_merged.resample('h').ffill().bfill()
213
  print(f"DataFrame resampled to hourly. Shape: {df_processed.shape}")
214
 
@@ -235,7 +233,7 @@ def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: f
235
  print(f"Selected and reordered columns. Final processing shape: {df_processed.shape}")
236
 
237
 
238
- # Handle any remaining NaNs after ffill/bfill (e.g., if the very first values were NaN or API returned all NaNs)
239
  initial_rows = len(df_processed)
240
  df_processed.dropna(inplace=True)
241
  if len(df_processed) < initial_rows:
@@ -248,7 +246,7 @@ def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: f
248
  return None, f"Error: Insufficient historical data ({len(df_processed)} points available, {sequence_length} required)."
249
 
250
  # Select the last `sequence_length` rows for the input sequence
251
- latest_data_sequence_df = df_processed.tail(sequence_length).copy() # Use .copy() to avoid SettingWithCopyWarning
252
  print(f"Selected last {sequence_length} data points.")
253
 
254
  # Convert to numpy array and reshape (1, sequence_length, num_features)
@@ -259,7 +257,7 @@ def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: f
259
 
260
  print(f"Prepared input sequence with shape: {latest_data_sequence.shape}")
261
 
262
- return latest_data_sequence, timestamps # Return data and timestamps
263
 
264
  except requests.exceptions.RequestException as e:
265
  print(f"API Request Error: {e}")
@@ -271,7 +269,6 @@ def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: f
271
 
272
 
273
  # --- Define paths to your saved files ---
274
- # Use relative paths assuming files are in the root directory of the Space
275
  MODEL_PATH = 'best_model_TKAN_nahead_1.keras'
276
  INPUT_SCALER_ATTR_PATH = 'input_scaler_attributes.json'
277
  TARGET_SCALER_ATTR_PATH = 'target_scaler_attributes.json'
@@ -280,25 +277,24 @@ Y_SCALER_TRAIN_PATH = 'y_scaler_train.npy'
280
 
281
  # --- Load the scalers and model ---
282
  input_scaler = None
283
- target_scaler = None # Scaler for the AQI/rolling_median ratio
284
  model = None
285
 
286
  try:
287
  print(f"Attempting to load input scaler attributes from {INPUT_SCALER_ATTR_PATH}...")
288
  with open(INPUT_SCALER_ATTR_PATH, 'r') as f:
289
  input_attrs = json.load(f)
290
- input_scaler = MinMaxScaler() # Create a new instance
291
- input_scaler.load_attributes(input_attrs) # Load attributes
292
  print("Input scaler loaded manually.")
293
 
294
  print(f"Attempting to load target scaler attributes from {TARGET_SCALER_ATTR_PATH}...")
295
  with open(TARGET_SCALER_ATTR_PATH, 'r') as f:
296
  target_attrs = json.load(f)
297
- target_scaler = MinMaxScaler() # Create a new instance
298
- target_scaler.load_attributes(target_attrs) # Load attributes
299
  print("Target scaler loaded manually.")
300
 
301
- # Load y_scaler_train numpy array if saved as .npy
302
  print(f"Attempting to load y_scaler_train numpy array from {Y_SCALER_TRAIN_PATH}...")
303
  y_scaler_train = np.load(Y_SCALER_TRAIN_PATH)
304
  print("y_scaler_train numpy array loaded.")
@@ -311,16 +307,13 @@ except Exception as e:
311
  import traceback
312
  traceback.print_exc()
313
 
314
- # Load the trained model with custom_object_scope
315
  custom_objects = {"TKAN": TKAN}
316
  if TKAT is not None:
317
  custom_objects["TKAT"] = TKAT
318
 
319
  try:
320
  print(f"Loading model from {MODEL_PATH}...")
321
- # Use custom_object_scope to register custom layers during loading
322
  with custom_object_scope(custom_objects):
323
- # compile=False because we only need the model for inference
324
  model = load_model(MODEL_PATH, compile=False)
325
  print("Model loaded successfully.")
326
  except FileNotFoundError:
@@ -334,38 +327,30 @@ except Exception as e:
334
  traceback.print_exc()
335
 
336
 
337
- # Initialize FastAPI app
338
  app = FastAPI()
339
 
340
- # Define the structure of the prediction request body
341
  class PredictionRequest(BaseModel):
342
  latitude: float
343
  longitude: float
344
- pm25: float = None # Make current inputs optional, rely primarily on historical fetch
345
  pm10: float = None
346
  co: float = None
347
  temp: float = None
348
- n_ahead: int = 1 # Default prediction steps
349
 
350
 
351
- # Define the structure of the prediction response body
352
  class PredictionResponse(BaseModel):
353
- status: str # "success" or "error"
354
- message: str # Description of the result or error
355
- predictions: list = None # List of {"timestamp": "...", "aqi": ...} or None on error
356
 
357
 
358
- # Define the prediction endpoint
359
  @app.post("/predict", response_model=PredictionResponse)
360
  async def predict_aqi_endpoint(request: PredictionRequest):
361
- # Check if model and scalers were loaded successfully on startup
362
  if model is None or input_scaler is None or target_scaler is None:
363
  print("API called but model or scalers are not loaded.")
364
- # Return a 500 Internal Server Error if dependencies failed to load
365
  raise HTTPException(status_code=500, detail="Model or scalers not loaded. Check server logs for details.")
366
 
367
- # Get the expected sequence length and number of features from the model's input shape
368
- # Assuming input shape is (None, sequence_length, num_features)
369
  if model.input_shape is None or len(model.input_shape) < 2:
370
  print(f"Error: Model has unexpected input shape: {model.input_shape}")
371
  raise HTTPException(status_code=500, detail=f"Model has unexpected input shape: {model.input_shape}")
@@ -378,34 +363,24 @@ async def predict_aqi_endpoint(request: PredictionRequest):
378
  raise HTTPException(status_code=500, detail=f"Model expects {NUM_FEATURES} features, but data processing provides {required_num_features}.")
379
 
380
 
381
- # Get the historical data sequence and its timestamps from Open-Meteo
382
- # The function now returns the data and a message (or error)
383
  latest_data_sequence_unscaled, message = get_latest_data_sequence(SEQUENCE_LENGTH, request.latitude, request.longitude)
384
 
385
- # Check if data retrieval was successful
386
  if latest_data_sequence_unscaled is None:
387
- # Return an error response if data fetching failed
388
  print(f"Data retrieval failed: {message}")
389
  return PredictionResponse(status="error", message=f"Data retrieval failed: {message}")
390
 
391
- # The timestamps returned are for the sequence itself. We need timestamps for the *predictions*.
392
- # The predictions are for n_ahead steps *after* the last timestamp in the sequence.
393
  prediction_timestamps = []
394
- if message and isinstance(message, list) and len(message) > 0: # 'message' is actually 'timestamps' here
395
- last_timestamp_of_sequence = message[-1] # Get the last timestamp from the sequence
396
  for i in range(request.n_ahead):
397
- # Prediction i (0-indexed) is for hour i+1 after the last timestamp
398
  prediction_timestamps.append(last_timestamp_of_sequence + timedelta(hours=i + 1))
399
  else:
400
  print("Warning: Could not get valid timestamps from data retrieval. Prediction timestamps will be approximate.")
401
- # Fallback: Approximate timestamps based on current time
402
  now_utc = datetime.now(pytz.utc)
403
  for i in range(request.n_ahead):
404
  prediction_timestamps.append(now_utc + timedelta(hours=i+1))
405
 
406
 
407
- # Optional: Update the last timestep with current user inputs if provided
408
- # Check if current inputs were provided and are valid (not None or NaN)
409
  if request.pm25 is not None and not pd.isna(request.pm25) and \
410
  request.pm10 is not None and not pd.isna(request.pm10) and \
411
  request.co is not None and not pd.isna(request.co) and \
@@ -414,8 +389,6 @@ async def predict_aqi_endpoint(request: PredictionRequest):
414
  current_aqi = calculate_overall_aqi({'pm25': request.pm25, 'pm10': request.pm10, 'co': request.co, 'temp': request.temp}, aqi_breakpoints)
415
 
416
  if not pd.isna(current_aqi):
417
- # Assuming column order: 'calculated_aqi', 'temp', 'pm25', 'pm10', 'co'
418
- # Update the last row (-1) of the input sequence
419
  latest_data_sequence_unscaled[0, -1, 0] = current_aqi
420
  latest_data_sequence_unscaled[0, -1, 1] = request.temp
421
  latest_data_sequence_unscaled[0, -1, 2] = request.pm25
@@ -425,7 +398,6 @@ async def predict_aqi_endpoint(request: PredictionRequest):
425
  else:
426
  print("Warning: Could not calculate AQI for current inputs. Last timestep remains historical.")
427
 
428
- # Scale the input data
429
  try:
430
  X_scaled = input_scaler.transform(latest_data_sequence_unscaled)
431
  print("Input data scaled successfully.")
@@ -435,9 +407,8 @@ async def predict_aqi_endpoint(request: PredictionRequest):
435
  raise HTTPException(status_code=500, detail="Error processing input data for prediction (scaling).")
436
 
437
 
438
- # Make prediction
439
  try:
440
- scaled_prediction = model.predict(X_scaled, verbose=0) # Shape (1, n_ahead)
441
  print(f"Model prediction made. Scaled prediction shape: {scaled_prediction.shape}")
442
  except Exception as e:
443
  print(f"Error during model prediction: {e}")
@@ -445,33 +416,22 @@ async def predict_aqi_endpoint(request: PredictionRequest):
445
  raise HTTPException(status_code=500, detail="Error during model prediction.")
446
 
447
 
448
- # Inverse transform the prediction
449
  try:
450
- # --- Inverse Transformation Logic (Based on Rolling Median Scaling) ---
451
- # This part needs the actual rolling median for the future prediction timesteps.
452
- # Using an approximation based on the input sequence.
453
-
454
  if latest_data_sequence_unscaled.shape[1] > 0:
455
- # Get the 'calculated_aqi' values from the unscaled input sequence
456
- calculated_aqi_sequence = latest_data_sequence_unscaled[0, :, 0] # Assuming AQI is the first feature
457
 
458
- # Approximate the rolling median based on the last few points of the input sequence
459
- # This is a simple approximation. A more robust method might be needed.
460
  approx_rolling_median_proxy = np.mean(calculated_aqi_sequence[-min(5, SEQUENCE_LENGTH):])
461
  if pd.isna(approx_rolling_median_proxy) or approx_rolling_median_proxy <= 0:
462
- approx_rolling_median_proxy = 1.0 # Prevent division by zero or invalid scaling
463
 
464
- # Create a placeholder scaler array for the future timesteps
465
  corresponding_rolling_median_scaler = np.full((1, request.n_ahead, 1), approx_rolling_median_proxy, dtype=np.float32)
466
  print(f"Approximated rolling median proxy for inverse transform: {approx_rolling_median_proxy:.2f}")
467
 
468
- # 1. Inverse transform the scaled prediction (ratio) using the target_scaler
469
  y_unscaled_pred_ratio = target_scaler.inverse_transform(scaled_prediction.reshape(1, request.n_ahead, 1))
470
  print(f"Inverse transformed to ratio scale. Shape: {y_unscaled_pred_ratio.shape}")
471
 
472
- # 2. Multiply the unscaled ratio by the approximated rolling median scaler
473
  predicted_aqi_values = y_unscaled_pred_ratio * corresponding_rolling_median_scaler
474
- predicted_aqi_values = predicted_aqi_values.flatten() # Shape (n_ahead,)
475
 
476
  else:
477
  print("Error: Input sequence is empty, cannot perform inverse transform.")
@@ -484,20 +444,16 @@ async def predict_aqi_endpoint(request: PredictionRequest):
484
  traceback.print_exc()
485
  raise HTTPException(status_code=500, detail="Error processing prediction results (inverse transform).")
486
 
487
- # Prepare the prediction output list
488
  predictions_list = []
489
  for i in range(request.n_ahead):
490
- # Use the calculated prediction_timestamps
491
  timestamp_str = prediction_timestamps[i].strftime('%Y-%m-%d %H:%M:%S')
492
  predictions_list.append({
493
  "timestamp": timestamp_str,
494
- "aqi": float(predicted_aqi_values[i]) # Ensure AQI is a standard float
495
  })
496
 
497
- # Return the successful response
498
  return PredictionResponse(status="success", message="Prediction successful.", predictions=predictions_list)
499
 
500
- # Root endpoint for health check
501
  @app.get("/")
502
  async def read_root():
503
  return {"message": "AQI Prediction API is running."}
 
1
+
2
  import os
3
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
4
  os.environ['JAX_PLATFORMS'] = 'cpu'
 
136
  def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: float):
137
  print(f"Attempting to retrieve data for the last {sequence_length} hours from Open-Meteo for Lat: {latitude}, Lon: {longitude}")
138
 
139
+ # Calculate fetch_hours needed (sequence_length + buffer for resampling/NaNs)
140
+ fetch_hours = sequence_length + 5
141
+
142
+ # For temperature, we still need a date range for the archive API
143
+ end_time_for_temp = datetime.now(pytz.utc)
144
+ start_time_for_temp = end_time_for_temp - timedelta(hours=fetch_hours)
145
 
 
 
 
146
 
147
+ print(f"Requesting data for the past {fetch_hours} hours for air quality.")
148
+ print(f"Requesting temperature data from {start_time_for_temp.strftime('%Y-%m-%d')} to {end_time_for_temp.strftime('%Y-%m-%d')}")
149
 
150
  # Open-Meteo Air Quality API
151
  air_quality_url = "https://air-quality-api.open-meteo.com/v1/air-quality"
 
154
  "longitude": longitude,
155
  "hourly": ["pm2_5", "pm10", "carbon_monoxide"],
156
  "timezone": "UTC",
157
+ "past_hours": fetch_hours # Use past_hours instead of start/end_date
 
 
158
  }
159
 
160
+ # Open-Meteo Historical Weather API for Temperature (still uses start/end_date)
161
  weather_url = "https://archive-api.open-meteo.com/v1/archive"
162
  weather_params = {
163
  "latitude": latitude,
164
  "longitude": longitude,
165
  "hourly": ["temperature_2m"],
166
  "timezone": "UTC",
167
+ "start_date": start_time_for_temp.strftime('%Y-%m-%d'),
168
+ "end_date": end_time_for_temp.strftime('%Y-%m-%d')
169
  }
170
 
171
  try:
172
  # Fetch Air Quality Data
173
+ print(f"Fetching air quality data from: {air_quality_url} with params: {air_quality_params}")
174
  air_quality_response = requests.get(air_quality_url, params=air_quality_params)
175
  air_quality_response.raise_for_status()
176
  air_quality_data = air_quality_response.json()
177
  print("Air quality data retrieved.")
178
 
179
  # Fetch Temperature Data
180
+ print(f"Fetching temperature data from: {weather_url} with params: {weather_params}")
181
  weather_response = requests.get(weather_url, params=weather_params)
182
  weather_response.raise_for_status()
183
  weather_data = weather_response.json()
 
207
 
208
 
209
  # Resample to ensure consistent hourly frequency and fill missing data
 
210
  df_processed = df_merged.resample('h').ffill().bfill()
211
  print(f"DataFrame resampled to hourly. Shape: {df_processed.shape}")
212
 
 
233
  print(f"Selected and reordered columns. Final processing shape: {df_processed.shape}")
234
 
235
 
236
+ # Handle any remaining NaNs after ffill/bfill
237
  initial_rows = len(df_processed)
238
  df_processed.dropna(inplace=True)
239
  if len(df_processed) < initial_rows:
 
246
  return None, f"Error: Insufficient historical data ({len(df_processed)} points available, {sequence_length} required)."
247
 
248
  # Select the last `sequence_length` rows for the input sequence
249
+ latest_data_sequence_df = df_processed.tail(sequence_length).copy()
250
  print(f"Selected last {sequence_length} data points.")
251
 
252
  # Convert to numpy array and reshape (1, sequence_length, num_features)
 
257
 
258
  print(f"Prepared input sequence with shape: {latest_data_sequence.shape}")
259
 
260
+ return latest_data_sequence, timestamps
261
 
262
  except requests.exceptions.RequestException as e:
263
  print(f"API Request Error: {e}")
 
269
 
270
 
271
  # --- Define paths to your saved files ---
 
272
  MODEL_PATH = 'best_model_TKAN_nahead_1.keras'
273
  INPUT_SCALER_ATTR_PATH = 'input_scaler_attributes.json'
274
  TARGET_SCALER_ATTR_PATH = 'target_scaler_attributes.json'
 
277
 
278
  # --- Load the scalers and model ---
279
  input_scaler = None
280
+ target_scaler = None
281
  model = None
282
 
283
  try:
284
  print(f"Attempting to load input scaler attributes from {INPUT_SCALER_ATTR_PATH}...")
285
  with open(INPUT_SCALER_ATTR_PATH, 'r') as f:
286
  input_attrs = json.load(f)
287
+ input_scaler = MinMaxScaler()
288
+ input_scaler.load_attributes(input_attrs)
289
  print("Input scaler loaded manually.")
290
 
291
  print(f"Attempting to load target scaler attributes from {TARGET_SCALER_ATTR_PATH}...")
292
  with open(TARGET_SCALER_ATTR_PATH, 'r') as f:
293
  target_attrs = json.load(f)
294
+ target_scaler = MinMaxScaler()
295
+ target_scaler.load_attributes(target_attrs)
296
  print("Target scaler loaded manually.")
297
 
 
298
  print(f"Attempting to load y_scaler_train numpy array from {Y_SCALER_TRAIN_PATH}...")
299
  y_scaler_train = np.load(Y_SCALER_TRAIN_PATH)
300
  print("y_scaler_train numpy array loaded.")
 
307
  import traceback
308
  traceback.print_exc()
309
 
 
310
  custom_objects = {"TKAN": TKAN}
311
  if TKAT is not None:
312
  custom_objects["TKAT"] = TKAT
313
 
314
  try:
315
  print(f"Loading model from {MODEL_PATH}...")
 
316
  with custom_object_scope(custom_objects):
 
317
  model = load_model(MODEL_PATH, compile=False)
318
  print("Model loaded successfully.")
319
  except FileNotFoundError:
 
327
  traceback.print_exc()
328
 
329
 
 
330
  app = FastAPI()
331
 
 
332
  class PredictionRequest(BaseModel):
333
  latitude: float
334
  longitude: float
335
+ pm25: float = None
336
  pm10: float = None
337
  co: float = None
338
  temp: float = None
339
+ n_ahead: int = 1
340
 
341
 
 
342
  class PredictionResponse(BaseModel):
343
+ status: str
344
+ message: str
345
+ predictions: list = None
346
 
347
 
 
348
  @app.post("/predict", response_model=PredictionResponse)
349
  async def predict_aqi_endpoint(request: PredictionRequest):
 
350
  if model is None or input_scaler is None or target_scaler is None:
351
  print("API called but model or scalers are not loaded.")
 
352
  raise HTTPException(status_code=500, detail="Model or scalers not loaded. Check server logs for details.")
353
 
 
 
354
  if model.input_shape is None or len(model.input_shape) < 2:
355
  print(f"Error: Model has unexpected input shape: {model.input_shape}")
356
  raise HTTPException(status_code=500, detail=f"Model has unexpected input shape: {model.input_shape}")
 
363
  raise HTTPException(status_code=500, detail=f"Model expects {NUM_FEATURES} features, but data processing provides {required_num_features}.")
364
 
365
 
 
 
366
  latest_data_sequence_unscaled, message = get_latest_data_sequence(SEQUENCE_LENGTH, request.latitude, request.longitude)
367
 
 
368
  if latest_data_sequence_unscaled is None:
 
369
  print(f"Data retrieval failed: {message}")
370
  return PredictionResponse(status="error", message=f"Data retrieval failed: {message}")
371
 
 
 
372
  prediction_timestamps = []
373
+ if message and isinstance(message, list) and len(message) > 0:
374
+ last_timestamp_of_sequence = message[-1]
375
  for i in range(request.n_ahead):
 
376
  prediction_timestamps.append(last_timestamp_of_sequence + timedelta(hours=i + 1))
377
  else:
378
  print("Warning: Could not get valid timestamps from data retrieval. Prediction timestamps will be approximate.")
 
379
  now_utc = datetime.now(pytz.utc)
380
  for i in range(request.n_ahead):
381
  prediction_timestamps.append(now_utc + timedelta(hours=i+1))
382
 
383
 
 
 
384
  if request.pm25 is not None and not pd.isna(request.pm25) and \
385
  request.pm10 is not None and not pd.isna(request.pm10) and \
386
  request.co is not None and not pd.isna(request.co) and \
 
389
  current_aqi = calculate_overall_aqi({'pm25': request.pm25, 'pm10': request.pm10, 'co': request.co, 'temp': request.temp}, aqi_breakpoints)
390
 
391
  if not pd.isna(current_aqi):
 
 
392
  latest_data_sequence_unscaled[0, -1, 0] = current_aqi
393
  latest_data_sequence_unscaled[0, -1, 1] = request.temp
394
  latest_data_sequence_unscaled[0, -1, 2] = request.pm25
 
398
  else:
399
  print("Warning: Could not calculate AQI for current inputs. Last timestep remains historical.")
400
 
 
401
  try:
402
  X_scaled = input_scaler.transform(latest_data_sequence_unscaled)
403
  print("Input data scaled successfully.")
 
407
  raise HTTPException(status_code=500, detail="Error processing input data for prediction (scaling).")
408
 
409
 
 
410
  try:
411
+ scaled_prediction = model.predict(X_scaled, verbose=0)
412
  print(f"Model prediction made. Scaled prediction shape: {scaled_prediction.shape}")
413
  except Exception as e:
414
  print(f"Error during model prediction: {e}")
 
416
  raise HTTPException(status_code=500, detail="Error during model prediction.")
417
 
418
 
 
419
  try:
 
 
 
 
420
  if latest_data_sequence_unscaled.shape[1] > 0:
421
+ calculated_aqi_sequence = latest_data_sequence_unscaled[0, :, 0]
 
422
 
 
 
423
  approx_rolling_median_proxy = np.mean(calculated_aqi_sequence[-min(5, SEQUENCE_LENGTH):])
424
  if pd.isna(approx_rolling_median_proxy) or approx_rolling_median_proxy <= 0:
425
+ approx_rolling_median_proxy = 1.0
426
 
 
427
  corresponding_rolling_median_scaler = np.full((1, request.n_ahead, 1), approx_rolling_median_proxy, dtype=np.float32)
428
  print(f"Approximated rolling median proxy for inverse transform: {approx_rolling_median_proxy:.2f}")
429
 
 
430
  y_unscaled_pred_ratio = target_scaler.inverse_transform(scaled_prediction.reshape(1, request.n_ahead, 1))
431
  print(f"Inverse transformed to ratio scale. Shape: {y_unscaled_pred_ratio.shape}")
432
 
 
433
  predicted_aqi_values = y_unscaled_pred_ratio * corresponding_rolling_median_scaler
434
+ predicted_aqi_values = predicted_aqi_values.flatten()
435
 
436
  else:
437
  print("Error: Input sequence is empty, cannot perform inverse transform.")
 
444
  traceback.print_exc()
445
  raise HTTPException(status_code=500, detail="Error processing prediction results (inverse transform).")
446
 
 
447
  predictions_list = []
448
  for i in range(request.n_ahead):
 
449
  timestamp_str = prediction_timestamps[i].strftime('%Y-%m-%d %H:%M:%S')
450
  predictions_list.append({
451
  "timestamp": timestamp_str,
452
+ "aqi": float(predicted_aqi_values[i])
453
  })
454
 
 
455
  return PredictionResponse(status="success", message="Prediction successful.", predictions=predictions_list)
456
 
 
457
  @app.get("/")
458
  async def read_root():
459
  return {"message": "AQI Prediction API is running."}