Spaces:
Sleeping
Sleeping
MaroueneA
committed on
Commit
·
b44bb03
1
Parent(s):
8ca1a85
added model comparison function
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import pandas as pd
|
|
| 3 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 4 |
from datasets import load_dataset
|
| 5 |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
|
|
|
|
| 6 |
import torch
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
import umap
|
|
@@ -88,6 +89,54 @@ def generate_embeddings_and_plot(categories):
|
|
| 88 |
tsne_plot_path = plot_embeddings(tsne_embeddings, "t-SNE Projection of Text Categories", "tsne")
|
| 89 |
return umap_plot_path, tsne_plot_path
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
def setup_gradio_interface():
|
| 92 |
with gr.Blocks() as demo:
|
| 93 |
gr.Markdown("## Model Comparison and Text Analysis")
|
|
|
|
| 3 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 4 |
from datasets import load_dataset
|
| 5 |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
|
| 6 |
+
from sklearn.cluster import KMeans
|
| 7 |
import torch
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
import umap
|
|
|
|
| 89 |
tsne_plot_path = plot_embeddings(tsne_embeddings, "t-SNE Projection of Text Categories", "tsne")
|
| 90 |
return umap_plot_path, tsne_plot_path
|
| 91 |
|
| 92 |
+
def compare_models(model1, model2):
    """Evaluate two fine-tuned models on the shared test split and compare them.

    Args:
        model1: Key into the module-level ``models``/``tokenizers`` dicts
            selecting the first model.
        model2: Key selecting the second model.

    Returns:
        Tuple ``(metrics_df, conf_matrix_path1, conf_matrix_path2,
        umap_plot_path, tsne_plot_path, categories)`` where ``metrics_df`` is a
        pandas DataFrame of Accuracy/Precision/Recall/F1 plus a "% Difference"
        column, the ``conf_matrix_path*`` values are confusion-matrix image
        paths, the plot paths come from ``generate_embeddings_and_plot``, and
        ``categories`` maps agreement buckets to lists of test texts.
    """
    # dataset['test'] columns are already plain Python lists, so no .tolist().
    test_texts = dataset['test']['text']
    labels = dataset['test']['label']

    inputs1 = encode(test_texts, tokenizers[model1])
    inputs2 = encode(test_texts, tokenizers[model2])

    preds1 = predict(models[model1], inputs1)
    preds2 = predict(models[model2], inputs2)

    # NOTE(review): calculate_metrics appears to return
    # (accuracy, precision, recall, f1, confusion_matrix) based on the
    # [:-1] / [-1] usage below -- TODO confirm against its definition.
    metrics1 = calculate_metrics(labels, preds1)
    metrics2 = calculate_metrics(labels, preds2)

    # Bucket every test example by which model(s) classified it correctly.
    categories = {
        "correct_both": [],
        "incorrect_both": [],
        "correct_model1_only": [],
        "correct_model2_only": [],
    }
    # zip instead of index-based enumerate(labels) + manual subscripting.
    for text, label, p1, p2 in zip(test_texts, labels, preds1, preds2):
        if p1 == label and p2 == label:
            categories["correct_both"].append(text)
        elif p1 != label and p2 != label:
            categories["incorrect_both"].append(text)
        elif p1 == label:
            categories["correct_model1_only"].append(text)
        else:  # p1 != label and p2 == label -- only remaining case
            categories["correct_model2_only"].append(text)

    # Scalar metrics table; the trailing confusion-matrix element of each
    # metrics tuple is excluded via [:-1].
    metrics_df = pd.DataFrame({
        "Metric": ["Accuracy", "Precision", "Recall", "F1 Score"],
        model1: metrics1[:-1],
        model2: metrics2[:-1],
    })
    # Guard against division by zero: a model2 metric of exactly 0 would
    # otherwise format as "inf%"/"nan%".
    metrics_df["% Difference"] = [
        f"{(a - b) / b * 100:.2f}%" if b else "n/a"
        for a, b in zip(metrics_df[model1], metrics_df[model2])
    ]

    # Confusion matrices and embedding-space visualizations.
    conf_matrix_path1 = generate_confusion_matrix(metrics1[-1], model1)
    conf_matrix_path2 = generate_confusion_matrix(metrics2[-1], model2)
    umap_plot_path, tsne_plot_path = generate_embeddings_and_plot(categories)

    return metrics_df, conf_matrix_path1, conf_matrix_path2, umap_plot_path, tsne_plot_path, categories
|
| 139 |
+
|
| 140 |
def setup_gradio_interface():
|
| 141 |
with gr.Blocks() as demo:
|
| 142 |
gr.Markdown("## Model Comparison and Text Analysis")
|