# 3C3H-HeatMap / app.py
# Hugging Face Space by alielfilali01; revision 6a40ae3 (verified, "Update app.py"), 3.65 kB.
import gradio as gr
import json
import requests
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
# ---------------------------------
# 1. Configuration and data loading
# ---------------------------------
# Where the results JSON lives; the "resolve/main" path serves the raw file.
DATA_URL = "https://huggingface.co/spaces/alielfilali01/3C3H-HeatMap/resolve/main/files/aragen_v1_results.json"
# The six 3C3H dimensions, in the order used for the heatmap rows and columns.
METRICS = [
    "Correctness",
    "Completeness",
    "Conciseness",
    "Helpfulness",
    "Honesty",
    "Harmlessness",
]
def load_data(url=DATA_URL):
    """Download the results JSON and return only the per-model entries.

    Args:
        url: Address of the JSON file. Defaults to ``DATA_URL``.

    Returns:
        A list of every top-level entry that is a dict containing a "Meta"
        key, in the order the entries appear in the file.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
        ValueError: If the response body is not valid JSON.
    """
    # A timeout keeps startup from hanging forever on an unresponsive server.
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of trying to parse an error page as JSON.
    response.raise_for_status()
    data = response.json()
    # Drop non-model entries (e.g. timestamp entries). The dict check stops a
    # bare string entry from slipping through via substring matching on "Meta".
    return [
        entry for entry in data
        if isinstance(entry, dict) and "Meta" in entry
    ]
# Load the JSON data once when the app starts. This runs at import time and
# calls load_data() with its default URL, so it performs a network request.
DATA = load_data()
# Collect the display name of every model; these populate the dropdown.
def get_model_names(data):
    """Return the "Model Name" found under each entry's "Meta" key, in order."""
    names = []
    for record in data:
        names.append(record["Meta"]["Model Name"])
    return names
# Names of all loaded models, in the same order as DATA.
MODEL_NAMES = get_model_names(DATA)
# -------------------------------
# 2. Heatmap Generation Functions
# -------------------------------
def generate_heatmap_image(model_entry):
    """Render the metric-similarity heatmap for one model as PNG bytes.

    The scores listed in ``METRICS`` are read from
    ``model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]`` and compared
    pairwise with ``similarity = 1 - |v_i - v_j|``. The matrix is symmetric,
    so only the lower triangle and the diagonal are drawn.

    Args:
        model_entry: One model entry from the results JSON.

    Returns:
        The rendered figure encoded as PNG, as a ``bytes`` object.

    Raises:
        KeyError: If the entry lacks the score section, one of the metrics,
            or the model name.
    """
    scores = model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]
    model_name = model_entry["Meta"]["Model Name"]

    # Score vector in the canonical metric order.
    v = np.array([scores[m] for m in METRICS])
    # Pairwise similarity: 1 on the diagonal, lower as two scores diverge.
    matrix = 1 - np.abs(np.subtract.outer(v, v))
    # Hide the strict upper triangle (k=1 keeps the diagonal visible).
    mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)

    # Work on an explicit figure rather than pyplot's implicit "current"
    # figure, and always close it so a failed render cannot leak figures.
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        sns.heatmap(
            matrix,
            mask=mask,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            xticklabels=METRICS,
            yticklabels=METRICS,
            cbar_kws={"label": "Similarity"},
            ax=ax,
        )
        # The plot shows the similarity matrix, not a confusion matrix.
        ax.set_title(f"Similarity Matrix for Model: {model_name}")
        ax.set_xlabel("Metrics")
        ax.set_ylabel("Metrics")
        fig.tight_layout()

        # Encode the figure into an in-memory PNG.
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    return buf.getvalue()
def generate_heatmaps(selected_model_names):
    """Build one heatmap image per selected model for the gallery.

    Entries of the global ``DATA`` whose model name is in the selection are
    rendered with ``generate_heatmap_image``. Results follow the order of
    ``DATA``, not the order of the selection.

    Args:
        selected_model_names: Model names chosen in the dropdown. ``None`` or
            an empty selection yields no images; a single string is treated
            as a one-element selection.

    Returns:
        A list of ``uint8`` RGBA arrays of shape (height, width, 4), one per
        matching model. ``gr.Gallery`` accepts numpy arrays, PIL images and
        file paths/URLs but not raw PNG bytes, so the encoded images are
        decoded before being returned.
    """
    if not selected_model_names:
        return []
    if isinstance(selected_model_names, str):
        selected_model_names = [selected_model_names]
    # Build the lookup set once instead of scanning a list for every entry.
    wanted = set(selected_model_names)

    images = []
    for entry in DATA:
        if entry["Meta"]["Model Name"] not in wanted:
            continue
        png_bytes = generate_heatmap_image(entry)
        # imread returns PNG data as floats in [0, 1]; scale back to 8 bits.
        rgba = plt.imread(BytesIO(png_bytes), format="png")
        images.append(np.rint(rgba * 255).astype(np.uint8))
    return images
# -------------------------------
# 3. Build the Gradio Interface
# -------------------------------
# Page layout: a heading, a row holding the model picker and the trigger
# button, and a gallery that receives the generated heatmaps.
with gr.Blocks() as demo:
    gr.Markdown("## 3C3H Heatmap Generator")
    gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
    with gr.Row():
        # Preselect up to the first three models so the page is usable at once.
        model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
        generate_btn = gr.Button("Generate Heatmaps")
    # Layout options go to the constructor: the chained .style(grid=...,
    # height=...) call is deprecated in late Gradio 3.x and removed in 4.x.
    gallery = gr.Gallery(label="Heatmaps", columns=2, height="auto")
    # Clicking the button renders heatmaps for the current dropdown selection.
    generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
# Launch the Gradio app
demo.launch()