LSTM-forecaster / gradio_ui.py
nkapila6's picture
Upload gradio_ui.py
820f8cc verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on 2024-07-19 15:52:34 Friday
@author: Nikhil Kapila
"""
import gradio as gr
import pandas as pd
import utils.fetcher as fetch
import utils.forecasting_utils as forecast_utils
import utils.prediction as predict
import utils.plot as plotter
import utils.anomaly_detection as anomaly
import utils.fetch_comstock_data as cfetcher
# ---------------------------------------------------------------------------
# Dropdown choices offered by the UI below.
# ---------------------------------------------------------------------------

# How the building is used. 'Lodging / Residential' is translated to
# 'lodging_residential' inside inference_pipeline before model lookup.
building_types = [
    'Office',
    'Education',
    'Lodging / Residential',
]

# Forecast horizon choices; inference_pipeline maps 'short term' to the
# 30-step models and everything else to the 360-step models.
prediction_period_choices = [
    'short term',
    'long term',
]

# Floor-area buckets (label in the UI: 'Size in sq-ft'), as '<low>_<high>'.
sizes = [
    '1001_5000',
    '5001_10000',
    '10001_25000',
    '25001_50000',
    '50001_100000',
    '100001_200000',
    '200001_500000',
]

# US states for which comparison buildings can be looked up.
states = [
    'GA',
]

# Construction-year ranges; not currently wired into the UI.
years = [
    '1990 to 1999',
    '1980 to 1989',
    '2000 to 2012',
    '2013 to 2018',
]
def _resolve_lookback(lookback_period):
    """Map a forecast-horizon choice to ``(lookback, model_directory)``.

    Accepts both the UI spelling ('short term') and the underscore spelling
    ('short_term', the default of inference_pipeline). Any other value,
    including None, selects the long-term models, as before.
    """
    normalized = (lookback_period or '').replace('_', ' ').strip().lower()
    if normalized == 'short term':
        return 30, 'metaflow_models_30hrs'
    return 360, 'metaflow_models_360hrs'


def _load_hourly_data(dataset_loc, index_col, output_col):
    """Read the uploaded file and return a single-column hourly DataFrame.

    The file is read as parquet first, falling back to CSV. ``index_col``
    becomes a DatetimeIndex. When several value columns are present only
    ``output_col`` is kept. Non-hourly data is summed into hourly buckets.

    Raises:
        ValueError: if no file was given or a required column is missing.
    """
    if not dataset_loc:
        raise ValueError('Please upload a dataset (parquet or csv) before forecasting.')
    try:
        df = pd.read_parquet(dataset_loc)
    except Exception:
        # Not a parquet file (or no parquet engine available): try CSV instead.
        df = pd.read_csv(dataset_loc)

    if index_col not in df.columns:
        raise ValueError(f'Error column {index_col} is not found in the uploaded file.')
    df = df.set_index(index_col)
    # CSV timestamps arrive as strings; resample() requires a DatetimeIndex.
    df.index = pd.to_datetime(df.index)

    # If there are other columns, keep only the output column. Use a list
    # selector so the result stays a DataFrame on every code path.
    if len(df.columns) > 1:
        if output_col not in df.columns:
            raise ValueError(f'Error column {output_col} is not found in the uploaded file.')
        df = df[[output_col]]

    # infer_freq reports hourly data as 'h' (pandas >= 2.2) or 'H' (older),
    # and None for irregular data; anything that is not hourly gets resampled.
    if (pd.infer_freq(df.index) or '').lower() != 'h':
        df = df.resample('h').sum()
    return df


def inference_pipeline(dataset_loc: str, btype: str, state: str, ft: str, index_col: str, output_col: str, lookback_period: str = 'short_term'):
    """Forecast, evaluate and compare an uploaded energy time series.

    Args:
        dataset_loc: Path to the uploaded parquet or CSV file.
        btype: Building-use label from the UI (e.g. 'Office').
        state: State code forwarded to the ComStock URL lookup.
        ft: Floor-area bucket (e.g. '1001_5000'). Used to pick the model and
            scaler and to find a comparable ComStock building.
        index_col: Date-time column. Empty/None falls back to 'timestamp'.
        output_col: Value column. Empty/None falls back to
            'out.site_energy.total.energy_consumption'.
        lookback_period: 'short term' (or 'short_term') uses the 30-step
            models; any other value uses the 360-step models.

    Returns:
        An 8-tuple matching the Gradio outputs, in this order:
        - the lookback forecast figure;
        - the full-input forecast figure;
        - the anomaly figure;
        - a visible Markdown reporting ``mape``;
        - comparison figures for ComStock upgrades 0, 28, 29 and 31.

    Raises:
        ValueError: if no file or building type is given, or a required
            column is missing from the file.
    """
    # Empty (or unset) textboxes fall back to the defaults shown in the UI.
    if not output_col:
        output_col = 'out.site_energy.total.energy_consumption'
    if not index_col:
        index_col = 'timestamp'
    if not btype:
        raise ValueError('Please select how the building is being used.')
    # Model artefacts use a filesystem-friendly name for this category.
    if btype == 'Lodging / Residential':
        btype = 'lodging_residential'
    lookback, inference_model_path = _resolve_lookback(lookback_period)

    # Validate and prepare the user's data before fetching any model artefacts.
    hourly_data = _load_hourly_data(dataset_loc, index_col, output_col)

    # fetch torch model and scaler
    model = fetch.fetch_model(ft, btype.lower(), inference_model_path)
    scaler = fetch.fetch_scaler(ft, btype.lower(), inference_model_path)

    # The torch model consumes sliding windows; only targets and timestamps
    # are needed here for plotting, so the input windows are discarded.
    _, y, timestamp = forecast_utils.sliding_windows(hourly_data, lookback)
    predicted_lookback = predict.model_predict_lookback(model, scaler, hourly_data, lookback)
    predicted_input, mape = predict.model_predict_full_input(model, scaler, hourly_data, lookback)
    anomalies = anomaly.detect_anomalies_with_sliding_windows(hourly_data, lookback)

    # Comparable ComStock building data, indexed by upgrade number
    # (0 is presented in the UI as the plain similar building).
    comstock_id = cfetcher.find_id(floor_area=ft, building_type=btype)
    upgrade_dict = cfetcher.get_datasets_from_comstock(
        b_id=comstock_id,
        url_dict=cfetcher.fetch_building_urls(b_id=comstock_id, state=state),
    )

    # return plotly fig objects so gr.Plot() can understand
    plot_anomalies = anomaly.plotly_anomaly(anomalies)
    plot_fig_full_input = plotter.standard_plotter(timestamp, y, predicted_input)
    plot_fig_lookback = plotter.lookback_plotter(hourly_data, predicted_lookback)
    plot_0 = plotter.upgrade_plotter(hourly_data, upgrade_dict[0])
    plot_28 = plotter.upgrade_plotter(hourly_data, upgrade_dict[28])
    plot_29 = plotter.upgrade_plotter(hourly_data, upgrade_dict[29])
    plot_31 = plotter.upgrade_plotter(hourly_data, upgrade_dict[31])
    return (plot_fig_lookback,
            plot_fig_full_input,
            plot_anomalies,
            gr.Markdown(f'Trained model has an error of {mape}', visible=True),
            plot_0,
            plot_28,
            plot_29,
            plot_31)
# Gradio UI starts
with gr.Blocks() as demo:
    gr.Markdown('# LSTM Forecaster')
    with gr.Row():
        # Left column: building description, column names and the data upload.
        with gr.Column():
            btype = gr.Dropdown(label="How is the building being used?", choices=building_types)
            state = gr.Dropdown(label='State in the US', choices=states)
            size = gr.Dropdown(label='Size in sq-ft', choices=sizes)
            forecast_period = gr.Dropdown(label='How far into the future to forecast?', choices=prediction_period_choices)
            # Placeholders must show the defaults inference_pipeline applies
            # when a textbox is left empty.
            index_col = gr.Textbox(label='Date-time (quarter-hourly or hourly) column in your dataset? (default in placeholder)', placeholder='timestamp')
            value_col = gr.Textbox(label='Energy consumption column in your dataset? (default in placeholder)', placeholder='out.site_energy.total.energy_consumption')
            # inference_pipeline only knows how to read parquet and CSV files.
            in_data_path = gr.UploadButton(variant='secondary', label="Upload past data", type="filepath", file_types=['.parquet', '.csv'])
            make_inference = gr.Button(value='Forecast', variant='primary')
        # Right column: every figure/markdown that inference_pipeline returns.
        with gr.Column():
            gr.Markdown('## Forecast On Input Data')
            plot_forecast_lookback = gr.Plot()
            gr.Markdown('## Using Input Data To Assess Model Performance')
            gr.Markdown('This gives us an idea on how much the trained model performs on unseen input by the user.')
            # Hidden until the pipeline returns a visible Markdown with the error.
            markdown_lstm_forecast_input = gr.Markdown('Model Forecast Accuracy on Provided Data', visible=False)
            plot_forecast_full_input = gr.Plot()
            gr.Markdown('## Comparison of input data with different upgrades')
            gr.Markdown('### Comparison with Similar Building')
            plot_0 = gr.Plot()
            gr.Markdown('### Comparison with Similar Building having Upgrade #28: Wall & Roof Insulation + New Windows')
            plot_28 = gr.Plot()
            gr.Markdown('### Comparison with Similar Building having Upgrade #29: LED Lighting + Variable Speed HP RTU or HP Boilers')
            plot_29 = gr.Plot()
            gr.Markdown('### Comparison with Similar Building having Upgrade #31: Upgrade 28 + Upgrade 29')
            plot_31 = gr.Plot()
            gr.Markdown('## Anomaly detection (WIP)')
            plot_anomalies = gr.Plot()
    # Input order must match inference_pipeline's parameters, and output
    # order must match the tuple it returns.
    make_inference.click(fn=inference_pipeline,
                         inputs=[in_data_path, btype, state, size, index_col, value_col, forecast_period],
                         outputs=[plot_forecast_lookback,
                                  plot_forecast_full_input,
                                  plot_anomalies,
                                  markdown_lstm_forecast_input,
                                  plot_0,
                                  plot_28,
                                  plot_29,
                                  plot_31])

if __name__ == "__main__":
    demo.launch()