# Source: APOO-Traffic-Optimizer / apoo_figures.py
# Author: omshrivastava
# Commit message: Add figure generator module
# Commit: 3dcf488 (verified)
"""
APOO Research Paper — Figure Generator & Architecture Diagram
=============================================================
Generates all figures for the APOO research paper and
a clean, production-quality system architecture diagram.
"""
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch
import matplotlib.patheffects as pe
import os
from apoo_core import (
RobertsonDispersion, IndianTrafficGenerator, CorridorSimulator,
OffsetOptimizer, EmissionCalculator, RoadLink, Intersection,
SignalPhase, format_kpi_comparison,
PCU_INDIA, VEHICLE_MIX_PROFILES, WEATHER_SPEED_FACTORS,
EMISSION_FACTORS, IDLE_EMISSION_RATES,
)
from apoo_ml import (
APOOPredictor, train_apoo_model, FEATURE_COLUMNS, TARGET_COLUMN,
prepare_features,
)
# Output directory for every PNG/CSV this module writes. The default is the
# container path used so far; set the APOO_FIGURE_DIR environment variable to
# redirect output when /app is not writable (e.g. a local checkout).
FIGURE_DIR = os.environ.get("APOO_FIGURE_DIR", "/app/figures")
# Created at import time so the later savefig()/to_csv() calls never target a
# missing directory.
os.makedirs(FIGURE_DIR, exist_ok=True)
# ============================================================
# 1. SYSTEM ARCHITECTURE DIAGRAM (Fixed — Clean Matplotlib)
# ============================================================
def create_architecture_diagram():
    """
    Build the APOO system-architecture figure with plain matplotlib.

    Four dashed column strips (data sources, inputs, core processing,
    outputs) hold rounded FancyBboxPatch panels, and annotate() arrows
    connect them. The figure is closed in pyplot before being returned, so
    the caller gets a Figure object that pyplot no longer tracks.
    """
    fig, ax = plt.subplots(figsize=(18, 12))
    ax.set_xlim(0, 18)
    ax.set_ylim(0, 12)
    ax.axis('off')
    fig.patch.set_facecolor('white')

    # ---- Colour palette ----
    palette = {
        'data': '#1abc9c',      # teal — databases
        'input': '#3498db',     # blue — inputs
        'process': '#e74c3c',   # red — core processing
        'ml': '#9b59b6',        # purple — ML
        'output': '#2ecc71',    # green — outputs
        'india': '#f39c12',     # orange — India-specific
        'arrow': '#34495e',     # dark grey
        'bg': '#ecf0f1',        # light grey background strips
    }

    def add_panel(x, y, w, h, title, lines, color, title_size=10, line_size=8.5):
        """Rounded panel with a bold centred title and left-aligned body lines."""
        ax.add_patch(FancyBboxPatch(
            (x, y), w, h,
            boxstyle="round,pad=0.15",
            facecolor=color, alpha=0.13,
            edgecolor=color, linewidth=2.2,
        ))
        top = y + h
        # White stroke keeps the title legible over the tinted panel.
        ax.text(x + w/2, top - 0.30, title,
                ha='center', va='top', fontsize=title_size,
                fontweight='bold', color='#2c3e50',
                path_effects=[pe.withStroke(linewidth=3, foreground='white')])
        for row, body in enumerate(lines):
            ax.text(x + 0.25, top - 0.65 - row*0.34, body,
                    ha='left', va='top', fontsize=line_size, color='#34495e')

    def connect(x1, y1, x2, y2, color, lw, rad=0.0):
        """Arrow from (x1, y1) to (x2, y2); rad=0.0 is straight, non-zero bends it."""
        ax.annotate('', xy=(x2, y2), xytext=(x1, y1),
                    arrowprops=dict(arrowstyle='->', color=color, lw=lw,
                                    connectionstyle=f'arc3,rad={rad}'))

    # ---- Section background strips: (left edge, width, heading, colour) ----
    strips = (
        (0, 3.4, 'DATA SOURCES', palette['data']),
        (3.7, 3.4, 'INPUTS', palette['input']),
        (7.4, 5.1, 'CORE PROCESSING', palette['process']),
        (12.8, 4.9, 'OUTPUTS', palette['output']),
    )
    for left, width, heading, tint in strips:
        ax.add_patch(mpatches.Rectangle(
            (left, 0.2), width, 11.3, facecolor=tint, alpha=0.04,
            edgecolor=tint, linewidth=0.8, linestyle='--'))
        ax.text(left + width/2, 11.65, heading, ha='center', va='bottom',
                fontsize=11, fontweight='bold', color=tint,
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
                          edgecolor=tint, alpha=0.9, linewidth=1.5))

    # ---- Panels: ((x, y, w, h), title, body lines, palette key, font overrides) ----
    # NOTE(review): the titles start with emoji; whether they render depends on
    # the active matplotlib font — confirm in the saved PNG.
    panels = (
        # Column 1: data sources (x=0.2)
        ((0.2, 9.0, 3.2, 2.0), "🗄️ Synthetic / SUMO Data",
         ["• SUMO simulator output",
          "• Python random generator",
          "• Calibrated for India",
          "• 5,000+ training samples"],
         'data', {}),
        ((0.2, 6.3, 3.2, 2.2), "📂 Historical / Kaggle",
         ["• Traffic count surveys (IRC)",
          "• Kaggle traffic datasets",
          "• Probe vehicle GPS traces",
          "• Delhi / Bengaluru FCD"],
         'data', {}),
        ((0.2, 3.2, 3.2, 2.6), "🏛️ Standards & Models",
         ["• IRC:106-1990 (PCU values)",
          "• ARAI BS-VI emission factors",
          "• CPCB air quality data",
          "• Robertson (1969) params",
          "• Greenshields (1935) model"],
         'india', {}),
        ((0.2, 0.5, 3.2, 2.2), "🔌 Real-Time (Future)",
         ["• Loop detectors / cameras",
          "• ANPR vehicle classification",
          "• V2I (Connected vehicles)",
          "• Weather API (IMD)"],
         'data', {'line_size': 8}),
        # Column 2: inputs (x=3.9)
        ((3.9, 9.0, 3.2, 2.0), "🛣️ Static Inputs",
         ["• Road length L (m)",
          "• Speed limit (km/h)",
          "• Lanes, gradient, width",
          "• Saturation flow (PCU/hr)"],
         'input', {}),
        ((3.9, 6.3, 3.2, 2.2), "📊 Dynamic Inputs",
         ["• Green start time t_green",
          "• Platoon size & composition",
          "• Traffic density / queue",
          "• Time-of-day, day-of-week"],
         'input', {}),
        ((3.9, 3.2, 3.2, 2.6), "🇮🇳 India-Specific Inputs",
         ["• Vehicle mix: 55-70% 2W",
          "• Side friction (0.1–0.6)",
          "• Monsoon weather flag",
          "• PCU vector (IRC:106)",
          "• β dispersion (0.50–0.80)"],
         'india', {}),
        # Column 3: core processing (x=7.6)
        ((7.6, 9.2, 4.8, 1.8), "⚙️ Robertson's Platoon Dispersion",
         [" q'(t) = F·q'(t-1) + (1-F)·α·q(t - t̄)",
          " India-calibrated: β = 0.50–0.80",
          " α = 1/(1+β·t̄) , F = 1 - α"],
         'process', {'title_size': 10.5}),
        ((7.6, 6.5, 4.8, 2.2), "🤖 ML Travel Time Predictor",
         [" XGBoost Quantile Regression",
          " P10 / P50 / P90 confidence bands",
          " 20 features · SHAP explainability",
          " MAE ≈ 8s · R² > 0.80"],
         'ml', {'title_size': 10.5}),
        ((7.6, 3.8, 4.8, 2.2), "🎯 Dynamic Offset Optimizer",
         [" Maximize platoon-green overlap",
          " Safety buffer: 10–20 s (India)",
          " Constraints: min ped green, cycle",
          " Fallback: fixed-time if uncertain"],
         'process', {'title_size': 10.5}),
        ((7.6, 1.0, 4.8, 2.3), "🌿 Emission & Fuel Calculator",
         [" Running: E = Σ(count × EF × km)",
          " Idling: E = Σ(count × IR × min)",
          " Fuel: CO₂ / 2310 → litres",
          " ARAI BS-VI + CPCB factors"],
         'process', {'title_size': 10.5}),
        # Column 4: outputs (x=13)
        ((13.0, 9.2, 4.5, 1.8), "🚦 Optimized Signal Timing",
         ["• Adjusted green start offset",
          "• Phase durations per cycle",
          "• Emergency / bus priority"],
         'output', {}),
        ((13.0, 6.5, 4.5, 2.2), "📈 KPI Dashboard",
         ["• Delay reduction (%)",
          "• Platoons on green (target >70%)",
          "• Stops avoided per km",
          "• Throughput (veh/hr)"],
         'output', {}),
        ((13.0, 3.5, 4.5, 2.5), "🌿 Emission / Fuel Report",
         ["• CO₂, CO, NOx, PM2.5 savings",
          "• Fuel saved (mL / corridor)",
          "• Health impact estimate",
          "• Before / After comparison"],
         'output', {}),
        ((13.0, 0.7, 4.5, 2.3), "📄 MoRTH Report Package",
         ["• Simulation KPI tables",
          "• SHAP feature analysis",
          "• Corridor map + timeline",
          "• Roadmap to pilot"],
         'output', {}),
    )
    for (x, y, w, h), title, body_lines, tint_key, fonts in panels:
        add_panel(x, y, w, h, title, body_lines, palette[tint_key], **fonts)

    # ---- Arrows: (x1, y1, x2, y2, colour, line width, bend) in draw order ----
    arrows = (
        # Data sources → inputs
        (3.4, 10.0, 3.9, 10.0, palette['data'], 1.5, 0.0),
        (3.4, 7.4, 3.9, 7.4, palette['data'], 1.5, 0.0),
        (3.4, 4.5, 3.9, 4.5, palette['data'], 1.5, 0.0),
        # Standards → India-specific
        (3.4, 4.0, 3.9, 4.0, palette['india'], 1.5, -0.05),
        # Inputs → processing
        (7.1, 10.0, 7.6, 10.1, palette['input'], 2, 0.0),
        (7.1, 7.4, 7.6, 7.6, palette['input'], 2, 0.0),
        (7.1, 4.8, 7.6, 4.9, palette['india'], 2, 0.0),
        # India inputs also feed ML
        (7.1, 4.2, 7.6, 7.0, palette['india'], 1.3, -0.25),
        # Processing internal (vertical)
        (10.0, 9.2, 10.0, 8.7, palette['process'], 2, 0.0),   # Robertson → ML
        (10.0, 6.5, 10.0, 6.0, palette['ml'], 2, 0.0),        # ML → Optimizer
        (10.0, 3.8, 10.0, 3.3, palette['process'], 2, 0.0),   # Optimizer → Emissions
        # Processing → outputs
        (12.4, 10.1, 13.0, 10.1, palette['process'], 2, 0.0),
        (12.4, 7.6, 13.0, 7.6, palette['ml'], 2, 0.0),
        (12.4, 4.5, 13.0, 4.75, palette['process'], 2, 0.0),
        (12.4, 2.1, 13.0, 1.85, palette['process'], 2, 0.0),
        # Feedback arrow (output → inputs)
        (15.25, 6.5, 5.5, 5.8, '#7f8c8d', 1.2, 0.4),
    )
    for x1, y1, x2, y2, tint, line_width, bend in arrows:
        connect(x1, y1, x2, y2, color=tint, lw=line_width, rad=bend)

    # Caption for the feedback arrow; placed above the y-limit of 12.
    ax.text(9.5, 12.0, "Feedback: actual arrival → retrain ML model",
            ha='center', va='bottom', fontsize=8, fontstyle='italic', color='#7f8c8d')

    # ---- Title ----
    ax.text(9.0, 12.4,
            "APOO System Architecture — Predict-Then-Optimize Framework",
            ha='center', va='bottom', fontsize=15, fontweight='bold', color='#2c3e50',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='white',
                      edgecolor='#bdc3c7', linewidth=1.5))

    plt.tight_layout(pad=0.5)
    plt.close(fig)
    return fig
# ============================================================
# 2. RESEARCH PAPER FIGURES
# ============================================================
def generate_all_figures(predictor, X_train, X_val, y_train, y_val, training_df):
    """Generate, save and return every figure needed for the research paper.

    Each figure is written as a 200-dpi PNG into ``FIGURE_DIR`` and then
    closed in pyplot (including the ones produced by ``predictor``), so no
    figure stays registered after the call. The scenario KPI table is also
    written to ``table2_scenario_kpis.csv`` in the same directory.

    Args:
        predictor: Trained model object. Must provide
            ``plot_predictions_vs_actual``, ``plot_shap_beeswarm``,
            ``plot_shap_bar`` and ``plot_quantile_calibration``; each return
            value is saved via ``.savefig`` (presumably a matplotlib Figure).
        X_train: Not used here; kept so existing callers keep working.
        X_val: Validation features, forwarded to the predictor plots.
        y_train: Not used here; kept so existing callers keep working.
        y_val: Validation targets, forwarded to the predictor plots.
        training_df: Training table. Its columns feed the Fig 8 histograms
            and its row count is shown in the Fig 3 / Fig 8 captions.

    Returns:
        ``(figures, kpi_df)`` where ``figures`` maps a figure key to its
        (already closed) figure object, plus ``'table2_kpis'`` mapped to the
        KPI DataFrame, and ``kpi_df`` is that same DataFrame.
    """
    figures = {}
    # Captions report the size of the data actually supplied instead of a
    # hard-coded "5,000" (identical text for the default 5,000-row run).
    n_samples = len(training_df)

    def _save_and_close(fig, filename, **savefig_kwargs):
        """Write *fig* into FIGURE_DIR at 200 dpi, then release it from pyplot."""
        fig.savefig(f"{FIGURE_DIR}/{filename}", dpi=200, bbox_inches='tight',
                    **savefig_kwargs)
        # plt.close() only accepts real Figure objects (or ids); closing an
        # already-closed figure is a harmless no-op.
        if isinstance(fig, plt.Figure):
            plt.close(fig)

    # ---- Fig 1: Architecture ----
    figures['fig1_architecture'] = create_architecture_diagram()
    _save_and_close(figures['fig1_architecture'], "fig1_architecture.png",
                    facecolor='white')
    print(" ✅ Fig 1: Architecture diagram")

    # ---- Fig 2: Robertson Dispersion — India vs Western ----
    fig2, axes2 = plt.subplots(1, 2, figsize=(14, 5))
    robertson_india = RobertsonDispersion(beta=0.60)
    robertson_west = RobertsonDispersion(beta=0.30)
    india_mix = VEHICLE_MIX_PROFILES['metro_peak']
    green_start = 5   # index at which the departure pulse begins
    green_time = 40
    departure = np.zeros(120)
    departure[green_start:green_start + green_time] = 0.5  # veh/s discharge
    t_dep = np.arange(len(departure))
    for ax, (link_len, speed) in zip(axes2, [(300, 40), (500, 35)]):
        arr_india, tt_india = robertson_india.disperse(
            departure, link_len, speed,
            vehicle_mix=india_mix,
            weather='clear', side_friction=0.3)
        arr_west, tt_west = robertson_west.disperse(
            departure, link_len, speed,
            vehicle_mix={'car': 1.0}, weather='clear', side_friction=0.1)
        t_arr_i = np.arange(len(arr_india))
        t_arr_w = np.arange(len(arr_west))
        ax.fill_between(t_dep, departure, alpha=0.2, color='#e74c3c', label='Departure')
        ax.plot(t_dep, departure, color='#e74c3c', linewidth=2.5)
        ax.plot(t_arr_i, arr_india, color='#2ecc71', linewidth=2.5,
                label=f'India (β={robertson_india._adjust_beta(india_mix):.2f})')
        ax.fill_between(t_arr_i, arr_india, alpha=0.12, color='#2ecc71')
        ax.plot(t_arr_w, arr_west, color='#3498db', linewidth=2,
                linestyle='--', label='Western (β=0.30)')
        # Marker sits at the pulse start shifted by the India travel time.
        ax.axvline(tt_india + green_start, linestyle=':', color='#7f8c8d', alpha=0.7)
        ax.text(tt_india + 7, max(arr_india)*0.9, f't̄={tt_india:.0f}s',
                fontsize=9, color='#7f8c8d')
        ax.set_xlabel('Time (seconds)', fontsize=11)
        ax.set_ylabel('Flow Rate (veh/s)', fontsize=11)
        ax.set_title(f'Link: {link_len}m, {speed} km/h', fontsize=12, fontweight='bold')
        ax.legend(fontsize=9, loc='upper right')
        ax.grid(alpha=0.25)
        ax.set_xlim(0, min(len(arr_india), 200))
    fig2.suptitle('Figure 2: Robertson Platoon Dispersion — Indian vs Western Traffic',
                  fontsize=13, fontweight='bold', y=1.02)
    fig2.tight_layout()
    _save_and_close(fig2, "fig2_robertson_dispersion.png")
    figures['fig2_robertson'] = fig2
    print(" ✅ Fig 2: Robertson dispersion comparison")

    # ---- Fig 3: ML Predictions vs Actual with Uncertainty ----
    figures['fig3_predictions'] = predictor.plot_predictions_vs_actual(
        X_val, y_val, f"Indian Urban Corridor ({n_samples:,} samples)")
    _save_and_close(figures['fig3_predictions'], "fig3_predictions.png")
    print(" ✅ Fig 3: Predictions vs actual")

    # ---- Fig 4: SHAP Beeswarm ----
    figures['fig4_shap_beeswarm'] = predictor.plot_shap_beeswarm(max_display=15)
    _save_and_close(figures['fig4_shap_beeswarm'], "fig4_shap_beeswarm.png")
    print(" ✅ Fig 4: SHAP beeswarm")

    # ---- Fig 5: SHAP Bar ----
    figures['fig5_shap_bar'] = predictor.plot_shap_bar(max_display=15)
    _save_and_close(figures['fig5_shap_bar'], "fig5_shap_bar.png")
    print(" ✅ Fig 5: SHAP bar")

    # ---- Fig 6: Quantile Calibration ----
    figures['fig6_calibration'] = predictor.plot_quantile_calibration(X_val, y_val)
    _save_and_close(figures['fig6_calibration'], "fig6_calibration.png")
    print(" ✅ Fig 6: Quantile calibration")

    # ---- Fig 7: Simulation Comparison (multi-scenario) ----
    fig7, axes7 = plt.subplots(2, 2, figsize=(14, 10))
    gen = IndianTrafficGenerator(seed=42)
    scenarios = [
        ("Metro Peak (Clear)", "metro", "morning_peak", "clear", 2000),
        ("Metro Peak (Monsoon)", "metro", "morning_peak", "heavy_rain", 2000),
        ("Tier-2 Peak (Clear)", "tier2", "morning_peak", "clear", 1500),
        ("Metro Off-Peak", "metro", "off_peak", "clear", 800),
    ]
    kpi_rows = []
    for idx, (name, city, profile, weather, flow) in enumerate(scenarios):
        ax = axes7.flat[idx]
        # Falls back to the metro peak mix when no '<city>_peak' profile exists.
        mix = VEHICLE_MIX_PROFILES.get(f'{city}_peak', VEHICLE_MIX_PROFILES['metro_peak'])
        # Fixed-time
        ints_f, links_f = gen.generate_corridor(5, 300, city)
        gen_d = IndianTrafficGenerator(seed=100+idx)
        demand = gen_d.generate_demand_profile(1.5, flow, profile, city)
        sim_f = CorridorSimulator(ints_f, links_f, RobertsonDispersion(0.60), OffsetOptimizer(12))
        res_f = sim_f.simulate(demand, mix, weather, 'fixed')
        # APOO
        # NOTE(review): this second generate_corridor() call reuses the same
        # generator; if corridor generation draws random numbers, the two
        # control modes are simulated on different corridors — confirm.
        ints_a, links_a = gen.generate_corridor(5, 300, city)
        sim_a = CorridorSimulator(ints_a, links_a, RobertsonDispersion(0.60), OffsetOptimizer(12))
        res_a = sim_a.simulate(demand, mix, weather, 'apoo')
        # Plot delay timeline
        if res_f.cycle_details and res_a.cycle_details:
            df_f = pd.DataFrame(res_f.cycle_details).groupby('time_min')['delay_s'].mean()
            df_a = pd.DataFrame(res_a.cycle_details).groupby('time_min')['delay_s'].mean()
            ax.plot(df_f.index, df_f.values, 'o-', color='#e74c3c', markersize=5,
                    linewidth=2, label='Fixed-Time')
            ax.plot(df_a.index, df_a.values, 's-', color='#2ecc71', markersize=5,
                    linewidth=2, label='APOO')
            # Only build a legend when there is something to label; an empty
            # legend just triggers a matplotlib warning.
            ax.legend(fontsize=9)
        ax.set_title(name, fontsize=11, fontweight='bold')
        ax.set_xlabel('Time (min)', fontsize=10)
        ax.set_ylabel('Avg Delay (s)', fontsize=10)
        ax.grid(alpha=0.25)
        fixed_delay = res_f.avg_delay_per_vehicle_s
        apoo_delay = res_a.avg_delay_per_vehicle_s
        kpi_rows.append({
            'Scenario': name,
            'Fixed Delay (s)': round(fixed_delay, 1),
            'APOO Delay (s)': round(apoo_delay, 1),
            # Denominator is floored at 0.01 to avoid dividing by zero.
            'Delay Reduction (%)': round((fixed_delay - apoo_delay) /
                                         max(fixed_delay, 0.01) * 100, 1),
            'Fixed Green (%)': round(res_f.green_arrival_pct, 1),
            'APOO Green (%)': round(res_a.green_arrival_pct, 1),
            'Fixed CO₂ (g)': round(res_f.total_co2_g, 0),
            'APOO CO₂ (g)': round(res_a.total_co2_g, 0),
        })
    fig7.suptitle('Figure 7: Delay Comparison Across Scenarios — Fixed-Time vs APOO',
                  fontsize=13, fontweight='bold', y=1.01)
    fig7.tight_layout()
    _save_and_close(fig7, "fig7_scenario_comparison.png")
    figures['fig7_scenarios'] = fig7
    print(" ✅ Fig 7: Multi-scenario simulation comparison")

    # Save KPI table
    kpi_df = pd.DataFrame(kpi_rows)
    kpi_df.to_csv(f"{FIGURE_DIR}/table2_scenario_kpis.csv", index=False)
    figures['table2_kpis'] = kpi_df

    # ---- Fig 8: Feature Distribution Grid ----
    fig8, axes8 = plt.subplots(3, 4, figsize=(16, 10))
    key_features = [
        ("link_length_m", "Link Length (m)"),
        ("speed_limit_kmh", "Speed Limit (km/h)"),
        ("density_veh_km_lane", "Density (veh/km/lane)"),
        ("pct_two_wheeler", "Two-Wheeler (%)"),
        ("pct_car", "Car (%)"),
        ("pct_auto", "Auto-Rickshaw (%)"),
        ("pct_bus", "Bus (%)"),
        ("weather_speed_factor", "Weather Factor"),
        ("platoon_size", "Platoon Size"),
        ("side_friction", "Side Friction"),
        ("actual_travel_time_s", "Travel Time (s)"),
        ("platoon_dispersion_s", "Dispersion (s)"),
    ]
    for panel, (col, label) in zip(axes8.flatten(), key_features):
        # Columns missing from training_df leave their panel untouched.
        if col in training_df.columns:
            panel.hist(training_df[col], bins=40, alpha=0.7,
                       color='#3498db', edgecolor='white')
            panel.set_title(label, fontsize=9, fontweight='bold')
            panel.grid(alpha=0.2)
    fig8.suptitle(f'Figure 8: Training Data Feature Distributions (N={n_samples:,})',
                  fontsize=13, fontweight='bold')
    fig8.tight_layout()
    _save_and_close(fig8, "fig8_feature_distributions.png")
    figures['fig8_distributions'] = fig8
    print(" ✅ Fig 8: Feature distributions")

    # ---- Fig 9: PCU Values Bar Chart ----
    fig9, ax9 = plt.subplots(figsize=(10, 5))
    vtypes = list(PCU_INDIA.keys())
    pcus = list(PCU_INDIA.values())
    # Red for PCU >= 2, blue for 1 <= PCU < 2, green below 1.
    colors9 = ['#e74c3c' if p >= 2 else '#3498db' if p >= 1 else '#2ecc71' for p in pcus]
    bars = ax9.barh(vtypes, pcus, color=colors9, edgecolor='white', height=0.6)
    for bar, pcu in zip(bars, pcus):
        ax9.text(bar.get_width() + 0.05, bar.get_y() + bar.get_height()/2,
                 f'{pcu:.1f}', va='center', fontsize=11, fontweight='bold')
    ax9.set_xlabel('Passenger Car Units (PCU)', fontsize=12)
    ax9.set_title('Figure 9: IRC:106-1990 PCU Values for Indian Vehicle Types',
                  fontsize=12, fontweight='bold')
    ax9.grid(alpha=0.2, axis='x')
    ax9.set_xlim(0, max(pcus) + 0.5)
    fig9.tight_layout()
    _save_and_close(fig9, "fig9_pcu_values.png")
    figures['fig9_pcu'] = fig9
    print(" ✅ Fig 9: PCU values")

    # ---- Fig 10: Emission Factors Grouped Bar ----
    fig10, ax10 = plt.subplots(figsize=(12, 5))
    pollutants = ['CO', 'NOx', 'CO2']
    etypes = ['two_wheeler', 'car', 'auto_rickshaw', 'bus', 'truck']
    x = np.arange(len(etypes))
    width = 0.25
    for i, poll in enumerate(pollutants):
        vals = [EMISSION_FACTORS[v][poll] for v in etypes]
        if poll == 'CO2':
            vals = [v / 10 for v in vals]  # Scale CO2 for visibility
            label = 'CO₂ (÷10)'
        else:
            label = poll
        ax10.bar(x + i*width, vals, width, label=label, alpha=0.8)
    # Tick under the middle bar of each three-bar group.
    ax10.set_xticks(x + width)
    ax10.set_xticklabels([t.replace('_', ' ').title() for t in etypes], fontsize=10)
    ax10.set_ylabel('Emission Factor (g/km)', fontsize=11)
    ax10.set_title('Figure 10: ARAI BS-VI Emission Factors by Vehicle Type',
                   fontsize=12, fontweight='bold')
    ax10.legend(fontsize=10)
    ax10.grid(alpha=0.2, axis='y')
    fig10.tight_layout()
    _save_and_close(fig10, "fig10_emission_factors.png")
    figures['fig10_emissions'] = fig10
    print(" ✅ Fig 10: Emission factors")

    # ---- Fig 11: Weather Impact ----
    fig11, ax11 = plt.subplots(figsize=(8, 4.5))
    conditions = list(WEATHER_SPEED_FACTORS.keys())
    factors = [WEATHER_SPEED_FACTORS[c] * 100 for c in conditions]
    colors11 = ['#2ecc71', '#f1c40f', '#e74c3c', '#9b59b6', '#34495e']
    bars11 = ax11.bar(conditions, factors, color=colors11, edgecolor='white', width=0.5)
    for bar, pct in zip(bars11, factors):
        ax11.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                  f'{pct:.0f}%', ha='center', fontsize=11, fontweight='bold')
    ax11.set_ylabel('Effective Speed (% of Free-Flow)', fontsize=11)
    ax11.set_title('Figure 11: Weather Impact on Effective Speed',
                   fontsize=12, fontweight='bold')
    ax11.set_ylim(0, 115)
    ax11.grid(alpha=0.2, axis='y')
    fig11.tight_layout()
    _save_and_close(fig11, "fig11_weather_impact.png")
    figures['fig11_weather'] = fig11
    print(" ✅ Fig 11: Weather impact")

    print(f"\n✅ All figures saved to {FIGURE_DIR}/")
    return figures, kpi_df
# ============================================================
# MAIN
# ============================================================
if __name__ == "__main__":
    # Script entry point: train the model, render every figure, print the KPIs.
    rule = "=" * 60
    print(rule)
    print("APOO Research Paper — Figure Generator")
    print(rule)

    print("\n[1/2] Training ML model...")
    (predictor, X_train, X_val,
     y_train, y_val, training_df) = train_apoo_model(n_samples=5000,
                                                     city_type='metro')

    print("\n[2/2] Generating all figures...")
    figures, kpi_df = generate_all_figures(
        predictor=predictor,
        X_train=X_train,
        X_val=X_val,
        y_train=y_train,
        y_val=y_val,
        training_df=training_df,
    )

    print("\n📊 Scenario KPI Table:")
    print(kpi_df.to_string(index=False))
    print("\nDone!")