Bhuvi20 commited on
Commit
fca02b7
·
verified ·
1 Parent(s): e4a4f84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -5
app.py CHANGED
@@ -12,16 +12,38 @@ model.eval()
12
 
13
  # Define the prediction function
14
  def predict_water_usage(state_idx, target_year, structured_data):
15
- tensor_data = torch.tensor(np.array(list(structured_data.values())), dtype=torch.float32).to(device)
 
 
 
 
 
 
 
 
16
  with torch.no_grad():
17
  output = model(tensor_data)
18
- return output.tolist()
 
19
 
20
- # Update Gradio interface using gr.components
21
  inputs = [
22
  gr.components.Number(label="State Index"),
23
  gr.components.Number(label="Target Year"),
24
- gr.components.JSON(label="Structured Data") # Expects JSON input
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  ]
26
  outputs = gr.components.JSON(label="Prediction")
27
 
@@ -29,4 +51,4 @@ interface = gr.Interface(fn=predict_water_usage, inputs=inputs, outputs=outputs)
29
 
30
  # Launch the Gradio app
31
  if __name__ == "__main__":
32
- interface.launch(share=True)
 
12
 
13
  # Define the prediction function
14
  def predict_water_usage(state_idx, target_year, structured_data):
15
+ # Ensure structured_data is correctly formatted
16
+ if state_idx not in structured_data or len(structured_data[state_idx]) < 5:
17
+ return {"error": "Structured data must include 5 years of data for the specified state."}
18
+
19
+ # Convert structured data for model input
20
+ data_values = [values for year, values in structured_data[state_idx].items()]
21
+ tensor_data = torch.tensor(data_values, dtype=torch.float32).to(device)
22
+
23
+ # Get model output
24
  with torch.no_grad():
25
  output = model(tensor_data)
26
+
27
+ return {"prediction": output.tolist()}
28
 
29
+ # Configure Gradio interface with structured JSON input example
30
  inputs = [
31
  gr.components.Number(label="State Index"),
32
  gr.components.Number(label="Target Year"),
33
+ gr.components.JSON(
34
+ label="Structured Data",
35
+ placeholder="""
36
+ {
37
+ "state_idx": {
38
+ "2020": [value1, value2, ..., value8],
39
+ "2021": [value1, value2, ..., value8],
40
+ "2022": [value1, value2, ..., value8],
41
+ "2023": [value1, value2, ..., value8],
42
+ "2024": [value1, value2, ..., value8]
43
+ }
44
+ }
45
+ """
46
+ )
47
  ]
48
  outputs = gr.components.JSON(label="Prediction")
49
 
 
51
 
52
  # Launch the Gradio app
53
  if __name__ == "__main__":
54
+ interface.launch()