Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, BartTokenizer, BartForConditionalGeneration, pipeline | |
| import numpy as np | |
| import torch | |
| import re | |
| from textstat import textstat | |
# ---------------------------------------------------------------------------
# Generation settings passed to `model.generate` in generate_candidate_text.
# ---------------------------------------------------------------------------
MAX_LEN = 256
NUM_BEAMS = 4
EARLY_STOPPING = True
N_OUT = 4  # number of candidate sequences returned per input

# ---------------------------------------------------------------------------
# Complex-word-identification (CWI) model and its tokenizer.
# ---------------------------------------------------------------------------
cwi_tok = AutoTokenizer.from_pretrained('twigs/cwi-regressor')
cwi_model = AutoModelForSequenceClassification.from_pretrained(
    'twigs/cwi-regressor',
)

# ---------------------------------------------------------------------------
# BART text-to-text simplifier and its tokenizer.
# ---------------------------------------------------------------------------
simpl_tok = BartTokenizer.from_pretrained('twigs/bart-text2text-simplifier')
simpl_model = BartForConditionalGeneration.from_pretrained(
    'twigs/bart-text2text-simplifier',
)

# ---------------------------------------------------------------------------
# Pipelines used by id_replace_complex.
# ---------------------------------------------------------------------------
# function_to_apply='none' asks the pipeline not to post-process the model
# output, so 'score' can be compared directly against a threshold.
cwi_pipe = pipeline(
    'text-classification',
    model=cwi_model,
    tokenizer=cwi_tok,
    function_to_apply='none',
)
# No model is named here, so the library's default fill-mask checkpoint is
# used; top_k=1 keeps only the best prediction per mask.
fill_pipe = pipeline('fill-mask', top_k=1)
def id_replace_complex(s, threshold=0.2):
    """Replace the words of *s* that the CWI pipeline scores as complex.

    Every word (``\\w+`` match) is scored by ``cwi_pipe`` on the input
    ``"<word>. <sentence>"``. Words whose score is at least *threshold* are
    masked one at a time and ``fill_pipe`` proposes a substitute for each.

    Args:
        s: Sentence to process.
        threshold: Minimum ``cwi_pipe`` score for a word to count as complex.

    Returns:
        A tuple ``(new_sentence, complex_words, replacements)``. The two
        lists are index-aligned. When the sentence has no words, or no word
        reaches the threshold, the sentence is returned unchanged together
        with two empty lists and the fill-mask pipeline is never called.
    """
    # Keep the match objects, not just the strings: the character span lets
    # us mask/replace exactly the occurrence that was scored, instead of the
    # first substring hit (e.g. the "a" inside "matchbook").
    words = list(re.finditer(r'\w+', s))
    if not words:
        return s, [], []

    scores = cwi_pipe([f"{w.group()}. {s}" for w in words])
    flagged = [w for w, res in zip(words, scores) if res['score'] >= threshold]
    if not flagged:
        return s, [], []

    compl_tok = [w.group() for w in flagged]
    masked = [s[:w.start()] + '<mask>' + s[w.end():] for w in flagged]
    cands = fill_pipe(masked)

    # The pipeline output is nested differently for 1 vs. n masked inputs:
    # a dict per input in the first case, a list of dicts in the second.
    # Predicted tokens may also carry a leading space, which is stripped.
    replacements = [
        (el if isinstance(el, dict) else el[0])['token_str'].strip()
        for el in cands
    ]

    # Splice from right to left so earlier spans keep their offsets and a
    # replacement can never be mistaken for a later word.
    for w, new in zip(reversed(flagged), reversed(replacements)):
        s = s[:w.start()] + new + s[w.end():]
    return s, compl_tok, replacements
def generate_candidate_text(s, model, tokenizer, tokenized=False):
    """Generate ``N_OUT`` candidate rewrites of *s* with beam search.

    Args:
        s: Source text, or - when *tokenized* is true - an already encoded
            batch that can be indexed with ``'input_ids'`` and
            ``'attention_mask'``.
        model: Model whose ``generate`` method produces the candidates.
        tokenizer: Tokenizer used both to encode *s* and to decode the
            generated ids.
        tokenized: Set to True when *s* is already encoded.

    Returns:
        A list with one decoded string per generated sequence.
    """
    if tokenized:
        encoded = s
    else:
        # Encode with the tokenizer that was passed in (not a module global)
        # so the function works with any model/tokenizer pair.
        encoded = tokenizer(
            [s],
            max_length=MAX_LEN,
            padding="max_length",
            truncation=True,
            return_tensors='pt',
        )

    generated_ids = model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        use_cache=True,
        # Decoding starts from the pad token of the model being used.
        decoder_start_token_id=model.config.pad_token_id,
        num_beams=NUM_BEAMS,
        max_length=MAX_LEN,
        early_stopping=EARLY_STOPPING,
        num_return_sequences=N_OUT,
    )

    # NOTE(review): the first character of every decoded string is dropped
    # unconditionally - presumably a leading space left by decoding; confirm
    # against real model output before relying on it.
    return [
        tokenizer.decode(
            ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[1:]
        for ids in generated_ids
    ]
def rank_candidate_text(sentences):
    """Return the sentence with the lowest Flesch-Kincaid grade level.

    Args:
        sentences: Candidate strings to compare.

    Returns:
        The element of *sentences* whose ``textstat.flesch_kincaid_grade``
        score is smallest.
    """
    grades = list(map(textstat.flesch_kincaid_grade, sentences))
    easiest = np.argmin(grades)
    return sentences[easiest]
def full_pipeline(source, simpl_model, simpl_tok, tokens, lexical=False):
    """Run optional lexical substitution, then generate and rank rewrites.

    Args:
        source: Sentence to simplify.
        simpl_model: Model handed to ``generate_candidate_text``.
        simpl_tok: Tokenizer handed to ``generate_candidate_text``.
        tokens: Control-token prefix that is prepended to the sentence.
        lexical: When true, complex words are first replaced through
            ``id_replace_complex`` (threshold 0.2).

    Returns:
        A tuple ``(output, complex_words, replacements)``. The last two are
        ``None`` when *lexical* is false.
    """
    if lexical:
        modified, complex_words, replacements = id_replace_complex(
            source, threshold=0.2)
    else:
        modified, complex_words, replacements = source, None, None

    candidates = generate_candidate_text(
        tokens + modified, simpl_model, simpl_tok)
    best = rank_candidate_text(candidates)
    return best, complex_words, replacements
def main():
    """Render the Streamlit page and run the pipeline when the form is sent.

    The page shows example sentences, a form (input sentence, lexical
    substitution toggle, control-token selection with one slider per
    selected token) and, after submission, the complex words found, their
    replacements and the simplified sentence.
    """
    # Control-token prefixes, index-aligned with `base_tokens`.
    aug_tok = ['c_', 'lev_', 'dep_', 'rank_', 'rat_', 'n_syl_']
    base_tokens = ['CharRatio', 'LevSim', 'DependencyTreeDepth',
                   'WordComplexity', 'WordRatio', 'NumberOfSyllables']
    default_values = [0.8, 0.6, 0.9, 0.8, 0.9, 1.9]
    # Values are keyed by token name so a slider can never be paired with
    # the wrong prefix, whatever subset of tokens the user selects.
    tok_values = dict(zip(base_tokens, default_values))
    example_sentences = ["A matchbook is a small cardboard folder (matchcover) enclosing a quantity of matches and having a coarse striking surface on the exterior.",
                         "If there are no strong land use controls, buildings are built along a bypass, converting it into an ordinary town road, and the bypass may eventually become as congested as the local streets it was intended to avoid.",
                         "Plot Captain Caleb Holt (Kirk Cameron) is a firefighter in Albany, Georgia and firmly keeps the cardinal rule of all firemen, \"Never leave your partner behind\".",
                         "Britpop emerged from the British independent music scene of the early 1990s and was characterised by bands influenced by British guitar pop music of the 1960s and 1970s."]

    st.title("Make it Simple")
    with st.expander("Example sentences"):
        for s in example_sentences:
            st.code(body=s)

    with st.form(key="simplify"):
        input_sentence = st.text_area("Original sentence")
        lexical = st.checkbox("Identify and replace complex words", value=True)
        tok = st.multiselect(
            label="Tokens to augment the sentence", options=base_tokens, default=base_tokens)
        if tok:
            st.text("Select the desired intensity")
        for t in tok:
            default = tok_values[t]
            # The slider rejects a default outside [min_value, max_value],
            # so widen the range for defaults above 1 (NumberOfSyllables).
            tok_values[t] = st.slider(
                t, min_value=0., max_value=max(1., default), value=default, step=0.1, key=t)
        submit = st.form_submit_button("Process")

    if not submit:
        return
    if not input_sentence.strip():
        st.warning("Please enter a sentence to simplify.")
        return

    # NOTE(review): tokens the user deselected are still sent with their
    # default value, as before - confirm whether they should be omitted.
    tokens = " ".join(
        f"{prefix}{tok_values[name]}" for prefix, name in zip(aug_tok, base_tokens)) + " "
    output, words, replacements = full_pipeline(
        input_sentence, simpl_model, simpl_tok, tokens, lexical)

    c1, c2, c3 = st.columns([1, 1, 2])
    with c1:
        st.markdown("#### Words identified as complex")
        if words:
            for w in words:
                st.markdown(f"* {w}")
        else:
            st.markdown("None :smile:")
    with c2:
        st.markdown("#### Their mask-predicted replacement")
        if replacements:
            for w in replacements:
                st.markdown(f"* {w}")
        else:
            st.markdown("None :smile:")
    with c3:
        st.markdown(f"#### Original Sentence:\n > {input_sentence}")
        st.markdown(f"#### Output Sentence:\n > {output}")
# Build the page only when this file is executed as a script; importing it
# as a module does not call main().
if __name__ == '__main__':
    main()