Spaces:
Build error
Build error
| from transformers import pipeline | |
| from transformers import AutoModelForSeq2SeqLM | |
| from transformers import AutoTokenizer | |
| from textblob import TextBlob | |
| from hatesonar import Sonar | |
| import gradio as gr | |
| import torch | |
# Load the fine-tuned seq2seq reframing model and its tokenizer from the local
# "output/reframer" checkpoint directory, then wrap them in a transformers pipeline.
# The pipeline's 'summary_text' output is what reframe() treats as the reframed sentence.
model, tokenizer = (
    AutoModelForSeq2SeqLM.from_pretrained("output/reframer"),
    AutoTokenizer.from_pretrained("output/reframer"),
)
reframer = pipeline("summarization", model=model, tokenizer=tokenizer)
CHAR_LENGTH_LOWER_BOUND = 15  # Minimum number of characters accepted for the input text
CHAR_LENGTH_HIGHER_BOUND = 150  # Maximum number of characters accepted for the input text
SENTIMENT_THRESHOLD = 0.2  # Maximum TextBlob polarity accepted; more positive input is rejected
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8  # Offensiveness confidence above which the input is rejected
LENGTH_ERROR = "The input text is too long or too short. Please try again by inputting text with moderate length."
SENTIMENT_ERROR = "The input text is too positive. Please try again by inputting text with negative sentiment."
OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputting non-offensive text."
CACHE = []  # Reframing history, oldest first; each record is [input_text, strategy, reframed_text]
MAX_STORE = 5  # Maximum number of history records kept in CACHE
BEST_N = 3  # Number of sampled decodes returned by show_n_best_decodes
def input_error_message(error_type):
    # type: (str) -> str
    """Build the user-facing error string for an invalid input.

    The fixed "[Error]: Invalid Input. " prefix is placed in front of
    *error_type*, which callers pass as one of the *_ERROR message constants.
    """
    prefix = "[Error]: Invalid Input. "
    return "".join((prefix, error_type))
def update_cache(cache, new_record, max_store=None):
    # type: (List[List[str]], List[str], Optional[int]) -> List[List[str]]
    """Return a new history list with *new_record* appended, keeping only the newest entries.

    Args:
        cache: Existing history, oldest record first. It is NOT modified.
        new_record: The record to append, e.g. [input_text, strategy, reframed_text].
        max_store: Maximum number of records to keep. Defaults to the module-level
            MAX_STORE when omitted.

    Returns:
        A new list holding at most *max_store* records, the most recent one last.
    """
    limit = MAX_STORE if max_store is None else max_store
    if limit <= 0:
        # Guard: a slice of [-0:] would return the whole list rather than an empty one.
        return []
    # Build a fresh list so the caller's list is never mutated, and slice from the end
    # so an over-full history is trimmed all the way down (not by just one element).
    return (list(cache) + [new_record])[-limit:]
_SONAR = None  # Lazily created Sonar classifier, shared across reframe() calls


def _get_sonar():
    """Return the shared Sonar instance, constructing it on first use only."""
    global _SONAR
    if _SONAR is None:
        _SONAR = Sonar()
    return _SONAR


def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe *input_text* positively using the given reframing *strategy*.

    The input is validated first. If a check fails, the input text is returned
    with an "[Error]: Invalid Input. ..." message appended instead of a reframing.
    Otherwise the strategy is appended to the text in the "Strategy: ['...']"
    prompt format and passed through the fine-tuned reframer pipeline. Each
    successful reframing is recorded in the module-level CACHE.

    Args:
        input_text: The sentence to reframe.
        strategy: Name of the reframing strategy (e.g. "optimism").

    Returns:
        The reframed text, or the input text followed by an error message.
    """
    # Input control
    # A missing strategy would make the prompt concatenation below raise TypeError.
    # NOTE(review): presumably this happens when no Radio option is selected in the UI.
    if not strategy:
        return input_text + input_error_message("Please select a reframing strategy.")
    # The text must be long enough to have substantial content to reframe, but short
    # enough to express a single focused idea.
    if not CHAR_LENGTH_LOWER_BOUND <= len(input_text) <= CHAR_LENGTH_HIGHER_BOUND:
        return input_text + input_error_message(LENGTH_ERROR)
    # Text that is already too positive cannot be meaningfully reframed positively.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return input_text + input_error_message(SENTIMENT_ERROR)
    # Offensive text is rejected. sonar.ping() returns a dict whose 'classes' list holds
    # per-class confidences; the entry at index 1 is used as the offensive-language score.
    # NOTE(review): the other classes Sonar reports (e.g. hate speech) are not checked;
    # confirm this is intended.
    offensive_confidence = _get_sonar().ping(input_text)['classes'][1]['confidence']
    if offensive_confidence > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return input_text + input_error_message(OFFENSIVE_ERROR)

    # Reframing: the pipeline returns a one-element list containing a dict whose
    # 'summary_text' value is the generated (reframed) text.
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"
    reframed_text = reframer(text_with_strategy)[0]['summary_text']

    # Record the successful reframing in the bounded history.
    global CACHE
    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])
    return reframed_text
def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, Optional[str]]]
    """Return a character-level diff between the input text and its reframing.

    The input is reframed via reframe(). Each returned pair is (character, marker),
    where marker is the difflib.Differ opcode character (e.g. "+" for an added
    character, "-" for a removed one) or None for a character left unchanged.
    """
    reframed = reframe(input_text, strategy)
    from difflib import Differ

    highlighted = []
    # Differ.compare treats each character of a string as one "line"; every emitted
    # line is a two-character opcode prefix followed by the character itself.
    for line in Differ().compare(input_text, reframed):
        opcode, char = line[0], line[2:]
        highlighted.append((char, None if opcode == " " else opcode))
    return highlighted
def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Sample BEST_N reframings of *input_text* and return them as numbered lines.

    The prompt uses the same "Strategy: ['...']" format as reframe(). Generation
    uses sampling (do_sample=True), so repeated calls can return different
    candidates. No input validation or history caching is performed here.

    Returns:
        One line per candidate, formatted "<rank> <text>", joined with newlines.
    """
    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    encoded = tokenizer(prompt, padding=True, return_tensors="pt")
    outputs = model.generate(
        encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        do_sample=True,
        num_return_sequences=BEST_N,
    )
    decodes = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Join on the actual decodes rather than counting against BEST_N, so the newline
    # separators stay correct whatever number of sequences generate() returns.
    return "\n".join(str(rank) + " " + text for rank, text in enumerate(decodes, start=1))
def show_history(cache):
    # type: (List[List[str]]) -> Any
    """Render the reframing history as text and return a Gradio update for a visible Textbox.

    Each record [input_text, strategy, reframed_text] in *cache* becomes one
    newline-terminated line, in the order the records appear.
    """
    lines = []
    for record in cache:
        source, chosen_strategy, result = record
        lines.append(
            "Input text: " + source
            + " Strategy: " + chosen_strategy
            + " -> Reframed text: " + result
            + "\n"
        )
    return gr.Textbox.update(value="".join(lines), visible=True)
# Gradio UI: the user enters a sentence and picks a strategy. The output shows the
# character-level diff returned by show_reframe_change ("+" in green, "-" in red).
demo = gr.Interface(
    fn=show_reframe_change,
    inputs=[
        gr.Textbox(lines=2, placeholder="Please input the sentence to be reframed.", label="Original Text"),
        gr.Radio(
            ["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"],
            label="Strategy to use?",
        ),
    ],
    # color_map is passed to the constructor: the Component.style() method is
    # deprecated in Gradio 3.x and removed in Gradio 4.
    outputs=gr.HighlightedText(
        label="Diff",
        combine_adjacent=True,
        color_map={"+": "green", "-": "red"},
    ),
)

if __name__ == "__main__":
    demo.launch(show_api=True)