import os

# Run on CPU and select the JAX backend for Keras
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['JAX_ENABLE_X64'] = 'True'
BACKEND = 'jax'
os.environ['KERAS_BACKEND'] = BACKEND
# app.py (or main.py)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import numpy as np
# With KERAS_BACKEND='jax' set above, import the standalone Keras 3 package;
# tensorflow.keras ignores the KERAS_BACKEND selection.
from keras.models import load_model
from keras.utils import custom_object_scope
from tkan import TKAN
import requests
import pandas as pd
from datetime import datetime, timedelta
import pytz
import json
import traceback
import firebase_admin
from firebase_admin import credentials, db
# Initialize Firebase using environment variables
firebase_credential_json = os.environ.get("FIREBASE_CREDENTIAL_JSON")
if firebase_credential_json:
    # Write the JSON string to a file
    with open("firebase-credentials.json", "w") as f:
        f.write(firebase_credential_json)
    firebase_credential_path = "firebase-credentials.json"
else:
    raise RuntimeError("Firebase credential not found in env variable.")

try:
    # Get Firebase config from environment variables
    FIREBASE_CREDENTIAL_PATH = firebase_credential_path
    FIREBASE_DATABASE_URL = os.getenv("FIREBASE_DATABASE_URL")
    if not os.path.exists(FIREBASE_CREDENTIAL_PATH):
        raise FileNotFoundError(f"Firebase credentials file not found at {FIREBASE_CREDENTIAL_PATH}")
    cred = credentials.Certificate(FIREBASE_CREDENTIAL_PATH)
    firebase_admin.initialize_app(cred, {
        'databaseURL': FIREBASE_DATABASE_URL
    })
    firebase_initialized = True
    print("Firebase initialized successfully")
except Exception as e:
    print(f"Firebase initialization failed: {e}")
    firebase_initialized = False
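# Example environment setup (a sketch; the service-account JSON is a
# placeholder, and the database URL depends on your Firebase project):
#
#   export FIREBASE_CREDENTIAL_JSON='{"type": "service_account", ...}'
#   export FIREBASE_DATABASE_URL='https://<your-project>-default-rtdb.firebaseio.com/'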
class MinMaxScaler:
    def __init__(self, feature_axis=None, minmax_range=(0, 1)):
        self.feature_axis = feature_axis
        self.min_ = None
        self.max_ = None
        self.scale_ = None
        self.minmax_range = minmax_range

    def load_attributes(self, attributes):
        self.min_ = np.array(attributes['min_']) if isinstance(attributes['min_'], list) else attributes['min_']
        self.max_ = np.array(attributes['max_']) if isinstance(attributes['max_'], list) else attributes['max_']
        self.scale_ = np.array(attributes['scale_']) if isinstance(attributes['scale_'], list) else attributes['scale_']
        self.minmax_range = tuple(attributes['minmax_range']) if isinstance(attributes['minmax_range'], list) else attributes['minmax_range']

    def fit(self, X):
        if X.ndim == 3 and self.feature_axis is not None:
            axis = tuple(i for i in range(X.ndim) if i != self.feature_axis)
            self.min_ = np.min(X, axis=axis)
            self.max_ = np.max(X, axis=axis)
        elif X.ndim == 2:
            self.min_ = np.min(X, axis=0)
            self.max_ = np.max(X, axis=0)
        elif X.ndim == 1:
            self.min_ = np.min(X)
            self.max_ = np.max(X)
        else:
            raise ValueError("Data must be 1D, 2D, or 3D.")
        self.scale_ = self.max_ - self.min_
        return self

    def transform(self, X):
        if self.min_ is None or self.max_ is None or self.scale_ is None:
            raise ValueError("Scaler attributes not loaded or scaler not fitted.")
        X_scaled = (X - self.min_) / self.scale_
        X_scaled = X_scaled * (self.minmax_range[1] - self.minmax_range[0]) + self.minmax_range[0]
        return X_scaled

    def inverse_transform(self, X_scaled):
        if self.min_ is None or self.max_ is None or self.scale_ is None:
            raise ValueError("Scaler attributes not loaded or scaler not fitted.")
        X = (X_scaled - self.minmax_range[0]) / (self.minmax_range[1] - self.minmax_range[0])
        X = X * self.scale_ + self.min_
        return X
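# Minimal sketch of the scaler round-trip (illustrative only; not called by the
# service — the values are made up, and the real attributes come from the JSON
# files loaded below).
def _demo_scaler_roundtrip():
    scaler = MinMaxScaler()
    X = np.array([[10.0, 0.5], [20.0, 1.5], [30.0, 2.5]])
    scaler.fit(X)                   # per-column min/max over axis 0
    X_scaled = scaler.transform(X)  # each column mapped into [0, 1]
    X_back = scaler.inverse_transform(X_scaled)
    assert np.allclose(X, X_back)   # round-trip recovers the input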
# --- AQI breakpoints and calculation functions ---
# Each tuple is (c_low, c_high, i_low, i_high): a concentration range and the
# AQI sub-index range it maps to (CO in mg/m^3, PM in ug/m^3).
aqi_breakpoints = {
    'pm25': [(0, 50, 0, 50), (51, 100, 51, 100), (101, 200, 101, 200), (201, 300, 201, 300)],
    'pm10': [(0, 50, 0, 50), (51, 100, 51, 100), (101, 250, 101, 200), (251, 350, 201, 300)],
    'co': [(0, 1.0, 0, 50), (1.1, 2.0, 51, 100), (2.1, 10.0, 101, 200), (10.1, 17.0, 201, 300)]
}

def calculate_sub_aqi(concentration, breakpoints):
    # Linear interpolation within the matching concentration range.
    # (Note: the tuples are (c_low, c_high, i_low, i_high); concentrations
    # falling in the gaps between ranges, e.g. 50.5, return NaN.)
    for c_low, c_high, i_low, i_high in breakpoints:
        if c_low <= concentration <= c_high:
            if c_high == c_low:
                return i_low
            return ((i_high - i_low) / (c_high - c_low)) * (concentration - c_low) + i_low
    if concentration < breakpoints[0][0]:
        return breakpoints[0][2]
    elif concentration > breakpoints[-1][1]:
        return breakpoints[-1][3]
    else:
        return np.nan
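# Worked example, using the CO table above: a reading of 1.5 mg/m^3 falls in
# the (1.1, 2.0, 51, 100) range, so
#   sub_aqi = (100 - 51) / (2.0 - 1.1) * (1.5 - 1.1) + 51 ≈ 72.8
# i.e. calculate_sub_aqi(1.5, aqi_breakpoints['co']) ≈ 72.8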
def calculate_overall_aqi(row, aqi_breakpoints):
    sub_aqis = []
    # Accept both internal names ('pm25', 'co') and raw API names
    # ('pm2_5', 'carbon_monoxide') as keys in the row.
    pollutant_mapping = {
        'pm25': 'pm25',
        'pm10': 'pm10',
        'co': 'co',
        'pm2_5': 'pm25',
        'carbon_monoxide': 'co',
    }
    for api_pollutant, internal_pollutant in pollutant_mapping.items():
        if api_pollutant in row:
            concentration = row[api_pollutant]
            if not pd.isna(concentration):
                sub_aqi = calculate_sub_aqi(concentration, aqi_breakpoints.get(internal_pollutant, []))
                sub_aqis.append(sub_aqi)
            else:
                sub_aqis.append(np.nan)
        else:
            sub_aqis.append(np.nan)
    # Overall AQI is the worst (largest) sub-index among available pollutants.
    if sub_aqis and not all(pd.isna(sub_aqis)):
        return np.nanmax(sub_aqis)
    else:
        return np.nan
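# Example (hypothetical readings): with pm25=80, pm10=40, co=0.5 the sub-indices
# are 80, 40 and 25, so the overall AQI is the maximum:
#   calculate_overall_aqi({'pm25': 80, 'pm10': 40, 'co': 0.5}, aqi_breakpoints) -> 80.0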
# --- Function to retrieve data from Firebase ---
def get_firebase_data(sequence_length: int, latitude: float, longitude: float):
    """
    Retrieve data from Firebase RTDB.
    Returns: (data_sequence, timestamps) or (None, error_message)
    (latitude/longitude are accepted for parity with the Open-Meteo fallback
    but are not used in the RTDB query.)
    """
    if not firebase_initialized:
        return None, "Firebase not initialized"
    try:
        print("Attempting to retrieve data from Firebase RTDB...")
        # Reference to your sensor data in Firebase
        ref = db.reference('/AQIData')  # Adjust based on your Firebase structure
        # Get current time and calculate the time window
        current_utc_time = datetime.now(pytz.utc)
        window_start = current_utc_time - timedelta(hours=sequence_length + 6)
        window_end = current_utc_time
        # Query Firebase data
        firebase_data = (
            ref.order_by_child('datetime')
            .start_at(window_start.strftime('%Y-%m-%dT%H:%M'))
            .end_at(window_end.strftime('%Y-%m-%dT%H:%M'))
            .get()
        )
        if not firebase_data:
            print("No data found in Firebase for the specified time range")
            return None, "No data in Firebase"
        # Convert Firebase data to DataFrame
        data_list = []
        for key, sensor_data in firebase_data.items():
            try:
                # Get the datetime string from inside the sensor_data dictionary
                datetime_str = sensor_data.get('datetime')
                # Check that datetime_str exists and is a string
                if isinstance(datetime_str, str):
                    # Parse using the format 'YYYY-MM-DDTHH:MM'; the stored
                    # datetime is assumed to be UTC (or naive and treated as UTC).
                    timestamp = datetime.strptime(datetime_str, '%Y-%m-%dT%H:%M')
                    timestamp = pytz.utc.localize(timestamp)
                    data_point = {
                        'time': timestamp,
                        'pm25': sensor_data.get('pm2_5', np.nan),
                        'pm10': sensor_data.get('pm10', np.nan),
                        'co': sensor_data.get('carbon_monoxide', np.nan),
                        'temp': sensor_data.get('temp', np.nan)
                    }
                    data_list.append(data_point)
                else:
                    # 'datetime' is missing or not a string for this item
                    print(f"Warning: Data item with key {key} is missing or has invalid 'datetime' field: {datetime_str}")
            except ValueError as ve:
                # Errors parsing the datetime string
                print(f"Error parsing datetime string '{datetime_str}' for key {key}: {ve}. Expected format '%Y-%m-%dT%H:%M'")
                continue  # Skip this data point if parsing fails
            except Exception as e:
                # Any other unexpected error while processing a single item
                print(f"An unexpected error occurred processing item with key {key}: {e}")
                continue
        if not data_list:
            print("No valid data points parsed from Firebase after attempting to process.")
            return None, "No valid data in Firebase after parsing"
        df = pd.DataFrame(data_list)
        df.set_index('time', inplace=True)
        df.sort_index(inplace=True)
        print(f"Retrieved {len(df)} data points from Firebase")
        # Resample to hourly frequency
        df_hourly = df.resample('h').mean()
        # Check for gaps in the raw readings (after resampling, the index is a
        # complete hourly range, so gaps must be measured on the original timestamps)
        time_diffs = df.index.to_series().diff()
        max_gap = time_diffs.max()
        if pd.notna(max_gap) and max_gap > timedelta(hours=1, minutes=30):
            print(f"Data has gaps larger than 1.5 hours. Max gap: {max_gap}")
            return None, f"Firebase data not consecutive (max gap: {max_gap})"
        # Check enough consecutive hours
        consecutive_hours = len(df_hourly)
        if consecutive_hours < sequence_length:
            print(f"Only {consecutive_hours} consecutive hours available, need {sequence_length}")
            return None, f"Insufficient consecutive hours in Firebase ({consecutive_hours}/{sequence_length})"
        # Calculate AQI
        df_hourly['calculated_aqi'] = df_hourly.apply(lambda row: calculate_overall_aqi(row, aqi_breakpoints), axis=1)
        # Select required columns
        required_columns = ['calculated_aqi', 'temp', 'pm25', 'pm10', 'co']
        df_final = df_hourly[required_columns].copy()
        # Drop rows with NaN values
        df_final.dropna(inplace=True)
        if len(df_final) < sequence_length:
            print(f"After dropping NaN values, only {len(df_final)} valid points remain")
            return None, f"Insufficient valid data after cleaning ({len(df_final)}/{sequence_length})"
        # Get the last sequence_length hours
        latest_data_df = df_final.tail(sequence_length)
        latest_data_sequence = latest_data_df.values.reshape(1, sequence_length, len(required_columns))
        timestamps = latest_data_df.index.tolist()
        print(f"Successfully prepared Firebase data sequence with shape: {latest_data_sequence.shape}")
        return latest_data_sequence, timestamps
    except Exception as e:
        print(f"Error retrieving data from Firebase: {e}")
        traceback.print_exc()
        return None, f"Firebase error: {str(e)}"
# --- Data retrieval function ---
def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: float):
    """
    Try to get data from Firebase first; fall back to Open-Meteo if insufficient.
    """
    # First, try Firebase
    firebase_data, firebase_message = get_firebase_data(sequence_length, latitude, longitude)
    if firebase_data is not None:
        print("Successfully retrieved data from Firebase RTDB")
        return firebase_data, firebase_message
    # If Firebase fails, fall back to Open-Meteo
    print(f"Firebase data retrieval failed: {firebase_message}")
    print("Falling back to OpenMeteo API...")

    # --- Open-Meteo fallback path ---
    print(f"Attempting to retrieve data for the last {sequence_length} hours from Open-Meteo for Lat: {latitude}, Lon: {longitude}")
    current_utc_time = datetime.now(pytz.utc)
    print(f"Current UTC time on server for API calls: {current_utc_time.strftime('%Y-%m-%d %H:%M:%S UTC')}")
    # Define a window to fetch from the APIs
    api_fetch_past_hours = sequence_length + 24
    processing_window_hours = sequence_length + 24
    print(f"Requesting data for the past {api_fetch_past_hours} hours for air quality and temperature from APIs.")
    air_quality_url = "https://air-quality-api.open-meteo.com/v1/air-quality"
    air_quality_params = {
        "latitude": latitude,
        "longitude": longitude,
        "hourly": ["pm2_5", "pm10", "carbon_monoxide"],
        "timezone": "UTC",
        "past_hours": api_fetch_past_hours
    }
    weather_url = "https://api.open-meteo.com/v1/forecast"
    weather_params = {
        "latitude": latitude,
        "longitude": longitude,
        "hourly": ["temperature_2m"],
        "timezone": "UTC",
        "past_hours": api_fetch_past_hours
    }
    try:
        print(f"Fetching air quality data from: {air_quality_url}")
        air_quality_response = requests.get(air_quality_url, params=air_quality_params)
        air_quality_response.raise_for_status()
        air_quality_data = air_quality_response.json()
        print("Air quality data retrieved.")
        print(f"Fetching temperature data from: {weather_url}")
        weather_response = requests.get(weather_url, params=weather_params)
        weather_response.raise_for_status()
        weather_data = weather_response.json()
        print("Temperature data retrieved.")
        print("Data fetched successfully from APIs.")

        if 'hourly' not in air_quality_data or 'time' not in air_quality_data['hourly']:
            print("Error: 'hourly' or 'time' key not found in air quality response.")
            return None, "Error: Invalid air quality data format from API."
        df_aq = pd.DataFrame(air_quality_data['hourly'])
        if df_aq.empty:
            print("Warning: Air quality data DataFrame is empty after fetching.")
        if not df_aq.empty and not all(col in df_aq.columns for col in ['time', 'pm2_5', 'pm10', 'carbon_monoxide']):
            print("Warning: Air quality data is missing some expected columns ('time', 'pm2_5', 'pm10', 'carbon_monoxide') after fetching.")
        if 'time' not in df_aq.columns and not df_aq.empty:
            return None, "Error: 'time' column missing in air quality data."
        if not df_aq.empty:
            df_aq['time'] = pd.to_datetime(df_aq['time'])
            df_aq.set_index('time', inplace=True)
        print(f"Processed df_aq. Shape: {df_aq.shape}. Columns: {df_aq.columns.tolist() if not df_aq.empty else 'N/A'}")

        if 'hourly' not in weather_data or 'time' not in weather_data['hourly']:
            print("Error: 'hourly' or 'time' key not found in weather response.")
            return None, "Error: Invalid weather data format from API."
        df_temp = pd.DataFrame(weather_data['hourly'])
        if df_temp.empty:
            print("Warning: Temperature data DataFrame is empty after fetching.")
        if not df_temp.empty and not all(col in df_temp.columns for col in ['time', 'temperature_2m']):
            print("Warning: Temperature data is missing some expected columns ('time', 'temperature_2m') after fetching.")
        if 'time' not in df_temp.columns and not df_temp.empty:
            return None, "Error: 'time' column missing in temperature data."
        if not df_temp.empty:
            df_temp['time'] = pd.to_datetime(df_temp['time'])
            df_temp.set_index('time', inplace=True)
        print(f"Processed df_temp. Shape: {df_temp.shape}. Columns: {df_temp.columns.tolist() if not df_temp.empty else 'N/A'}")

        if df_aq.empty or df_temp.empty:
            print("Error: One or both dataframes (AQ, Temp) are empty before merge. Cannot proceed.")
            return None, "Error: Insufficient data from APIs (AQ or Temp empty)."
        df_merged = df_aq.merge(df_temp, left_index=True, right_index=True, how='inner')
        print(f"DataFrames merged (inner). Initial merged shape: {df_merged.shape}")
        if df_merged.empty:
            print("Error: Inner merge of AQ and Temperature data resulted in an empty DataFrame. No overlapping timestamps with data.")
            return None, "Error: No overlapping AQ and Temperature data available for the period."

        # Resample to a consistent hourly frequency and fill missing data
        df_processed = df_merged.resample('h').mean()  # mean handles duplicates at the same hour
        df_processed = df_processed.ffill().bfill()  # then fill gaps
        print(f"DataFrame resampled to hourly, filled NaNs. Shape: {df_processed.shape}")

        df_processed.rename(columns={'pm2_5': 'pm25', 'carbon_monoxide': 'co', 'temperature_2m': 'temp'}, inplace=True)
        print(f"Renamed columns. Current columns: {df_processed.columns.tolist()}")
        expected_cols_for_aqi = ['pm25', 'pm10', 'co']
        for col in expected_cols_for_aqi:
            if col not in df_processed.columns:
                print(f"Warning: Column '{col}' for AQI calculation is missing after rename. Adding as NaN.")
                df_processed[col] = np.nan
        df_processed['calculated_aqi'] = df_processed.apply(lambda row: calculate_overall_aqi(row, aqi_breakpoints), axis=1)
        print("Calculated AQI.")

        required_columns = ['calculated_aqi', 'temp', 'pm25', 'pm10', 'co']
        for col in required_columns:
            if col not in df_processed.columns:
                print(f"Warning: Column '{col}' is missing before final selection. Adding it as NaN.")
                df_processed[col] = np.nan
        df_processed = df_processed[required_columns].copy()

        # Filter to the defined processing window relative to the current time,
        # considering only data up to the current hour and back by processing_window_hours
        window_start_time_dt = current_utc_time.replace(minute=0, second=0, microsecond=0) - timedelta(hours=processing_window_hours - 1)
        window_end_time_dt = current_utc_time.replace(minute=0, second=0, microsecond=0)
        # Convert to Pandas Timestamps (already UTC-aware) for robust comparison
        window_start_time_ts = pd.Timestamp(window_start_time_dt)
        window_end_time_ts = pd.Timestamp(window_end_time_dt)
        # The index should be timezone-aware if the APIs return UTC; localize if not
        if df_processed.index.tz is None:
            print("Warning: df_processed.index is timezone-naive. Localizing to UTC.")
            df_processed.index = df_processed.index.tz_localize('UTC')
        df_recent_processed = df_processed[(df_processed.index >= window_start_time_ts) & (df_processed.index <= window_end_time_ts)].copy()
        print(f"Filtered to recent processing window ({processing_window_hours}hrs). Shape: {df_recent_processed.shape}")

        initial_rows_recent = len(df_recent_processed)
        df_recent_processed.dropna(inplace=True)
        if len(df_recent_processed) < initial_rows_recent:
            print(f"Warning: Dropped {initial_rows_recent - len(df_recent_processed)} rows with NaNs from the recent processing window.")
        print(f"Shape after dropna on recent window: {df_recent_processed.shape}")
        if len(df_recent_processed) < sequence_length:
            print(f"Error: Only {len(df_recent_processed)} valid data points remain in the recent window after processing, but {sequence_length} are required.")
            return None, f"Error: Insufficient historical data in the recent window ({len(df_recent_processed)} points available, {sequence_length} required)."

        latest_data_sequence_df = df_recent_processed.tail(sequence_length).copy()
        print(f"Selected last {sequence_length} data points for model input. Shape: {latest_data_sequence_df.shape}")
        latest_data_sequence = latest_data_sequence_df.values.reshape(1, sequence_length, len(required_columns))
        timestamps = latest_data_sequence_df.index.tolist()
        return latest_data_sequence, timestamps
    except requests.exceptions.RequestException as e:
        print(f"API Request Error: {e}")
        traceback.print_exc()
        return None, f"API Request Error: {e}"
    except Exception as e:
        print(f"An unexpected error occurred during data retrieval and processing: {e}")
        traceback.print_exc()
        return None, f"An unexpected error occurred during data processing: {e}"
# --- Define paths to your saved files ---
MODEL_PATH = 'best_model_TKAN_nahead_1.keras'
INPUT_SCALER_ATTR_PATH = 'input_scaler_attributes.json'
TARGET_SCALER_ATTR_PATH = 'target_scaler_attributes.json'
Y_SCALER_TRAIN_PATH = 'y_scaler_train.npy'

# --- Load the scalers and model ---
input_scaler = None
target_scaler = None
y_scaler_train = None
model = None
try:
    print(f"Attempting to load input scaler attributes from {INPUT_SCALER_ATTR_PATH}...")
    with open(INPUT_SCALER_ATTR_PATH, 'r') as f:
        input_attrs = json.load(f)
    input_scaler = MinMaxScaler()
    input_scaler.load_attributes(input_attrs)
    print("Input scaler loaded manually.")
    print(f"Attempting to load target scaler attributes from {TARGET_SCALER_ATTR_PATH}...")
    with open(TARGET_SCALER_ATTR_PATH, 'r') as f:
        target_attrs = json.load(f)
    target_scaler = MinMaxScaler()
    target_scaler.load_attributes(target_attrs)
    print("Target scaler loaded manually.")
    print(f"Attempting to load y_scaler_train numpy array from {Y_SCALER_TRAIN_PATH}...")
    y_scaler_train = np.load(Y_SCALER_TRAIN_PATH)
    print("y_scaler_train numpy array loaded.")
except FileNotFoundError as e:
    print(f"Error loading scaler attribute files (FileNotFoundError): {e}")
except Exception as e:
    print(f"An error occurred during manual scaler loading: {e}")
    traceback.print_exc()
custom_objects = {"TKAN": TKAN}
try:
    print(f"Loading model from {MODEL_PATH}...")
    with custom_object_scope(custom_objects):
        model = load_model(MODEL_PATH, compile=False)
    print("Model loaded successfully.")
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}.")
except ValueError as e:
    print(f"Error loading model (ValueError): {e}")
    print("This can happen if the file is not a valid Keras file or if custom objects are not registered.")
    traceback.print_exc()
except Exception as e:
    print(f"An unexpected error occurred during model loading: {e}")
    traceback.print_exc()
app = FastAPI()

class PredictionRequest(BaseModel):
    latitude: float
    longitude: float
    # Optional current readings; when provided, they overwrite the last
    # timestep of the historical input sequence.
    pm25: Optional[float] = None
    pm10: Optional[float] = None
    co: Optional[float] = None
    temp: Optional[float] = None
    n_ahead: int = 1

class PredictionResponse(BaseModel):
    status: str
    message: str
    predictions: Optional[list] = None
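# Example exchange (illustrative values; the /predict path matches the
# assumed route below):
#   POST /predict  {"latitude": 28.6, "longitude": 77.2, "n_ahead": 1}
#   -> {"status": "success", "message": "Prediction successful.",
#       "predictions": [{"timestamp": "2024-01-01 13:00:00", "aqi": 87.3}]}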
# NOTE: the route decorator was missing; "/predict" is an assumed path.
@app.post("/predict", response_model=PredictionResponse)
async def predict_aqi_endpoint(request: PredictionRequest):
    if model is None or input_scaler is None or target_scaler is None:
        print("API called but model or scalers are not loaded.")
        raise HTTPException(status_code=500, detail="Model or scalers not loaded. Check server logs for details.")
    if model.input_shape is None or len(model.input_shape) < 3:
        print(f"Error: Model has unexpected input shape: {model.input_shape}")
        raise HTTPException(status_code=500, detail=f"Model has unexpected input shape: {model.input_shape}")
    SEQUENCE_LENGTH = model.input_shape[1]
    NUM_FEATURES = model.input_shape[2]
    required_num_features_model = len(['calculated_aqi', 'temp', 'pm25', 'pm10', 'co'])
    if NUM_FEATURES != required_num_features_model:
        print(f"Error: Model expects {NUM_FEATURES} features, but data processing provides {required_num_features_model}.")
        raise HTTPException(status_code=500, detail=f"Model expects {NUM_FEATURES} features, data processing provides {required_num_features_model}.")
    latest_data_sequence_unscaled, message = get_latest_data_sequence(SEQUENCE_LENGTH, request.latitude, request.longitude)
    if latest_data_sequence_unscaled is None:
        print(f"Data retrieval failed: {message}")
        return PredictionResponse(status="error", message=f"Data retrieval failed: {message}")
    prediction_timestamps = []
    # On success, `message` holds the list of timestamps for the input sequence.
    if message and isinstance(message, list) and len(message) > 0:
        last_timestamp_of_sequence = message[-1]
        for i in range(request.n_ahead):
            prediction_timestamps.append(last_timestamp_of_sequence + timedelta(hours=i + 1))
    else:
        print("Warning: Could not get valid timestamps from data retrieval. Prediction timestamps will be approximate.")
        now_utc = datetime.now(pytz.utc)
        for i in range(request.n_ahead):
            prediction_timestamps.append(now_utc + timedelta(hours=i + 1))
    if request.pm25 is not None and not pd.isna(request.pm25) and \
       request.pm10 is not None and not pd.isna(request.pm10) and \
       request.co is not None and not pd.isna(request.co) and \
       request.temp is not None and not pd.isna(request.temp):
        current_aqi = calculate_overall_aqi({'pm25': request.pm25, 'pm10': request.pm10, 'co': request.co, 'temp': request.temp}, aqi_breakpoints)
        if not pd.isna(current_aqi) and latest_data_sequence_unscaled.shape[1] == SEQUENCE_LENGTH:  # ensure sequence is correctly shaped
            latest_data_sequence_unscaled[0, -1, 0] = current_aqi
            latest_data_sequence_unscaled[0, -1, 1] = request.temp
            latest_data_sequence_unscaled[0, -1, 2] = request.pm25
            latest_data_sequence_unscaled[0, -1, 3] = request.pm10
            latest_data_sequence_unscaled[0, -1, 4] = request.co
            print("Updated last timestep of input sequence with current user inputs.")
        elif pd.isna(current_aqi):
            print("Warning: Could not calculate AQI for current inputs. Last timestep remains historical.")
        else:
            print("Warning: Sequence not correctly shaped to update with current user inputs.")
    try:
        X_scaled = input_scaler.transform(latest_data_sequence_unscaled)
        print("Input data scaled successfully.")
    except Exception as e:
        print(f"Error scaling input data: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error processing input data for prediction (scaling).")
    try:
        scaled_prediction = model.predict(X_scaled, verbose=0)
        print(f"Model prediction made. Scaled prediction shape: {scaled_prediction.shape}")
    except Exception as e:
        print(f"Error during model prediction: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error during model prediction.")
    try:
        if latest_data_sequence_unscaled.shape[1] > 0:
            # The model outputs an AQI ratio relative to a rolling median; use
            # the mean of the last few observed AQI values as a proxy for it.
            calculated_aqi_sequence = latest_data_sequence_unscaled[0, :, 0]
            approx_rolling_median_proxy = np.mean(calculated_aqi_sequence[-min(5, SEQUENCE_LENGTH):])
            if pd.isna(approx_rolling_median_proxy) or approx_rolling_median_proxy <= 0:
                approx_rolling_median_proxy = 1.0
            corresponding_rolling_median_scaler = np.full((1, request.n_ahead, 1), approx_rolling_median_proxy, dtype=np.float32)
            print(f"Approximated rolling median proxy for inverse transform: {approx_rolling_median_proxy:.2f}")
            y_unscaled_pred_ratio = target_scaler.inverse_transform(scaled_prediction.reshape(1, request.n_ahead, 1))
            print(f"Inverse transformed to ratio scale. Shape: {y_unscaled_pred_ratio.shape}")
            predicted_aqi_values = y_unscaled_pred_ratio * corresponding_rolling_median_scaler
            predicted_aqi_values = predicted_aqi_values.flatten()
        else:
            print("Error: Input sequence is empty, cannot perform inverse transform.")
            raise ValueError("Input sequence is empty.")
        print(f"Final predicted AQI values: {predicted_aqi_values}")
    except Exception as e:
        print(f"Error during inverse transformation: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error processing prediction results (inverse transform).")
    predictions_list = []
    for i in range(request.n_ahead):
        timestamp_str = prediction_timestamps[i].strftime('%Y-%m-%d %H:%M:%S')
        predictions_list.append({
            "timestamp": timestamp_str,
            "aqi": float(predicted_aqi_values[i])
        })
    return PredictionResponse(status="success", message="Prediction successful.", predictions=predictions_list)
@app.get("/")
async def read_root():
    return {"message": "AQI Prediction API is running."}