# app.py (or main.py)
import os

# Hide GPUs from TensorFlow; this must be set before TensorFlow is imported,
# otherwise the runtime may already have claimed a GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import custom_object_scope
import pickle
import requests
import pandas as pd
from datetime import datetime, timedelta
import pytz
import traceback  # For printing detailed error info to the logs

# Assuming TKAN is installed and available
from tkan import TKAN
try:
    from tkat import TKAT
except ImportError:
    print("TKAT library not found. If your model uses TKAT, ensure the library is installed.")
    TKAT = None

# --- Your MinMaxScaler Class (Copied from Notebook) ---
# This class is essential for loading your scalers
class MinMaxScaler:
    def __init__(self, feature_axis=None, minmax_range=(0, 1)):
        self.feature_axis = feature_axis
        self.min_ = None
        self.max_ = None
        self.scale_ = None
        self.minmax_range = minmax_range

    def fit(self, X):
        if X.ndim == 3 and self.feature_axis is not None:
            axis = tuple(i for i in range(X.ndim) if i != self.feature_axis)
            self.min_ = np.min(X, axis=axis)
            self.max_ = np.max(X, axis=axis)
        elif X.ndim == 2:
            self.min_ = np.min(X, axis=0)
            self.max_ = np.max(X, axis=0)
        elif X.ndim == 1:
            self.min_ = np.min(X)
            self.max_ = np.max(X)
        else:
            raise ValueError("Data must be 1D, 2D, or 3D.")

        self.scale_ = self.max_ - self.min_
        return self

    def transform(self, X):
        if self.min_ is None or self.max_ is None or self.scale_ is None:
            raise ValueError("Scaler has not been fitted.")
        X_scaled = (X - self.min_) / self.scale_
        X_scaled = X_scaled * (self.minmax_range[1] - self.minmax_range[0]) + self.minmax_range[0]
        return X_scaled

    def inverse_transform(self, X_scaled):
        if self.min_ is None or self.max_ is None or self.scale_ is None:
            raise ValueError("Scaler has not been fitted.")
        X = (X_scaled - self.minmax_range[0]) / (self.minmax_range[1] - self.minmax_range[0])
        X = X * self.scale_ + self.min_
        return X
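
# For reference, the pickled scalers loaded further down are assumed to have
# been fitted during training roughly like this (hypothetical shapes):
#   input_scaler = MinMaxScaler(feature_axis=2).fit(X_train)   # (samples, seq_len, features)
#   target_scaler = MinMaxScaler(feature_axis=2).fit(y_ratio)  # ratio targets, see inverse transform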

# --- AQI breakpoints and calculation functions (Copied from Notebook) ---
# Each tuple is (c_low, c_high, i_low, i_high): a concentration band and the
# AQI index band it maps to (the PM10 and CO bands match the Indian CPCB
# scale; CO is in mg/m^3).
aqi_breakpoints = {
    'pm25': [(0, 50, 0, 50), (51, 100, 51, 100), (101, 200, 101, 200), (201, 300, 201, 300)],
    'pm10': [(0, 50, 0, 50), (51, 100, 51, 100), (101, 250, 101, 200), (251, 350, 201, 300)],
    'co': [(0, 1.0, 0, 50), (1.1, 2.0, 51, 100), (2.1, 10.0, 101, 200), (10.1, 17.0, 201, 300)]
}

def calculate_sub_aqi(concentration, breakpoints):
    # Linear interpolation within the matching concentration band.
    for c_low, c_high, i_low, i_high in breakpoints:
        if c_low <= concentration <= c_high:
            if c_high == c_low:
                return i_low
            return ((i_high - i_low) / (c_high - c_low)) * (concentration - c_low) + i_low
    # Clamp out-of-range concentrations to the edge AQI values.
    if concentration < breakpoints[0][0]:
        return breakpoints[0][2]
    elif concentration > breakpoints[-1][1]:
        return breakpoints[-1][3]
    else:
        # Concentration falls in a gap between bands (e.g., between 50 and 51).
        return np.nan
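
# Worked example: a PM10 concentration of 150 falls in the (101, 250) band,
# which maps to AQI 101-200, so the sub-index is
# ((200 - 101) / (250 - 101)) * (150 - 101) + 101 ≈ 133.6.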

def calculate_overall_aqi(row, aqi_breakpoints):
    sub_aqis = []
    # Mapping API names to internal names if necessary
    pollutant_mapping = {
        'pm25': 'pm25',
        'pm10': 'pm10',
        'co': 'co',
        'pm2_5': 'pm25', # Common API name for PM2.5
        'carbon_monoxide': 'co', # Common API name for CO
    }
    for api_pollutant, internal_pollutant in pollutant_mapping.items():
        if api_pollutant in row:
            concentration = row[api_pollutant]
            if not pd.isna(concentration): # Use pd.isna for pandas DataFrames/Series
                sub_aqi = calculate_sub_aqi(concentration, aqi_breakpoints.get(internal_pollutant, []))
                sub_aqis.append(sub_aqi)
            else:
                sub_aqis.append(np.nan)
        else:
            sub_aqis.append(np.nan)

    # Take the worst (maximum) sub-index while ignoring NaNs; np.nanmax warns
    # and returns NaN on an all-NaN input, so guard for that case first.
    if sub_aqis and not all(pd.isna(sub_aqis)):
        return np.nanmax(sub_aqis)
    else:
        return np.nan # Return NaN if no valid pollutant data is available

# --- Data Retrieval Function ---
def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: float):
    print(f"Attempting to retrieve data for the last {sequence_length} hours from Open-Meteo for Lat: {latitude}, Lon: {longitude}")

    end_time = datetime.now(pytz.utc)
    # Fetch slightly more data to allow for resampling and ensure sequence_length is met
    fetch_hours = sequence_length + 5
    start_time = end_time - timedelta(hours=fetch_hours)

    # ISO 8601 strings (used only for logging below)
    start_time_str = start_time.isoformat().split('.')[0] + 'Z'
    end_time_str = end_time.isoformat().split('.')[0] + 'Z'

    print(f"Requesting data from {start_time_str} to {end_time_str}")

    # Open-Meteo Air Quality API
    air_quality_url = "https://air-quality-api.open-meteo.com/v1/air-quality"
    air_quality_params = {
        "latitude": latitude,
        "longitude": longitude,
        "hourly": ["pm2_5", "pm10", "carbon_monoxide"],
        "timezone": "UTC",
        "start_date": start_time.strftime('%Y-%m-%d'), # Use YYYY-MM-DD format
        "end_date": end_time.strftime('%Y-%m-%d'),
        "past_hours": fetch_hours
    }

    # Open-Meteo Historical Weather API for Temperature
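    # Note: the archive API typically lags real time by a few days, so the most
    # recent hours may be missing (NaN) and are forward/back-filled further down.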
    weather_url = "https://archive-api.open-meteo.com/v1/archive"
    weather_params = {
        "latitude": latitude,
        "longitude": longitude,
        "hourly": ["temperature_2m"],
        "timezone": "UTC",
        "start_date": start_time.strftime('%Y-%m-%d'),
        "end_date": end_time.strftime('%Y-%m-%d')
    }

    try:
        # Fetch Air Quality Data
        print(f"Fetching air quality data from: {air_quality_url}")
        air_quality_response = requests.get(air_quality_url, params=air_quality_params)
        air_quality_response.raise_for_status()
        air_quality_data = air_quality_response.json()
        print("Air quality data retrieved.")

        # Fetch Temperature Data
        print(f"Fetching temperature data from: {weather_url}")
        weather_response = requests.get(weather_url, params=weather_params)
        weather_response.raise_for_status()
        weather_data = weather_response.json()
        print("Temperature data retrieved.")

        print("Data fetched successfully.")

        # Process Air Quality Data
        if 'hourly' not in air_quality_data or 'time' not in air_quality_data['hourly']:
            print("Error: 'hourly' or 'time' key not found in air quality response.")
            return None, "Error: Invalid air quality data format from API."
        df_aq = pd.DataFrame(air_quality_data['hourly'])
        df_aq['time'] = pd.to_datetime(df_aq['time'])
        df_aq.set_index('time', inplace=True)

        # Process Temperature Data
        if 'hourly' not in weather_data or 'time' not in weather_data['hourly']:
            print("Error: 'hourly' or 'time' key not found in weather response.")
            return None, "Error: Invalid weather data format from API."
        df_temp = pd.DataFrame(weather_data['hourly'])
        df_temp['time'] = pd.to_datetime(df_temp['time'])
        df_temp.set_index('time', inplace=True)

        # Merge dataframes
        df_merged = df_aq.merge(df_temp, left_index=True, right_index=True, how='outer')
        print("DataFrames merged.")


        # Resample to ensure consistent hourly frequency and fill missing data
        # Use 'h' for hourly resampling
        df_processed = df_merged.resample('h').ffill().bfill()
        print(f"DataFrame resampled to hourly. Shape: {df_processed.shape}")


        # Rename columns to match internal naming convention
        df_processed.rename(columns={'pm2_5': 'pm25', 'carbon_monoxide': 'co', 'temperature_2m': 'temp'}, inplace=True)
        # Open-Meteo reports carbon monoxide in ug/m^3; convert to mg/m^3 so the
        # values are on the same scale as the CO breakpoint bands defined above.
        df_processed['co'] = df_processed['co'] / 1000.0
        print("Renamed columns and converted CO to mg/m^3.")


        # Calculate AQI for the processed data
        df_processed['calculated_aqi'] = df_processed.apply(lambda row: calculate_overall_aqi(row, aqi_breakpoints), axis=1)
        print("Calculated AQI.")


        # Select and reorder columns to match training data order
        required_columns = ['calculated_aqi', 'temp', 'pm25', 'pm10', 'co']
        # Ensure all required columns exist before selecting
        if not all(col in df_processed.columns for col in required_columns):
            missing_cols = [col for col in required_columns if col not in df_processed.columns]
            print(f"Error: Missing required columns after processing: {missing_cols}")
            return None, f"Error: Missing required data columns: {missing_cols}"

        df_processed = df_processed[required_columns].copy()
        print(f"Selected and reordered columns. Final processing shape: {df_processed.shape}")


        # Handle any remaining NaNs after ffill/bfill (e.g., if the very first values were NaN or API returned all NaNs)
        initial_rows = len(df_processed)
        df_processed.dropna(inplace=True)
        if len(df_processed) < initial_rows:
            print(f"Warning: Dropped {initial_rows - len(df_processed)} rows with remaining NaNs.")


        # Check if enough data points are available
        if len(df_processed) < sequence_length:
            print(f"Error: Only retrieved and processed {len(df_processed)} data points, but {sequence_length} are required.")
            return None, f"Error: Insufficient historical data ({len(df_processed)} points available, {sequence_length} required)."

        # Select the last `sequence_length` rows for the input sequence
        latest_data_sequence_df = df_processed.tail(sequence_length).copy() # Use .copy() to avoid SettingWithCopyWarning
        print(f"Selected last {sequence_length} data points.")

        # Convert to numpy array and reshape (1, sequence_length, num_features)
        latest_data_sequence = latest_data_sequence_df.values.reshape(1, sequence_length, len(required_columns))

        # Get the timestamps for output formatting later
        timestamps = latest_data_sequence_df.index.tolist()

        print(f"Prepared input sequence with shape: {latest_data_sequence.shape}")

        return latest_data_sequence, timestamps # Return data and timestamps

    except requests.exceptions.RequestException as e:
        print(f"API Request Error: {e}")
        return None, f"API Request Error: {e}"
    except Exception as e:
        print(f"An unexpected error occurred during data retrieval and processing: {e}")
        traceback.print_exc()
        return None, f"An unexpected error occurred during data processing: {e}"


# --- Define paths to your saved files ---
# Use relative paths assuming files are in the root directory of the Space
MODEL_PATH = 'best_model_TKAN_nahead_1.keras'
INPUT_SCALER_PATH = 'input_scaler.pkl'
TARGET_SCALER_PATH = 'target_scaler.pkl' # This should be the scaler for the ratio
# Y_SCALER_TRAIN_PATH = 'y_scaler_train.pkl' # Keep commented out unless you find a specific use for it in the inverse transform


# --- Load the scalers and model ---
input_scaler = None
target_scaler = None # Scaler for the AQI/rolling_median ratio
model = None

try:
    with open(INPUT_SCALER_PATH, 'rb') as f:
        input_scaler = pickle.load(f)
    print(f"Input scaler loaded successfully from {INPUT_SCALER_PATH}")

    with open(TARGET_SCALER_PATH, 'rb') as f:
        target_scaler = pickle.load(f)
    print(f"Target scaler (for ratio) loaded successfully from {TARGET_SCALER_PATH}")

except FileNotFoundError as e:
    print(f"Error loading scaler files: {e}")
    print("Please ensure input_scaler.pkl and target_scaler.pkl are in the correct directory.")
    # The app cannot predict without these scalers; we leave them as None and
    # fail each request with a 500 (see the endpoint) rather than crashing now,
    # so the error shows up clearly in the server logs.
except Exception as e:
    print(f"An unexpected error occurred during scaler loading: {e}")
    traceback.print_exc()


# Load the trained model with custom_object_scope
custom_objects = {"TKAN": TKAN}
if TKAT is not None:
    custom_objects["TKAT"] = TKAT

try:
    print(f"Loading model from {MODEL_PATH}...")
    # Use custom_object_scope to register custom layers during loading
    with custom_object_scope(custom_objects):
        # compile=False because we only need the model for inference
        model = load_model(MODEL_PATH, compile=False)
    print("Model loaded successfully.")
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}.")
except ValueError as e:
    print(f"Error loading model (ValueError): {e}")
    print("This can happen if the file is not a valid Keras file or if custom objects are not registered.")
    traceback.print_exc()
except Exception as e:
    print(f"An unexpected error occurred during model loading: {e}")
    traceback.print_exc()


# Initialize FastAPI app
app = FastAPI()

# Define the structure of the prediction request body
class PredictionRequest(BaseModel):
    latitude: float
    longitude: float
    # Current readings are optional; the endpoint relies primarily on the
    # fetched historical sequence and only overrides its last timestep.
    pm25: Optional[float] = None
    pm10: Optional[float] = None
    co: Optional[float] = None
    temp: Optional[float] = None
    n_ahead: int = 1  # Number of hourly steps to predict
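
# Example request body (hypothetical coordinates; current readings omitted,
# CO would be expected in mg/m^3):
#   {"latitude": 24.86, "longitude": 67.0, "n_ahead": 3}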


# Define the structure of the prediction response body
class PredictionResponse(BaseModel):
    status: str  # "success" or "error"
    message: str  # Description of the result or error
    predictions: Optional[list] = None  # List of {"timestamp": ..., "aqi": ...}, or None on error


# Define the prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict_aqi_endpoint(request: PredictionRequest):
    # Check if model and scalers were loaded successfully on startup
    if model is None or input_scaler is None or target_scaler is None:
        print("API called but model or scalers are not loaded.")
        # Return a 500 Internal Server Error if dependencies failed to load
        raise HTTPException(status_code=500, detail="Model or scalers not loaded. Check server logs for details.")

    # Get the expected sequence length and number of features from the model's
    # input shape, which should be (None, sequence_length, num_features)
    if model.input_shape is None or len(model.input_shape) != 3:
        print(f"Error: Model has unexpected input shape: {model.input_shape}")
        raise HTTPException(status_code=500, detail=f"Model has unexpected input shape: {model.input_shape}")

    SEQUENCE_LENGTH = model.input_shape[1]
    NUM_FEATURES = model.input_shape[2]
    required_num_features = len(['calculated_aqi', 'temp', 'pm25', 'pm10', 'co'])
    if NUM_FEATURES != required_num_features:
        print(f"Error: Model expects {NUM_FEATURES} features, but data processing provides {required_num_features}.")
        raise HTTPException(status_code=500, detail=f"Model expects {NUM_FEATURES} features, but data processing provides {required_num_features}.")


    # Fetch the historical sequence from Open-Meteo; on success the second
    # return value is the list of sequence timestamps, on failure it is an
    # error message.
    latest_data_sequence_unscaled, timestamps_or_error = get_latest_data_sequence(SEQUENCE_LENGTH, request.latitude, request.longitude)

    if latest_data_sequence_unscaled is None:
        print(f"Data retrieval failed: {timestamps_or_error}")
        return PredictionResponse(status="error", message=f"Data retrieval failed: {timestamps_or_error}")

    # The returned timestamps cover the input sequence; the predictions are for
    # the n_ahead hours *after* the last timestamp in that sequence.
    prediction_timestamps = []
    timestamps = timestamps_or_error
    if isinstance(timestamps, list) and len(timestamps) > 0:
        last_timestamp_of_sequence = timestamps[-1]
        for i in range(request.n_ahead):
            # Prediction i (0-indexed) is for hour i+1 after the last timestamp
            prediction_timestamps.append(last_timestamp_of_sequence + timedelta(hours=i + 1))
    else:
        print("Warning: Could not get valid timestamps from data retrieval. Prediction timestamps will be approximate.")
        # Fallback: approximate timestamps based on the current time
        now_utc = datetime.now(pytz.utc)
        for i in range(request.n_ahead):
            prediction_timestamps.append(now_utc + timedelta(hours=i + 1))


    # Optional: overwrite the last timestep with current user inputs, if all
    # four were provided and are valid (CO assumed in mg/m^3 to match the
    # breakpoint bands)
    if request.pm25 is not None and not pd.isna(request.pm25) and \
       request.pm10 is not None and not pd.isna(request.pm10) and \
       request.co is not None and not pd.isna(request.co) and \
       request.temp is not None and not pd.isna(request.temp):

        current_aqi = calculate_overall_aqi({'pm25': request.pm25, 'pm10': request.pm10, 'co': request.co, 'temp': request.temp}, aqi_breakpoints)

        if not pd.isna(current_aqi):
            # Assuming column order: 'calculated_aqi', 'temp', 'pm25', 'pm10', 'co'
            # Update the last row (-1) of the input sequence
            latest_data_sequence_unscaled[0, -1, 0] = current_aqi
            latest_data_sequence_unscaled[0, -1, 1] = request.temp
            latest_data_sequence_unscaled[0, -1, 2] = request.pm25
            latest_data_sequence_unscaled[0, -1, 3] = request.pm10
            latest_data_sequence_unscaled[0, -1, 4] = request.co
            print("Updated last timestep of input sequence with current user inputs.")
        else:
            print("Warning: Could not calculate AQI for current inputs. Last timestep remains historical.")

    # Scale the input data
    try:
        X_scaled = input_scaler.transform(latest_data_sequence_unscaled)
        print("Input data scaled successfully.")
    except Exception as e:
        print(f"Error scaling input data: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error processing input data for prediction (scaling).")


    # Make prediction
    try:
        scaled_prediction = model.predict(X_scaled, verbose=0) # Shape (1, n_ahead)
        print(f"Model prediction made. Scaled prediction shape: {scaled_prediction.shape}")
    except Exception as e:
        print(f"Error during model prediction: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error during model prediction.")


    # Inverse transform the prediction
    try:
        # --- Inverse Transformation Logic (Based on Rolling Median Scaling) ---
        # This part needs the actual rolling median for the future prediction timesteps.
        # Using an approximation based on the input sequence.

        if latest_data_sequence_unscaled.shape[1] > 0:
            # Get the 'calculated_aqi' values from the unscaled input sequence
            calculated_aqi_sequence = latest_data_sequence_unscaled[0, :, 0] # Assuming AQI is the first feature

            # Approximate the rolling median based on the last few points of the input sequence
            # This is a simple approximation. A more robust method might be needed.
            approx_rolling_median_proxy = np.mean(calculated_aqi_sequence[-min(5, SEQUENCE_LENGTH):])
            if pd.isna(approx_rolling_median_proxy) or approx_rolling_median_proxy <= 0:
                approx_rolling_median_proxy = 1.0  # Prevent division by zero or invalid scaling

            # Create a placeholder scaler array for the future timesteps
            corresponding_rolling_median_scaler = np.full((1, request.n_ahead, 1), approx_rolling_median_proxy, dtype=np.float32)
            print(f"Approximated rolling median proxy for inverse transform: {approx_rolling_median_proxy:.2f}")

            # 1. Inverse transform the scaled prediction (ratio) using the target_scaler
            y_unscaled_pred_ratio = target_scaler.inverse_transform(scaled_prediction.reshape(1, request.n_ahead, 1))
            print(f"Inverse transformed to ratio scale. Shape: {y_unscaled_pred_ratio.shape}")

            # 2. Multiply the unscaled ratio by the approximated rolling median scaler
            predicted_aqi_values = y_unscaled_pred_ratio * corresponding_rolling_median_scaler
            predicted_aqi_values = predicted_aqi_values.flatten() # Shape (n_ahead,)

        else:
            print("Error: Input sequence is empty, cannot perform inverse transform.")
            raise ValueError("Input sequence is empty.")

        print(f"Final predicted AQI values: {predicted_aqi_values}")

    except Exception as e:
        print(f"Error during inverse transformation: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error processing prediction results (inverse transform).")

    # Prepare the prediction output list
    predictions_list = []
    for i in range(request.n_ahead):
        # Use the calculated prediction_timestamps
        timestamp_str = prediction_timestamps[i].strftime('%Y-%m-%d %H:%M:%S')
        predictions_list.append({
            "timestamp": timestamp_str,
            "aqi": float(predicted_aqi_values[i]) # Ensure AQI is a standard float
        })

    # Return the successful response
    return PredictionResponse(status="success", message="Prediction successful.", predictions=predictions_list)

# Root endpoint for health check
@app.get("/")
async def read_root():
    return {"message": "AQI Prediction API is running."}
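

# Optional local entry point: a minimal sketch for running the API directly,
# assuming uvicorn is installed (on a hosted Space, the platform usually starts
# the server instead).
#
# Example request once the server is running (hypothetical coordinates):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"latitude": 24.86, "longitude": 67.0, "n_ahead": 3}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)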