l1aF2027 committed on
Commit
8993367
·
verified ·
1 Parent(s): 6b76727

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -71
app.py CHANGED
@@ -1,6 +1,8 @@
1
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
  from fastapi.responses import JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
 
 
4
  import pandas as pd
5
  import numpy as np
6
  import torch
@@ -13,44 +15,27 @@ from utils import normalize, date_encode, interpolate_nans
13
  from datetime import datetime
14
  from typing import List, Optional
15
 
16
- app = FastAPI(
17
- title="Drought Prediction API",
18
- description="API for predicting drought severity based on weather data",
19
- version="1.0.0"
20
- )
21
-
22
- # Enable CORS
23
- app.add_middleware(
24
- CORSMiddleware,
25
- allow_origins=["*"],
26
- allow_credentials=True,
27
- allow_methods=["*"],
28
- allow_headers=["*"],
29
- )
30
-
31
- # Load model and scalers
32
- @app.on_event("startup")
33
- async def load_model():
34
  global model, scaler_dict, scaler_dict_static, device
35
-
36
- # Set device
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
-
39
- # Load scalers
40
  scaler_dict = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict.joblib"))
41
  scaler_dict_static = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict_static.joblib"))
42
-
43
- # Define model parameters
44
- time_dim = 20 # Number of features in time data (18 weather + 2 date encoding)
 
45
  lstm_dim = 256
46
  num_layers = 2
47
  dropout = 0.15
48
- static_dim = 29 # Number of features in static data
49
  staticfc_dim = 16
50
  hidden_dim = 256
51
- output_size = 6 # Output classes
52
- print("Khởi tạo dữ scaler hoàn tất")
53
- # Initialize model
54
  model = DroughtNetLSTM(
55
  time_dim=time_dim,
56
  lstm_dim=lstm_dim,
@@ -61,16 +46,31 @@ async def load_model():
61
  hidden_dim=hidden_dim,
62
  output_size=output_size
63
  )
64
-
65
- # Load model weights
66
  model.load_state_dict(torch.load(
67
  os.path.join(os.path.dirname(__file__), "best_macro_f1_model.pt"),
68
  map_location=device
69
  ))
70
- print("Khởi tạo dữ model hoàn tất")
71
-
72
  model.to(device)
73
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  @app.get("/")
76
  async def root():
@@ -82,50 +82,44 @@ async def predict(
82
  x_static: str = Form(...),
83
  ):
84
  try:
85
- # Parse x_static from JSON string to list
86
  x_static_list = json.loads(x_static)
87
  x_static_array = np.array([x_static_list], dtype=np.float32)
88
-
89
- # Read CSV file
90
  content = await csv_file.read()
91
  df = pd.read_csv(io.StringIO(content.decode('utf-8')))
92
-
93
- # Prepare time data
94
  df = prepare_time_data(df)
95
-
96
- # Get features
97
  float_cols = [
98
  'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
99
  'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
100
  'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
101
  ]
102
-
103
  features = float_cols + ['sin_day', 'cos_day']
104
  x_time_array = df[features].to_numpy(dtype=np.float32)
105
  x_time_array = np.expand_dims(x_time_array, axis=0)
106
-
107
- # Normalize data
108
  x_static_norm, x_time_norm = normalize(
109
- x_static_array,
110
- x_time_array,
111
- scaler_dict=scaler_dict,
112
  scaler_dict_static=scaler_dict_static
113
  )
114
-
115
- # Convert to tensors
116
  x_time_tensor = torch.tensor(x_time_norm).float().to(device)
117
  x_static_tensor = torch.tensor(x_static_norm).float().to(device)
118
-
119
  # Predict
120
  with torch.no_grad():
121
  output = model(x_time_tensor, x_static_tensor)
122
- # Clamp output to [0, 5]
123
  output = torch.clamp(output, min=0.0, max=5.0)
124
-
125
- # Convert to list
126
  predictions = output.cpu().numpy().tolist()[0]
127
-
128
- # Create result with class interpretations
129
  drought_classes = {
130
  0: "No Drought (D0)",
131
  1: "Abnormally Dry (D1)",
@@ -134,7 +128,7 @@ async def predict(
134
  4: "Extreme Drought (D4)",
135
  5: "Exceptional Drought (D5)"
136
  }
137
-
138
  result = {
139
  "raw_predictions": predictions,
140
  "max_class": {
@@ -146,46 +140,38 @@ async def predict(
146
  drought_classes[i]: float(predictions[i]) for i in range(len(predictions))
147
  }
148
  }
149
-
150
  return JSONResponse(content=result)
151
-
152
  except Exception as e:
153
  raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
154
 
 
155
  def prepare_time_data(df):
156
- """
157
- Prepare time-series data for the model.
158
- """
159
- # Ensure we have YEAR and DOY columns
160
  if 'YEAR' not in df.columns or 'DOY' not in df.columns:
161
- # Try to extract from date column if it exists
162
  if 'date' in df.columns:
163
  df['date'] = pd.to_datetime(df['date'])
164
  df['YEAR'] = df['date'].dt.year
165
  df['DOY'] = df['date'].dt.dayofyear
166
  else:
167
  raise ValueError("Input CSV must contain either 'date' column or both 'YEAR' and 'DOY' columns")
168
-
169
- # Create date column if it doesn't exist
170
  if 'date' not in df.columns:
171
  df['date'] = pd.to_datetime(df['YEAR'].astype(str) + df['DOY'].astype(str), format="%Y%j")
172
-
173
- # Apply date encoding to create sin_day and cos_day columns
174
  df[['sin_day', 'cos_day']] = df['date'].apply(lambda d: pd.Series(date_encode(d)))
175
-
176
- # Handle missing values if any
177
  float_cols = [
178
  'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
179
  'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
180
  'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
181
  ]
182
-
183
  for col in float_cols:
184
  if col in df.columns and df[col].isna().any():
185
  df[col] = interpolate_nans(df[col].values)
186
-
187
  return df
188
 
189
  if __name__ == "__main__":
190
  import uvicorn
191
- uvicorn.run("app:app", host="0.0.0.0", port=8000)
 
1
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
  from fastapi.responses import JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ from contextlib import asynccontextmanager
5
+
6
  import pandas as pd
7
  import numpy as np
8
  import torch
 
15
  from datetime import datetime
16
  from typing import List, Optional
17
 
18
+ # Lifespan event handler (thay thế @app.on_event)
19
+ @asynccontextmanager
20
+ async def lifespan(app: FastAPI):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  global model, scaler_dict, scaler_dict_static, device
22
+
 
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
 
25
  scaler_dict = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict.joblib"))
26
  scaler_dict_static = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict_static.joblib"))
27
+ print("Khởi tạo dữ scaler hoàn tất")
28
+
29
+ # Define model params
30
+ time_dim = 20
31
  lstm_dim = 256
32
  num_layers = 2
33
  dropout = 0.15
34
+ static_dim = 29
35
  staticfc_dim = 16
36
  hidden_dim = 256
37
+ output_size = 6
38
+
 
39
  model = DroughtNetLSTM(
40
  time_dim=time_dim,
41
  lstm_dim=lstm_dim,
 
46
  hidden_dim=hidden_dim,
47
  output_size=output_size
48
  )
 
 
49
  model.load_state_dict(torch.load(
50
  os.path.join(os.path.dirname(__file__), "best_macro_f1_model.pt"),
51
  map_location=device
52
  ))
 
 
53
  model.to(device)
54
  model.eval()
55
+ print("Khởi tạo dữ model hoàn tất")
56
+
57
+ yield # Cho phép app chạy
58
+
59
+ app = FastAPI(
60
+ title="Drought Prediction API",
61
+ description="API for predicting drought severity based on weather data",
62
+ version="1.0.0",
63
+ lifespan=lifespan
64
+ )
65
+
66
+ # Enable CORS
67
+ app.add_middleware(
68
+ CORSMiddleware,
69
+ allow_origins=["*"],
70
+ allow_credentials=True,
71
+ allow_methods=["*"],
72
+ allow_headers=["*"],
73
+ )
74
 
75
  @app.get("/")
76
  async def root():
 
82
  x_static: str = Form(...),
83
  ):
84
  try:
85
+ # Parse static input
86
  x_static_list = json.loads(x_static)
87
  x_static_array = np.array([x_static_list], dtype=np.float32)
88
+
89
+ # Load and process CSV
90
  content = await csv_file.read()
91
  df = pd.read_csv(io.StringIO(content.decode('utf-8')))
 
 
92
  df = prepare_time_data(df)
93
+
94
+ # Feature extraction
95
  float_cols = [
96
  'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
97
  'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
98
  'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
99
  ]
 
100
  features = float_cols + ['sin_day', 'cos_day']
101
  x_time_array = df[features].to_numpy(dtype=np.float32)
102
  x_time_array = np.expand_dims(x_time_array, axis=0)
103
+
104
+ # Normalize
105
  x_static_norm, x_time_norm = normalize(
106
+ x_static_array,
107
+ x_time_array,
108
+ scaler_dict=scaler_dict,
109
  scaler_dict_static=scaler_dict_static
110
  )
111
+
112
+ # To tensors
113
  x_time_tensor = torch.tensor(x_time_norm).float().to(device)
114
  x_static_tensor = torch.tensor(x_static_norm).float().to(device)
115
+
116
  # Predict
117
  with torch.no_grad():
118
  output = model(x_time_tensor, x_static_tensor)
 
119
  output = torch.clamp(output, min=0.0, max=5.0)
120
+
 
121
  predictions = output.cpu().numpy().tolist()[0]
122
+
 
123
  drought_classes = {
124
  0: "No Drought (D0)",
125
  1: "Abnormally Dry (D1)",
 
128
  4: "Extreme Drought (D4)",
129
  5: "Exceptional Drought (D5)"
130
  }
131
+
132
  result = {
133
  "raw_predictions": predictions,
134
  "max_class": {
 
140
  drought_classes[i]: float(predictions[i]) for i in range(len(predictions))
141
  }
142
  }
143
+
144
  return JSONResponse(content=result)
145
+
146
  except Exception as e:
147
  raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
148
 
149
+
150
def prepare_time_data(df):
    """Prepare a raw weather DataFrame for the model.

    Ensures that YEAR/DOY and a datetime ``date`` column all exist, derives
    the cyclic day-of-year encoding (``sin_day``/``cos_day``), and
    interpolates NaNs in the weather feature columns.

    Args:
        df: DataFrame containing either a ``date`` column or both ``YEAR``
            and ``DOY`` columns, plus the weather feature columns.

    Returns:
        The same DataFrame, augmented with ``date``, ``YEAR``, ``DOY``,
        ``sin_day`` and ``cos_day`` columns and NaN-free feature columns.

    Raises:
        ValueError: if neither ``date`` nor both ``YEAR`` and ``DOY`` are
            present in the input.
    """
    # Derive YEAR/DOY from 'date' when they are not supplied directly.
    if 'YEAR' not in df.columns or 'DOY' not in df.columns:
        if 'date' in df.columns:
            df['date'] = pd.to_datetime(df['date'])
            df['YEAR'] = df['date'].dt.year
            df['DOY'] = df['date'].dt.dayofyear
        else:
            raise ValueError("Input CSV must contain either 'date' column or both 'YEAR' and 'DOY' columns")

    # Conversely, rebuild 'date' from YEAR/DOY. Zero-pad DOY to 3 digits so
    # the concatenated string is unambiguous for the "%Y%j" parser.
    if 'date' not in df.columns:
        df['date'] = pd.to_datetime(
            df['YEAR'].astype(str) + df['DOY'].astype(str).str.zfill(3),
            format="%Y%j",
        )

    # Cyclic day-of-year encoding; date_encode (from utils) is expected to
    # return the (sin, cos) pair — confirm against its definition.
    df[['sin_day', 'cos_day']] = df['date'].apply(lambda d: pd.Series(date_encode(d)))

    # Interpolate missing values in the weather feature columns, if any.
    float_cols = [
        'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
        'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
        'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
    ]
    for col in float_cols:
        if col in df.columns and df[col].isna().any():
            df[col] = interpolate_nans(df[col].values)

    return df
174
 
175
  if __name__ == "__main__":
176
  import uvicorn
177
+ uvicorn.run("app:app", host="0.0.0.0", port=8000)