Update app.py
Browse files
app.py
CHANGED
|
@@ -5,26 +5,25 @@ from src.model import LSTM # Adjust to your model path
|
|
| 5 |
|
| 6 |
# Load the model
|
| 7 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 8 |
-
model_path = "./water_forecast_2.pth"
|
| 9 |
model = LSTM(input_size=8, lstm_layer_sizes=[128, 128, 128], output_size=3).to(device)
|
| 10 |
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
|
| 11 |
model.eval()
|
| 12 |
|
| 13 |
# Define the prediction function
|
| 14 |
def predict_water_usage(state_idx, target_year, structured_data):
|
| 15 |
-
# Convert input data to tensor
|
| 16 |
tensor_data = torch.tensor(np.array(list(structured_data.values())), dtype=torch.float32).to(device)
|
| 17 |
with torch.no_grad():
|
| 18 |
output = model(tensor_data)
|
| 19 |
return output.tolist()
|
| 20 |
|
| 21 |
-
#
|
| 22 |
inputs = [
|
| 23 |
-
gr.
|
| 24 |
-
gr.
|
| 25 |
-
gr.
|
| 26 |
]
|
| 27 |
-
outputs = gr.
|
| 28 |
|
| 29 |
interface = gr.Interface(fn=predict_water_usage, inputs=inputs, outputs=outputs)
|
| 30 |
|
|
|
|
| 5 |
|
| 6 |
# Load the model
|
| 7 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 8 |
+
model_path = "./water_forecast_2.pth"
|
| 9 |
model = LSTM(input_size=8, lstm_layer_sizes=[128, 128, 128], output_size=3).to(device)
|
| 10 |
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
|
| 11 |
model.eval()
|
| 12 |
|
| 13 |
# Define the prediction function
|
| 14 |
def predict_water_usage(state_idx, target_year, structured_data):
|
|
|
|
| 15 |
tensor_data = torch.tensor(np.array(list(structured_data.values())), dtype=torch.float32).to(device)
|
| 16 |
with torch.no_grad():
|
| 17 |
output = model(tensor_data)
|
| 18 |
return output.tolist()
|
| 19 |
|
| 20 |
+
# Update Gradio interface using gr.components
|
| 21 |
inputs = [
|
| 22 |
+
gr.components.Number(label="State Index"),
|
| 23 |
+
gr.components.Number(label="Target Year"),
|
| 24 |
+
gr.components.JSON(label="Structured Data") # Expects JSON input
|
| 25 |
]
|
| 26 |
+
outputs = gr.components.JSON(label="Prediction")
|
| 27 |
|
| 28 |
interface = gr.Interface(fn=predict_water_usage, inputs=inputs, outputs=outputs)
|
| 29 |
|