Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files
main.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import matplotlib
|
| 3 |
+
import models
|
| 4 |
+
import utils
|
| 5 |
+
|
| 6 |
+
def infer(gt: str, data: str) -> tuple:
    """Compare a model response against the ground truth.

    Runs NLI (entailment / contradiction / neutral) and tone analysis on
    the pair, and returns everything Gradio needs to render the result.

    Parameters:
        gt: the ground-truth text.
        data: the model response to evaluate.

    Returns:
        (verdict_text, nli_pie_chart_figure, tone_bar_chart_figure)
    """
    nli_res = models.compute_metric(gt, data)
    tone_res = models.compare_tone(gt, data)

    # Map the winning NLI label to a human-readable verdict.
    # (Fixed typo: "consistant" -> "consistent".)
    verdicts = {
        "neutral": "Model's response is unrelated to the Ground Truth",
        "contradiction": "Model's response contradicts the Ground Truth",
        "entailment": "Model's response is consistent with the Ground Truth",
    }
    res_text = verdicts.get(nli_res["label"], "")

    return res_text, utils.create_pie_chart_nli(nli_res), utils.plot_tones(tone_res)
|
| 17 |
+
|
| 18 |
+
# Example (Ground Truth, Model Response) pairs surfaced in the UI.
examples = [
    ["Cross-encoders are better than bi-encoders for analyzing the relationship betwen texts", "Bi-encoders are superior to cross-encoders"],
    ["Cross-encoders are better than bi-encoders for analyzing the relationship betwen texts", "The cosine similarity function can be used to compare the outputs of a bi-encoder"],
    ["Cross-encoders are better than bi-encoders for analyzing the relationship betwen texts", "Bi-encoders are outperformed by cross-encoders in the task of relationship analysis"],
    ["Birds can fly. There are fish in the sea.", "Fish inhabit the ocean. Birds can aviate."],
    ["Birds can fly. There are fish in the sea.", "Fish inhabit the ocean. Birds can not aviate."],
]

# Two text inputs in, one verdict string plus two matplotlib figures out.
app = gr.Interface(
    fn=infer,
    inputs=[gr.Textbox(label="Ground Truth"), gr.Textbox(label="Model Response")],
    examples=examples,
    outputs=[
        gr.Textbox(label="Result"),
        gr.Plot(label="Comparison with GT"),
        gr.Plot(label="Difference in Tone"),
    ],
)
app.launch()
|
models.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sentence_transformers import CrossEncoder
|
| 2 |
+
from transformers import AutoModelForSequenceClassification
|
| 3 |
+
from transformers import AutoTokenizer, AutoConfig
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
def softmax(x):
    """Return the softmax of *x* over axis 0.

    Subtracting the max before exponentiating keeps the computation
    numerically stable for large inputs.
    """
    exps = np.exp(x - np.max(x))
    total = exps.sum(axis=0)
    return exps / total
|
| 9 |
+
|
| 10 |
+
# 90.04% accuracy on MNLI mismatched set
nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base')

def compute_metric(ground_truth: str, inference: str) -> dict:
    """Run NLI between the ground truth and the model's answer.

    Returns a dict with the winning 'label' plus the softmax score of
    each class under its own key ('contradiction', 'entailment',
    'neutral').
    """
    probs = nli_model.predict([ground_truth, inference], apply_softmax=True)
    # Class order matches the model's output head.
    labels = ('contradiction', 'entailment', 'neutral')
    result = {'label': labels[int(probs.argmax())]}
    result.update(zip(labels, probs))
    return result
|
| 22 |
+
|
| 23 |
+
def _load_sentiment():
    """Load and cache the sentiment tokenizer, config and model.

    The original code re-instantiated (and potentially re-downloaded)
    all three on every call; caching them on the function object makes
    repeat calls cheap while keeping module import side-effect free.
    """
    cached = getattr(_load_sentiment, "_cached", None)
    if cached is None:
        # Trained on ~124M Tweets for sentiment analysis
        model_name = r"cardiffnlp/twitter-roberta-base-sentiment-latest"
        cached = (
            AutoTokenizer.from_pretrained(model_name),
            AutoConfig.from_pretrained(model_name),
            AutoModelForSequenceClassification.from_pretrained(model_name),
        )
        _load_sentiment._cached = cached
    return cached

def _compare_tone(text: str) -> dict:
    """Score the sentiment (tone) of *text*.

    Returns a dict mapping sentiment label -> probability rounded to 4
    decimals, ordered from most to least likely.
    """
    tokenizer, config, model = _load_sentiment()

    encoded_input = tokenizer(text, return_tensors='pt')
    output = model(**encoded_input)
    scores = output[0][0].detach().numpy()
    scores = softmax(scores)

    # Rank classes from highest to lowest probability.
    ranking = np.argsort(scores)[::-1]
    result = {}
    for idx in ranking:
        label = config.id2label[int(idx)]
        result[label] = np.round(float(scores[idx]), 4)

    return result
|
| 43 |
+
|
| 44 |
+
def compare_tone(ground_truth: str, inference: str) -> dict:
    """Score the tone of both texts and bundle the results.

    Returns ``{"gt": <tone of ground_truth>, "model": <tone of inference>}``,
    where each value is the label -> score dict from ``_compare_tone``.
    """
    return {
        "gt": _compare_tone(ground_truth),
        "model": _compare_tone(inference),
    }
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
    # Quick smoke test: one entailing and one contradicting response,
    # then a tone comparison.
    premise = "Foxes are closer to dogs than they are to cats. Therefore, foxes are not cats."
    print(compute_metric(premise, "Foxes are not cats."))
    print(compute_metric(premise, "Foxes are cats."))
    print(compare_tone("This is neutural", "Wtf"))
|
utils.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
|
| 4 |
+
def create_pie_chart_nli(data: dict) -> matplotlib.figure.Figure:
    """Render NLI class scores as a pie chart.

    Parameters:
        data: dict holding 'neutral', 'contradiction' and 'entailment'
            scores (e.g. the result of models.compute_metric).

    Returns:
        The matplotlib Figure containing the chart.

    Note: the return annotation previously referenced the module
    ``matplotlib.figure``; it now names the actual ``Figure`` type.
    """
    labels = ["neutral", "contradiction", "entailment"]
    sizes = [data[label] for label in labels]
    colors = ["gray", "red", "green"]

    fig, ax = plt.subplots()

    ax.set_title("Comparison with GT")
    ax.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%')

    # Equal aspect ratio keeps the pie circular.
    ax.axis('equal')

    return fig
|
| 17 |
+
|
| 18 |
+
def plot_tones(data: dict) -> matplotlib.figure.Figure:
    """Plot ground-truth vs model tone scores as overlaid bar charts.

    Parameters:
        data: ``{"gt": {...}, "model": {...}}`` as produced by
            models.compare_tone. Assumes both inner dicts share the same
            label keys (true for compare_tone output) — TODO confirm for
            other callers.

    Returns:
        The matplotlib Figure containing the chart.
    """
    keys = data["gt"].keys()

    fig, ax = plt.subplots()
    ax.set_title("Tone")
    # Wider opaque bars behind, narrower translucent bars in front so
    # both series stay visible when they overlap.
    ax.bar(x=keys, height=[data["gt"][key] for key in keys], color="b", label="Ground Truth", width=0.7)
    ax.bar(x=keys, height=[data["model"][key] for key in keys], color="r", alpha=0.5, label="Model response", width=0.5)

    fig.legend()

    return fig
|