Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pandas as pd | |
| import plotly.express as px | |
| import requests | |
| import re | |
| import os | |
| import glob | |
| # Download the main results file | |
def download_main_results():
    """Download the main ImageNet results CSV into the working directory.

    The file is fetched from the pytorch-image-models GitHub repository and
    stored as ``results-imagenet.csv``. If that file already exists it is
    reused and no request is made.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
    """
    url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
    if os.path.exists("results-imagenet.csv"):
        return
    response = requests.get(url, timeout=30)
    # Fail loudly instead of saving an HTTP error page as the CSV: once the
    # file exists it is never downloaded again, so a bad payload would stick.
    response.raise_for_status()
    with open("results-imagenet.csv", "wb") as f:
        f.write(response.content)
def download_github_csvs_api(
    repo="huggingface/pytorch-image-models",
    folder="results",
    filename_pattern=r"benchmark-.*\.csv",
    output_dir="benchmarks",
):
    """Download benchmark CSV files listed by the GitHub contents API.

    Args:
        repo: ``owner/name`` of the GitHub repository to query.
        folder: Folder inside the repository whose contents are listed.
        filename_pattern: Regular expression applied with ``re.match`` to
            each listed file name.
        output_dir: Local directory the matching files are saved into; it is
            created when missing.

    Returns:
        The names of all listed files matching ``filename_pattern``, whether
        or not each one had to be downloaded. An empty list is returned when
        the listing cannot be retrieved or nothing matches.
    """
    api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
    try:
        r = requests.get(api_url, timeout=30)
    except requests.RequestException:
        # Same best-effort contract as a non-200 answer: report "nothing".
        return []
    if r.status_code != 200:
        return []

    files = r.json()
    if not isinstance(files, list):
        # Anything other than a list of entries cannot be iterated as one.
        return []

    matched_files = [f["name"] for f in files if re.match(filename_pattern, f["name"])]
    if not matched_files:
        return []

    raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
    os.makedirs(output_dir, exist_ok=True)
    for fname in matched_files:
        out_path = os.path.join(output_dir, fname)
        if os.path.exists(out_path):
            # Already cached locally; only missing files are downloaded.
            continue
        try:
            resp = requests.get(raw_base + fname, timeout=30)
        except requests.RequestException:
            # One unreachable file must not abort the remaining downloads.
            continue
        if resp.ok:
            with open(out_path, "wb") as f:
                f.write(resp.content)
    return matched_files
def load_main_data():
    """Return the ImageNet results table, downloading the CSV if needed.

    The unmodified model identifier is kept in ``model_org``; ``model`` is
    shortened to the text before its first ``.``.
    """
    download_main_results()
    results = pd.read_csv("results-imagenet.csv")
    full_names = results["model"]
    results["model_org"] = full_names
    results["model"] = full_names.str.split(".").str[0]
    return results
def get_data(benchmark_file, df_results):
    """Process benchmark data and merge it with the main results.

    Args:
        benchmark_file: Path of a benchmark CSV. It must provide ``model``
            and ``infer_samples_per_sec`` columns.
        df_results: Results table with a ``model`` column to join on.

    Returns:
        The inner join of both tables with two added columns, ``secs``
        (``1 / infer_samples_per_sec``) and ``family`` (derived from the
        model name), restricted to the families listed in ``pattern``.
        Models whose name ends in ``gn`` are dropped. An empty DataFrame is
        returned when ``benchmark_file`` does not exist.
    """
    pattern = (
        r"^(?:"
        r"eva|"
        r"maxx?vit(?:v2)?|"
        r"coatnet|coatnext|"
        r"convnext(?:v2)?|"
        r"beit(?:v2)?|"
        r"efficient(?:net(?:v2)?|former(?:v2)?|vit)|"
        r"regnet[xyvz]?|"
        r"levit|"
        r"mobilenet(?:v\d*)?|"
        r"vitd?|"
        r"swin(?:v2)?"
        r")$"
    )
    if not os.path.exists(benchmark_file):
        return pd.DataFrame()

    df = pd.read_csv(benchmark_file).merge(df_results, on="model")
    df["secs"] = 1.0 / df["infer_samples_per_sec"]
    # Family = leading lowercase letters (plus an optional "v2") up to the
    # first digit, underscore or end of name; NaN when the name does not fit.
    df["family"] = df["model"].str.extract(
        r"^([a-z]+?(?:v2)?)(?:\d|_|$)", expand=False
    )
    # Copy so the .loc assignment below writes to an independent frame.
    df = df[~df["model"].str.endswith("gn", na=False)].copy()

    is_resnet_d = df["model"].str.contains("resnet.*d", na=False)
    df.loc[is_resnet_d, "family"] = df.loc[is_resnet_d, "family"] + "d"

    # na=False: rows without a family would otherwise put NaN into the mask,
    # which boolean indexing rejects.
    return df[df["family"].str.contains(pattern, na=False)]
def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
    """Build the scatter figure for the current UI selections.

    Args:
        benchmark_file: Path of the benchmark CSV to plot.
        x_axis: Column name plotted on the x-axis.
        y_axis: Column name plotted on the y-axis.
        selected_families: Families to keep; a falsy value keeps all rows.
        log_x: Whether the x-axis uses a log scale.
        log_y: Whether the y-axis uses a log scale.

    Returns:
        A Plotly figure, or ``None`` when no rows are available to plot.
    """
    data = get_data(benchmark_file, load_main_data())
    if data.empty:
        return None

    # Narrow down to the ticked families, if any were ticked.
    if selected_families:
        data = data[data["family"].isin(selected_families)]
    if data.empty:
        return None

    # Marker size is the squared inference image size.
    marker_size = data["infer_img_size"] ** 2
    return px.scatter(
        data,
        x=x_axis,
        y=y_axis,
        size=marker_size,
        color="family",
        log_x=log_x,
        log_y=log_y,
        width=1000,
        height=800,
        hover_name="model_org",
        hover_data=["infer_samples_per_sec", "infer_img_size"],
        title=f"Model Performance: {y_axis} vs {x_axis}",
    )
def setup_interface():
    """Collect the choices the Gradio interface is built from.

    Returns:
        A ``(benchmark_files, plot_columns, families)`` tuple:
        ``benchmark_files`` holds the sorted local benchmark CSV paths, or
        the single placeholder ``"No benchmark files found"``;
        ``plot_columns`` holds the column names offered for the axes;
        ``families`` holds the sorted families found in the first benchmark
        file (empty when there is none).
    """
    # Populate the local cache. The returned names are not needed because
    # the files are discovered on disk right below.
    download_github_csvs_api()

    # Sorted so the dropdown order and the first-file fallback do not depend
    # on the order in which the filesystem lists the directory.
    benchmark_files = sorted(glob.glob("benchmarks/benchmark-*.csv"))
    if not benchmark_files:
        benchmark_files = ["No benchmark files found"]

    df_results = load_main_data()

    # Relevant columns for plotting
    plot_columns = [
        "top1",
        "top5",
        "infer_samples_per_sec",
        "secs",
        "param_count_x",
        "infer_img_size",
    ]

    # Families come from the first benchmark file, when one is available.
    families = []
    if benchmark_files[0] != "No benchmark files found":
        sample_df = get_data(benchmark_files[0], df_results)
        if not sample_df.empty:
            families = sorted(sample_df["family"].unique().tolist())
    return benchmark_files, plot_columns, families
# Gather the dropdown, radio and checkbox choices once at import time.
benchmark_files, plot_columns, families = setup_interface()

# Build the Gradio interface.
with gr.Blocks(title="Image Model Performance Analysis") as demo:
    gr.Markdown("# Image Model Performance Analysis")
    gr.Markdown(
        "Analyze and visualize performance metrics of different image models based on benchmark data."
    )

    with gr.Row():
        # Left column: all controls.
        with gr.Column(scale=1):
            # Start on the preferred benchmark when it is available,
            # otherwise on the first listed file.
            preferred_file = (
                "benchmarks/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv"
            )
            if preferred_file in benchmark_files:
                default_file = preferred_file
            elif benchmark_files:
                default_file = benchmark_files[0]
            else:
                default_file = None

            benchmark_dropdown = gr.Dropdown(
                label="Select Benchmark File",
                choices=benchmark_files,
                value=default_file,
            )
            x_axis_radio = gr.Radio(label="X-axis", choices=plot_columns, value="secs")
            y_axis_radio = gr.Radio(label="Y-axis", choices=plot_columns, value="top1")
            family_checkboxes = gr.CheckboxGroup(
                label="Select Model Families", choices=families, value=families
            )
            log_x_checkbox = gr.Checkbox(label="Log scale X-axis", value=True)
            log_y_checkbox = gr.Checkbox(label="Log scale Y-axis", value=False)
            update_button = gr.Button("Update Plot", variant="primary")

        # Right column: the figure.
        with gr.Column(scale=2):
            plot_output = gr.Plot()

    # NOTE(review): the original indentation was lost; these footer notes are
    # assumed to sit below the row rather than inside a column - confirm.
    gr.Markdown("The benchmark data comes from the [pytorch-image-models](https://github.com/huggingface/pytorch-image-models) repository by [Ross Wightman](https://huggingface.co/rwightman).")
    gr.Markdown("Based on the original notebook by [Jeremy Howard](https://huggingface.co/jph00).")
    gr.Markdown("Read more about the project on my blog [dronelab.dev](https://dronelab.dev/posts/which-image-models-are-best-updated/).")

    # The button click and the initial page load feed create_plot with the
    # same components, in create_plot's parameter order.
    plot_inputs = [
        benchmark_dropdown,
        x_axis_radio,
        y_axis_radio,
        family_checkboxes,
        log_x_checkbox,
        log_y_checkbox,
    ]

    # Update plot when button is clicked
    update_button.click(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

    def update_families(benchmark_file):
        """Return a checkbox group listing the families of ``benchmark_file``.

        The group is empty when no real file is selected or the file yields
        no rows; otherwise every family is listed and pre-selected.
        """
        available = []
        if benchmark_file and benchmark_file != "No benchmark files found":
            frame = get_data(benchmark_file, load_main_data())
            if not frame.empty:
                available = sorted(frame["family"].unique().tolist())
        return gr.CheckboxGroup(choices=available, value=available)

    # Refresh the family checkboxes when the benchmark file changes
    benchmark_dropdown.change(
        fn=update_families, inputs=benchmark_dropdown, outputs=family_checkboxes
    )

    # Load initial plot
    demo.load(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

if __name__ == "__main__":
    demo.launch()