l1aF2027 committed on
Commit
6b76727
·
verified ·
1 Parent(s): 55aaccc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -189
app.py CHANGED
@@ -1,190 +1,191 @@
1
- from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
- from fastapi.responses import JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
- import pandas as pd
5
- import numpy as np
6
- import torch
7
- import json
8
- import io
9
- import joblib
10
- import os
11
- from model import DroughtNetLSTM
12
- from utils import normalize, date_encode, interpolate_nans
13
- from datetime import datetime
14
- from typing import List, Optional
15
-
16
- app = FastAPI(
17
- title="Drought Prediction API",
18
- description="API for predicting drought severity based on weather data",
19
- version="1.0.0"
20
- )
21
-
22
- # Enable CORS
23
- app.add_middleware(
24
- CORSMiddleware,
25
- allow_origins=["*"],
26
- allow_credentials=True,
27
- allow_methods=["*"],
28
- allow_headers=["*"],
29
- )
30
-
31
- # Load model and scalers
32
- @app.on_event("startup")
33
- async def load_model():
34
- global model, scaler_dict, scaler_dict_static, device
35
-
36
- # Set device
37
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
-
39
- # Load scalers
40
- scaler_dict = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict.joblib"))
41
- scaler_dict_static = joblib.load(os.path.join(os.path.dirname(__file__), "scaler_dict_static.joblib"))
42
-
43
- # Define model parameters
44
- time_dim = 20 # Number of features in time data (18 weather + 2 date encoding)
45
- lstm_dim = 256
46
- num_layers = 2
47
- dropout = 0.15
48
- static_dim = 29 # Number of features in static data
49
- staticfc_dim = 16
50
- hidden_dim = 256
51
- output_size = 6 # Output classes
52
-
53
- # Initialize model
54
- model = DroughtNetLSTM(
55
- time_dim=time_dim,
56
- lstm_dim=lstm_dim,
57
- num_layers=num_layers,
58
- dropout=dropout,
59
- static_dim=static_dim,
60
- staticfc_dim=staticfc_dim,
61
- hidden_dim=hidden_dim,
62
- output_size=output_size
63
- )
64
-
65
- # Load model weights
66
- model.load_state_dict(torch.load(
67
- os.path.join(os.path.dirname(__file__), "best_macro_f1_model.pt"),
68
- map_location=device
69
- ))
70
-
71
- model.to(device)
72
- model.eval()
73
-
74
- @app.get("/")
75
- async def root():
76
- return {"message": "Welcome to Drought Prediction API. Use /predict endpoint to make predictions."}
77
-
78
- @app.post("/predict")
79
- async def predict(
80
- csv_file: UploadFile = File(...),
81
- x_static: str = Form(...),
82
- ):
83
- try:
84
- # Parse x_static from JSON string to list
85
- x_static_list = json.loads(x_static)
86
- x_static_array = np.array([x_static_list], dtype=np.float32)
87
-
88
- # Read CSV file
89
- content = await csv_file.read()
90
- df = pd.read_csv(io.StringIO(content.decode('utf-8')))
91
-
92
- # Prepare time data
93
- df = prepare_time_data(df)
94
-
95
- # Get features
96
- float_cols = [
97
- 'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
98
- 'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
99
- 'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
100
- ]
101
-
102
- features = float_cols + ['sin_day', 'cos_day']
103
- x_time_array = df[features].to_numpy(dtype=np.float32)
104
- x_time_array = np.expand_dims(x_time_array, axis=0)
105
-
106
- # Normalize data
107
- x_static_norm, x_time_norm = normalize(
108
- x_static_array,
109
- x_time_array,
110
- scaler_dict=scaler_dict,
111
- scaler_dict_static=scaler_dict_static
112
- )
113
-
114
- # Convert to tensors
115
- x_time_tensor = torch.tensor(x_time_norm).float().to(device)
116
- x_static_tensor = torch.tensor(x_static_norm).float().to(device)
117
-
118
- # Predict
119
- with torch.no_grad():
120
- output = model(x_time_tensor, x_static_tensor)
121
- # Clamp output to [0, 5]
122
- output = torch.clamp(output, min=0.0, max=5.0)
123
-
124
- # Convert to list
125
- predictions = output.cpu().numpy().tolist()[0]
126
-
127
- # Create result with class interpretations
128
- drought_classes = {
129
- 0: "No Drought (D0)",
130
- 1: "Abnormally Dry (D1)",
131
- 2: "Moderate Drought (D2)",
132
- 3: "Severe Drought (D3)",
133
- 4: "Extreme Drought (D4)",
134
- 5: "Exceptional Drought (D5)"
135
- }
136
-
137
- result = {
138
- "raw_predictions": predictions,
139
- "max_class": {
140
- "class": int(np.argmax(predictions)),
141
- "label": drought_classes[int(np.argmax(predictions))],
142
- "confidence": float(np.max(predictions))
143
- },
144
- "class_probabilities": {
145
- drought_classes[i]: float(predictions[i]) for i in range(len(predictions))
146
- }
147
- }
148
-
149
- return JSONResponse(content=result)
150
-
151
- except Exception as e:
152
- raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
153
-
154
- def prepare_time_data(df):
155
- """
156
- Prepare time-series data for the model.
157
- """
158
- # Ensure we have YEAR and DOY columns
159
- if 'YEAR' not in df.columns or 'DOY' not in df.columns:
160
- # Try to extract from date column if it exists
161
- if 'date' in df.columns:
162
- df['date'] = pd.to_datetime(df['date'])
163
- df['YEAR'] = df['date'].dt.year
164
- df['DOY'] = df['date'].dt.dayofyear
165
- else:
166
- raise ValueError("Input CSV must contain either 'date' column or both 'YEAR' and 'DOY' columns")
167
-
168
- # Create date column if it doesn't exist
169
- if 'date' not in df.columns:
170
- df['date'] = pd.to_datetime(df['YEAR'].astype(str) + df['DOY'].astype(str), format="%Y%j")
171
-
172
- # Apply date encoding to create sin_day and cos_day columns
173
- df[['sin_day', 'cos_day']] = df['date'].apply(lambda d: pd.Series(date_encode(d)))
174
-
175
- # Handle missing values if any
176
- float_cols = [
177
- 'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
178
- 'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
179
- 'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
180
- ]
181
-
182
- for col in float_cols:
183
- if col in df.columns and df[col].isna().any():
184
- df[col] = interpolate_nans(df[col].values)
185
-
186
- return df
187
-
188
- if __name__ == "__main__":
189
- import uvicorn
 
190
  uvicorn.run("app:app", host="0.0.0.0", port=8000)
 
1
# --- Imports --------------------------------------------------------------
# NOTE(review): `datetime`, `List` and `Optional` are imported but never
# used anywhere in this module — candidates for removal once confirmed
# nothing else re-imports them from here.
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import pandas as pd
import numpy as np
import torch
import json
import io
import joblib
import os
from model import DroughtNetLSTM
from utils import normalize, date_encode, interpolate_nans
from datetime import datetime
from typing import List, Optional

# FastAPI application serving drought-severity predictions (see /predict).
app = FastAPI(
    title="Drought Prediction API",
    description="API for predicting drought severity based on weather data",
    version="1.0.0"
)

# Enable CORS
# NOTE(review): wildcard `allow_origins=["*"]` combined with
# `allow_credentials=True` is rejected by browsers for credentialed
# requests under the CORS spec — confirm whether credentials are actually
# needed, otherwise drop `allow_credentials`.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
31
# Load model and scalers
@app.on_event("startup")
async def load_model():
    """Initialize inference state once at application startup.

    Populates the module-level globals used by ``/predict``:
    ``device`` (CUDA when available, else CPU), ``scaler_dict`` /
    ``scaler_dict_static`` (fitted feature scalers consumed by
    ``normalize``) and ``model`` (the LSTM in eval mode on ``device``).

    NOTE(review): ``@app.on_event`` is deprecated in recent FastAPI
    releases in favour of lifespan handlers — kept as-is to preserve the
    existing startup wiring.
    """
    # Local import so this fix is self-contained; move to file top if preferred.
    import logging
    logger = logging.getLogger(__name__)

    global model, scaler_dict, scaler_dict_static, device

    # Set device: prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load scalers that normalize() expects, relative to this file.
    base_dir = os.path.dirname(__file__)
    scaler_dict = joblib.load(os.path.join(base_dir, "scaler_dict.joblib"))
    scaler_dict_static = joblib.load(os.path.join(base_dir, "scaler_dict_static.joblib"))

    # Model hyper-parameters — must match the checkpoint being loaded.
    time_dim = 20  # Number of features in time data (18 weather + 2 date encoding)
    lstm_dim = 256
    num_layers = 2
    dropout = 0.15
    static_dim = 29  # Number of features in static data
    staticfc_dim = 16
    hidden_dim = 256
    output_size = 6  # Output classes (drought severities D0..D5)
    # FIX: was a Vietnamese `print(...)` debug line — use logging instead.
    logger.info("Scaler initialization complete")

    # Initialize model
    model = DroughtNetLSTM(
        time_dim=time_dim,
        lstm_dim=lstm_dim,
        num_layers=num_layers,
        dropout=dropout,
        static_dim=static_dim,
        staticfc_dim=staticfc_dim,
        hidden_dim=hidden_dim,
        output_size=output_size
    )

    # Load model weights.
    # NOTE(review): consider torch.load(..., weights_only=True) — the file
    # should contain only a state_dict; verify the checkpoint before enabling.
    model.load_state_dict(torch.load(
        os.path.join(base_dir, "best_macro_f1_model.pt"),
        map_location=device
    ))
    # FIX: was a Vietnamese `print(...)` debug line — use logging instead.
    logger.info("Model initialization complete")

    model.to(device)
    model.eval()
+
75
@app.get("/")
async def root():
    """Landing endpoint: directs clients to the /predict endpoint."""
    greeting = (
        "Welcome to Drought Prediction API."
        " Use /predict endpoint to make predictions."
    )
    return {"message": greeting}
78
+
79
@app.post("/predict")
async def predict(
    csv_file: UploadFile = File(...),
    x_static: str = Form(...),
):
    """Predict drought severity from an uploaded weather CSV.

    Args:
        csv_file: CSV of daily weather rows; must contain the 18 weather
            columns plus either a 'date' column or 'YEAR'/'DOY' columns.
        x_static: JSON-encoded list of the static features (the startup
            config sizes the model for 29 of them — confirm with caller).

    Returns:
        JSONResponse with raw model outputs, the arg-max class (index,
        label, score) and a per-class score map.

    Raises:
        HTTPException: 400 for malformed client input, 500 otherwise.
    """
    try:
        # Parse x_static from JSON string to list. Malformed JSON is the
        # client's fault, so answer 400 rather than a generic 500.
        try:
            x_static_list = json.loads(x_static)
        except json.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"x_static must be valid JSON: {e}")
        x_static_array = np.array([x_static_list], dtype=np.float32)

        # Read CSV file
        content = await csv_file.read()
        df = pd.read_csv(io.StringIO(content.decode('utf-8')))

        # Prepare time data (adds date / sin_day / cos_day, fills NaNs)
        df = prepare_time_data(df)

        # Get features
        float_cols = [
            'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
            'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
            'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
        ]
        features = float_cols + ['sin_day', 'cos_day']

        # Fail fast with a client error when required columns are absent,
        # instead of a KeyError surfacing as a 500 below.
        missing = [c for c in features if c not in df.columns]
        if missing:
            raise HTTPException(status_code=400, detail=f"CSV is missing required columns: {missing}")

        x_time_array = df[features].to_numpy(dtype=np.float32)
        x_time_array = np.expand_dims(x_time_array, axis=0)  # add batch dim

        # Normalize data with the scalers loaded at startup
        x_static_norm, x_time_norm = normalize(
            x_static_array,
            x_time_array,
            scaler_dict=scaler_dict,
            scaler_dict_static=scaler_dict_static
        )

        # Convert to tensors on the inference device
        x_time_tensor = torch.tensor(x_time_norm).float().to(device)
        x_static_tensor = torch.tensor(x_static_norm).float().to(device)

        # Predict
        with torch.no_grad():
            output = model(x_time_tensor, x_static_tensor)
            # Clamp output to the valid severity range [0, 5]
            output = torch.clamp(output, min=0.0, max=5.0)

        # Convert to list (first — and only — batch element)
        predictions = output.cpu().numpy().tolist()[0]

        # Human-readable labels for the six severity classes.
        drought_classes = {
            0: "No Drought (D0)",
            1: "Abnormally Dry (D1)",
            2: "Moderate Drought (D2)",
            3: "Severe Drought (D3)",
            4: "Extreme Drought (D4)",
            5: "Exceptional Drought (D5)"
        }

        # NOTE(review): the model emits clamped regression-style scores, not
        # normalized probabilities — "confidence" / "class_probabilities"
        # carry raw scores; confirm with the model's training objective.
        result = {
            "raw_predictions": predictions,
            "max_class": {
                "class": int(np.argmax(predictions)),
                "label": drought_classes[int(np.argmax(predictions))],
                "confidence": float(np.max(predictions))
            },
            "class_probabilities": {
                drought_classes[i]: float(predictions[i]) for i in range(len(predictions))
            }
        }

        return JSONResponse(content=result)

    # FIX: the generic handler below used to swallow HTTPExceptions raised
    # above and re-report them as 500s — re-raise them unchanged.
    except HTTPException:
        raise
    except ValueError as e:
        # Schema problems (e.g. from prepare_time_data) are client errors.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
154
+
155
def prepare_time_data(df):
    """
    Prepare time-series data for the model.

    Ensures YEAR/DOY and a datetime 'date' column exist, derives the
    cyclic day-of-year encodings ('sin_day', 'cos_day') via date_encode,
    and interpolates NaNs in the weather feature columns.

    Args:
        df: input DataFrame; mutated in place and also returned.

    Returns:
        The same DataFrame with 'date', 'sin_day' and 'cos_day' added and
        NaNs in the weather columns interpolated.

    Raises:
        ValueError: if neither a 'date' column nor both 'YEAR' and 'DOY'
            columns are present.
    """
    # Ensure we have YEAR and DOY columns
    if 'YEAR' not in df.columns or 'DOY' not in df.columns:
        # Try to extract from date column if it exists
        if 'date' in df.columns:
            df['date'] = pd.to_datetime(df['date'])
            df['YEAR'] = df['date'].dt.year
            df['DOY'] = df['date'].dt.dayofyear
        else:
            raise ValueError("Input CSV must contain either 'date' column or both 'YEAR' and 'DOY' columns")

    # Create date column if it doesn't exist.
    # FIX: zero-pad DOY to three digits so "%Y%j" parsing is unambiguous —
    # previously year 2020, day 5 produced "20205" instead of "2020005".
    if 'date' not in df.columns:
        df['date'] = pd.to_datetime(
            df['YEAR'].astype(str) + df['DOY'].astype(str).str.zfill(3),
            format="%Y%j"
        )

    # Apply date encoding to create sin_day and cos_day columns
    df[['sin_day', 'cos_day']] = df['date'].apply(lambda d: pd.Series(date_encode(d)))

    # Handle missing values, if any, in the weather feature columns
    float_cols = [
        'PRECTOTCORR', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET', 'T2M_MAX', 'T2M_MIN', 'T2M_RANGE',
        'TS', 'WS10M', 'WS10M_MAX', 'WS10M_MIN', 'WS10M_RANGE',
        'WS50M', 'WS50M_MAX', 'WS50M_MIN', 'WS50M_RANGE',
    ]

    for col in float_cols:
        if col in df.columns and df[col].isna().any():
            df[col] = interpolate_nans(df[col].values)

    return df
188
+
189
# Allow running this module directly as a development server.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000)