"""
PowerZoo: Interactive Web Showcase
HuggingFace Spaces application with Gradio + Plotly.

6 Tabs: Project Overview | Power System Explorer | Data Visualization |
        Training Dashboard | Algorithm Comparison | Architecture Diagrams
"""

import functools
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import gradio as gr

# === Monkey-patch: fix Gradio additionalProperties schema error with Plotly ===
_original_plot_init = gr.Plot.__init__


@functools.wraps(_original_plot_init)
def _patched_plot_init(self, *args, **kwargs):
    """Run the stock ``gr.Plot.__init__``, then drop ``additionalProperties``.

    ``functools.wraps`` copies the original's metadata and sets ``__wrapped__``,
    so ``inspect.signature(gr.Plot.__init__)`` still reports the real
    parameters instead of a bare ``(self, *args, **kwargs)``.
    """
    _original_plot_init(self, *args, **kwargs)
    # NOTE(review): only has an effect when the instance exposes a dict
    # ``schema`` attribute -- confirm against the installed Gradio version.
    if hasattr(self, "schema") and isinstance(self.schema, dict):
        self.schema.pop("additionalProperties", None)


gr.Plot.__init__ = _patched_plot_init

# === Data Loading ===
DATA_DIR = Path(__file__).parent / "data"
ASSETS_DIR = Path(__file__).parent / "assets"


def _load_json(path: Path):
    """Parse the JSON document stored (UTF-8 encoded) at *path*."""
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    return json.loads(path.read_text(encoding="utf-8"))


LOADSHAPES = _load_json(DATA_DIR / "loadshapes.json")
PV_DATA = _load_json(DATA_DIR / "pv_sample.json")
ENVIRONMENTS = _load_json(DATA_DIR / "environments.json")
ALGORITHMS = _load_json(DATA_DIR / "algorithms.json")
TRAINING = _load_json(DATA_DIR / "sample_training.json")

# Architecture diagram JSON data (Plotly figures); missing files are skipped.
ARCH_FIGS = {}
_ARCH_NAMES = [
    "algorithm_hierarchy",
    "training_pipeline",
    "runner_algorithm_matrix",
    "happo_family",
    "mappo_family",
    "dan_happo",
    "ddpg_family",
    "hasac",
    "value_decomposition",
    "twots_vvc",
]
for fig_name in _ARCH_NAMES:
    fig_path = DATA_DIR / f"{fig_name}.json"
    if fig_path.exists():
        ARCH_FIGS[fig_name] = go.Figure(_load_json(fig_path))

# === Color Palette ===
COLORS = {
    "primary": "#1565c0",
    "secondary": "#5e35b1",
    "accent": "#2e7d32",
    "warning": "#e65100",
    "agents": ["#1565c0", "#5e35b1", "#2e7d32", "#e65100", "#c62828", "#00838f"],
    "loadshapes": ["#1565c0", "#e65100", "#2e7d32"],
}

# === Init figures (required for gr.Plot(value=fig) pattern) ===
_INIT_FIG = go.Figure()
_INIT_FIG.update_layout(
    template="plotly_white",
    height=400,
    margin=dict(l=40, r=40, t=40, b=40),
)


# ============================================================
# Plot Factory Functions
# ============================================================


def _placeholder_figure(message: str, height: int = 450) -> go.Figure:
    """Return a new, empty figure that shows *message* as an annotation.

    The figure is a copy of ``_INIT_FIG``, so the shared module-level figure is
    never handed out (and therefore never mutated) by callers.
    """
    fig = go.Figure(_INIT_FIG)
    fig.update_layout(height=height)
    fig.add_annotation(text=message, showarrow=False, font=dict(size=16))
    return fig


def _split_series(points) -> tuple[list, list]:
    """Split a sequence of ``[step, value]`` pairs into step and value lists."""
    steps = [pt[0] for pt in points]
    values = [pt[1] for pt in points]
    return steps, values


def plot_load_profiles(selected_shapes: list[str]) -> go.Figure:
    """Plot annual load curves for the selected LoadShapes.

    Args:
        selected_shapes: Keys of ``LOADSHAPES`` to draw; unknown names are skipped.

    Returns:
        A line chart with one trace per known shape, or a placeholder figure
        when nothing is selected.
    """
    if not selected_shapes:
        return _placeholder_figure("Select at least one load shape")

    palette = COLORS["loadshapes"]
    fig = go.Figure()
    for i, name in enumerate(selected_shapes):
        if name not in LOADSHAPES:
            continue
        data = LOADSHAPES[name]
        # X-axis: approximate hours across the year (730 points, each ~12h apart)
        x = np.linspace(0, 8760, len(data))
        fig.add_trace(go.Scatter(
            x=x,
            y=data,
            mode="lines",
            name=name,
            line=dict(color=palette[i % len(palette)], width=1.5),
            opacity=0.85,
        ))
    fig.update_layout(
        template="plotly_white",
        height=450,
        title="Annual Load Profiles (8760 hours)",
        xaxis_title="Hour of Year",
        yaxis_title="Load (p.u.)",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        hovermode="x unified",
    )
    return fig


def plot_load_daily_stats(shape_name: str) -> go.Figure:
    """Plot the hour-of-day mean +/- one standard deviation for a LoadShape.

    The stored series is linearly interpolated onto 8760 hourly samples and
    reshaped to (365 days, 24 hours) before per-hour statistics are taken.

    Args:
        shape_name: Key of ``LOADSHAPES`` to analyse.

    Returns:
        A figure with a shaded std band and a mean line, or a placeholder if
        the shape is unknown or has no data.
    """
    raw = LOADSHAPES.get(shape_name)
    # An empty series would make np.interp raise, so treat it like "not found".
    if not raw:
        return _placeholder_figure("Select a load shape")

    # Upsample back to ~8760 via linear interpolation for reshape
    full = np.interp(np.arange(8760), np.linspace(0, 8759, len(raw)), raw)
    daily = full.reshape(365, 24)
    mean = daily.mean(axis=0)
    std = daily.std(axis=0)
    hours = list(range(24))

    fig = go.Figure()
    # Std band: closed polygon, upper edge left-to-right then lower edge back.
    fig.add_trace(go.Scatter(
        x=hours + hours[::-1],
        y=np.concatenate([mean + std, (mean - std)[::-1]]).tolist(),
        fill="toself",
        fillcolor="rgba(21, 101, 192, 0.15)",
        line=dict(color="rgba(0,0,0,0)"),
        showlegend=True,
        name="Std Dev Band",
    ))
    # Mean line
    fig.add_trace(go.Scatter(
        x=hours,
        y=mean.tolist(),
        mode="lines+markers",
        name="Daily Mean",
        line=dict(color=COLORS["primary"], width=2.5),
        marker=dict(size=5),
    ))
    fig.update_layout(
        template="plotly_white",
        height=450,
        title=f"{shape_name} - Daily Average Pattern",
        xaxis_title="Hour of Day",
        yaxis_title="Load (p.u.)",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    return fig


def plot_pv_timeseries() -> go.Figure:
    """Plot the PV irradiance time series stored in ``PV_DATA``."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=PV_DATA["timestamps"],
        y=PV_DATA["irradiance_wm2"],
        mode="lines",
        name="Irradiance",
        line=dict(color=COLORS["warning"], width=1.2),
        fill="tozeroy",
        fillcolor="rgba(230, 81, 0, 0.1)",
    ))
    fig.update_layout(
        template="plotly_white",
        height=450,
        title="7-Day Solar Irradiance (Jan 2025, 10-min resolution)",
        xaxis_title="Timestamp",
        yaxis_title="Irradiance (W/m²)",
    )
    return fig


def plot_pv_scatter() -> go.Figure:
    """Plot irradiance against temperature, with markers colored by hour of day."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=PV_DATA["temperature_c"],
        y=PV_DATA["irradiance_wm2"],
        mode="markers",
        marker=dict(
            size=4,
            color=PV_DATA["hours"],
            colorscale="Viridis",
            colorbar=dict(title="Hour"),
            opacity=0.6,
        ),
        text=[f"Hour: {h:.1f}" for h in PV_DATA["hours"]],
        # Plotly hover labels break lines with <br>, not "\n".
        hovertemplate="Temp: %{x:.1f}°C<br>Irradiance: %{y:.0f} W/m²<br>%{text}",
    ))
    fig.update_layout(
        template="plotly_white",
        height=450,
        title="Irradiance vs Temperature (colored by hour of day)",
        xaxis_title="Temperature (°C)",
        yaxis_title="Irradiance (W/m²)",
    )
    return fig


def plot_training_rewards() -> go.Figure:
    """Plot the episode-reward training curve from ``TRAINING["episode_rewards"]``.

    Returns a fresh placeholder figure when that series is absent.
    """
    if "episode_rewards" not in TRAINING:
        return _placeholder_figure("No episode reward data available")

    steps, values = _split_series(TRAINING["episode_rewards"])
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=steps,
        y=values,
        mode="lines+markers",
        name="Episode Rewards",
        line=dict(color=COLORS["primary"], width=2),
        marker=dict(size=4),
    ))
    fig.update_layout(
        template="plotly_white",
        height=450,
        title="HAPPO on 13Bus - Episode Rewards",
        xaxis_title="Training Steps",
        yaxis_title="Total Reward",
    )
    return fig


def plot_training_metric(metric_name: str) -> go.Figure:
    """Plot one metric for each agent, comparing agents on a shared axis.

    Series are read from ``TRAINING["agent{i}_{metric_name}"]`` for agents 0-5;
    agents without that key are skipped.
    """
    palette = COLORS["agents"]
    fig = go.Figure()
    for agent_id in range(6):
        key = f"agent{agent_id}_{metric_name}"
        if key not in TRAINING:
            continue
        steps, values = _split_series(TRAINING[key])
        fig.add_trace(go.Scatter(
            x=steps,
            y=values,
            mode="lines+markers",
            name=f"Agent {agent_id}",
            line=dict(color=palette[agent_id % len(palette)], width=1.8),
            marker=dict(size=4),
        ))

    display_name = metric_name.replace("_", " ").title()
    fig.update_layout(
        template="plotly_white",
        height=450,
        title=f"Per-Agent {display_name}",
        xaxis_title="Training Steps",
        yaxis_title=display_name,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    return fig


def plot_power_metrics() -> go.Figure:
    """Plot a 2x2 grid: power_loss_kw, power_loss_kvar, total_power_kw, total_power_kvar.

    Panels whose series is missing from ``TRAINING`` are left empty.
    """
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Power Loss (kW)",
            "Power Loss (kVar)",
            "Total Power (kW)",
            "Total Power (kVar)",
        ),
        vertical_spacing=0.12,
        horizontal_spacing=0.1,
    )
    metrics = [
        ("power_loss_kw", 1, 1, COLORS["primary"]),
        ("power_loss_kvar", 1, 2, COLORS["secondary"]),
        ("total_power_kw", 2, 1, COLORS["accent"]),
        ("total_power_kvar", 2, 2, COLORS["warning"]),
    ]
    for key, row, col, color in metrics:
        if key not in TRAINING:
            continue
        steps, values = _split_series(TRAINING[key])
        fig.add_trace(
            go.Scatter(
                x=steps,
                y=values,
                mode="lines+markers",
                line=dict(color=color, width=2),
                marker=dict(size=4),
                showlegend=False,
            ),
            row=row,
            col=col,
        )
    fig.update_layout(
        template="plotly_white",
        height=550,
        title_text="Power System Metrics During Training",
    )
    return fig


# ============================================================
# Environment Explorer Helpers
# ============================================================


def get_env_names() -> list[str]:
    """Return the environment names, in ``ENVIRONMENTS`` order."""
    return list(ENVIRONMENTS.keys())


def get_system_names(env_name: str) -> gr.Dropdown:
    """Return a dropdown update listing the systems of *env_name*.

    The first system is preselected; unknown or empty input clears the dropdown.
    """
    if env_name and env_name in ENVIRONMENTS:
        systems = list(ENVIRONMENTS[env_name]["systems"].keys())
        return gr.Dropdown(choices=systems, value=systems[0] if systems else None)
    return gr.Dropdown(choices=[], value=None)


def get_env_info(env_name: str) -> str:
    """Return a markdown description of *env_name* (or a prompt if unknown)."""
    if not env_name or env_name not in ENVIRONMENTS:
        return "Select an environment to view details."

    env = ENVIRONMENTS[env_name]
    md = (
        f"### {env_name}\n\n"
        f"**Description**: {env['description']}\n\n"
        f"**Action Space**: {env['action_space']}\n\n"
        f"**Observation**: {env['observation']}\n\n"
        f"**Reward**: {env['reward']}\n\n"
        "**Key Features**:\n"
    )
    md += "".join(f"- {feat}\n" for feat in env["features"])
    return md


def get_system_table(env_name: str, system_name: str) -> pd.DataFrame:
    """Return the configuration of one system as a Property/Value table.

    The ``name`` entry is omitted; keys are title-cased and values stringified.
    """
    if not env_name or not system_name:
        return pd.DataFrame({"Property": ["Select environment and system"], "Value": ["-"]})

    env = ENVIRONMENTS.get(env_name, {})
    system = env.get("systems", {}).get(system_name, {})
    if not system:
        return pd.DataFrame({"Property": ["System not found"], "Value": ["-"]})

    rows = [
        {"Property": key.replace("_", " ").title(), "Value": str(val)}
        for key, val in system.items()
        if key != "name"
    ]
    # Explicit columns keep the headers even when only "name" was present.
    return pd.DataFrame(rows, columns=["Property", "Value"])


# ============================================================
# Algorithm Table
# ============================================================


def get_algorithm_df() -> pd.DataFrame:
    """Return one row per entry in ``ALGORITHMS`` for the comparison table."""
    columns = ["Algorithm", "Type", "Policy", "Action Space", "Key Feature"]
    return pd.DataFrame(
        [
            {
                "Algorithm": a["name"],
                "Type": a["type"],
                "Policy": a["policy"],
                "Action Space": a["action_space"],
                "Key Feature": a["key_feature"],
            }
            for a in ALGORITHMS
        ],
        columns=columns,
    )


# ============================================================
# Build Gradio App
# ============================================================


def build_app() -> gr.Blocks:
    """Construct the Gradio Blocks application.

    Six top-level tabs are built: Project Overview, Power System Explorer,
    Data Visualization, Training Dashboard, Algorithm Comparison and
    Architecture Diagrams. Event handlers are wired to the helper and plot
    factory functions above.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` instance.
    """
    env_names = get_env_names()
    default_env = env_names[0] if env_names else None
    shape_names = list(LOADSHAPES.keys())
    # Keep the historical default when it exists; otherwise use the first shape.
    default_shape = "LoadShape1" if "LoadShape1" in LOADSHAPES else next(iter(LOADSHAPES), None)

    with gr.Blocks(
        title="PowerZoo: MARL for Power Systems",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="purple",
        ),
    ) as app:
        # Header
        gr.Markdown(
            """
            # ⚡ PowerZoo: A Universal MARL Platform for Power System Control

            **4 Environments** · **15 Algorithms** · **9 IEEE Test Systems** · **IEEE TSG 2025**
            """
        )

        with gr.Tabs():
            # --------------------------------------------------------
            # Tab 1: Project Overview
            # --------------------------------------------------------
            with gr.Tab("Project Overview"):
                gr.Markdown(
                    """
                    ## About PowerZoo

                    PowerZoo is a comprehensive multi-agent reinforcement learning (MARL)
                    platform designed for intelligent power system control. It provides a
                    unified interface for training and evaluating MARL algorithms across
                    diverse power system environments.

                    ### Key Highlights
                    """
                )
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown(
                            """
                            **🏗️ 4 Environments**
                            - PowerZoo VVC (Volt-VAR Control)
                            - SmartGrid (PV Integration)
                            - Stackelberg Game (Market)
                            - DSR (Fault Recovery)
                            """
                        )
                    with gr.Column(scale=1):
                        gr.Markdown(
                            """
                            **🤖 15 MARL Algorithms**
                            - On-Policy: HAPPO, HATRPO, MAPPO, etc.
                            - Off-Policy: HADDPG, HASAC, MADDPG, etc.
                            - Value-Based: QMIX, HAD3QN
                            - Special: 2TS-VVC, DAN-HAPPO
                            """
                        )
                    with gr.Column(scale=1):
                        gr.Markdown(
                            """
                            **🔌 9 IEEE Test Systems**
                            - 13-Bus, 34-Bus, 123-Bus, 8500-Node
                            - PV variants (Conservative/Optimized/Aggressive)
                            - From rapid prototyping to scalability testing
                            """
                        )

                # Architecture SVG, inlined into the page when the asset exists.
                svg_path = ASSETS_DIR / "architecture.svg"
                if svg_path.exists():
                    svg_content = svg_path.read_text(encoding="utf-8")
                    gr.HTML(
                        '<div style="text-align: center; margin: 20px 0;">'
                        f"{svg_content}"
                        "</div>"
                    )

                gr.Markdown(
                    """
                    ### Citation

                    > PowerZoo: A Universal Multi-Agent Reinforcement Learning Platform for Power System Control.
                    > *IEEE Transactions on Smart Grid*, 2025.

                    **Links**: [GitHub](https://github.com/XJTU-RL/PowerZoo) · [Paper](https://ieeexplore.ieee.org/)
                    """
                )

            # --------------------------------------------------------
            # Tab 2: Power System Explorer
            # --------------------------------------------------------
            with gr.Tab("Power System Explorer"):
                gr.Markdown("## Explore Environments & IEEE Test Systems")
                with gr.Row():
                    env_dropdown = gr.Dropdown(
                        choices=env_names,
                        label="Select Environment",
                        value=default_env,
                    )
                    system_dropdown = gr.Dropdown(
                        choices=[],
                        label="Select IEEE System",
                    )
                env_info = gr.Markdown(value="Select an environment to view details.")
                system_table = gr.Dataframe(
                    headers=["Property", "Value"],
                    label="System Configuration",
                )

                # Wire events
                env_dropdown.change(
                    fn=get_system_names,
                    inputs=env_dropdown,
                    outputs=system_dropdown,
                )
                env_dropdown.change(
                    fn=get_env_info,
                    inputs=env_dropdown,
                    outputs=env_info,
                )
                system_dropdown.change(
                    fn=get_system_table,
                    inputs=[env_dropdown, system_dropdown],
                    outputs=system_table,
                )
                # Trigger initial load for the preselected environment.
                app.load(
                    fn=get_system_names,
                    inputs=env_dropdown,
                    outputs=system_dropdown,
                )
                app.load(
                    fn=get_env_info,
                    inputs=env_dropdown,
                    outputs=env_info,
                )

            # --------------------------------------------------------
            # Tab 3: Data Visualization
            # --------------------------------------------------------
            with gr.Tab("Data Visualization"):
                with gr.Tabs():
                    # Sub-tab: Load Profiles
                    with gr.Tab("Load Profiles"):
                        gr.Markdown("### Annual Load Shape Visualization")
                        load_checkbox = gr.CheckboxGroup(
                            choices=shape_names,
                            value=shape_names,
                            label="Select Load Shapes",
                        )
                        load_annual_plot = gr.Plot(value=plot_load_profiles(shape_names))

                        gr.Markdown("### Daily Average Pattern")
                        load_radio = gr.Radio(
                            choices=shape_names,
                            value=default_shape,
                            label="Select Load Shape for Daily Analysis",
                        )
                        load_daily_plot = gr.Plot(value=plot_load_daily_stats(default_shape))

                        load_checkbox.change(
                            fn=plot_load_profiles,
                            inputs=load_checkbox,
                            outputs=load_annual_plot,
                        )
                        load_radio.change(
                            fn=plot_load_daily_stats,
                            inputs=load_radio,
                            outputs=load_daily_plot,
                        )

                    # Sub-tab: PV Data (static figures, no event wiring)
                    with gr.Tab("PV Generation"):
                        gr.Markdown("### Solar PV Data (January 2025, First 7 Days)")
                        gr.Plot(value=plot_pv_timeseries())
                        gr.Markdown("### Irradiance-Temperature Correlation")
                        gr.Plot(value=plot_pv_scatter())

            # --------------------------------------------------------
            # Tab 4: Training Dashboard
            # --------------------------------------------------------
            with gr.Tab("Training Dashboard"):
                gr.Markdown(
                    """
                    ## HAPPO Training on IEEE 13-Bus System

                    Sample training metrics from a HAPPO experiment on the PowerZoo VVC environment.
                    """
                )
                gr.Plot(value=plot_training_rewards())

                gr.Markdown("### Per-Agent Metrics")
                metric_dropdown = gr.Dropdown(
                    choices=["policy_loss", "dist_entropy"],
                    value="policy_loss",
                    label="Select Metric",
                )
                agent_plot = gr.Plot(value=plot_training_metric("policy_loss"))
                metric_dropdown.change(
                    fn=plot_training_metric,
                    inputs=metric_dropdown,
                    outputs=agent_plot,
                )

                gr.Markdown("### Power System Metrics")
                gr.Plot(value=plot_power_metrics())

            # --------------------------------------------------------
            # Tab 5: Algorithm Comparison
            # --------------------------------------------------------
            with gr.Tab("Algorithm Comparison"):
                gr.Markdown(
                    """
                    ## 15 MARL Algorithms

                    PowerZoo supports a comprehensive suite of multi-agent reinforcement learning
                    algorithms spanning on-policy, off-policy, value-based, and specialized methods.
                    """
                )
                gr.Dataframe(
                    value=get_algorithm_df(),
                    label="Algorithm Feature Matrix",
                    interactive=False,
                )

            # --------------------------------------------------------
            # Tab 6: Architecture Diagrams
            # --------------------------------------------------------
            with gr.Tab("Architecture Diagrams"):
                gr.Markdown(
                    """
                    ## Interactive Architecture Diagrams

                    Explore the algorithm inheritance hierarchy, training pipeline flow,
                    and runner-algorithm compatibility matrix.
                    """
                )
                # Diagrams whose JSON file was missing are simply not rendered.
                _overview_figs = [
                    ("algorithm_hierarchy", "Algorithm Inheritance Hierarchy"),
                    ("training_pipeline", "Training Pipeline Flow"),
                    ("runner_algorithm_matrix", "Runner-Algorithm Compatibility Matrix"),
                ]
                for key, label in _overview_figs:
                    if key in ARCH_FIGS:
                        gr.Markdown(f"### {label}")
                        gr.Plot(value=ARCH_FIGS[key])

                gr.Markdown("---\n## Algorithm Internal Architectures")
                _algo_details = [
                    ("happo_family", "HAPPO / HATRPO / HAA2C"),
                    ("mappo_family", "MAPPO / SN-MAPPO"),
                    ("dan_happo", "DAN-HAPPO"),
                    ("ddpg_family", "DDPG Family (HADDPG / HATD3 / MADDPG / MATD3)"),
                    ("hasac", "HASAC"),
                    ("value_decomposition", "QMIX / HAD3QN"),
                    ("twots_vvc", "2TS-VVC"),
                ]
                for key, label in _algo_details:
                    if key in ARCH_FIGS:
                        gr.Markdown(f"### {label}")
                        gr.Plot(value=ARCH_FIGS[key])

        # Footer
        gr.Markdown(
            """
            ---

            **PowerZoo** · MIT License · [XJTU-RL](https://github.com/XJTU-RL) · IEEE Transactions on Smart Grid, 2025
            """
        )

    return app


# ============================================================
# Launch
# ============================================================
if __name__ == "__main__":
    app = build_app()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        # Absolute path so asset serving does not depend on the working directory.
        allowed_paths=[str(ASSETS_DIR)],
    )