File size: 1,692 Bytes
4e188a6
a063804
4e188a6
5bf6359
4e188a6
 
a063804
5bf6359
 
4e188a6
a063804
 
5bf6359
 
 
fca02b7
145227f
2a200cd
5bf6359
45ca800
145227f
 
5bf6359
 
2a200cd
 
 
5bf6359
 
 
4e188a6
5bf6359
 
 
2a200cd
5bf6359
a063804
57cee6e
a063804
145227f
 
2a200cd
a063804
145227f
 
4e188a6
2a200cd
a063804
7859c40
2a200cd
7859c40
2a200cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import gradio as gr
import numpy as np
from src.model import LSTM  

# Model setup: choose the compute device, then restore the saved forecaster.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model_path = "./andhra_forecast.pt" 

# NOTE(review): the object returned by torch.load is used directly as the
# model (eval() is called on it), so the file presumably holds a whole pickled
# module rather than a state dict -- confirm, and only load trusted files.
model = torch.load(model_path, map_location=device)
model.eval()  # put the loaded module into evaluation mode
# Define the prediction function
def predict_water_usage(state_idx, target_year, structured_data):
    """Run the forecasting model on three entries of structured input data.

    The values of ``structured_data`` are log1p-transformed, passed through
    the module-level ``model`` (on the module-level ``device``) with gradient
    tracking disabled, and the model output is mapped back with expm1.

    Args:
        state_idx: State index supplied by the UI. Not used by the
            computation in this function.
        target_year: Target year supplied by the UI. Not used by the
            computation in this function.
        structured_data: Mapping with exactly three entries whose values are
            equal-length sequences of numbers greater than -1 (presumably one
            entry per year -- confirm against the caller).

    Returns:
        On success, the back-transformed model output as nested plain Python
        lists of floats, so it can be serialized as JSON. On invalid input,
        a dict of the form ``{"error": message}``.
    """
    # An empty JSON field arrives as None; report it instead of crashing.
    if not isinstance(structured_data, dict):
        return {"error": "Structured data must be a JSON object with 3 years of data."}

    # Two distinct messages are kept: one for too few entries, one for too many.
    if len(structured_data) < 3:
        return {"error": "Structured data must include 3 years of data for the specified state."}
    if len(structured_data) != 3:
        return {"error": "Structured data should have 3 years of data."}

    # Build a rectangular float array; ragged or non-numeric rows fail here.
    try:
        raw = np.asarray(
            [list(values) for values in structured_data.values()],
            dtype=np.float64,
        )
    except (TypeError, ValueError):
        return {"error": "Structured data values must be equal-length lists of numbers."}

    # log1p is undefined for values <= -1 and would silently produce NaN/-inf.
    if not np.all(np.isfinite(raw)) or np.any(raw <= -1):
        return {"error": "Structured data values must be finite numbers greater than -1."}

    # The model was loaded onto `device`, so the input must live there too.
    batch = torch.tensor(np.log1p(raw), dtype=torch.float32, device=device)

    # Inference only: the forward pass itself must run without autograd.
    with torch.no_grad():
        predictions = model(batch).detach().cpu().numpy()

    # Undo the log1p transform and hand back plain lists for JSON output.
    return np.expm1(predictions).tolist()

# Gradio UI wiring: two numeric fields and one JSON field feed the predictor,
# and whatever it returns is rendered as JSON.
_state_field = gr.Number(label="State Index")
_year_field = gr.Number(label="Target Year")
_data_field = gr.JSON(label="Structured Data")

inputs = [_state_field, _year_field, _data_field]
outputs = gr.JSON(label="Prediction")

# Bind the prediction function to the components declared above.
interface = gr.Interface(
    fn=predict_water_usage,
    inputs=inputs,
    outputs=outputs,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch(show_error=True)