San-NLP's picture
Create app.py
05373f8 verified
import streamlit as st
import torch
from transformers import MarianMTModel, MarianTokenizer
# Browser-tab title and icon, plus the wide page layout.
_PAGE_SETTINGS = {
    "page_title": "Language Translation App",
    "page_icon": "🌍",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
# -----------------------------
# Custom CSS
# -----------------------------
# Stylesheet for the page container, the title/subtitle banner, widget labels,
# the text areas, the buttons and the result label. It is injected as raw HTML,
# hence unsafe_allow_html=True on the call below.
_CUSTOM_CSS = """
<style>
.block-container {
padding-top: 1.5rem;
padding-bottom: 2rem;
max-width: 1100px;
}
.app-title {
font-size: 3.2rem;
font-weight: 800;
color: #2d2d3a;
margin-bottom: 0.25rem;
}
.app-subtitle {
font-size: 1.2rem;
color: #555;
margin-bottom: 1.8rem;
}
.stSelectbox label, .stTextArea label {
font-size: 1.05rem !important;
font-weight: 600 !important;
}
.stTextArea textarea {
font-size: 1.15rem !important;
border-radius: 12px !important;
}
.stButton > button {
min-width: 140px;
height: 48px;
font-size: 1rem;
font-weight: 600;
border-radius: 10px;
}
.result-label {
font-size: 1.05rem;
font-weight: 600;
margin-top: 1rem;
margin-bottom: 0.5rem;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# -----------------------------
# Supported languages
# -----------------------------
# Display name -> language code. The display names feed the two select boxes
# and the codes are used as the keys of MODEL_MAP.
LANGUAGES = dict(
    English="en",
    French="fr",
    German="de",
    Spanish="es",
    Italian="it",
    Portuguese="pt",
    Dutch="nl",
    Romanian="ro",
    Arabic="ar",
    Hindi="hi",
)
# -----------------------------
# Helsinki-NLP OPUS-MT models
# -----------------------------
# (source code, target code) -> model identifier.
# Only pairs that include English are mapped: for every partner language there
# is an en->xx entry immediately followed by the xx->en entry.
# NOTE(review): the identifiers follow the "Helsinki-NLP/opus-mt-<src>-<tgt>"
# pattern; confirm that each checkpoint is actually published before relying
# on it.
_PARTNER_CODES = ("fr", "de", "es", "it", "pt", "nl", "ro", "ar", "hi")
MODEL_MAP = {
    (src, tgt): f"Helsinki-NLP/opus-mt-{src}-{tgt}"
    for code in _PARTNER_CODES
    for src, tgt in (("en", code), (code, "en"))
}
# -----------------------------
# Session state
# -----------------------------
# Seed each session key with an empty string the first time the script runs;
# values already present from an earlier rerun are left untouched.
for _state_key in ("input_text", "translated_text", "model_info"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
# -----------------------------
# Device
# -----------------------------
# Device string used for both the model and its input tensors:
# "cuda" when torch reports CUDA as available, otherwise "cpu".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# -----------------------------
# Load model + tokenizer
# -----------------------------
@st.cache_resource
def load_model(model_name: str):
    """Load the Marian tokenizer and model for ``model_name``.

    The model is moved to ``DEVICE`` before being returned. The function is
    wrapped in ``st.cache_resource``.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    marian_tokenizer = MarianTokenizer.from_pretrained(model_name)
    marian_model = MarianMTModel.from_pretrained(model_name)
    marian_model.to(DEVICE)
    return (marian_tokenizer, marian_model)
# -----------------------------
# Translation function
# -----------------------------
def translate_text(text: str, src_lang: str, tgt_lang: str):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

    Args:
        text: Text to translate. It is tokenized as a single sequence and
            truncated to 512 tokens, so longer input is cut off.
        src_lang: Source language code; with ``tgt_lang`` it forms the
            ``MODEL_MAP`` lookup key.
        tgt_lang: Target language code.

    Returns:
        A ``(translation, info)`` tuple:

        * same source and target: ``(text, "Same language selected")``;
        * success: ``(translated_text, model_name)``;
        * unsupported pair or any failure while loading or running the model:
          ``(None, message)`` with a human-readable ``message``.
    """
    if src_lang == tgt_lang:
        return text, "Same language selected"

    pair = (src_lang, tgt_lang)
    if pair not in MODEL_MAP:
        # Separate the two codes so the message reads "en -> fr", not "enfr".
        return None, f"No open-source model available for {src_lang} -> {tgt_lang}"

    model_name = MODEL_MAP[pair]
    try:
        tokenizer, model = load_model(model_name)
        inputs = tokenizer(
            [text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # The input tensors must live on the same device as the model.
        inputs = {key: value.to(DEVICE) for key, value in inputs.items()}
        translated_tokens = model.generate(
            **inputs,
            max_length=512,
            num_beams=4,
            early_stopping=True,
        )
        translated_text = tokenizer.decode(
            translated_tokens[0],
            skip_special_tokens=True,
        )
        return translated_text, model_name
    except Exception as e:
        # Deliberately broad: the caller shows the message in the UI instead
        # of letting a load or inference error crash the page.
        return None, f"Translation failed: {e}"
# -----------------------------
# Header
# -----------------------------
# Title and subtitle banners, emitted as raw HTML so the custom CSS classes
# apply to them.
for _banner_html in (
    '<div class="app-title">Language Translation App 🌍</div>',
    '<div class="app-subtitle">Translate text between multiple languages using open-source models.</div>',
):
    st.markdown(_banner_html, unsafe_allow_html=True)
# -----------------------------
# Language selection
# -----------------------------
# Two side-by-side select boxes sharing the same option list. The source box
# starts on the first option and the target box on the third.
_language_names = list(LANGUAGES)
col1, col2 = st.columns(2)
source_language_name = col1.selectbox(
    "Select Source Language",
    _language_names,
    index=0,
)
target_language_name = col2.selectbox(
    "Select Target Language",
    _language_names,
    index=2,
)
# Map the chosen display names to the language codes.
source_language, target_language = (
    LANGUAGES[source_language_name],
    LANGUAGES[target_language_name],
)
# -----------------------------
# Input area
# -----------------------------
# The text box starts from the text kept in session state, and whatever it
# returns is written straight back to session state.
_input_area_options = {
    "value": st.session_state.input_text,
    "height": 220,
    "placeholder": "Type or paste your text here...",
}
input_text = st.text_area("Enter Text to Translate", **_input_area_options)
st.session_state["input_text"] = input_text
# -----------------------------
# Buttons
# -----------------------------
# Translate and Clear buttons in two equal-width columns.
b1, b2 = st.columns([1, 1])
translate_button = b1.button("Translate")
clear_button = b2.button("Clear")

# Clear: blank the stored input, result and model info, then rerun the script
# so the page redraws from the emptied state.
if clear_button:
    for _cleared_key in ("input_text", "translated_text", "model_info"):
        st.session_state[_cleared_key] = ""
    st.rerun()
# -----------------------------
# Translate action
# -----------------------------
# Translate: warn on blank input; otherwise translate the stripped text and
# either show the error or store the result for the output area.
if translate_button:
    _stripped_input = input_text.strip()
    if _stripped_input:
        with st.spinner("Translating..."):
            translated_text, info = translate_text(
                _stripped_input,
                source_language,
                target_language,
            )
        if translated_text is not None:
            st.session_state.translated_text = translated_text
            st.session_state.model_info = info
            st.success("Translation completed successfully.")
        else:
            # On failure, info carries the message to display.
            st.error(info)
    else:
        st.warning("Please enter some text to translate.")
# -----------------------------
# Output area
# -----------------------------
# Show the last stored translation, if any, with the info string saved
# alongside it.
if st.session_state.translated_text:
    st.markdown('<div class="result-label">Translated Text</div>', unsafe_allow_html=True)
    st.text_area(
        # An empty label is discouraged for accessibility and triggers a
        # Streamlit warning. Pass a real label and collapse it so the page
        # looks unchanged; the styled heading above stays the visible title.
        "Translated Text",
        value=st.session_state.translated_text,
        height=220,
        label_visibility="collapsed",
    )
    st.caption(f"Model used: {st.session_state.model_info}")