3C3H-HeatMap / app.py
alielfilali01's picture
Update app.py
00f7c3d verified
raw
history blame
8.52 kB
import gradio as gr
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image
# -------------------------------
# 1. Load Results from Local File
# -------------------------------
def load_results():
    """Load the per-model evaluation results bundled with the app.

    Reads ``files/aragen_v1_results.json`` from the directory containing this
    source file and keeps only the entries that describe a model.

    Returns:
        list[dict]: The dict entries of the loaded JSON data that contain a
        ``"Meta"`` key, in file order.

    Raises:
        OSError: If the results file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Resolve the path relative to this file so the app works from any
    # working directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    results_file = os.path.join(current_dir, "files", "aragen_v1_results.json")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(results_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Drop non-model entries (e.g. timestamp entries). The isinstance check
    # guarantees that `"Meta" in entry` is a key lookup: on a string it would
    # be a substring test and on a number it would raise TypeError.
    return [
        entry
        for entry in data
        if isinstance(entry, dict) and "Meta" in entry
    ]
# Load the JSON results once, at import time. generate_heatmaps() filters
# this list on each call instead of re-reading the file.
DATA = load_results()
# Extract model names for the dropdown from the JSON "Meta" field
def get_model_names(data):
    """Collect the model name stored in every entry of ``data``.

    Args:
        data: Iterable of entries, each holding its name under
            ``entry["Meta"]["Model Name"]``.

    Returns:
        list: The model names, in the same order as ``data``.
    """
    names = []
    for record in data:
        meta = record["Meta"]
        names.append(meta["Model Name"])
    return names
# Names of all loaded models; used as the choices of the model dropdown.
MODEL_NAMES = get_model_names(DATA)
# -------------------------------
# 2. Define Metrics and Heatmap Generation Functions
# -------------------------------
# The six 3C3H metrics. This order fixes both the order of the score vector
# and the row/column tick labels of every heatmap.
METRICS = ["Correctness", "Completeness", "Conciseness", "Helpfulness", "Honesty", "Harmlessness"]
def generate_heatmap_image(model_entry):
    """Render the pairwise metric-similarity heatmap for one model.

    The scores named in ``METRICS`` are read from
    ``model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]`` and compared
    pairwise with ``similarity = 1 - |v_i - v_j|``. The resulting square
    matrix is symmetric, so only its lower triangle (diagonal included) is
    drawn.

    Args:
        model_entry: One result entry. It must provide the score section
            above and ``model_entry["Meta"]["Model Name"]`` for the title.

    Returns:
        PIL.Image.Image: The heatmap as an RGB image, scaled down if needed
        to fit within 800x600 pixels.

    Raises:
        KeyError: If the entry lacks the score section, one of the metrics,
            or the model name.
    """
    # Read everything that can raise KeyError before a figure is allocated.
    scores = model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]
    model_name = model_entry["Meta"]["Model Name"]

    # Score vector in METRICS order, then the pairwise similarity matrix.
    v = np.array([scores[m] for m in METRICS])
    matrix = 1 - np.abs(np.subtract.outer(v, v))
    # Hide the redundant upper triangle; k=1 keeps the diagonal visible.
    mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)

    # Draw on an explicit figure/axes instead of pyplot's implicit "current
    # figure", so overlapping calls cannot draw on or close each other's
    # figure.
    fig, ax = plt.subplots(figsize=(6, 5), dpi=100)
    try:
        sns.heatmap(
            matrix,
            mask=mask,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            xticklabels=METRICS,
            yticklabels=METRICS,
            cbar_kws={"label": "Similarity"},
            ax=ax,
        )
        # The plot shows metric similarity, not a confusion matrix.
        ax.set_title(f"Similarity Matrix for Model: {model_name}")
        ax.set_xlabel("Metrics")
        ax.set_ylabel("Metrics")
        fig.tight_layout()

        # Render the figure to an in-memory PNG.
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    buf.seek(0)
    # convert() forces the lazily opened PNG to be fully decoded.
    image = Image.open(buf).convert("RGB")
    # Cap the size so gallery items stay consistent; thumbnail() only ever
    # shrinks and preserves the aspect ratio.
    max_size = (800, 600)
    image.thumbnail(max_size, Image.Resampling.LANCZOS)
    return image
def generate_heatmaps(selected_model_names):
    """Generate one heatmap per selected model.

    Args:
        selected_model_names: Model names to render. ``None`` or an empty
            collection yields no images; a single string is treated as a
            one-element selection.

    Returns:
        list: One PIL image per entry of the global ``DATA`` whose
        ``Meta["Model Name"]`` is selected, in ``DATA`` order.
    """
    # Nothing selected: return early instead of failing on `in None`.
    if not selected_model_names:
        return []
    # A bare string would otherwise be substring-matched against each name.
    if isinstance(selected_model_names, str):
        selected_model_names = [selected_model_names]
    wanted = set(selected_model_names)
    return [
        generate_heatmap_image(entry)
        for entry in DATA
        if entry["Meta"]["Model Name"] in wanted
    ]
# -------------------------------
# 3. Build the Gradio Interface
# -------------------------------
# Custom CSS keeps each gallery image inside its cell without cropping.
with gr.Blocks(css="""
.gallery-item img {
max-width: 100% !important;
max-height: 100% !important;
object-fit: contain !important;
}
""") as demo:
    gr.Markdown("## 3C3H Heatmap Generator")
    gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
    with gr.Row():
        # Pre-select the preferred models only when the loaded results
        # actually contain them: a default that is not among `choices` is not
        # a valid dropdown selection. Otherwise fall back to the first three
        # available models.
        preferred_models = ["silma-ai/SILMA-9B-Instruct-v1.0", "google/gemma-2-9b-it"]
        default_models = [
            name for name in preferred_models if name in MODEL_NAMES
        ] or MODEL_NAMES[:3]
        model_dropdown = gr.Dropdown(
            choices=MODEL_NAMES,
            label="Select Model(s)",
            multiselect=True,
            value=default_models,
        )
        generate_btn = gr.Button("Generate Heatmaps")
    # Two-column gallery; images are scaled to fit their cell.
    gallery = gr.Gallery(
        label="Heatmaps",
        columns=2,
        height="auto",
        object_fit="contain"
    )
    # The multiselect dropdown value is passed to generate_heatmaps and the
    # returned images fill the gallery.
    generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
# Launch the Gradio app
demo.launch()
# import gradio as gr
# import json
# import os
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns
# from io import BytesIO
# from PIL import Image
# # -------------------------------
# # 1. Load Results from Local File
# # -------------------------------
# def load_results():
# # Get the directory of the current file
# current_dir = os.path.dirname(os.path.abspath(__file__))
# # Construct the path to the JSON file (assumes file is stored in "files/aragen_v1_results.json")
# results_file = os.path.join(current_dir, "files", "aragen_v1_results.json")
# with open(results_file, "r") as f:
# data = json.load(f)
# # Filter out any non-model entries (e.g., timestamp entries)
# model_data = [entry for entry in data if "Meta" in entry]
# return model_data
# # Load the JSON data once when the app starts
# DATA = load_results()
# # Extract model names for the dropdown from the JSON "Meta" field
# def get_model_names(data):
# model_names = [entry["Meta"]["Model Name"] for entry in data]
# return model_names
# MODEL_NAMES = get_model_names(DATA)
# # -------------------------------
# # 2. Define Metrics and Heatmap Generation Functions
# # -------------------------------
# # Define the six metrics in the desired order.
# METRICS = ["Correctness", "Completeness", "Conciseness", "Helpfulness", "Honesty", "Harmlessness"]
# def generate_heatmap_image(model_entry):
# """
# For a given model entry, extract the six metrics and compute a 6x6 similarity matrix
# using the definition: similarity = 1 - |v_i - v_j|, then return the heatmap as a PIL image.
# """
# scores = model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]
# # Create a vector with the metrics in the defined order.
# v = np.array([scores[m] for m in METRICS])
# # Compute the 6x6 similarity matrix.
# matrix = 1 - np.abs(np.subtract.outer(v, v))
# # Create a mask for the upper triangle (keeping the diagonal visible).
# mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
# plt.figure(figsize=(6, 5))
# sns.heatmap(matrix,
# mask=mask,
# annot=True,
# fmt=".2f",
# cmap="viridis",
# xticklabels=METRICS,
# yticklabels=METRICS,
# cbar_kws={"label": "Similarity"})
# plt.title(f"Confusion Matrix for Model: {model_entry['Meta']['Model Name']}")
# plt.xlabel("Metrics")
# plt.ylabel("Metrics")
# plt.tight_layout()
# # Save the plot to a bytes buffer.
# buf = BytesIO()
# plt.savefig(buf, format="png")
# plt.close()
# buf.seek(0)
# # Convert the buffer into a PIL Image.
# image = Image.open(buf).convert("RGB")
# return image
# def generate_heatmaps(selected_model_names):
# """
# Filter the global DATA for entries matching the selected model names,
# generate a heatmap for each, and return a list of PIL images.
# """
# filtered_entries = [entry for entry in DATA if entry["Meta"]["Model Name"] in selected_model_names]
# images = []
# for entry in filtered_entries:
# img = generate_heatmap_image(entry)
# images.append(img)
# return images
# # -------------------------------
# # 3. Build the Gradio Interface
# # -------------------------------
# with gr.Blocks() as demo:
# gr.Markdown("## 3C3H Heatmap Generator")
# gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
# with gr.Row():
# model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
# generate_btn = gr.Button("Generate Heatmaps")
# # Use the 'columns' parameter to set a grid layout in the gallery.
# gallery = gr.Gallery(label="Heatmaps", columns=2)
# generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
# # Launch the Gradio app
# demo.launch()