from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import uvicorn
import pandas as pd
import numpy as np
import torch
import json
import io
import joblib
import os
import sys
import logging
from model import DroughtNetLSTM
from utils import normalize, date_encode, interpolate_nans
from datetime import datetime
from typing import List, Optional

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Globals populated during startup (see lifespan below); defaults keep
# /health from raising NameError if initialization fails
model: Optional[DroughtNetLSTM] = None
scaler_dict = {}
scaler_dict_static = {}
device = torch.device("cpu")


# Lifespan event handler
@asynccontextmanager
async def lifespan(app: FastAPI):
    global model, scaler_dict, scaler_dict_static, device
    try:
        logger.info("Starting application initialization")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Load scalers with safety measures for version compatibility
        try:
            logger.info("Loading scalers")
            scaler_dict = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict.joblib"))
            scaler_dict_static = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict_static.joblib"))
            logger.info("Scalers loaded successfully")
        except Exception as e:
            logger.error(f"Error loading scalers: {str(e)}")
            # Provide fallback empty dictionaries if loading fails
            scaler_dict = {}
            scaler_dict_static = {}
            logger.warning("Using empty scalers as fallback")

        # Define model params
        logger.info("Initializing model")
        time_dim = 20
        lstm_dim = 256
        num_layers = 2
        dropout = 0.15
        static_dim = 29
        staticfc_dim = 16
        hidden_dim = 256
        output_size = 6

        model = DroughtNetLSTM(
            time_dim=time_dim,
            lstm_dim=lstm_dim,
            num_layers=num_layers,
            dropout=dropout,
            static_dim=static_dim,
            staticfc_dim=staticfc_dim,
            hidden_dim=hidden_dim,
            output_size=output_size
        )

        try:
            model_path = os.path.join(os.path.dirname(__file__), "best_macro_f1_model.pt")
            logger.info(f"Loading model from {model_path}")
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.to(device)
            model.eval()
            logger.info("Model loaded and initialized successfully")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise  # Re-raise to prevent app from starting with broken model

        logger.info("Application initialization completed successfully")
        yield  # Allow app to run
        logger.info("Application shutdown initiated")
    except Exception as e:
        logger.error(f"Critical error during initialization: {str(e)}")
        # Still yield to allow proper error handling
        yield
        logger.info("Application shutdown after initialization error")


app = FastAPI(
    title="Drought Prediction API",
    description="API for predicting drought severity based on weather data",
    version="1.0.0",
    lifespan=lifespan
)

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    return {"message": "Welcome to Drought Prediction API. Use /predict endpoint to make predictions."}


@app.get("/health")
async def health():
    """Simple health check endpoint"""
    return {"status": "ok", "model_loaded": model is not None}


@app.post("/predict")
async def predict(
    csv_file: UploadFile = File(...),
    x_static: str = Form(...),
):
    try:
        logger.info("Received prediction request")

        # Parse static input
        x_static_list = json.loads(x_static)
        x_static_array = np.array([x_static_list], dtype=np.float32)
        logger.info(f"Static data shape: {x_static_array.shape}")

        # Load and process CSV
        content = await csv_file.read()
        df = pd.read_csv(io.StringIO(content.decode('utf-8')), skiprows=26)
        logger.info(f"Loaded CSV with shape: {df.shape}")

        df = prepare_time_data(df)
        logger.info("Time data prepared successfully")

        # Feature extraction
        float_cols = [
            'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET',
            'T2M_MAX', 'T2M_MIN', 'T2M_RANGE', 'TS',
            'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
            'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
        ]
        features = float_cols + ['sin_day', 'cos_day']
        x_time_array = df[features].to_numpy(dtype=np.float32)
        x_time_array = np.expand_dims(x_time_array, axis=0)
        logger.info(f"Time features shape: {x_time_array.shape}")

        # Normalize
        try:
            x_static_norm, x_time_norm = normalize(
                x_static_array,
                x_time_array,
                scaler_dict=scaler_dict,
                scaler_dict_static=scaler_dict_static
            )
            logger.info("Data normalized successfully")
        except Exception as norm_error:
            logger.error(f"Normalization error: {str(norm_error)}")
            # Fall back to using unnormalized data if normalization fails
            logger.warning("Using unnormalized data as fallback")
            x_static_norm = x_static_array
            x_time_norm = x_time_array

        # To tensors
        x_time_tensor = torch.tensor(x_time_norm).float().to(device)
        x_static_tensor = torch.tensor(x_static_norm).float().to(device)

        # Predict
        logger.info("Running prediction")
        with torch.no_grad():
            output = model(x_time_tensor, x_static_tensor)
            output = torch.clamp(output, min=0.0, max=5.0)
            predictions = output.cpu().numpy().tolist()[0]
        logger.info(f"Prediction completed: {predictions}")

        # US Drought Monitor severity scale: class 0 is "no drought",
        # classes 1-5 map to categories D0-D4
        drought_classes = {
            0: "No Drought (None)",
            1: "Abnormally Dry (D0)",
            2: "Moderate Drought (D1)",
            3: "Severe Drought (D2)",
            4: "Extreme Drought (D3)",
            5: "Exceptional Drought (D4)"
        }

        result = {
            "raw_predictions": predictions,
            "max_class": {
                "class": int(np.argmax(predictions)),
                "label": drought_classes[int(np.argmax(predictions))],
                "confidence": float(np.max(predictions))
            },
            "class_probabilities": {
                drought_classes[i]: float(predictions[i])
                for i in range(len(predictions))
            }
        }
        logger.info("Returning prediction result")
        return JSONResponse(content=result)
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")


def prepare_time_data(df):
    try:
        if 'YEAR' not in df.columns or 'DOY' not in df.columns:
            if 'date' in df.columns:
                df['date'] = pd.to_datetime(df['date'])
                df['YEAR'] = df['date'].dt.year
                df['DOY'] = df['date'].dt.dayofyear
            else:
                raise ValueError("Input CSV must contain either a 'date' column or both 'YEAR' and 'DOY' columns")

        if 'date' not in df.columns:
            df['date'] = pd.to_datetime(df['YEAR'].astype(str) + df['DOY'].astype(str), format="%Y%j")

        df[['sin_day', 'cos_day']] = df['date'].apply(lambda d: pd.Series(date_encode(d)))

        float_cols = [
            'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET',
            'T2M_MAX', 'T2M_MIN', 'T2M_RANGE', 'TS',
            'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
            'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
        ]
        for col in float_cols:
            if col in df.columns and df[col].isna().any():
                df[col] = interpolate_nans(df[col].values)

        return df
    except Exception as e:
        logger.error(f"Error preparing time data: {str(e)}")
        raise


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)
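
# ---------------------------------------------------------------------------
# Illustrative client call (a sketch only, not part of the app): the file name
# "weather.csv" and the all-zero 29-element static vector are placeholders; a
# real request needs a daily weather CSV containing the float_cols columns
# (skiprows=26 assumes 26 header lines before the data) and the static
# features the model was trained on.
#
#   import json
#   import requests
#
#   with open("weather.csv", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/predict",
#           files={"csv_file": ("weather.csv", f, "text/csv")},
#           data={"x_static": json.dumps([0.0] * 29)},
#       )
#   print(resp.json())
# ---------------------------------------------------------------------------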