Spaces:
Runtime error
Runtime error
| import numpy as np | |
| from load_data import * | |
| import matplotlib.pyplot as plt | |
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer | |
| from transformers import AutoModelForSequenceClassification | |
# Use the first GPU when one is present, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Zero-shot backbone: BART-large trained on MNLI, used by get_prob below.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
# Move the model with `.to(device)` rather than `.cuda()`: `.cuda()` targets the
# *current* CUDA device, while get_prob sends its inputs to `device` ("cuda:0").
# Using the same `device` for both guarantees model and inputs always agree.
nli_model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli"
).to(device)
def get_prob(sequence, label, hypothesis_template="This example is {}."):
    """Return the zero-shot probability that ``label`` applies to ``sequence``.

    ``sequence`` is used as the NLI premise. ``hypothesis_template`` is filled
    with ``label`` to form the hypothesis. Both are run through the
    module-level ``tokenizer`` and ``nli_model``. The "neutral" logit is
    discarded, and a softmax over the remaining (contradiction, entailment)
    pair gives the entailment probability.

    Args:
        sequence: Text to classify (the premise).
        label: Candidate label inserted into the hypothesis.
        hypothesis_template: ``str.format`` template with a single ``{}`` slot.
            The default reproduces the original hypothesis
            ``"This example is {label}."``.

    Returns:
        The entailment probability as a Python float in [0, 1].
    """
    hypothesis = hypothesis_template.format(label)
    # `truncation="only_first"` trims only the premise if the pair is too long;
    # it replaces the deprecated `truncation_strategy=` keyword.
    x = tokenizer.encode(
        sequence, hypothesis, return_tensors="pt", truncation="only_first"
    )
    # Pure inference: skip building an autograd graph on every call.
    with torch.no_grad():
        logits = nli_model(x.to(device))[0]
    # We throw away "neutral" (dim 1) and take the probability of
    # "entailment" (dim 2) as the probability of the label being true.
    entail_contradiction_logits = logits[:, [0, 2]]
    probs = entail_contradiction_logits.softmax(dim=1)
    prob_label_is_true = probs[:, 1]
    return prob_label_is_true[0].item()
def judge_mbti(sequence, labels):
    """Score each candidate label against ``sequence`` and rank them.

    Every label is scored with ``get_prob``, in the order given.

    Returns:
        A list of ``(label, probability)`` tuples, highest probability first.
        Equal probabilities keep their input order.
    """
    ranked = [(candidate, get_prob(sequence, candidate)) for candidate in labels]
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
def compute_score(text, type):
    """Decide one MBTI axis for ``text``.

    ``type`` is an axis key such as ``"E_I"``. It is split on ``"_"`` into its
    two pole letters, and ``keywords_en[type][letter]`` lists each pole's
    keyword labels. ``keywords_en`` presumably comes from
    ``from load_data import *`` — confirm.

    A pole's raw score is the sum of the ``judge_mbti`` probabilities over
    its keywords.

    Returns:
        (choice, stat) where:
        - ``choice`` is the letter with the higher raw score (ties go to the
          second letter);
        - ``stat`` maps each letter to its percentage share of the combined
          score.
    """
    # `type` shadows the builtin but is kept so keyword-argument callers still work.
    x, y = type.split("_")
    x_score = np.sum([prob for _, prob in judge_mbti(text, keywords_en[type][x])])
    y_score = np.sum([prob for _, prob in judge_mbti(text, keywords_en[type][y])])
    choice = x if x_score > y_score else y

    total = x_score + y_score
    if total > 0:
        x_score_scaled = (x_score / total) * 100
        y_score_scaled = (y_score / total) * 100
    else:
        # Both poles scored zero (e.g. empty keyword lists). 0/0 would give NaN
        # percentages, which later break int() formatting when plotted, so
        # report an even split instead.
        x_score_scaled = y_score_scaled = 50.0
    stat = {x: x_score_scaled, y: y_score_scaled}
    return choice, stat
def mbti_translator(text):
    """Classify ``text`` into a four-letter MBTI type.

    The axes "E_I", "N_S", "T_F" and "P_J" are scored independently, in that
    order, with ``compute_score``.

    Returns:
        (type_string, stats):
        - ``type_string`` is the four chosen letters concatenated;
        - ``stats`` is a tuple of the four per-axis percentage dicts, in the
          same axis order.
    """
    axes = ("E_I", "N_S", "T_F", "P_J")
    per_axis = [compute_score(text, axis) for axis in axes]
    letters = "".join(letter for letter, _ in per_axis)
    breakdown = tuple(split for _, split in per_axis)
    return letters, breakdown
def plot_mbti(result):
    """Render one MBTI axis split as a two-colour horizontal bar in Streamlit.

    Args:
        result: A two-entry dict ``{letter: percentage}``, as produced by
            ``compute_score``. The first entry is drawn on the left, the
            second on the right.
    """
    (x_type, x), (y_type, y) = result.items()
    fig, ax = plt.subplots(figsize=(10, 5))
    # broken_barh takes (xmin, xwidth) pairs. The second segment starts where
    # the first ends and is `y` wide; it was previously given width x + y.
    ax.broken_barh([(0, x), (x, y)], [10, 9], facecolors=("#FFC5BF", "#D4F0F0"))
    ax.set_ylim(5, 15)
    ax.set_xlim(0, 100)
    for side in ("left", "bottom", "top", "right"):
        ax.spines[side].set_visible(False)
    # NOTE(review): 25 is outside the y-limits, so only the tick at 15 is drawn.
    # Confirm whether y ticks were meant to be hidden entirely.
    ax.set_yticks([15, 25])
    ax.set_xticks([0, 25, 50, 75, 100])
    ax.text(x - 6, 14.5, x_type + " :" + str(int(x)) + "%", fontsize=15)
    ax.text((x + y) - 6, 14.5, y_type + " :" + str(int(y)) + "%", fontsize=15)
    st.pyplot(fig)
    # pyplot keeps every figure registered until it is closed. Without this,
    # each call would leak one figure.
    plt.close(fig)