Bhuvi20 commited on
Commit
d2c867c
·
verified ·
1 Parent(s): 445c1f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -5,26 +5,25 @@ from src.model import LSTM # Adjust to your model path
5
 
6
  # Load the model
7
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
- model_path = "./water_forecast_2.pth" # Your model path
9
  model = LSTM(input_size=8, lstm_layer_sizes=[128, 128, 128], output_size=3).to(device)
10
  model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
11
  model.eval()
12
 
13
  # Define the prediction function
14
  def predict_water_usage(state_idx, target_year, structured_data):
15
- # Convert input data to tensor
16
  tensor_data = torch.tensor(np.array(list(structured_data.values())), dtype=torch.float32).to(device)
17
  with torch.no_grad():
18
  output = model(tensor_data)
19
  return output.tolist()
20
 
21
- # Set up Gradio interface
22
  inputs = [
23
- gr.inputs.Number(label="State Index"),
24
- gr.inputs.Number(label="Target Year"),
25
- gr.inputs.JSON(label="Structured Data") # Expects JSON input
26
  ]
27
- outputs = gr.outputs.JSON(label="Prediction")
28
 
29
  interface = gr.Interface(fn=predict_water_usage, inputs=inputs, outputs=outputs)
30
 
 
5
 
6
  # Load the model
7
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
+ model_path = "./water_forecast_2.pth"
9
  model = LSTM(input_size=8, lstm_layer_sizes=[128, 128, 128], output_size=3).to(device)
10
  model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
11
  model.eval()
12
 
13
  # Define the prediction function
14
  def predict_water_usage(state_idx, target_year, structured_data):
 
15
  tensor_data = torch.tensor(np.array(list(structured_data.values())), dtype=torch.float32).to(device)
16
  with torch.no_grad():
17
  output = model(tensor_data)
18
  return output.tolist()
19
 
20
+ # Update Gradio interface using gr.components
21
  inputs = [
22
+ gr.components.Number(label="State Index"),
23
+ gr.components.Number(label="Target Year"),
24
+ gr.components.JSON(label="Structured Data") # Expects JSON input
25
  ]
26
+ outputs = gr.components.JSON(label="Prediction")
27
 
28
  interface = gr.Interface(fn=predict_water_usage, inputs=inputs, outputs=outputs)
29