simple-translator / src /streamlit_app.py
aribilgiogr's picture
Update src/streamlit_app.py
fbd9b4d verified
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st
# Page header plus a radio selector for the translation direction; the
# selected option string ("tr-en" or "en-tr") is stored in ``direction``.
st.title("Translation Model (TR-EN/EN-TR)")
direction = st.radio(label="Direction:", options=["tr-en", "en-tr"])
@st.cache_resource
def load_models(lang='tr-en'):
    """Load the OPUS-MT "tc-big" tokenizer and model for a direction.

    ``st.cache_resource`` keeps the returned pair across Streamlit reruns,
    keyed by ``lang``, so each direction is only loaded once.

    Args:
        lang: Translation direction; only ``'tr-en'`` and ``'en-tr'`` are
            supported.

    Returns:
        A ``(tokenizer, model)`` tuple for
        ``Helsinki-NLP/opus-mt-tc-big-<lang>``, or ``(None, None)`` when
        ``lang`` is not a supported direction.
    """
    if lang not in ('tr-en', 'en-tr'):
        # No checkpoint exists for any other direction.
        return None, None
    checkpoint = f'Helsinki-NLP/opus-mt-tc-big-{lang}'
    # Tokenizer is loaded before the model, as in the original ordering.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
def translate(text, tokenizer, model):
    """Translate ``text`` using a seq2seq tokenizer/model pair.

    Args:
        text: Source-language text; must contain at least one
            non-whitespace character.
        tokenizer: Tokenizer matching ``model`` (e.g. the first element
            returned by ``load_models``).
        model: Seq2seq model whose ``generate`` produces the translation.

    Returns:
        The decoded translation of the first generated sequence, with
        special tokens removed.

    Raises:
        ValueError: If ``text`` is empty or whitespace-only.
    """
    if not text.strip():
        raise ValueError('Text cannot be empty')
    # truncation=True caps the input at the tokenizer's maximum length.
    # Without it, over-long input is passed through un-capped and can make
    # generate() raise instead of returning a translation.
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    # Inference only: disable autograd bookkeeping during generation.
    with torch.no_grad():
        translated_tokens = model.generate(**inputs)
    # A single input string yields one output sequence with default
    # generation settings, so only index 0 is decoded.
    return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
# Fetch the (cached) tokenizer/model pair for the selected direction, then
# render the input box and translate on demand.
tokenizer, model = load_models(direction)
text_to_translate = st.text_area("Enter text to translate:")
if st.button("Translate"):
    # Reject blank input up front with a UI warning rather than letting
    # translate() raise.
    if not text_to_translate.strip():
        st.warning('Text cannot be empty')
    else:
        st.info(translate(text_to_translate, tokenizer, model))