# NOTE: "Spaces: Sleeping" header below is residue from a Hugging Face Spaces
# page scrape, not part of the application source; the app begins at the imports.
import gradio as gr
import joblib
import pandas as pd
import numpy as np
import pickle
from openai import OpenAI
import os

# ====================================
# 1. LOAD MODEL FILES
# ====================================
print("🚀 Loading model...")  # emoji reconstructed from mojibake — confirm against original

# Pre-trained regressor, fitted feature scaler and the list of selected
# feature names produced by the training notebook.
model = joblib.load('saudi_gdp_model.pkl')
scaler = joblib.load('saudi_gdp_scaler.pkl')
selected_features = joblib.load('saudi_gdp_features.pkl')

# Historical base data that the recursive forecast extends year by year.
df_base = pd.read_csv('df_base.csv')

# NOTE(review): pickle.load on a bundled artifact — safe only if metadata.pkl
# ships with the app and is trusted.
with open('metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)

sector_cols = metadata['sector_cols']
key_sectors = metadata['key_sectors']
known_forecasts = metadata['known_forecasts']

print("✅ Model ready")  # emoji reconstructed from mojibake

# ====================================
# 2. INITIALIZE OPENAI CLIENT
# ====================================
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# ====================================
# 3. MODEL FORECASTING FUNCTION
# ====================================
def recursive_forecast(target_year):
    """Recursively forecast non-oil GDP YoY growth (%) up to ``target_year``.

    Starting from the last year present in ``df_base``, each future year's
    feature row is built from lagged values of the growing forecast frame,
    scaled, and fed to the trained model; the predicted YoY growth is then
    appended as a new row so the next year can lag on it (mirrors the Colab
    training notebook).

    Args:
        target_year: Calendar year to forecast.

    Returns:
        Predicted YoY growth (%) for ``target_year`` as a float.

    Raises:
        ValueError: if ``target_year`` is within the historical range but has
            no stored ``NonOil_GDP_YoY`` value.
    """
    forecast_df = df_base.copy()
    last_known_year = int(df_base['Year'].max())

    # Bug fix: the original returned an unbound y_pred (NameError) whenever
    # target_year was not strictly beyond the historical data. For such years,
    # return the recorded YoY growth instead.
    if target_year <= last_known_year:
        known = forecast_df.loc[forecast_df['Year'] == target_year, 'NonOil_GDP_YoY']
        if known.empty:
            raise ValueError(f"No historical data for year {target_year}")
        return float(known.iloc[0])

    y_pred = 0.0
    for year in range(last_known_year + 1, target_year + 1):
        feature_row = {}
        for feat in selected_features:
            # Lag features read from the tail of the growing forecast frame;
            # anything unavailable falls back to 0, as in the notebook.
            if 'lag1' in feat:
                base_feat = feat.replace('_lag1', '')
                feature_row[feat] = forecast_df[base_feat].iloc[-1] if base_feat in forecast_df.columns else 0
            elif 'lag2' in feat:
                base_feat = feat.replace('_lag2', '')
                feature_row[feat] = forecast_df[base_feat].iloc[-2] if base_feat in forecast_df.columns and len(forecast_df) >= 2 else 0
            elif 'lag3' in feat:
                base_feat = feat.replace('_lag3', '')
                feature_row[feat] = forecast_df[base_feat].iloc[-3] if base_feat in forecast_df.columns and len(forecast_df) >= 3 else 0
            else:
                feature_row[feat] = 0

        X_new = pd.DataFrame([feature_row])[selected_features]
        X_new_scaled = scaler.transform(X_new)
        y_pred = model.predict(X_new_scaled)[0]

        # Convert the predicted growth rate into a GDP level and append the
        # new year so later iterations can lag on it.
        last_gdp = forecast_df['Non-Oil Activities'].iloc[-1]
        new_gdp = last_gdp * (1 + y_pred / 100)
        new_row = {'Year': year, 'Non-Oil Activities': new_gdp, 'NonOil_GDP_YoY': y_pred}

        # Heuristic from the notebook: sectors grow at 90% of headline growth.
        for sector in sector_cols:
            new_row[f"{sector}_YoY"] = y_pred * 0.9

        # Macro indicators are carried forward as 3-year trailing means.
        new_row['CPI_YoY_%'] = forecast_df['CPI_YoY_%'].tail(3).mean()
        new_row['PMI_Annual_Avg'] = forecast_df['PMI_Annual_Avg'].tail(3).mean()
        new_row['POS_YoY'] = forecast_df['POS_YoY'].tail(3).mean()

        for sector in key_sectors:
            if f"{sector}_roll3" in forecast_df.columns:
                new_row[f"{sector}_roll3"] = forecast_df[sector].tail(3).mean() if sector in forecast_df.columns else y_pred * 0.9

        forecast_df = pd.concat([forecast_df, pd.DataFrame([new_row])], ignore_index=True)

    return y_pred
def forecast_interface(year_input):
    """Gradio handler: validate ``year_input`` and return a Markdown forecast.

    Args:
        year_input: Year as typed by the user (string or int).

    Returns:
        Markdown string — either the formatted forecast or an error message
        (never raises; all failures are rendered for the UI).
    """
    try:
        year = int(year_input)
        if year < 2025 or year > 2050:
            # Emoji reconstructed from mojibake (was a broken ❌).
            return "❌ Please enter a year between 2025-2050"
        prediction = recursive_forecast(year)
        # Markdown with line breaks so the output card renders cleanly.
        result = f"""
**Year:** {year}
**Predicted Growth:** {prediction:.2f}%
"""
        return result
    except Exception as e:
        # Catch-all keeps the UI responsive on bad input or model errors.
        return f"❌ Error: {str(e)}"
# ====================================
# 4. CHATBOT FUNCTION
# ====================================
def chatbot_interface(user_message, history):
    """Gradio chat handler: append the user turn and a GPT reply to history.

    Args:
        user_message: Raw text from the input box.
        history: List of ``{"role", "content"}`` dicts (Chatbot type="messages").

    Returns:
        New history list; on API failure the error text becomes the assistant
        turn instead of raising, so the chat UI never breaks.
    """
    if not user_message.strip():
        return history  # ignore empty / whitespace-only submissions
    try:
        # System prompt pins the model to the app's own forecast numbers.
        system_prompt = f"""You are a Saudi economy expert. Answer clearly and conversationally.
KEY DATA:
- Forecasts: 2025 ({known_forecasts[2025]:.2f}%), 2026 ({known_forecasts[2026]:.2f}%), 2027 ({known_forecasts[2027]:.2f}%)
- Top drivers: Retail Trade, Government, Construction, Transport
- Recent: 2022 (12.4%), 2023 (7.0%)
- Accuracy: RMSE {metadata['test_rmse']:.2f}%
STYLE: Conversational, 2-4 sentences, direct, helpful."""

        gpt_messages = [{"role": "system", "content": system_prompt}]
        # Only the last 4 turns are replayed, to bound token usage.
        for msg in history[-4:]:
            if msg["role"] in ("user", "assistant"):
                gpt_messages.append({"role": msg["role"], "content": msg["content"]})
        gpt_messages.append({"role": "user", "content": user_message})

        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=gpt_messages,
            temperature=0.7,
            max_tokens=250,
        )
        assistant_message = response.choices[0].message.content
        return history + [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message},
        ]
    except Exception as e:
        # Surface the failure in-chat rather than crashing the UI.
        return history + [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": f"Error: {str(e)}"},
        ]
# ====================================
# 5. GRADIO INTERFACE - IMPROVED DESIGN
# ====================================
custom_css = """
/* Container settings - WIDER */
.gradio-container {
    max-width: 1600px !important;
    margin: 0 auto !important;
}
/* Main content centering */
.contain {
    max-width: 1400px !important;
    margin: 0 auto !important;
}
/* Forecast output - WHITE text on BRIGHT GREEN background */
.forecast-output {
    font-size: 1.8em !important;
    font-weight: 600 !important;
    padding: 40px 30px !important;
    background: linear-gradient(135deg, #00a651 0%, #006C35 100%) !important;
    color: #FFFFFF !important;
    border-radius: 16px !important;
    text-align: center !important;
    box-shadow: 0 8px 20px rgba(0, 108, 53, 0.3) !important;
    line-height: 1.8 !important;
    min-height: 200px !important;
    display: flex !important;
    align-items: center !important;
    justify-content: center !important;
}
/* Force white text in forecast output */
.forecast-output p {
    color: #FFFFFF !important;
    margin: 0 !important;
}
.forecast-output strong {
    color: #FFFFFF !important;
    font-weight: 700 !important;
}
/* Tab styling */
.tab-nav button {
    font-size: 1.1em !important;
    padding: 12px 24px !important;
}
/* Input boxes - larger */
.gradio-textbox input {
    font-size: 1.1em !important;
    padding: 12px !important;
}
/* Buttons - larger */
.gradio-button {
    font-size: 1.1em !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
}
/* Examples */
.example-container {
    margin-top: 15px !important;
}
/* Chatbot messages - larger */
.message-wrap {
    font-size: 1.05em !important;
}
/* Better spacing */
.block {
    padding: 20px !important;
}
/* Center align tabs content */
.tabitem {
    padding: 30px 20px !important;
}
/* Title styling */
h1 {
    text-align: center !important;
    margin-bottom: 10px !important;
}
h3 {
    text-align: center !important;
    color: #666 !important;
    margin-top: 5px !important;
}
"""
# Build the two-tab Gradio app: numeric forecast + GPT-backed chat.
# All emoji below are reconstructed from mojibake in the scraped source
# (🇸🇦 flag, 🔮 and 💬 confirmed byte-for-byte; 🚀/📈/📊 are best guesses).
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="green"),
    css=custom_css,
    title="🇸🇦 Saudi GDP Intelligence"
) as demo:
    gr.Markdown("""
    # 🇸🇦 Saudi Non-Oil GDP Intelligence
    ### AI-Powered Economic Forecasting & Analysis
    """)

    with gr.Tabs():
        with gr.Tab("📈 GDP Forecast"):
            with gr.Row():
                with gr.Column(scale=1):
                    year_input = gr.Textbox(
                        label="Enter Year (2025-2050)",
                        placeholder="2030",
                        value="2025",
                        lines=1
                    )
                    forecast_btn = gr.Button("🔮 Forecast", variant="primary", size="lg")
                    gr.Examples(
                        examples=[["2025"], ["2030"], ["2035"]],
                        inputs=year_input,
                        label="Quick Examples"
                    )
                with gr.Column(scale=1):
                    forecast_output = gr.Markdown(
                        value="📊 **Enter a year to see forecast**",
                        elem_classes="forecast-output"
                    )
            forecast_btn.click(
                fn=forecast_interface,
                inputs=year_input,
                outputs=forecast_output
            )

        with gr.Tab("💬 Economic Chat"):
            chatbot = gr.Chatbot(
                label="",
                height=600,
                show_label=False,
                type="messages",
                avatar_images=(None, "https://em-content.zobj.net/source/twitter/376/flag-saudi-arabia_1f1f8-1f1e6.png")
            )
            with gr.Row():
                msg_input = gr.Textbox(
                    label="",
                    placeholder="Ask me about Saudi GDP, forecasts, sectors, or Vision 2030...",
                    lines=1,
                    scale=9,
                    show_label=False
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            gr.Examples(
                examples=[
                    ["What's the forecast for 2025?"],
                    ["How accurate is the model?"],
                    ["What drives GDP growth?"],
                    ["Tell me about Vision 2030"]
                ],
                inputs=msg_input,
                label="Try these questions"
            )
            clear_btn = gr.Button("🗑️ Clear Chat", size="sm")

            # Send via button or Enter; both clear the input box afterwards.
            send_btn.click(
                fn=chatbot_interface,
                inputs=[msg_input, chatbot],
                outputs=chatbot
            ).then(
                fn=lambda: "",
                outputs=msg_input
            )
            msg_input.submit(
                fn=chatbot_interface,
                inputs=[msg_input, chatbot],
                outputs=chatbot
            ).then(
                fn=lambda: "",
                outputs=msg_input
            )
            clear_btn.click(fn=lambda: None, outputs=chatbot)

if __name__ == "__main__":
    demo.launch()