GDP-Prediction / app.py
Jineet's picture
Update app.py
4c9997e verified
import os
import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from model import GDPPredictor
# Shared predictor instance used by the UI construction and the predict callback
predictor = GDPPredictor()

# On-disk locations: cached trained models and the source dataset
model_file = 'gdp_models.pkl'
data_file = 'Consolidated.csv'

# Train (and cache) the models only when no cached model file exists yet;
# otherwise reuse the cached models. The dataset is loaded in both paths.
if not os.path.exists(model_file):
    print("Training new models...")
    predictor.load_data(data_file)
    predictor.train_models()
    predictor.save_models(model_file)
else:
    print("Loading pre-trained models...")
    # NOTE(review): the '.pkl' name suggests pickle-based loading - only load
    # model files from a trusted source; confirm against GDPPredictor.load_models.
    predictor.load_models(model_file)
    predictor.load_data(data_file)

# Most recent (year, GDP) pair, used as the baseline for reported changes
latest_year, latest_gdp = predictor.get_latest_gdp()

# Mapping of category -> {feature name: (min, max, mean, current)}; it drives
# both the slider construction and the feature ordering used for predictions
feature_info = predictor.get_feature_info()
# Create sliders for each feature, organized by category
def create_feature_inputs():
    """Render one slider per feature, grouped under per-category headings.

    Intended to be called inside a Gradio layout context so that the
    components appear where the call is made.

    Returns:
        list: The ``gr.Slider`` components only, in the iteration order of
        ``feature_info`` (category by category, feature by feature). That is
        the same order ``predict`` uses to pair incoming values with feature
        names, so the list can be passed directly as the event ``inputs``.
    """
    sliders = []
    for category, features in feature_info.items():
        # The heading is rendered for layout only. It must NOT be part of the
        # returned list: a Markdown component used as an event input passes
        # its text to the callback and shifts every slider value by one slot.
        gr.Markdown(f"## {category}")

        for feature_name, (min_val, max_val, _mean_val, current_val) in features.items():
            # Widen the historical range by 10% on each side. The factor is
            # chosen by sign so that negative bounds are pushed outwards too
            # (a plain `min * 0.9` would move a negative minimum inwards).
            slider_min = min_val * 0.9 if min_val >= 0 else min_val * 1.1
            slider_max = max_val * 1.1 if max_val >= 0 else max_val * 0.9

            # Degenerate range (e.g. a feature that is constant at zero):
            # pad it so the slider has a non-zero span and a non-zero step.
            if slider_max <= slider_min:
                slider_min -= 1.0
                slider_max += 1.0

            sliders.append(
                gr.Slider(
                    minimum=slider_min,
                    maximum=slider_max,
                    value=current_val,  # Default to current value
                    step=(slider_max - slider_min) / 100,  # 100 steps across range
                    label=feature_name,
                )
            )
    return sliders
def predict(*feature_values):
    """Predict GDP from the UI inputs and build the report and chart.

    Args:
        *feature_values: Values supplied by the input components, expected in
            the iteration order of ``feature_info`` (category by category,
            feature by feature).

    Returns:
        tuple: ``(markdown_text, figure)`` on success. On any failure the
        first element is an ``"Error making prediction: ..."`` message and
        the second is ``None``.
    """
    # Text values can only come from non-slider components (for example
    # Markdown headings wired in as inputs); drop them so the slider readings
    # stay aligned with the feature names below.
    numeric_inputs = [value for value in feature_values if not isinstance(value, str)]

    # Feature names in the same order the sliders are created
    feature_names = [
        feature_name
        for features in feature_info.values()
        for feature_name in features
    ]

    # Refuse to guess when the counts disagree: a silent positional mismatch
    # would feed values to the wrong features.
    if len(numeric_inputs) != len(feature_names):
        return (
            f"Error making prediction: expected {len(feature_names)} feature values, "
            f"got {len(numeric_inputs)}",
            None,
        )

    input_dict = dict(zip(feature_names, numeric_inputs))

    fig = None
    try:
        predictions = predictor.predict_gdp(input_dict)

        # Ensemble prediction and its change relative to the latest known GDP
        ensemble_pred = predictions['Ensemble']
        change = ensemble_pred - latest_gdp
        pct_change = (change / latest_gdp) * 100

        # Report header, followed by every individual model, highest first
        report_parts = [
            "# GDP Prediction Results\n\n",
            "## Primary Prediction\n",
            f"**Ensemble Model:** {ensemble_pred:.2f} USD billion\n\n",
            f"## Comparison with {latest_year} GDP ({latest_gdp:.2f} USD billion)\n",
            f"**Change:** {change:.2f} USD billion ({pct_change:.2f}%)\n\n",
            "## All Model Predictions\n",
        ]
        report_parts.extend(
            f"- **{name}:** {pred:.2f} USD billion\n"
            for name, pred in sorted(predictions.items(), key=lambda x: x[1], reverse=True)
            if name != 'Ensemble'
        )
        result_text = "".join(report_parts)

        # Chart: last 10 years of historical GDP plus the predicted point
        fig, ax = plt.subplots(figsize=(10, 6))
        df = predictor.cleaned_df
        last_years = df.sort_values('Year').tail(10)
        ax.plot(last_years['Year'], last_years[predictor.target], 'o-', linewidth=2, label='Historical GDP')

        # The prediction is plotted one year after the latest data point
        pred_year = latest_year + 1
        ax.scatter([pred_year], [ensemble_pred], color='green', s=150, label='Prediction')

        ax.set_title('GDP Prediction', fontsize=14)
        ax.set_xlabel('Year', fontsize=12)
        ax.set_ylabel('Real GDP (USD billion)', fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend()

        return result_text, fig
    except Exception as e:
        # UI boundary: surface the failure in the output panel instead of
        # letting the callback crash.
        return f"Error making prediction: {str(e)}", None
    finally:
        # pyplot keeps a global reference to every figure it creates; release
        # it so figures do not accumulate across clicks. The Figure object
        # itself remains usable by the caller for rendering.
        if fig is not None:
            plt.close(fig)
# Create the interface
# Builds the page: a title and intro text, then a row with the feature
# sliders on the left (scale 2) and the prediction outputs on the right
# (scale 3), and finally the button that triggers `predict`.
# NOTE(review): the original indentation was lost; the nesting below (button
# placed under the row, at the Blocks level) is an assumption - confirm.
with gr.Blocks(title="GDP Predictor") as demo:
    gr.Markdown("# GDP Prediction Model")
    # Intro text; interpolates the latest year available in the dataset
    gr.Markdown(f"""
This application predicts GDP based on various economic indicators. The current dataset contains data up to the year {latest_year}.
Adjust the sliders below to see how changes in different economic indicators might affect GDP.
The default values are set to the most recent values from the dataset.
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Create input sliders from feature info
            # The returned list is passed as the click event's inputs below
            inputs = create_feature_inputs()
        with gr.Column(scale=3):
            # Output components
            # Markdown panel for the textual report; shows a hint until the first prediction
            prediction_text = gr.Markdown("Adjust sliders and click 'Predict' to see results")
            # Plot panel for the figure returned by `predict`
            prediction_plot = gr.Plot(label="GDP Prediction Visualization")
    # Predict button
    predict_btn = gr.Button("Predict GDP")
    # On click: call `predict` with the current values of `inputs` and route
    # its (text, figure) result to the two output components
    predict_btn.click(fn=predict, inputs=inputs, outputs=[prediction_text, prediction_plot])
# Launch the app
# Only start the server when this file is run as a script
if __name__ == "__main__":
    demo.launch()