# Source: Hugging Face Space file view "project / app.py" (author: dsid271)
# Commit 00baf30 "Update app.py" (verified), ~22 kB
# app.py (or main.py)
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Input
from tensorflow.keras.utils import custom_object_scope
import pickle
import os
import requests
import pandas as pd
from datetime import datetime, timedelta, timezone
from typing import List, Optional
import pytz
import json
import traceback # Import traceback to print detailed error info
import os
# Hide all CUDA devices from TensorFlow so inference runs on the CPU.
# NOTE(review): tensorflow is imported above this line; this only takes effect if TF has
# not initialised its GPU devices yet — confirm, or move it before the TF import.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# TKAN is required by the saved model: a missing package should fail loudly at startup.
from tkan import TKAN

# TKAT is optional; when unavailable it is left as None and not registered with Keras.
try:
    from tkat import TKAT
except ImportError:
    TKAT = None
    print("TKAT library not found. If your model uses TKAT, ensure the library is installed.")
# --- Your MinMaxScaler Class (Copied from Notebook) ---
# This class is essential for loading your scalers
class MinMaxScaler:
def __init__(self, feature_axis=None, minmax_range=(0, 1)):
self.feature_axis = feature_axis
self.min_ = None
self.max_ = None
self.scale_ = None
self.minmax_range = minmax_range
def fit(self, X):
if X.ndim == 3 and self.feature_axis is not None:
axis = tuple(i for i in range(X.ndim) if i != self.feature_axis)
self.min_ = np.min(X, axis=axis)
self.max_ = np.max(X, axis=axis)
elif X.ndim == 2:
self.min_ = np.min(X, axis=0)
self.max_ = np.max(X, axis=0)
elif X.ndim == 1:
self.min_ = np.min(X)
self.max_ = np.max(X)
else:
raise ValueError("Data must be 1D, 2D, or 3D.")
self.scale_ = self.max_ - self.min_
return self
def transform(self, X):
if self.min_ is None or self.max_ is None or self.scale_ is None:
raise ValueError("Scaler has not been fitted.")
X_scaled = (X - self.min_) / self.scale_
X_scaled = X_scaled * (self.minmax_range[1] - self.minmax_range[0]) + self.minmax_range[0]
return X_scaled
def inverse_transform(self, X_scaled):
if self.min_ is None or self.max_ is None or self.scale_ is None:
raise ValueError("Scaler has not been fitted.")
X = (X_scaled - self.minmax_range[0]) / (self.minmax_range[1] - self.minmax_range[0])
X = X * self.scale_ + self.min_
return X
# --- AQI breakpoints and calculation functions (Copied from Notebook) ---
# One list of 4-tuples per pollutant, consumed by calculate_sub_aqi.
# NOTE(review): the values read as (conc_low, conc_high, index_low, index_high) — e.g. 'co'
# maps 0-1.0 onto 0-50 and 'pm10' maps 101-250 onto 101-200 — whereas calculate_sub_aqi
# unpacks them as (i_low, i_high, c_low, c_high). Units are not stated; the 'co' ranges
# look like mg/m^3 — confirm against the training data.
aqi_breakpoints = dict(
    pm25=[
        (0, 50, 0, 50),
        (51, 100, 51, 100),
        (101, 200, 101, 200),
        (201, 300, 201, 300),
    ],
    pm10=[
        (0, 50, 0, 50),
        (51, 100, 51, 100),
        (101, 250, 101, 200),
        (251, 350, 201, 300),
    ],
    co=[
        (0, 1.0, 0, 50),
        (1.1, 2.0, 51, 100),
        (2.1, 10.0, 101, 200),
        (10.1, 17.0, 201, 300),
    ],
)
def calculate_sub_aqi(concentration, breakpoints):
    """Map one pollutant concentration to a sub-index by piecewise-linear interpolation.

    Each breakpoint tuple is unpacked as ``(i_low, i_high, c_low, c_high)``. For the first
    segment with ``c_low <= concentration <= c_high`` the result is interpolated linearly
    from ``[c_low, c_high]`` onto ``[i_low, i_high]``; a zero-width segment returns ``i_low``.
    Below the first segment's ``c_low`` the first ``i_low`` is returned; above the last
    segment's ``c_high`` the last ``i_high`` is returned. Anything else returns ``np.nan``:
    a NaN input, or a value in a gap between segments (e.g. 50 < x < 51).

    NOTE(review): the ``aqi_breakpoints`` table in this file lists its tuples as
    (conc_low, conc_high, index_low, index_high), the reverse of this unpacking, so the
    'co' and 'pm10' ranges are effectively swapped. This function is copied from the
    training notebook, so the model was presumably trained on features computed exactly
    this way. Correcting it here alone would skew live inputs away from the training
    distribution. Confirm, then fix in the notebook and retrain together.
    """
    for idx_lo, idx_hi, conc_lo, conc_hi in breakpoints:
        if not conc_lo <= concentration <= conc_hi:
            continue
        if conc_lo == conc_hi:
            return idx_lo
        slope = (idx_hi - idx_lo) / (conc_hi - conc_lo)
        return slope * (concentration - conc_lo) + idx_lo

    # No segment matched: clamp to the table ends, otherwise report "undefined".
    first_segment, last_segment = breakpoints[0], breakpoints[-1]
    if concentration < first_segment[2]:
        return first_segment[0]
    if concentration > last_segment[3]:
        return last_segment[1]
    return np.nan
def calculate_overall_aqi(row, aqi_breakpoints):
    """Return the overall AQI of ``row``: the maximum of its pollutant sub-indices.

    ``row`` is anything supporting ``key in row`` and ``row[key]`` (a dict or a pandas
    Series). Both internal names (pm25, co) and Open-Meteo names (pm2_5,
    carbon_monoxide) are recognised. Missing or NaN readings are ignored. Returns
    ``np.nan`` when no pollutant yields a valid sub-index.
    """
    # Row key -> key in ``aqi_breakpoints``.
    aliases = {
        'pm25': 'pm25',
        'pm10': 'pm10',
        'co': 'co',
        'pm2_5': 'pm25',  # Open-Meteo name for PM2.5
        'carbon_monoxide': 'co',  # Open-Meteo name for CO
    }

    sub_indices = []
    for row_key, table_key in aliases.items():
        reading = row[row_key] if row_key in row else np.nan
        if pd.isna(reading):
            sub_indices.append(np.nan)
        else:
            sub_indices.append(calculate_sub_aqi(reading, aqi_breakpoints.get(table_key, [])))

    # np.nanmax on an all-NaN list would warn and return NaN, so check for a valid value first.
    has_valid_reading = any(not pd.isna(value) for value in sub_indices)
    return np.nanmax(sub_indices) if has_valid_reading else np.nan
# --- Data Retrieval Function ---
def _fetch_json(url, params, timeout):
    """GET ``url`` with query ``params`` and return the decoded JSON body.

    Raises ``requests.exceptions.RequestException`` on connection errors, timeouts and
    (via ``raise_for_status``) HTTP error statuses.
    """
    response = requests.get(url, params=params, timeout=timeout)
    response.raise_for_status()
    return response.json()


def _hourly_frame(payload):
    """Turn the ``hourly`` section of an Open-Meteo response into a time-indexed DataFrame.

    Returns None when ``payload`` has no ``hourly`` mapping containing ``time``.
    """
    hourly = payload.get('hourly') if isinstance(payload, dict) else None
    if not isinstance(hourly, dict) or 'time' not in hourly:
        return None
    frame = pd.DataFrame(hourly)
    frame['time'] = pd.to_datetime(frame['time'])
    return frame.set_index('time')


def get_latest_data_sequence(sequence_length: int, latitude: float, longitude: float,
                             request_timeout: float = 30.0):
    """Fetch and prepare the most recent ``sequence_length`` hourly rows for the model.

    pm2_5, pm10 and carbon_monoxide come from the Open-Meteo air-quality API, and
    temperature_2m from the Open-Meteo archive API. The frames are merged on time, cut
    at the current UTC hour, resampled hourly with forward/backward fill, and given a
    ``calculated_aqi`` column. Columns are ordered
    ``['calculated_aqi', 'temp', 'pm25', 'pm10', 'co']``.

    Args:
        sequence_length: number of hourly timesteps the model expects.
        latitude: latitude of the location to query.
        longitude: longitude of the location to query.
        request_timeout: seconds to wait for each HTTP request before giving up.

    Returns:
        ``(array, timestamps)`` on success: ``array`` has shape
        ``(1, sequence_length, 5)`` and ``timestamps`` is the list of row timestamps.
        ``(None, error_message)`` on any failure. Check the first element before
        interpreting the second.
    """
    print(f"Attempting to retrieve data for the last {sequence_length} hours from Open-Meteo for Lat: {latitude}, Lon: {longitude}")
    end_time = datetime.now(pytz.utc)
    # Fetch slightly more than needed so resampling/NaN-dropping still leaves enough rows.
    fetch_hours = sequence_length + 5
    start_time = end_time - timedelta(hours=fetch_hours)
    # ISO 8601 strings, used only for logging.
    start_time_str = start_time.isoformat().split('.')[0] + 'Z'
    end_time_str = end_time.isoformat().split('.')[0] + 'Z'
    print(f"Requesting data from {start_time_str} to {end_time_str}")

    air_quality_url = "https://air-quality-api.open-meteo.com/v1/air-quality"
    air_quality_params = {
        "latitude": latitude,
        "longitude": longitude,
        "hourly": ["pm2_5", "pm10", "carbon_monoxide"],
        "timezone": "UTC",
        "start_date": start_time.strftime('%Y-%m-%d'),  # YYYY-MM-DD
        "end_date": end_time.strftime('%Y-%m-%d'),
        "past_hours": fetch_hours,
    }
    # NOTE(review): the archive endpoint may lag real time; if recent hours come back
    # null, they are filled from the last available temperature by ffill below.
    weather_url = "https://archive-api.open-meteo.com/v1/archive"
    weather_params = {
        "latitude": latitude,
        "longitude": longitude,
        "hourly": ["temperature_2m"],
        "timezone": "UTC",
        "start_date": start_time.strftime('%Y-%m-%d'),
        "end_date": end_time.strftime('%Y-%m-%d'),
    }

    try:
        print(f"Fetching air quality data from: {air_quality_url}")
        air_quality_data = _fetch_json(air_quality_url, air_quality_params, request_timeout)
        print("Air quality data retrieved.")
        print(f"Fetching temperature data from: {weather_url}")
        weather_data = _fetch_json(weather_url, weather_params, request_timeout)
        print("Temperature data retrieved.")
        print("Data fetched successfully.")

        df_aq = _hourly_frame(air_quality_data)
        if df_aq is None:
            print("Error: 'hourly' or 'time' key not found in air quality response.")
            return None, "Error: Invalid air quality data format from API."
        df_temp = _hourly_frame(weather_data)
        if df_temp is None:
            print("Error: 'hourly' or 'time' key not found in weather response.")
            return None, "Error: Invalid weather data format from API."

        df_merged = df_aq.merge(df_temp, left_index=True, right_index=True, how='outer')
        print("DataFrames merged.")

        # start_date/end_date request whole calendar days, so the response may contain
        # hours after end_time. Drop them so tail() below selects the latest *past* hours.
        # This also keeps bfill from pulling future values backwards.
        cutoff = pd.Timestamp(end_time).floor('h')
        if df_merged.index.tz is None:
            cutoff = cutoff.tz_localize(None)  # API times are UTC without an offset
        df_merged = df_merged[df_merged.index <= cutoff]

        df_processed = df_merged.resample('h').ffill().bfill()
        print(f"DataFrame resampled to hourly. Shape: {df_processed.shape}")
        # NOTE(review): per the Open-Meteo air-quality docs carbon_monoxide is in ug/m^3,
        # while the 'co' AQI breakpoints look like mg/m^3. Confirm the unit used at
        # training time before converting.
        df_processed.rename(columns={'pm2_5': 'pm25', 'carbon_monoxide': 'co', 'temperature_2m': 'temp'}, inplace=True)
        print("Renamed columns.")

        # apply(axis=1) on an empty frame does not reliably return a Series, so stop here.
        if df_processed.empty:
            print(f"Error: Only retrieved and processed 0 data points, but {sequence_length} are required.")
            return None, f"Error: Insufficient historical data (0 points available, {sequence_length} required)."

        df_processed['calculated_aqi'] = df_processed.apply(lambda row: calculate_overall_aqi(row, aqi_breakpoints), axis=1)
        print("Calculated AQI.")

        # Column order must match the order used at training time.
        required_columns = ['calculated_aqi', 'temp', 'pm25', 'pm10', 'co']
        missing_cols = [col for col in required_columns if col not in df_processed.columns]
        if missing_cols:
            print(f"Error: Missing required columns after processing: {missing_cols}")
            return None, f"Error: Missing required data columns: {missing_cols}"
        df_processed = df_processed[required_columns].copy()
        print(f"Selected and reordered columns. Final processing shape: {df_processed.shape}")

        # ffill/bfill cannot fill a column that is entirely NaN; drop whatever remains.
        initial_rows = len(df_processed)
        df_processed.dropna(inplace=True)
        if len(df_processed) < initial_rows:
            print(f"Warning: Dropped {initial_rows - len(df_processed)} rows with remaining NaNs.")

        if len(df_processed) < sequence_length:
            print(f"Error: Only retrieved and processed {len(df_processed)} data points, but {sequence_length} are required.")
            return None, f"Error: Insufficient historical data ({len(df_processed)} points available, {sequence_length} required)."

        latest_data_sequence_df = df_processed.tail(sequence_length).copy()
        print(f"Selected last {sequence_length} data points.")
        latest_data_sequence = latest_data_sequence_df.values.reshape(1, sequence_length, len(required_columns))
        timestamps = latest_data_sequence_df.index.tolist()
        print(f"Prepared input sequence with shape: {latest_data_sequence.shape}")
        return latest_data_sequence, timestamps
    except requests.exceptions.RequestException as e:
        print(f"API Request Error: {e}")
        return None, f"API Request Error: {e}"
    except Exception as e:
        print(f"An unexpected error occurred during data retrieval and processing: {e}")
        traceback.print_exc()
        return None, f"An unexpected error occurred during data processing: {e}"
# --- Define paths to your saved files ---
# Relative paths resolve against the process's working directory (the Space root).
MODEL_PATH = 'best_model_TKAN_nahead_1.keras'
INPUT_SCALER_PATH = 'input_scaler.pkl'
TARGET_SCALER_PATH = 'target_scaler.pkl'  # This should be the scaler for the ratio
# Y_SCALER_TRAIN_PATH = 'y_scaler_train.pkl' # Keep commented out unless you find a specific use for it in the inverse transform


class _ScalerUnpickler(pickle.Unpickler):
    """Unpickler that resolves a notebook-pickled ``__main__.MinMaxScaler`` to this module's class.

    Objects pickled inside a notebook record their class as ``__main__.<name>``. When this
    file is imported as a module (for example by an ASGI server), ``__main__`` is not this
    file, and the default lookup would fail with AttributeError. All other classes
    resolve normally.
    """

    def find_class(self, module, name):
        if module == "__main__" and name == "MinMaxScaler":
            return MinMaxScaler
        return super().find_class(module, name)


def _load_scaler(path):
    """Unpickle and return the scaler stored at ``path``."""
    # SECURITY: pickle can execute arbitrary code; only load files shipped with the app.
    with open(path, 'rb') as f:
        return _ScalerUnpickler(f).load()


# --- Load the scalers and model ---
input_scaler = None
target_scaler = None  # Scaler for the AQI/rolling_median ratio
model = None
try:
    input_scaler = _load_scaler(INPUT_SCALER_PATH)
    print(f"Input scaler loaded successfully from {INPUT_SCALER_PATH}")
    target_scaler = _load_scaler(TARGET_SCALER_PATH)
    print(f"Target scaler (for ratio) loaded successfully from {TARGET_SCALER_PATH}")
except FileNotFoundError as e:
    print(f"Error loading scaler files: {e}")
    print("Please ensure input_scaler.pkl and target_scaler.pkl are in the correct directory.")
    # Startup continues with the missing scaler(s) left as None; /predict then answers 500.
except Exception as e:
    print(f"An unexpected error occurred during scaler loading: {e}")
    traceback.print_exc()
# Custom layers that must be registered for Keras to deserialize the saved model.
# TKAT is only added when its package was importable.
custom_objects = dict(TKAN=TKAN, **({"TKAT": TKAT} if TKAT is not None else {}))

try:
    print(f"Loading model from {MODEL_PATH}...")
    with custom_object_scope(custom_objects):
        # Inference only: skip restoring optimizer/loss configuration.
        model = load_model(MODEL_PATH, compile=False)
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}.")
except ValueError as e:
    print(f"Error loading model (ValueError): {e}")
    print("This can happen if the file is not a valid Keras file or if custom objects are not registered.")
    traceback.print_exc()
except Exception as e:
    print(f"An unexpected error occurred during model loading: {e}")
    traceback.print_exc()
else:
    print("Model loaded successfully.")

# Initialize FastAPI app
app = FastAPI()
# Define the structure of the prediction request body
class PredictionRequest(BaseModel):
    """Body of ``POST /predict``.

    Only the location is required; history is fetched server-side. When all four current
    readings (pm25, pm10, co, temp) are supplied and an AQI can be computed from them,
    they overwrite the last timestep of the fetched sequence.
    """

    latitude: float
    longitude: float
    # Optional current readings. Declared Optional so an explicit JSON null is accepted:
    # pydantic v2 does not treat `float = None` as nullable and would reject null with 422.
    pm25: Optional[float] = None
    pm10: Optional[float] = None
    co: Optional[float] = None
    temp: Optional[float] = None
    # Number of hourly steps to predict; must be at least 1.
    n_ahead: int = Field(1, ge=1)
# Define the structure of the prediction response body
class PredictionResponse(BaseModel):
    """Body returned by ``POST /predict``."""

    status: str  # "success" or "error"
    message: str  # Description of the result or error
    # List of {"timestamp": "...", "aqi": ...}, or None on error. Must be Optional:
    # FastAPI validates returned data against this response_model, and under pydantic v2
    # a plain `list = None` field rejects None, turning every error response into a 500.
    predictions: Optional[List[dict]] = None
# Define the prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict_aqi_endpoint(request: PredictionRequest):
    """Predict hourly AQI for the requested location.

    Steps:
    1. Fetch the latest hourly history from Open-Meteo.
    2. Optionally overwrite the last timestep with the caller's current readings.
    3. Scale the inputs and run the model.
    4. Map the scaled output back to AQI values.

    Returns:
        ``PredictionResponse`` with status "success" and one entry per predicted hour,
        or status "error" when the historical data could not be retrieved.

    Raises:
        HTTPException(500): model/scalers not loaded, unexpected model input shape, or a
            failure while scaling, predicting or inverse-transforming.
        HTTPException(400): ``n_ahead`` differs from the number of steps the model outputs.
    """
    # Model and scalers are loaded at import time and are None if loading failed.
    if model is None or input_scaler is None or target_scaler is None:
        print("API called but model or scalers are not loaded.")
        raise HTTPException(status_code=500, detail="Model or scalers not loaded. Check server logs for details.")

    # Expect (batch, sequence_length, num_features) with a fixed sequence length; the
    # sequence length decides how many hours of history to fetch.
    input_shape = model.input_shape
    if input_shape is None or len(input_shape) < 3 or input_shape[1] is None:
        print(f"Error: Model has unexpected input shape: {input_shape}")
        raise HTTPException(status_code=500, detail=f"Model has unexpected input shape: {input_shape}")
    SEQUENCE_LENGTH = input_shape[1]
    NUM_FEATURES = input_shape[2]
    required_num_features = len(['calculated_aqi', 'temp', 'pm25', 'pm10', 'co'])
    if NUM_FEATURES != required_num_features:
        print(f"Error: Model expects {NUM_FEATURES} features, but data processing provides {required_num_features}.")
        raise HTTPException(status_code=500, detail=f"Model expects {NUM_FEATURES} features, but data processing provides {required_num_features}.")

    # get_latest_data_sequence does blocking HTTP and pandas work; run it in a worker
    # thread so this coroutine does not stall the event loop for other requests.
    latest_data_sequence_unscaled, timestamps_or_error = await run_in_threadpool(
        get_latest_data_sequence, SEQUENCE_LENGTH, request.latitude, request.longitude
    )
    if latest_data_sequence_unscaled is None:
        # On failure the second value is an error message.
        print(f"Data retrieval failed: {timestamps_or_error}")
        return PredictionResponse(status="error", message=f"Data retrieval failed: {timestamps_or_error}")

    # On success the second value is the list of sequence timestamps; prediction k
    # (1-based) is for k hours after the last one.
    if isinstance(timestamps_or_error, list) and timestamps_or_error:
        last_known_time = timestamps_or_error[-1]
    else:
        print("Warning: Could not get valid timestamps from data retrieval. Prediction timestamps will be approximate.")
        last_known_time = datetime.now(pytz.utc)
    prediction_timestamps = [last_known_time + timedelta(hours=step) for step in range(1, request.n_ahead + 1)]

    # Overwrite the last timestep only when every current reading is present and valid.
    current_readings = (request.pm25, request.pm10, request.co, request.temp)
    if all(value is not None and not pd.isna(value) for value in current_readings):
        current_aqi = calculate_overall_aqi({'pm25': request.pm25, 'pm10': request.pm10, 'co': request.co, 'temp': request.temp}, aqi_breakpoints)
        if not pd.isna(current_aqi):
            # Column order matches get_latest_data_sequence: calculated_aqi, temp, pm25, pm10, co
            latest_data_sequence_unscaled[0, -1, :] = [current_aqi, request.temp, request.pm25, request.pm10, request.co]
            print("Updated last timestep of input sequence with current user inputs.")
        else:
            print("Warning: Could not calculate AQI for current inputs. Last timestep remains historical.")

    try:
        X_scaled = input_scaler.transform(latest_data_sequence_unscaled)
        print("Input data scaled successfully.")
    except Exception as e:
        print(f"Error scaling input data: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error processing input data for prediction (scaling).")

    try:
        scaled_prediction = np.asarray(model.predict(X_scaled, verbose=0))
        print(f"Model prediction made. Scaled prediction shape: {scaled_prediction.shape}")
    except Exception as e:
        print(f"Error during model prediction: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error during model prediction.")

    # The inverse transform reshapes the output to (1, n_ahead, 1), so the model must emit
    # exactly n_ahead values. Report a mismatch as a client error, not an opaque 500.
    if scaled_prediction.size != request.n_ahead:
        print(f"Error: Model produced {scaled_prediction.size} step(s) but n_ahead={request.n_ahead} was requested.")
        raise HTTPException(status_code=400, detail=f"n_ahead={request.n_ahead} is not supported: this model predicts {scaled_prediction.size} step(s) ahead.")

    try:
        if latest_data_sequence_unscaled.shape[1] > 0:
            calculated_aqi_sequence = latest_data_sequence_unscaled[0, :, 0]
            # NOTE(review): the scaled target appears to be AQI / rolling median. The true
            # future rolling median is unknown here, so the *mean* of the last <=5 input AQI
            # values is used as a proxy. Confirm this matches the training-time window.
            approx_rolling_median_proxy = np.mean(calculated_aqi_sequence[-min(5, SEQUENCE_LENGTH):])
            if pd.isna(approx_rolling_median_proxy) or approx_rolling_median_proxy <= 0:
                approx_rolling_median_proxy = 1.0  # avoid an invalid or zero multiplier
            corresponding_rolling_median_scaler = np.full((1, request.n_ahead, 1), approx_rolling_median_proxy, dtype=np.float32)
            print(f"Approximated rolling median proxy for inverse transform: {approx_rolling_median_proxy:.2f}")
            # 1. Undo the target scaling to recover the ratio.
            y_unscaled_pred_ratio = target_scaler.inverse_transform(scaled_prediction.reshape(1, request.n_ahead, 1))
            print(f"Inverse transformed to ratio scale. Shape: {y_unscaled_pred_ratio.shape}")
            # 2. Multiply the ratio by the rolling-median proxy to get AQI values.
            predicted_aqi_values = (y_unscaled_pred_ratio * corresponding_rolling_median_scaler).flatten()
        else:
            print("Error: Input sequence is empty, cannot perform inverse transform.")
            raise ValueError("Input sequence is empty.")
        print(f"Final predicted AQI values: {predicted_aqi_values}")
    except Exception as e:
        print(f"Error during inverse transformation: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error processing prediction results (inverse transform).")

    predictions_list = [
        {"timestamp": when.strftime('%Y-%m-%d %H:%M:%S'), "aqi": float(value)}
        for when, value in zip(prediction_timestamps, predicted_aqi_values)
    ]
    return PredictionResponse(status="success", message="Prediction successful.", predictions=predictions_list)
# Root endpoint for health check
@app.get("/")
async def read_root():
    """Liveness probe: report that the service is up."""
    status_payload = {"message": "AQI Prediction API is running."}
    return status_payload