Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import re | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline | |
| import torch | |
| from keybert import KeyBERT | |
| from datasets import load_dataset | |
| import shap | |
| from transformers_interpret import SequenceClassificationExplainer | |
| from ferret import Benchmark | |
| #model_identifier = "karalif/myTestModel" | |
| #model = AutoModelForSequenceClassification.from_pretrained(model_identifier) | |
| #tokenizer = AutoTokenizer.from_pretrained(model_identifier) | |
| name = "karalif/myTestModel" | |
| model = AutoModelForSequenceClassification.from_pretrained(name) | |
| tokenizer = AutoTokenizer.from_pretrained(name, normalization=True) | |
| bench = Benchmark(model, tokenizer) | |
| #text = "hvað er maðurinn eiginlega að pæla ég fatta ekki??????????" | |
def get_prediction(text):
    """Classify *text* with the multi-label model and extract keywords.

    Returns a tuple ``(response, keywords, influential_keywords)``:
      response             -- HTML snippet, one probability line per label
      keywords             -- list of (keyword, score) pairs from KeyBERT
      influential_keywords -- HTML snippet listing those keywords
    """
    encoding = tokenizer(text, return_tensors="pt", padding="max_length",
                         truncation=True, max_length=200)
    encoding = {k: v.to(model.device) for k, v in encoding.items()}
    with torch.no_grad():
        outputs = model(**encoding)
    logits = outputs.logits
    # Multi-label head: an independent sigmoid per label, not a softmax.
    probs = torch.sigmoid(logits.squeeze().cpu()).numpy()

    # FIX: cache the KeyBERT model on the function instead of re-loading
    # its underlying sentence-transformer on every single request.
    kw_model = getattr(get_prediction, "_kw_model", None)
    if kw_model is None:
        kw_model = KeyBERT()
        get_prediction._kw_model = kw_model
    # NOTE(review): stop_words='english' on Icelandic input filters almost
    # nothing — consider an Icelandic stop-word list.
    keywords = kw_model.extract_keywords(
        text, keyphrase_ngram_range=(1, 1), stop_words='english',
        use_maxsum=True, nr_candidates=20, top_n=5)

    # NOTE(review): this label order disagrees with the target indices used
    # in predict() (0=formality … 3=toxicity there); at most one ordering can
    # match the model head — verify against the model's id2label config.
    labels = ['Politeness', 'Toxicity', 'Sentiment', 'Formality']
    colors = ['#b8e994', '#f8d7da', '#fff3cd', '#bee5eb']  # Corresponding colors for labels
    response = ""
    for i, label in enumerate(labels):
        response += f"<span style='background-color:{colors[i]}; color:black;'>{label}</span>: {probs[i]*100:.1f}%<br>"

    influential_keywords = "INFLUENTIAL KEYWORDS:<br>"
    for keyword, score in keywords:
        influential_keywords += f"{keyword} (Score: {score:.2f})<br>"

    return response, keywords, influential_keywords
def replace_encoding(tokens):
    """Map GPT-2-style byte-level BPE token text back to readable Icelandic.

    The byte-level tokenizer renders each UTF-8 byte of a non-ASCII
    character as a printable surrogate glyph (e.g. 'Á' appears as the
    two-glyph sequence 'Ãģ'); this reverses the renderings observed in
    practice. The first and last tokens (special [CLS]/[SEP]-style
    markers) are dropped via ``tokens[1:-1]``.
    """
    # (rendered-glyph sequence, real character) pairs, applied per token.
    # FIX: per the tokenizer's bytes_to_unicode table, 'Ãį' is bytes
    # C3 8D = 'Í' (was wrongly mapped to 'Ú') and 'Ãļ' is bytes
    # C3 9A = 'Ú' (was wrongly mapped to 'ý').
    replacements = (
        ('Ġ', ' '),    # 'Ġ' marks a leading space in byte-level BPE
        ('ð', 'ð'),
        ('é', 'é'),
        ('æ', 'æ'),
        ('ý', 'ý'),
        ('á', 'á'),
        ('ú', 'ú'),
        ('ÃŃ', 'í'),
        ('Ãö', 'ö'),
        ('þ', 'þ'),
        ('Ãģ', 'Á'),
        ('Ãį', 'Í'),
        ('Ãĵ', 'Ó'),
        ('ÃĨ', 'Æ'),
        ('ÃIJ', 'Ð'),
        ('Ãĸ', 'Ö'),
        ('Ãī', 'É'),
        ('Ãļ', 'Ú'),
    )
    cleaned = []
    for token in tokens[1:-1]:
        for rendered, real in replacements:
            token = token.replace(rendered, real)
        cleaned.append(token)
    return cleaned
def predict(text):
    """Run the multi-label prediction plus ferret explanations on *text*
    and return a single HTML report string for the Gradio interface.
    """
    # One ferret explanation run per label index of the model head.
    # NOTE(review): the variable names assume 0=formality, 1=sentiment,
    # 2=politeness, 3=toxicity, while get_prediction() prints probs[0..3]
    # as Politeness, Toxicity, Sentiment, Formality — at most one ordering
    # can match the model; confirm against the model's id2label config.
    explanations_formality = bench.explain(text, target=0)
    explanations_sentiment = bench.explain(text, target=1)
    explanations_politeness = bench.explain(text, target=2)
    explanations_toxicity = bench.explain(text, target=3)
    # Icelandic greeting detector — currently unused (see commented block below).
    greeting_pattern = r"^(Halló|Hæ|Sæl|Góðan dag|Kær kveðja|Daginn|Kvöldið|Ágætis|Elsku)"
    prediction_output, keywords, influential_keywords = get_prediction(text)
    greeting_feedback = ""
    # Highlight each extracted keyword in green inside the echoed input.
    modified_input = text
    for keyword, _ in keywords:
        modified_input = modified_input.replace(keyword, f"<span style='color:green;'>{keyword}</span>")
    #if not re.match(greeting_pattern, text, re.IGNORECASE):
    # greeting_feedback = "OTHER FEEDBACK:<br>Heilsaðu dóninn þinn<br>"
    response = f"INPUT:<br>{modified_input}<br><br>MY PREDICTION:<br>{prediction_output}<br>{influential_keywords}<br>{greeting_feedback}"
    # Influential words
    # NOTE(review): this display order (Toxicity, Formality, Sentiment,
    # Politeness) is a third ordering distinct from both lists above.
    explanation_lists = [explanations_toxicity, explanations_formality, explanations_sentiment, explanations_politeness]
    labels = ['Toxicity', 'Formality', 'Sentiment', 'Politeness']
    response += "<br>MOST INFLUENTIAL WORDS FOR EACH LABEL:<br>"
    for i, explanations in enumerate(explanation_lists):
        label = labels[i]
        for explanation in explanations:
            # Only report the Partition SHAP explainer's token attributions.
            if explanation.explainer == 'Partition SHAP':
                tokens = replace_encoding(explanation.tokens)
                token_score_pairs = zip(tokens, explanation.scores)
                formatted_output = ' '.join([f"{token} ({score})" for token, score in token_score_pairs])
                response += f"{label}: {formatted_output}<br>"
    #response += "<br>TOP 2 MOST INFLUENTIAL WORDS FOR EACH LABEL:<br>"
    #for i, explanations in enumerate(explanation_lists):
    # label = labels[i]
    # response += f"{label}:<br>"
    # for explanation in explanations:
    # if explanation.explainer == 'Partition SHAP':
    # sorted_scores = sorted(enumerate(explanation.scores), key=lambda x: abs(x[1]), reverse=True)[:2]
    # tokens = replace_encoding(explanation.tokens)
    # tokens = [tokens[idx] for idx, _ in sorted_scores]
    # formatted_output = ' '.join(tokens)
    # response += f"{formatted_output}<br>"
    return response
# HTML header shown above the interface: the Reykjavík University logo.
# FIX: use https for the image — Spaces are served over HTTPS, so an
# http:// src is blocked by browsers as mixed content.
description_html = """
<center>
<img src='https://www.ru.is/media/HR_logo_vinstri_transparent.png' width='250' height='auto'>
</center>
"""

# Single text-in, HTML-out Gradio app wrapping predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.TextArea(label="Enter text here:"),
    outputs=gr.HTML(label="Leiðrétt"),
    description=description_html,
    examples=[
        ["Sæl og blessuð Kristín, hvað er að frella af þér gamla??"],
    ],
    theme=gr.themes.Default(primary_hue="red", secondary_hue="pink"),
)

demo.launch()