Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| from src.model import LSTM | |
# Load the forecaster once at import time so every request reuses it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Resolved relative to the current working directory, not this file.
model_path = "./water_forecast_8.pt"
# The checkpoint is a fully pickled nn.Module (eval() is called on the loaded
# object directly), so unpickling presumably relies on `src.model.LSTM` being
# importable. PyTorch >= 2.6 defaults to weights_only=True, which refuses to
# unpickle arbitrary classes, so opt out explicitly.
# SECURITY: weights_only=False executes pickle code -- only load trusted files.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()  # switch layers such as dropout to inference behaviour
def predict_water_usage(state_idx, target_year, structured_data):
    """Forecast water usage from three years of structured history.

    Args:
        state_idx: Value of the "State Index" input. Currently unused.
            NOTE(review): confirm whether the model should be conditioned on it.
        target_year: Value of the "Target Year" input. Currently unused.
        structured_data: Parsed JSON from the "Structured Data" input. It is either
            an object whose values are per-year lists of numbers, or an array
            of such lists. Exactly three rows of equal length are required.
            Object values are used in their insertion (JSON) order.

    Returns:
        A JSON-serialisable (nested) list of predictions, mapped back from
        log1p space with expm1. When the input is invalid, a
        ``{"error": message}`` dict is returned instead.
    """
    # gr.JSON passes through whatever was entered (None when empty), so
    # validate the container before using it.
    if isinstance(structured_data, dict):
        rows = list(structured_data.values())
    elif isinstance(structured_data, (list, tuple)):
        rows = list(structured_data)
    else:
        return {"error": "Structured data must be a JSON object or array of yearly values."}

    if len(rows) < 3:
        return {"error": "Structured data must include 3 years of data for the specified state."}
    if len(rows) != 3:
        return {"error": "Structured data should have 3 years of data."}

    if not all(isinstance(row, (list, tuple)) for row in rows):
        return {"error": "Each year of structured data must be a list of numbers of equal length."}
    try:
        # Ragged rows or non-numeric entries fail here, not inside torch.
        values = np.asarray(rows, dtype=np.float64)
    except (TypeError, ValueError):
        return {"error": "Each year of structured data must be a list of numbers of equal length."}
    # log1p is only finite for x > -1; never feed NaN/inf features to the model.
    if not np.all(np.isfinite(values)) or np.any(values <= -1):
        return {"error": "Structured data values must be finite numbers greater than -1."}

    # The model consumes log(x + 1) features. The tensor must live on the same
    # device the model was mapped to (CUDA when available).
    inputs = torch.tensor(np.log1p(values), dtype=torch.float32, device=device)
    # Wrap the forward pass itself so no autograd graph is built per request.
    with torch.no_grad():
        predictions = model(inputs).cpu().numpy()
    # Invert the log1p transform; .tolist() yields plain floats for gr.JSON.
    return np.expm1(predictions).tolist()
# Gradio UI: the three components are passed positionally to
# predict_water_usage as (state_idx, target_year, structured_data), and the
# returned value is rendered as JSON.
inputs = [
    gr.Number(label="State Index"),
    gr.Number(label="Target Year"),
    gr.JSON(label="Structured Data"),
]
outputs = gr.JSON(label="Prediction")

interface = gr.Interface(
    fn=predict_water_usage,
    inputs=inputs,
    outputs=outputs,
)

if __name__ == "__main__":
    # show_error=True surfaces exceptions raised by the handler in the browser UI.
    interface.launch(show_error=True)