Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -50,7 +50,8 @@ def generate_heatmap_image(model_entry):
|
|
| 50 |
# Create a mask for the upper triangle (keeping the diagonal visible).
|
| 51 |
mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
|
| 52 |
|
| 53 |
-
|
|
|
|
| 54 |
sns.heatmap(matrix,
|
| 55 |
mask=mask,
|
| 56 |
annot=True,
|
|
@@ -66,11 +67,18 @@ def generate_heatmap_image(model_entry):
|
|
| 66 |
|
| 67 |
# Save the plot to a bytes buffer.
|
| 68 |
buf = BytesIO()
|
| 69 |
-
plt.savefig(buf, format="png")
|
| 70 |
plt.close()
|
| 71 |
buf.seek(0)
|
|
|
|
| 72 |
# Convert the buffer into a PIL Image.
|
| 73 |
image = Image.open(buf).convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
return image
|
| 75 |
|
| 76 |
def generate_heatmaps(selected_model_names):
|
|
@@ -88,7 +96,13 @@ def generate_heatmaps(selected_model_names):
|
|
| 88 |
# -------------------------------
|
| 89 |
# 3. Build the Gradio Interface
|
| 90 |
# -------------------------------
|
| 91 |
-
with gr.Blocks(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
gr.Markdown("## 3C3H Heatmap Generator")
|
| 93 |
gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
|
| 94 |
|
|
@@ -96,10 +110,123 @@ with gr.Blocks() as demo:
|
|
| 96 |
model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
|
| 97 |
|
| 98 |
generate_btn = gr.Button("Generate Heatmaps")
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
|
| 103 |
|
| 104 |
# Launch the Gradio app
|
| 105 |
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
# Create a mask for the upper triangle (keeping the diagonal visible).
|
| 51 |
mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
|
| 52 |
|
| 53 |
+
# Set a consistent figure size that will work well in the gallery
|
| 54 |
+
plt.figure(figsize=(6, 5), dpi=100)
|
| 55 |
sns.heatmap(matrix,
|
| 56 |
mask=mask,
|
| 57 |
annot=True,
|
|
|
|
| 67 |
|
| 68 |
# Save the plot to a bytes buffer.
|
| 69 |
buf = BytesIO()
|
| 70 |
+
plt.savefig(buf, format="png", bbox_inches="tight")
|
| 71 |
plt.close()
|
| 72 |
buf.seek(0)
|
| 73 |
+
|
| 74 |
# Convert the buffer into a PIL Image.
|
| 75 |
image = Image.open(buf).convert("RGB")
|
| 76 |
+
|
| 77 |
+
# Shrink the image in place so it fits within a maximum size for the gallery (thumbnail() keeps the aspect ratio and never enlarges)
|
| 78 |
+
# This helps maintain consistency and prevent oversized images
|
| 79 |
+
max_size = (800, 600)
|
| 80 |
+
image.thumbnail(max_size, Image.Resampling.LANCZOS)
|
| 81 |
+
|
| 82 |
return image
|
| 83 |
|
| 84 |
def generate_heatmaps(selected_model_names):
|
|
|
|
| 96 |
# -------------------------------
|
| 97 |
# 3. Build the Gradio Interface
|
| 98 |
# -------------------------------
|
| 99 |
+
with gr.Blocks(css="""
|
| 100 |
+
.gallery-item img {
|
| 101 |
+
max-width: 100% !important;
|
| 102 |
+
max-height: 100% !important;
|
| 103 |
+
object-fit: contain !important;
|
| 104 |
+
}
|
| 105 |
+
""") as demo:
|
| 106 |
gr.Markdown("## 3C3H Heatmap Generator")
|
| 107 |
gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
|
| 108 |
|
|
|
|
| 110 |
model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
|
| 111 |
|
| 112 |
generate_btn = gr.Button("Generate Heatmaps")
|
| 113 |
+
|
| 114 |
+
# Set height and columns for better display
|
| 115 |
+
gallery = gr.Gallery(
|
| 116 |
+
label="Heatmaps",
|
| 117 |
+
columns=2,
|
| 118 |
+
height="auto",
|
| 119 |
+
object_fit="contain"
|
| 120 |
+
)
|
| 121 |
|
| 122 |
generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
|
| 123 |
|
| 124 |
# Launch the Gradio app
|
| 125 |
demo.launch()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# import gradio as gr
|
| 129 |
+
# import json
|
| 130 |
+
# import os
|
| 131 |
+
# import numpy as np
|
| 132 |
+
# import matplotlib.pyplot as plt
|
| 133 |
+
# import seaborn as sns
|
| 134 |
+
# from io import BytesIO
|
| 135 |
+
# from PIL import Image
|
| 136 |
+
|
| 137 |
+
# # -------------------------------
|
| 138 |
+
# # 1. Load Results from Local File
|
| 139 |
+
# # -------------------------------
|
| 140 |
+
# def load_results():
|
| 141 |
+
# # Get the directory of the current file
|
| 142 |
+
# current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 143 |
+
# # Construct the path to the JSON file (assumes file is stored in "files/aragen_v1_results.json")
|
| 144 |
+
# results_file = os.path.join(current_dir, "files", "aragen_v1_results.json")
|
| 145 |
+
# with open(results_file, "r") as f:
|
| 146 |
+
# data = json.load(f)
|
| 147 |
+
# # Filter out any non-model entries (e.g., timestamp entries)
|
| 148 |
+
# model_data = [entry for entry in data if "Meta" in entry]
|
| 149 |
+
# return model_data
|
| 150 |
+
|
| 151 |
+
# # Load the JSON data once when the app starts
|
| 152 |
+
# DATA = load_results()
|
| 153 |
+
|
| 154 |
+
# # Extract model names for the dropdown from the JSON "Meta" field
|
| 155 |
+
# def get_model_names(data):
|
| 156 |
+
# model_names = [entry["Meta"]["Model Name"] for entry in data]
|
| 157 |
+
# return model_names
|
| 158 |
+
|
| 159 |
+
# MODEL_NAMES = get_model_names(DATA)
|
| 160 |
+
|
| 161 |
+
# # -------------------------------
|
| 162 |
+
# # 2. Define Metrics and Heatmap Generation Functions
|
| 163 |
+
# # -------------------------------
|
| 164 |
+
# # Define the six metrics in the desired order.
|
| 165 |
+
# METRICS = ["Correctness", "Completeness", "Conciseness", "Helpfulness", "Honesty", "Harmlessness"]
|
| 166 |
+
|
| 167 |
+
# def generate_heatmap_image(model_entry):
|
| 168 |
+
# """
|
| 169 |
+
# For a given model entry, extract the six metrics and compute a 6x6 similarity matrix
|
| 170 |
+
# using the definition: similarity = 1 - |v_i - v_j|, then return the heatmap as a PIL image.
|
| 171 |
+
# """
|
| 172 |
+
# scores = model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]
|
| 173 |
+
# # Create a vector with the metrics in the defined order.
|
| 174 |
+
# v = np.array([scores[m] for m in METRICS])
|
| 175 |
+
# # Compute the 6x6 similarity matrix.
|
| 176 |
+
# matrix = 1 - np.abs(np.subtract.outer(v, v))
|
| 177 |
+
# # Create a mask for the upper triangle (keeping the diagonal visible).
|
| 178 |
+
# mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
|
| 179 |
+
|
| 180 |
+
# plt.figure(figsize=(6, 5))
|
| 181 |
+
# sns.heatmap(matrix,
|
| 182 |
+
# mask=mask,
|
| 183 |
+
# annot=True,
|
| 184 |
+
# fmt=".2f",
|
| 185 |
+
# cmap="viridis",
|
| 186 |
+
# xticklabels=METRICS,
|
| 187 |
+
# yticklabels=METRICS,
|
| 188 |
+
# cbar_kws={"label": "Similarity"})
|
| 189 |
+
# plt.title(f"Confusion Matrix for Model: {model_entry['Meta']['Model Name']}")
|
| 190 |
+
# plt.xlabel("Metrics")
|
| 191 |
+
# plt.ylabel("Metrics")
|
| 192 |
+
# plt.tight_layout()
|
| 193 |
+
|
| 194 |
+
# # Save the plot to a bytes buffer.
|
| 195 |
+
# buf = BytesIO()
|
| 196 |
+
# plt.savefig(buf, format="png")
|
| 197 |
+
# plt.close()
|
| 198 |
+
# buf.seek(0)
|
| 199 |
+
# # Convert the buffer into a PIL Image.
|
| 200 |
+
# image = Image.open(buf).convert("RGB")
|
| 201 |
+
# return image
|
| 202 |
+
|
| 203 |
+
# def generate_heatmaps(selected_model_names):
|
| 204 |
+
# """
|
| 205 |
+
# Filter the global DATA for entries matching the selected model names,
|
| 206 |
+
# generate a heatmap for each, and return a list of PIL images.
|
| 207 |
+
# """
|
| 208 |
+
# filtered_entries = [entry for entry in DATA if entry["Meta"]["Model Name"] in selected_model_names]
|
| 209 |
+
# images = []
|
| 210 |
+
# for entry in filtered_entries:
|
| 211 |
+
# img = generate_heatmap_image(entry)
|
| 212 |
+
# images.append(img)
|
| 213 |
+
# return images
|
| 214 |
+
|
| 215 |
+
# # -------------------------------
|
| 216 |
+
# # 3. Build the Gradio Interface
|
| 217 |
+
# # -------------------------------
|
| 218 |
+
# with gr.Blocks() as demo:
|
| 219 |
+
# gr.Markdown("## 3C3H Heatmap Generator")
|
| 220 |
+
# gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
|
| 221 |
+
|
| 222 |
+
# with gr.Row():
|
| 223 |
+
# model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
|
| 224 |
+
|
| 225 |
+
# generate_btn = gr.Button("Generate Heatmaps")
|
| 226 |
+
# # Use the 'columns' parameter to set a grid layout in the gallery.
|
| 227 |
+
# gallery = gr.Gallery(label="Heatmaps", columns=2)
|
| 228 |
+
|
| 229 |
+
# generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
|
| 230 |
+
|
| 231 |
+
# # Launch the Gradio app
|
| 232 |
+
# demo.launch()
|