Bhuvi20 commited on
Commit
145227f
·
verified ·
1 Parent(s): 57cee6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -16
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import gradio as gr
3
  import numpy as np
 
4
  from src.model import LSTM # Adjust to your model path
5
 
6
  # Load the model
@@ -13,13 +14,21 @@ model.eval()
13
  # Define the prediction function
14
  def predict_water_usage(state_idx, target_year, structured_data):
15
  # Convert structured data JSON string to dictionary
16
- structured_data = eval(structured_data) if isinstance(structured_data, str) else structured_data
17
-
 
 
 
18
  if state_idx not in structured_data or len(structured_data[state_idx]) < 5:
19
  return {"error": "Structured data must include 5 years of data for the specified state."}
20
 
21
- # Convert structured data for model input
22
- data_values = [values for year, values in structured_data[state_idx].items()]
 
 
 
 
 
23
  tensor_data = torch.tensor(data_values, dtype=torch.float32).to(device)
24
 
25
  # Get model output
@@ -30,23 +39,24 @@ def predict_water_usage(state_idx, target_year, structured_data):
30
 
31
  # Configure Gradio interface
32
  inputs = [
33
- gr.components.Number(label="State Index"),
34
- gr.components.Number(label="Target Year"),
35
- gr.components.Textbox(
36
  label="Structured Data (JSON format)",
37
  lines=10,
38
  placeholder="""{
39
- "state_idx": {
40
- "2020": [value1, value2, ..., value8],
41
- "2021": [value1, value2, ..., value8],
42
- "2022": [value1, value2, ..., value8],
43
- "2023": [value1, value2, ..., value8],
44
- "2024": [value1, value2, ..., value8]
45
- }
46
- }"""
47
  )
48
  ]
49
- outputs = gr.components.JSON(label="Prediction")
 
50
 
51
  interface = gr.Interface(fn=predict_water_usage, inputs=inputs, outputs=outputs)
52
 
 
1
  import torch
2
  import gradio as gr
3
  import numpy as np
4
+ import json # Import json for safer parsing
5
  from src.model import LSTM # Adjust to your model path
6
 
7
  # Load the model
 
14
  # Define the prediction function
15
  def predict_water_usage(state_idx, target_year, structured_data):
16
  # Convert structured data JSON string to dictionary
17
+ try:
18
+ structured_data = json.loads(structured_data) if isinstance(structured_data, str) else structured_data
19
+ except json.JSONDecodeError:
20
+ return {"error": "Invalid JSON format for structured data."}
21
+
22
  if state_idx not in structured_data or len(structured_data[state_idx]) < 5:
23
  return {"error": "Structured data must include 5 years of data for the specified state."}
24
 
25
+ # Convert structured data for model input (extract values for model)
26
+ data_values = [list(values) for year, values in structured_data[state_idx].items()]
27
+
28
+ # Ensure the data has the right shape for the model
29
+ if len(data_values) != 5: # Check if there are exactly 5 years of data
30
+ return {"error": "Structured data should have 5 years of data."}
31
+
32
  tensor_data = torch.tensor(data_values, dtype=torch.float32).to(device)
33
 
34
  # Get model output
 
39
 
40
  # Configure Gradio interface
41
  inputs = [
42
+ gr.Number(label="State Index"), # Numeric input for state index
43
+ gr.Number(label="Target Year"), # Numeric input for target year
44
+ gr.Textbox(
45
  label="Structured Data (JSON format)",
46
  lines=10,
47
  placeholder="""{
48
+ "state_idx": {
49
+ "2020": [value1, value2, ..., value8],
50
+ "2021": [value1, value2, ..., value8],
51
+ "2022": [value1, value2, ..., value8],
52
+ "2023": [value1, value2, ..., value8],
53
+ "2024": [value1, value2, ..., value8]
54
+ }
55
+ }"""
56
  )
57
  ]
58
+
59
+ outputs = gr.JSON(label="Prediction")
60
 
61
  interface = gr.Interface(fn=predict_water_usage, inputs=inputs, outputs=outputs)
62