# Hugging Face Space by HussainKAUST — commit 2f7cacd ("Update app.py", verified)
import gradio as gr
import joblib
import pandas as pd
import numpy as np
import pickle
from openai import OpenAI
import os
# ====================================
# 1. LOAD MODEL FILES
# ====================================
print("🚀 Loading model...")
# NOTE(review): joblib/pickle deserialisation can execute arbitrary code; these
# artifacts must come from a trusted source (presumably bundled with this app).
model = joblib.load('saudi_gdp_model.pkl')
scaler = joblib.load('saudi_gdp_scaler.pkl')
selected_features = joblib.load('saudi_gdp_features.pkl')
# recursive_forecast() treats the last rows as the most recent years
# (iloc[-1], iloc[-2], ...), so guarantee chronological order here.
# A stable sort leaves an already-sorted CSV unchanged.
df_base = pd.read_csv('df_base.csv').sort_values('Year', kind='stable', ignore_index=True)
with open('metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)
# Column/sector lists and reference forecasts saved alongside the model.
sector_cols = metadata['sector_cols']
key_sectors = metadata['key_sectors']
known_forecasts = metadata['known_forecasts']
print("✓ Model ready")
# ====================================
# 2. INITIALIZE OPENAI CLIENT
# ====================================
# Only the chat tab needs this client. OpenAI() raises OpenAIError at
# construction time when no API key is configured; instead of letting that
# stop the whole app (including the forecast tab), fall back to ``None`` so
# chat requests fail inside chatbot_interface's error handling instead.
from openai import OpenAIError

try:
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
except OpenAIError as exc:
    print(f"WARNING: OpenAI client unavailable, chat is disabled: {exc}")
    client = None
# ====================================
# 3. MODEL FORECASTING FUNCTION
# ====================================
def recursive_forecast(target_year):
    """Forecast non-oil GDP YoY growth for ``target_year`` by rolling forward one year at a time.

    Ported from the Colab notebook. Starting after the last year in
    ``df_base``, each step builds the model's lag features from the history
    (real rows plus previously forecast rows), predicts that year's growth,
    and appends a synthetic row so the next step can use it as lag input.

    Args:
        target_year: Year to forecast; must be later than the last ``Year``
            in ``df_base``.

    Returns:
        The predicted growth in percent for ``target_year`` (the model's
        output, applied as ``gdp * (1 + growth / 100)``).

    Raises:
        ValueError: If ``target_year`` is not after the last known year, in
            which case there is nothing to forecast.
    """
    forecast_df = df_base.copy()
    last_known_year = int(df_base['Year'].max())
    # Without this guard the loop below would not run and `y_pred` would be
    # unbound at the return statement (UnboundLocalError).
    if target_year <= last_known_year:
        raise ValueError(
            f"Year {target_year} is not after the last known data year "
            f"({last_known_year}); nothing to forecast."
        )

    for year in range(last_known_year + 1, target_year + 1):
        # A feature containing 'lagK' (checked in order K = 1, 2, 3) takes the
        # value of its base column K rows back. It is 0 when the base column
        # is missing or history is too short. Features with no lag marker are 0.
        feature_row = {}
        for feat in selected_features:
            value = 0
            for lag in (1, 2, 3):
                if f'lag{lag}' in feat:
                    base_feat = feat.replace(f'_lag{lag}', '')
                    if base_feat in forecast_df.columns and len(forecast_df) >= lag:
                        value = forecast_df[base_feat].iloc[-lag]
                    break
            feature_row[feat] = value

        X_new = pd.DataFrame([feature_row])[selected_features]
        X_new_scaled = scaler.transform(X_new)
        y_pred = model.predict(X_new_scaled)[0]

        # Synthetic row for `year`, used as lag input by later iterations.
        last_gdp = forecast_df['Non-Oil Activities'].iloc[-1]
        new_row = {
            'Year': year,
            'Non-Oil Activities': last_gdp * (1 + y_pred / 100),
            'NonOil_GDP_YoY': y_pred,
        }
        for sector in sector_cols:
            # Every sector's YoY is set to 90% of the predicted headline growth.
            new_row[f"{sector}_YoY"] = y_pred * 0.9
        # Macro indicators are carried forward as their trailing 3-row mean.
        for col in ('CPI_YoY_%', 'PMI_Annual_Avg', 'POS_YoY'):
            new_row[col] = forecast_df[col].tail(3).mean()
        for sector in key_sectors:
            if f"{sector}_roll3" in forecast_df.columns:
                new_row[f"{sector}_roll3"] = (
                    forecast_df[sector].tail(3).mean()
                    if sector in forecast_df.columns
                    else y_pred * 0.9
                )
        forecast_df = pd.concat([forecast_df, pd.DataFrame([new_row])], ignore_index=True)

    return y_pred
def forecast_interface(year_input):
    """Return a Markdown summary of the growth forecast for one year.

    Args:
        year_input: The year as typed by the user (usually a string from the
            textbox); anything accepted by ``int()`` is allowed.

    Returns:
        Markdown with the year and predicted growth on separate lines, or a
        user-facing error message starting with "❌" for invalid input or a
        failed forecast.
    """
    try:
        year = int(year_input)
    except (TypeError, ValueError):
        return "❌ Please enter a valid year (e.g. 2030)"
    if not 2025 <= year <= 2050:
        return "❌ Please enter a year between 2025-2050"
    try:
        prediction = recursive_forecast(year)
    except Exception as e:  # UI boundary: show the failure instead of crashing the handler
        return f"❌ Error: {e}"
    # Markdown merges lines separated by a single newline; "  \n" forces a
    # hard line break so year and growth render on separate lines.
    return "  \n".join([
        f"**Year:** {year}",
        f"**Predicted Growth:** {prediction:.2f}%",
    ])
# ====================================
# 4. CHATBOT FUNCTION
# ====================================
def chatbot_interface(user_message, history):
    """Answer one chat turn with the OpenAI model and return the updated history.

    Args:
        user_message: Text typed by the user. ``None`` is treated as empty.
        history: Prior conversation as a list of ``{"role", "content"}``
            dicts (Gradio ``type="messages"`` format). ``None`` is treated as
            an empty conversation, which is what the Clear button produces.

    Returns:
        A new list containing the prior history plus the user message and the
        assistant reply. The reply is ``"Error: ..."`` if building the prompt
        or calling the API fails. A blank message returns the history
        unchanged.
    """
    history = list(history or [])
    user_message = user_message or ""
    if not user_message.strip():
        return history
    try:
        system_prompt = f"""You are a Saudi economy expert. Answer clearly and conversationally.
KEY DATA:
- Forecasts: 2025 ({known_forecasts[2025]:.2f}%), 2026 ({known_forecasts[2026]:.2f}%), 2027 ({known_forecasts[2027]:.2f}%)
- Top drivers: Retail Trade, Government, Construction, Transport
- Recent: 2022 (12.4%), 2023 (7.0%)
- Accuracy: RMSE {metadata['test_rmse']:.2f}%
STYLE: Conversational, 2-4 sentences, direct, helpful."""
        gpt_messages = [{"role": "system", "content": system_prompt}]
        # Replay only the last 4 messages (about two exchanges) to keep the
        # prompt short. Copy just role/content, because Gradio may attach
        # extra keys.
        gpt_messages += [
            {"role": msg["role"], "content": msg["content"]}
            for msg in history[-4:]
            if msg.get("role") in ("user", "assistant")
        ]
        gpt_messages.append({"role": "user", "content": user_message})
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=gpt_messages,
            temperature=0.7,
            max_tokens=250,
        )
        # `content` can be None (e.g. on a refusal); keep the history text-only.
        assistant_message = response.choices[0].message.content or ""
    except Exception as e:  # UI boundary: report the failure inside the chat
        assistant_message = f"Error: {str(e)}"
    return history + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message},
    ]
# ====================================
# 5. GRADIO INTERFACE - IMPROVED DESIGN
# ====================================
# Stylesheet passed to gr.Blocks(css=...) below. The `.forecast-output` rules
# style the forecast result Markdown (created with elem_classes="forecast-output").
# The remaining selectors target class names generated by Gradio.
# NOTE(review): those class names can change between Gradio versions; confirm
# they still match the installed version.
custom_css = """
/* Container settings - WIDER */
.gradio-container {
max-width: 1600px !important;
margin: 0 auto !important;
}
/* Main content centering */
.contain {
max-width: 1400px !important;
margin: 0 auto !important;
}
/* Forecast output - WHITE text on BRIGHT GREEN background */
.forecast-output {
font-size: 1.8em !important;
font-weight: 600 !important;
padding: 40px 30px !important;
background: linear-gradient(135deg, #00a651 0%, #006C35 100%) !important;
color: #FFFFFF !important;
border-radius: 16px !important;
text-align: center !important;
box-shadow: 0 8px 20px rgba(0, 108, 53, 0.3) !important;
line-height: 1.8 !important;
min-height: 200px !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
}
/* Force white text in forecast output */
.forecast-output p {
color: #FFFFFF !important;
margin: 0 !important;
}
.forecast-output strong {
color: #FFFFFF !important;
font-weight: 700 !important;
}
/* Tab styling */
.tab-nav button {
font-size: 1.1em !important;
padding: 12px 24px !important;
}
/* Input boxes - larger */
.gradio-textbox input {
font-size: 1.1em !important;
padding: 12px !important;
}
/* Buttons - larger */
.gradio-button {
font-size: 1.1em !important;
padding: 12px 24px !important;
border-radius: 8px !important;
}
/* Examples */
.example-container {
margin-top: 15px !important;
}
/* Chatbot messages - larger */
.message-wrap {
font-size: 1.05em !important;
}
/* Better spacing */
.block {
padding: 20px !important;
}
/* Center align tabs content */
.tabitem {
padding: 30px 20px !important;
}
/* Title styling */
h1 {
text-align: center !important;
margin-bottom: 10px !important;
}
h3 {
text-align: center !important;
color: #666 !important;
margin-top: 5px !important;
}
"""
# Two-tab UI: a year -> growth forecast panel and an OpenAI-backed chat.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="green"),
    css=custom_css,
    title="🇸🇦 Saudi GDP Intelligence",
) as demo:
    gr.Markdown("""
# 🇸🇦 Saudi Non-Oil GDP Intelligence
### AI-Powered Economic Forecasting & Analysis
""")
    with gr.Tabs():
        with gr.Tab("📊 GDP Forecast"):
            with gr.Row():
                with gr.Column(scale=1):
                    year_input = gr.Textbox(
                        label="Enter Year (2025-2050)",
                        placeholder="2030",
                        value="2025",
                        lines=1,
                    )
                    forecast_btn = gr.Button("🔮 Forecast", variant="primary", size="lg")
                    gr.Examples(
                        examples=[["2025"], ["2030"], ["2035"]],
                        inputs=year_input,
                        label="Quick Examples",
                    )
                with gr.Column(scale=1):
                    forecast_output = gr.Markdown(
                        value="👈 **Enter a year to see forecast**",
                        elem_classes="forecast-output",
                    )
            # Button click and Enter in the year box both run the forecast,
            # matching how the chat tab is wired.
            for trigger in (forecast_btn.click, year_input.submit):
                trigger(
                    fn=forecast_interface,
                    inputs=year_input,
                    outputs=forecast_output,
                )

        with gr.Tab("💬 Economic Chat"):
            chatbot = gr.Chatbot(
                label="",
                height=600,
                show_label=False,
                type="messages",
                avatar_images=(None, "https://em-content.zobj.net/source/twitter/376/flag-saudi-arabia_1f1f8-1f1e6.png"),
            )
            with gr.Row():
                msg_input = gr.Textbox(
                    label="",
                    placeholder="Ask me about Saudi GDP, forecasts, sectors, or Vision 2030...",
                    lines=1,
                    scale=9,
                    show_label=False,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            gr.Examples(
                examples=[
                    ["What's the forecast for 2025?"],
                    ["How accurate is the model?"],
                    ["What drives GDP growth?"],
                    ["Tell me about Vision 2030"],
                ],
                inputs=msg_input,
                label="Try these questions",
            )
            clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
            # Send button and Enter key share one handler; the input box is
            # emptied once the reply has been added.
            for trigger in (send_btn.click, msg_input.submit):
                trigger(
                    fn=chatbot_interface,
                    inputs=[msg_input, chatbot],
                    outputs=chatbot,
                ).then(
                    fn=lambda: "",
                    outputs=msg_input,
                )
            # Reset to an empty list (not None) so the next turn receives a
            # list-valued history.
            clear_btn.click(fn=lambda: [], outputs=chatbot)
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()