File size: 1,692 Bytes
4e188a6 a063804 4e188a6 5bf6359 4e188a6 a063804 5bf6359 4e188a6 a063804 5bf6359 fca02b7 145227f 2a200cd 5bf6359 45ca800 145227f 5bf6359 2a200cd 5bf6359 4e188a6 5bf6359 2a200cd 5bf6359 a063804 57cee6e a063804 145227f 2a200cd a063804 145227f 4e188a6 2a200cd a063804 7859c40 2a200cd 7859c40 2a200cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import torch
import gradio as gr
import numpy as np
from src.model import LSTM
# Load the model.
# `model.eval()` is called on the loaded object, so the checkpoint holds a
# whole pickled nn.Module rather than a bare state_dict (presumably why
# `LSTM` is imported above: unpickling needs the class importable).
# torch >= 2.6 defaults torch.load to weights_only=True, which rejects such
# checkpoints, so full unpickling is requested explicitly.
# NOTE: full unpickling can execute arbitrary code -- only load trusted files.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = "./andhra_forecast.pt"
model = torch.load(model_path, map_location=device, weights_only=False)
# Put the module in evaluation mode (affects layers such as dropout).
model.eval()
# Define the prediction function
def predict_water_usage(state_idx, target_year, structured_data):
    """Forecast water usage from three years of structured history.

    Each value in ``structured_data`` is log1p-transformed, the resulting
    3-row matrix is passed through the module-level ``model``, and the model
    output is mapped back with expm1 (the inverse of the input transform).

    Args:
        state_idx: State index from the Gradio UI. Not used by this
            implementation; kept so the signature matches the interface.
        target_year: Target year from the Gradio UI. Not used by this
            implementation.
        structured_data: JSON object with exactly three entries whose values
            are equal-length sequences of numbers. Rows are fed to the model
            in the dict's insertion order. The number of values per row the
            model expects is defined by the model -- TODO confirm.

    Returns:
        A nested list of floats (one element per row of the model output),
        which is JSON-serializable, or a dict with an ``"error"`` key when
        ``structured_data`` is malformed.
    """
    # Reject None, lists and scalars up front; the code below needs .values().
    if not isinstance(structured_data, dict):
        return {"error": "Structured data must be a JSON object mapping each year to its values."}
    if len(structured_data) < 3:
        return {"error": "Structured data must include 3 years of data for the specified state."}
    if len(structured_data) != 3:
        return {"error": "Structured data should have 3 years of data."}

    try:
        # log1p(x) == log(x + 1), but stays accurate for values near 0.
        features = np.log1p(
            np.array([list(values) for values in structured_data.values()], dtype=np.float64)
        )
    except (TypeError, ValueError):
        # Non-iterable entries, non-numeric items, or rows of differing length.
        return {"error": "Structured data values must be equal-length lists of numbers."}

    # Build the tensor on `device`, the same device the model was mapped to
    # when loaded; a CPU tensor fed to a CUDA model raises a device error.
    inputs = torch.tensor(features, dtype=torch.float32, device=device)
    # Inference only: run the forward pass without autograd bookkeeping.
    with torch.no_grad():
        predictions = model(inputs)
    # Invert the log1p input transform; .tolist() yields plain Python floats
    # that the gr.JSON output component can serialize.
    return np.expm1(predictions.cpu().numpy()).tolist()
# Gradio UI: one input component per predict_water_usage parameter, in order
# (state index, target year, structured data).
inputs = [gr.Number(label=caption) for caption in ("State Index", "Target Year")]
inputs.append(gr.JSON(label="Structured Data"))
outputs = gr.JSON(label="Prediction")

interface = gr.Interface(
    fn=predict_water_usage,
    inputs=inputs,
    outputs=outputs,
)

# Serve the app only when executed as a script; show_error=True surfaces
# exceptions raised by the prediction function in the UI.
if __name__ == "__main__":
    interface.launch(show_error=True)
|