"""Streamlit front-end for a T5-based text corrector.

The user picks a model, enters text and presses "Correct". The text is
prefixed with the ``correction: `` task prompt and run through the selected
seq2seq model. The decoded output is then displayed as markdown.
"""
import streamlit as st
import torch
import torch.nn.functional as F

# Import T5 Model and Tokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Streamlit re-executes this whole script on every widget interaction. Without
# caching, the tokenizer and model would be re-instantiated on each rerun.
# ``st.cache_resource`` exists from Streamlit 1.18. On older versions, fall
# back to ``experimental_singleton``, and failing that, to no caching at all.
_cache_resource = getattr(st, "cache_resource", None) or getattr(
    st, "experimental_singleton", lambda func: func
)

st.title("Corrector")

# Display name -> Hugging Face Hub model id.
models = {"T5 Small": "douha/T5_SpellCorrector"}

selected_model = st.radio("Select Model", list(models.keys()))
model_name = models[selected_model]


@_cache_resource
def load_model(name):
    """Load and return ``(tokenizer, model)`` for the Hub model id *name*."""
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForSeq2SeqLM.from_pretrained(name),
    )


# Works for any entry in ``models``, not just "T5 Small". Previously, any
# other selection left ``tokenizer``/``model`` undefined.
tokenizer, model = load_model(model_name)


# Inference function for T5
def t5_summarize(input_text, tokenizer, model):
    """Return the model's corrected version of *input_text*.

    Args:
        input_text: Raw user text. It is prefixed with ``"correction: "``
            before tokenization.
        tokenizer: Tokenizer matching *model*.
        model: Seq2seq model used for generation.

    Returns:
        The first decoded output sequence, with special tokens stripped.
        The encoder input is truncated to 600 tokens, and at most 100 new
        tokens are generated.
    """
    # A single sequence needs no padding. The old ``padding='max_length'``
    # forced every input to 600 tokens, which only wasted encoder compute.
    inputs = tokenizer(
        "correction: " + input_text,
        truncation=True,
        max_length=600,
        return_tensors="pt",
    )
    output_sequence = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=100,
    )
    corrected = tokenizer.batch_decode(output_sequence, skip_special_tokens=True)
    return corrected[0]


input_text = st.text_area("Input the text to correct", "", height=300)

if st.button("Correct"):
    # Don't spend a model call on whitespace-only input.
    if not input_text.strip():
        st.warning("Please enter some text to correct.")
    else:
        # The spinner is shown only while generation runs, unlike static text.
        with st.spinner("It may take a minute or two."):
            corrected_text = t5_summarize(input_text, tokenizer, model)
        st.header("Corrected text")
        st.markdown(corrected_text)