Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import json | |
| import warnings | |
| import plotly.graph_objects as go | |
# Hide FutureWarnings whose message mentions `reduce_op` (noise from upstream
# Hugging Face dependencies, per the original note). Every other warning,
# including other FutureWarnings, is still shown.
warnings.filterwarnings(
    action="ignore",
    message=r".*reduce_op.*",
    category=FutureWarning,
)
# ----------------------------
# Core callback (SAFE + HF READY)
# ----------------------------
def _energy_figure(rng):
    """Return the per-layer energy line chart.

    The five energy values are placeholders drawn with
    ``rng.uniform(10, 50, size=5)`` for layers 1-5; they are not measurements.
    ``rng`` may be the ``numpy.random`` module or a ``numpy.random.Generator``
    (both expose ``uniform(low, high, size)``).
    """
    layers = np.arange(1, 6)
    energy_mj = rng.uniform(10, 50, size=layers.size)
    fig = go.Figure(
        data=go.Scatter(x=layers, y=energy_mj, mode="lines+markers")
    )
    fig.update_layout(
        title="Layerwise Energy Consumption",
        xaxis_title="Layer Index",
        yaxis_title="Energy (mJ)",
    )
    return fig


def _pareto_figure():
    """Return the energy-vs-accuracy trade-off chart built from fixed placeholder points.

    Energy must rise with accuracy for the curve to be a Pareto frontier.
    If it fell, the highest-accuracy point would also be the lowest-energy
    one and would dominate every other point.
    """
    accuracy = np.array([0.82, 0.85, 0.88, 0.90])
    energy_mj = np.array([34, 40, 48, 55])
    fig = go.Figure(
        data=go.Scatter(x=accuracy, y=energy_mj, mode="lines+markers")
    )
    fig.update_layout(
        title="Energy–Accuracy Pareto Frontier",
        xaxis_title="Accuracy",
        yaxis_title="Energy (mJ)",
    )
    return fig


def _policy_json(model_name, hardware, acc_budget):
    """Return the optimisation policy for the given UI inputs as indented JSON.

    Only the three inputs vary; every other field is a fixed placeholder.
    """
    policy = {
        "model": model_name,
        "hardware": hardware,
        "accuracy_budget": acc_budget,
        "quantization": "INT8",
        "curvature_metric": "trace(H)",
        "activation_information": "mutual_information",
        "selected_layers": [1, 3, 5],
        "expected_energy_saving_percent": 32.4,
    }
    return json.dumps(policy, indent=2)


def run_curvopt(model_name, hardware, acc_budget, *, seed=None):
    """Gradio callback: build the energy plot, the Pareto plot and the policy JSON.

    Args:
        model_name: Model selected in the UI; echoed into the policy.
        hardware: Target hardware selected in the UI; echoed into the policy.
        acc_budget: Accuracy budget from the slider; echoed into the policy.
        seed: Optional seed for the placeholder energy values. ``None`` (the
            default, and what the UI uses) samples from NumPy's global RNG,
            exactly as before. It is keyword-only so Gradio never fills it.

    Returns:
        ``(energy_figure, pareto_figure, policy_json)``. On any failure the
        traceback is logged and ``(None, None, error_text)`` is returned.
        ``error_text`` is ``"ERROR:"``, a newline, then the exception message.
        Gradio therefore still receives one value per output component.
    """
    try:
        rng = np.random if seed is None else np.random.default_rng(seed)
        # Same evaluation order as before: the energy plot consumes the RNG first.
        return (
            _energy_figure(rng),
            _pareto_figure(),
            _policy_json(model_name, hardware, acc_budget),
        )
    except Exception as e:
        # Top-level UI boundary: show the error in the JSON box rather than
        # failing the event, but keep the full traceback in the logs.
        import logging

        logging.getLogger(__name__).exception("run_curvopt failed")
        # Must return the SAME NUMBER of outputs as the click handler expects.
        return None, None, f"ERROR:\n{e}"
# ----------------------------
# UI
# ----------------------------
# Build the Gradio interface. Components appear on the page in the order they
# are created inside the `with` blocks below.
# NOTE(review): the nesting here (button and code box outside their Rows) was
# reconstructed from a whitespace-stripped copy. Confirm against the original.
with gr.Blocks() as demo:
    # Header text shown at the top of the page.
    gr.Markdown(
        """
# ⚡ CurvOpt
**Energy-Efficient Inference via Curvature & Information**
A systems-oriented ML demo focusing on **lower energy and compute footprint**.
"""
    )
    # Input controls, placed side by side in one row.
    with gr.Row():
        # Selected value is passed to run_curvopt as `model_name`.
        model_dd = gr.Dropdown(
            choices=["ResNet18", "MobileNetV2", "ViT-Tiny"],
            value="ResNet18",
            label="Model"
        )
        # Selected value is passed to run_curvopt as `hardware`.
        hardware_radio = gr.Radio(
            choices=["CPU", "GPU", "EDGE"],
            value="CPU",
            label="Target Hardware"
        )
        # Passed to run_curvopt as `acc_budget`, which echoes it into the
        # policy JSON.
        # NOTE(review): labelled "%" with a 0.1-2.0 range. Presumably this is
        # the tolerated accuracy drop in percentage points; confirm the intent.
        acc_slider = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.5,
            label="Accuracy Budget (%)"
        )
    run_btn = gr.Button("🚀 Run CurvOpt")
    # Output components. Their order matches the 3-tuple run_curvopt returns.
    with gr.Row():
        energy_plot = gr.Plot(label="Energy Profile")
        pareto_plot = gr.Plot(label="Energy–Accuracy Pareto")
    policy_box = gr.Code(
        label="Generated Policy (JSON)",
        language="json"
    )
    # Wire the button. Input values are handed to run_curvopt in list order,
    # and its returned tuple fills the outputs in list order.
    run_btn.click(
        fn=run_curvopt,
        inputs=[model_dd, hardware_radio, acc_slider],
        outputs=[energy_plot, pareto_plot, policy_box]
    )
# ----------------------------
# Launch (HF-SAFE)
# ----------------------------
# Start the server only when this file is executed as a script. Importing the
# module (e.g. to reuse or test run_curvopt) must not spin up a web server as
# a side effect.
if __name__ == "__main__":
    demo.launch(
        # NOTE(review): `theme` is given to launch() rather than gr.Blocks().
        # That presumably targets a Gradio release where app-level options
        # live on launch(). Confirm against the pinned Gradio version.
        theme=gr.themes.Soft(),
        # Disable server-side rendering.
        ssr_mode=False,
    )