# curvopt-space / app.py
# Author: GirishaBuilds01 — commit 7969146 ("Update app.py", verified)
import json
import traceback
import warnings
import zlib

import gradio as gr
import numpy as np
import plotly.graph_objects as go
# Silence an irrelevant warning from the HF environment: any FutureWarning
# whose text contains "reduce_op" is dropped. `message` is a regex matched
# against the start of the warning text, hence the leading ".*".
warnings.filterwarnings(
    action="ignore",
    message=".*reduce_op.*",
    category=FutureWarning,
)
# ----------------------------
# Core callback (SAFE + HF READY)
# ----------------------------
def _config_seed(model_name, hardware):
    """Return a stable seed derived from the model/hardware selection.

    ``zlib.crc32`` is used rather than ``hash()`` because ``str`` hashing is
    salted per process, which would make the chart change across restarts.
    """
    return zlib.crc32(f"{model_name}|{hardware}".encode("utf-8"))


def _energy_figure(rng, n_layers=5):
    """Build the per-layer energy line chart.

    Args:
        rng: A ``numpy.random.Generator`` supplying the values.
        n_layers: Number of layers to plot (x axis runs 1..n_layers).

    The energies are placeholder values drawn uniformly from [10, 50) mJ.
    They are not measured from any model.
    """
    layers = np.arange(1, n_layers + 1)
    energy_mj = rng.uniform(10, 50, size=n_layers)
    fig = go.Figure(
        data=go.Scatter(x=layers, y=energy_mj, mode="lines+markers")
    )
    fig.update_layout(
        title="Layerwise Energy Consumption",
        xaxis_title="Layer Index",
        yaxis_title="Energy (mJ)",
    )
    return fig


def _pareto_figure():
    """Build the energy-vs-accuracy chart from fixed placeholder points."""
    # NOTE(review): in these points energy *falls* as accuracy rises, so the
    # last point dominates all others and this is not a real trade-off
    # frontier. Confirm the intended data before relying on this chart.
    acc = np.array([0.82, 0.85, 0.88, 0.90])
    energy_mj = np.array([55, 48, 40, 34])
    fig = go.Figure(
        data=go.Scatter(x=acc, y=energy_mj, mode="lines+markers")
    )
    fig.update_layout(
        title="Energy–Accuracy Pareto Frontier",
        xaxis_title="Accuracy",
        yaxis_title="Energy (mJ)",
    )
    return fig


def _build_policy(model_name, hardware, acc_budget):
    """Return the generated policy as a pretty-printed JSON string.

    Only ``model``, ``hardware`` and ``accuracy_budget`` come from the UI.
    Every other field is a hard-coded placeholder.
    """
    policy = {
        "model": model_name,
        "hardware": hardware,
        "accuracy_budget": acc_budget,
        "quantization": "INT8",
        "curvature_metric": "trace(H)",
        "activation_information": "mutual_information",
        "selected_layers": [1, 3, 5],
        "expected_energy_saving_percent": 32.4,
    }
    return json.dumps(policy, indent=2)


def run_curvopt(model_name, hardware, acc_budget):
    """Gradio callback: build the energy plot, Pareto plot and policy JSON.

    Args:
        model_name: Model chosen in the UI dropdown (e.g. "ResNet18").
        hardware: Target hardware chosen in the UI radio ("CPU", "GPU", "EDGE").
        acc_budget: Accuracy budget from the UI slider (labelled as a percent).

    Returns:
        A 3-tuple ``(energy_figure, pareto_figure, policy_json)``, one value
        per output component. On any failure the two figures are ``None``
        and the third element is a message starting with "ERROR:". The UI
        therefore always receives exactly three values.
    """
    try:
        # Seed from the selection so identical inputs give an identical chart.
        rng = np.random.default_rng(_config_seed(model_name, hardware))
        return (
            _energy_figure(rng),
            _pareto_figure(),
            _build_policy(model_name, hardware, acc_budget),
        )
    except Exception as e:
        # Top-level UI boundary: log the full traceback to the server console
        # and show a short message instead of failing the request.
        traceback.print_exc()
        return None, None, f"ERROR:\n{e}"
# ----------------------------
# UI
# ----------------------------
# Build the Gradio page and bind it to the module-level name `demo`:
# - a Markdown header
# - a row of inputs (model, hardware, accuracy budget)
# - a run button
# - a row with two plots
# - a JSON policy viewer
# The button is wired to run_curvopt.
# Gradio Blocks places components in the order they are created, so the
# statement order below determines the page layout.
with gr.Blocks() as demo:
    # Page title and tagline, rendered as Markdown.
    gr.Markdown(
        """
# ⚡ CurvOpt
**Energy-Efficient Inference via Curvature & Information**
A systems-oriented ML demo focusing on **lower energy and compute footprint**.
"""
    )
    # Input controls, laid out side by side in one row.
    with gr.Row():
        model_dd = gr.Dropdown(
            choices=["ResNet18", "MobileNetV2", "ViT-Tiny"],
            value="ResNet18",
            label="Model"
        )
        hardware_radio = gr.Radio(
            choices=["CPU", "GPU", "EDGE"],
            value="CPU",
            label="Target Hardware"
        )
        # Forwarded to run_curvopt as `acc_budget` and echoed into the
        # policy JSON. The label presents the value as a percentage.
        acc_slider = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.5,
            label="Accuracy Budget (%)"
        )
    run_btn = gr.Button("🚀 Run CurvOpt")
    # Output plots, side by side.
    with gr.Row():
        energy_plot = gr.Plot(label="Energy Profile")
        pareto_plot = gr.Plot(label="Energy–Accuracy Pareto")
    # Shows the JSON string returned as run_curvopt's third value, or its
    # "ERROR:" message on failure.
    policy_box = gr.Code(
        label="Generated Policy (JSON)",
        language="json"
    )
    # `inputs` are passed positionally as (model_name, hardware, acc_budget).
    # `outputs` must line up with run_curvopt's 3-tuple return.
    run_btn.click(
        fn=run_curvopt,
        inputs=[model_dd, hardware_radio, acc_slider],
        outputs=[energy_plot, pareto_plot, policy_box]
    )
# ----------------------------
# Launch (HF-SAFE)
# ----------------------------
# Start the server only when run as a script (`python app.py`, which is also
# how Hugging Face Spaces runs it). Importing this module, e.g. from tests,
# then has no side effect of binding a port.
if __name__ == "__main__":
    # `theme` is passed to launch() rather than gr.Blocks(), following the
    # newer Gradio API. NOTE(review): confirm the pinned gradio version
    # accepts `theme` here.
    # ssr_mode=False turns off Gradio's server-side rendering.
    demo.launch(
        theme=gr.themes.Soft(),
        ssr_mode=False,
    )