l1aF2027 commited on
Commit
a12b2da
·
verified ·
1 Parent(s): 72b20e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -72
app.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
  from fastapi.responses import JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from contextlib import asynccontextmanager
5
- import uvicorn
6
  import pandas as pd
7
  import numpy as np
8
  import torch
@@ -10,51 +10,88 @@ import json
10
  import io
11
  import joblib
12
  import os
 
 
13
  from model import DroughtNetLSTM
14
  from utils import normalize, date_encode, interpolate_nans
15
  from datetime import datetime
16
  from typing import List, Optional
17
 
18
- # Lifespan event handler (thay thế @app.on_event)
 
 
 
 
 
 
 
 
19
  @asynccontextmanager
20
  async def lifespan(app: FastAPI):
21
  global model, scaler_dict, scaler_dict_static, device
22
-
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
-
25
- scaler_dict = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict.joblib"))
26
- scaler_dict_static = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict_static.joblib"))
27
- print("Khởi tạo dữ scaler hoàn tất")
28
-
29
- # Define model params
30
- time_dim = 20
31
- lstm_dim = 256
32
- num_layers = 2
33
- dropout = 0.15
34
- static_dim = 29
35
- staticfc_dim = 16
36
- hidden_dim = 256
37
- output_size = 6
38
-
39
- model = DroughtNetLSTM(
40
- time_dim=time_dim,
41
- lstm_dim=lstm_dim,
42
- num_layers=num_layers,
43
- dropout=dropout,
44
- static_dim=static_dim,
45
- staticfc_dim=staticfc_dim,
46
- hidden_dim=hidden_dim,
47
- output_size=output_size
48
- )
49
- model.load_state_dict(torch.load(
50
- os.path.join(os.path.dirname(__file__), "best_macro_f1_model.pt"),
51
- map_location=device
52
- ))
53
- model.to(device)
54
- model.eval()
55
- print("Khởi tạo dữ model hoàn tất")
56
-
57
- yield # Cho phép app chạy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  app = FastAPI(
60
  title="Drought Prediction API",
@@ -76,20 +113,32 @@ app.add_middleware(
76
  async def root():
77
  return {"message": "Welcome to Drought Prediction API. Use /predict endpoint to make predictions."}
78
 
 
 
 
 
 
79
  @app.post("/predict")
80
  async def predict(
81
  csv_file: UploadFile = File(...),
82
  x_static: str = Form(...),
83
  ):
84
  try:
 
 
85
  # Parse static input
86
  x_static_list = json.loads(x_static)
87
  x_static_array = np.array([x_static_list], dtype=np.float32)
 
88
 
89
  # Load and process CSV
90
  content = await csv_file.read()
91
- df = pd.read_csv(io.StringIO(content.decode('utf-8')))
 
 
 
92
  df = prepare_time_data(df)
 
93
 
94
  # Feature extraction
95
  float_cols = [
@@ -100,25 +149,36 @@ async def predict(
100
  features = float_cols + ['sin_day', 'cos_day']
101
  x_time_array = df[features].to_numpy(dtype=np.float32)
102
  x_time_array = np.expand_dims(x_time_array, axis=0)
 
103
 
104
  # Normalize
105
- x_static_norm, x_time_norm = normalize(
106
- x_static_array,
107
- x_time_array,
108
- scaler_dict=scaler_dict,
109
- scaler_dict_static=scaler_dict_static
110
- )
 
 
 
 
 
 
 
 
111
 
112
  # To tensors
113
  x_time_tensor = torch.tensor(x_time_norm).float().to(device)
114
  x_static_tensor = torch.tensor(x_static_norm).float().to(device)
115
 
116
  # Predict
 
117
  with torch.no_grad():
118
  output = model(x_time_tensor, x_static_tensor)
119
  output = torch.clamp(output, min=0.0, max=5.0)
120
 
121
  predictions = output.cpu().numpy().tolist()[0]
 
122
 
123
  drought_classes = {
124
  0: "No Drought (D0)",
@@ -141,37 +201,39 @@ async def predict(
141
  }
142
  }
143
 
 
144
  return JSONResponse(content=result)
145
 
146
  except Exception as e:
 
147
  raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
148
 
149
 
150
  def prepare_time_data(df):
151
- if 'YEAR' not in df.columns or 'DOY' not in df.columns:
152
- if 'date' in df.columns:
153
- df['date'] = pd.to_datetime(df['date'])
154
- df['YEAR'] = df['date'].dt.year
155
- df['DOY'] = df['date'].dt.dayofyear
156
- else:
157
- raise ValueError("Input CSV must contain either 'date' column or both 'YEAR' and 'DOY' columns")
158
-
159
- if 'date' not in df.columns:
160
- df['date'] = pd.to_datetime(df['YEAR'].astype(str) + df['DOY'].astype(str), format="%Y%j")
161
-
162
- df[['sin_day', 'cos_day']] = df['date'].apply(lambda d: pd.Series(date_encode(d)))
163
-
164
- float_cols = [
165
- 'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
166
- 'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
167
- 'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
168
- ]
169
- for col in float_cols:
170
- if col in df.columns and df[col].isna().any():
171
- df[col] = interpolate_nans(df[col].values)
172
-
173
- return df
174
-
175
- if __name__ == "__main__":
176
- port = int(os.environ.get("PORT", 7860)) # Hugging Face Spaces sử dụng cổng 7860
177
- uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)
 
2
  from fastapi.responses import JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from contextlib import asynccontextmanager
5
+
6
  import pandas as pd
7
  import numpy as np
8
  import torch
 
10
  import io
11
  import joblib
12
  import os
13
+ import sys
14
+ import logging
15
  from model import DroughtNetLSTM
16
  from utils import normalize, date_encode, interpolate_nans
17
  from datetime import datetime
18
  from typing import List, Optional
19
 
20
# Configure logging: INFO-level records go to stdout with a timestamped format.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO,
                    format=_LOG_FORMAT,
                    handlers=[logging.StreamHandler(sys.stdout)])
logger = logging.getLogger(__name__)
27
+
28
# Lifespan event handler (replaces the deprecated @app.on_event hooks).
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the scalers and LSTM model once at startup; log shutdown on exit.

    Populates the module-level globals ``model``, ``scaler_dict``,
    ``scaler_dict_static`` and ``device`` that the /predict endpoint reads.
    Scaler loading failures fall back to empty dicts (best effort, as
    before); a model-loading failure propagates so the app does NOT start
    in a broken state.
    """
    global model, scaler_dict, scaler_dict_static, device

    # Pre-initialize so readers (e.g. /health) see None/empty values
    # instead of a NameError if initialization fails partway through.
    model = None
    scaler_dict = {}
    scaler_dict_static = {}

    logger.info("Starting application initialization")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Using device: %s", device)

    base_dir = os.path.dirname(__file__)

    # Load scalers; tolerate version-compatibility failures with empty
    # fallbacks (normalization then degrades rather than blocking startup).
    try:
        logger.info("Loading scalers")
        scaler_dict = joblib.load(os.path.join(base_dir, "scaler_dict.joblib"))
        scaler_dict_static = joblib.load(os.path.join(base_dir, "scaler_dict_static.joblib"))
        logger.info("Scalers loaded successfully")
    except Exception:
        logger.exception("Error loading scalers; using empty scalers as fallback")
        scaler_dict = {}
        scaler_dict_static = {}

    # Model hyperparameters (must match the checkpoint's training config).
    logger.info("Initializing model")
    net = DroughtNetLSTM(
        time_dim=20,
        lstm_dim=256,
        num_layers=2,
        dropout=0.15,
        static_dim=29,
        staticfc_dim=16,
        hidden_dim=256,
        output_size=6,
    )

    # BUG FIX: the previous version wrapped all of this in an outer
    # try/except that caught the model-load re-raise, logged it, and
    # yielded anyway — so the app could start with no model despite the
    # stated intent. Let the error propagate so startup actually fails.
    model_path = os.path.join(base_dir, "best_macro_f1_model.pt")
    try:
        logger.info("Loading model from %s", model_path)
        net.load_state_dict(torch.load(model_path, map_location=device))
    except Exception:
        logger.exception("Error loading model")
        raise  # Prevent the app from starting with a broken model.
    net.to(device)
    net.eval()
    model = net
    logger.info("Application initialization completed successfully")

    yield  # Application serves requests here.

    logger.info("Application shutdown initiated")
95
 
96
  app = FastAPI(
97
  title="Drought Prediction API",
 
113
  async def root():
114
  return {"message": "Welcome to Drought Prediction API. Use /predict endpoint to make predictions."}
115
 
116
@app.get("/health")
async def health():
    """Lightweight liveness probe.

    Reports whether the ``model`` global has been populated by the lifespan
    startup. BUG FIX: uses ``globals().get()`` so the endpoint returns a
    clean payload instead of raising NameError (→ HTTP 500) when startup
    failed before ``model`` was ever assigned.
    """
    return {"status": "ok", "model_loaded": globals().get("model") is not None}
120
+
121
  @app.post("/predict")
122
  async def predict(
123
  csv_file: UploadFile = File(...),
124
  x_static: str = Form(...),
125
  ):
126
  try:
127
+ logger.info("Received prediction request")
128
+
129
  # Parse static input
130
  x_static_list = json.loads(x_static)
131
  x_static_array = np.array([x_static_list], dtype=np.float32)
132
+ logger.info(f"Static data shape: {x_static_array.shape}")
133
 
134
  # Load and process CSV
135
  content = await csv_file.read()
136
+ df = pd.read_csv(io.StringIO(content.decode('utf-8')), skiprows=26)
137
+
138
+ logger.info(f"Loaded CSV with shape: {df.shape}")
139
+
140
  df = prepare_time_data(df)
141
+ logger.info("Time data prepared successfully")
142
 
143
  # Feature extraction
144
  float_cols = [
 
149
  features = float_cols + ['sin_day', 'cos_day']
150
  x_time_array = df[features].to_numpy(dtype=np.float32)
151
  x_time_array = np.expand_dims(x_time_array, axis=0)
152
+ logger.info(f"Time features shape: {x_time_array.shape}")
153
 
154
  # Normalize
155
+ try:
156
+ x_static_norm, x_time_norm = normalize(
157
+ x_static_array,
158
+ x_time_array,
159
+ scaler_dict=scaler_dict,
160
+ scaler_dict_static=scaler_dict_static
161
+ )
162
+ logger.info("Data normalized successfully")
163
+ except Exception as norm_error:
164
+ logger.error(f"Normalization error: {str(norm_error)}")
165
+ # Fall back to using unnormalized data if normalization fails
166
+ logger.warning("Using unnormalized data as fallback")
167
+ x_static_norm = x_static_array
168
+ x_time_norm = x_time_array
169
 
170
  # To tensors
171
  x_time_tensor = torch.tensor(x_time_norm).float().to(device)
172
  x_static_tensor = torch.tensor(x_static_norm).float().to(device)
173
 
174
  # Predict
175
+ logger.info("Running prediction")
176
  with torch.no_grad():
177
  output = model(x_time_tensor, x_static_tensor)
178
  output = torch.clamp(output, min=0.0, max=5.0)
179
 
180
  predictions = output.cpu().numpy().tolist()[0]
181
+ logger.info(f"Prediction completed: {predictions}")
182
 
183
  drought_classes = {
184
  0: "No Drought (D0)",
 
201
  }
202
  }
203
 
204
+ logger.info("Returning prediction result")
205
  return JSONResponse(content=result)
206
 
207
  except Exception as e:
208
+ logger.error(f"Prediction error: {str(e)}")
209
  raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
210
 
211
 
212
def prepare_time_data(df):
    """Derive date features and interpolate gaps in the weather columns.

    Ensures the frame has 'YEAR'/'DOY' and a 'date' column (deriving
    whichever is missing from the other), adds cyclic 'sin_day'/'cos_day'
    encodings via ``date_encode``, and interpolates NaNs in the known
    float weather columns with ``interpolate_nans``. Mutates and returns
    the same DataFrame.

    Raises:
        ValueError: if neither 'date' nor both 'YEAR' and 'DOY' exist.
    """
    try:
        if 'YEAR' not in df.columns or 'DOY' not in df.columns:
            if 'date' in df.columns:
                df['date'] = pd.to_datetime(df['date'])
                df['YEAR'] = df['date'].dt.year
                df['DOY'] = df['date'].dt.dayofyear
            else:
                raise ValueError("Input CSV must contain either 'date' column or both 'YEAR' and 'DOY' columns")

        if 'date' not in df.columns:
            # ROBUSTNESS: zero-pad DOY to three digits — parsing an
            # unpadded "%Y%j" concatenation relies on lenient strptime
            # behavior and is fragile across pandas versions.
            df['date'] = pd.to_datetime(
                df['YEAR'].astype(str) + df['DOY'].astype(str).str.zfill(3),
                format="%Y%j",
            )

        # Cyclic day-of-year encoding (sin/cos) per row via the project helper.
        df[['sin_day', 'cos_day']] = df['date'].apply(lambda d: pd.Series(date_encode(d)))

        float_cols = [
            'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
            'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
            'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
        ]
        for col in float_cols:
            # Only touch columns that exist and actually contain gaps.
            if col in df.columns and df[col].isna().any():
                df[col] = interpolate_nans(df[col].values)

        return df
    except Exception:
        # Log with full traceback, then propagate to the caller's handler.
        logger.exception("Error preparing time data")
        raise