# IndicTrans / app.py
# Uploaded by NetraVerse — "Create the app.py" (commit 60ffa1d, verified)
import torch
import os
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor
# Hugging Face Hub token, read from the environment (may be None for public models)
token = os.getenv("HUGGINGFACE_HUB_TOKEN")
# Run on GPU when available, otherwise fall back to CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# FLORES-style language codes: English (Latin script) -> Kannada (Kannada script)
src_lang, tgt_lang = "eng_Latn", "kan_Knda"
# Distilled 200M-parameter English->Indic checkpoint from AI4Bharat
model_name = "ai4bharat/indictrans2-en-indic-dist-200M"
# Module-level handles populated by load_model(); None until loading succeeds
model = None
tokenizer = None
ip = None
def load_model():
    """Load the IndicTrans2 model, tokenizer, and processor into module globals.

    Populates the module-level ``model``, ``tokenizer``, and ``ip`` handles.

    Returns:
        bool: True when all three components loaded successfully, False on
        any failure (auth, network, missing dependencies, ...).
    """
    global model, tokenizer, ip
    try:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            token=token,
        )
        # fp16 is only safe and fast on GPU; half-precision generation on CPU
        # is unsupported or extremely slow, so fall back to fp32 there.
        load_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            dtype=load_dtype,
            token=token,
        ).to(DEVICE)
        # Inference only: disable dropout and other training-mode behavior.
        model.eval()
        ip = IndicProcessor(inference=True)
        print(f"Model loaded successfully on {DEVICE}")
        return True
    except Exception as e:
        # Broad catch is deliberate: any loading failure should degrade to the
        # error UI in create_interface() rather than crash the app at startup.
        print(f"Error loading model: {str(e)}")
        return False
def translate_text(input_text):
    """Translate a single English sentence to Kannada.

    Args:
        input_text: Sentence to translate (may be None or empty, e.g. when the
            Gradio textbox is cleared).

    Returns:
        str: The translated sentence, or a human-readable error message when
        the model is not loaded, the input is empty, or translation fails.
    """
    if not model or not tokenizer or not ip:
        return "❌ Model not loaded. Please check the model configuration."
    # Guard None as well as whitespace-only input: Gradio can hand us None,
    # and .strip() on None would raise before reaching the try block.
    if not input_text or not input_text.strip():
        return "Please enter some text to translate."
    try:
        # IndicProcessor handles normalization and adds the language tags the
        # model expects; it works on batches, so wrap the single sentence.
        batch = ip.preprocess_batch(
            [input_text.strip()],
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        # Tokenize and move tensors to the model's device.
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        # Beam-search generation; no gradients needed at inference time.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=False,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )
        # Decode token ids back to text, dropping special tokens.
        decoded = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        # Undo the processor's normalization/tagging for the target language.
        translations = ip.postprocess_batch(decoded, lang=tgt_lang)
        return translations[0] if translations else "Translation failed."
    except Exception as e:
        # Surface the failure in the output box instead of crashing the UI.
        return f"❌ Translation error: {str(e)}"
def create_interface():
    """Create and configure the Gradio interface.

    Loads the model first; on failure returns a minimal error page instead of
    the translation UI.

    Returns:
        gr.Blocks: The fully wired Gradio app, ready for .launch().
    """
    # Load model on startup so the UI can reflect load failures immediately.
    model_loaded = load_model()
    if not model_loaded:
        # Create a simple error interface with troubleshooting hints.
        with gr.Blocks(title="Translation App - Error") as demo:
            gr.Markdown("## ❌ Model Loading Error")
            gr.Markdown("Failed to load the translation model. Please check:")
            gr.Markdown("- Your Hugging Face token is set correctly")
            gr.Markdown("- You have access to the gated model")
            gr.Markdown("- Your internet connection is working")
        return demo
    # Create the main interface.
    with gr.Blocks(
        title="AI4Bharat IndicTrans2 Translation",
        theme=gr.themes.Soft(),
    ) as demo:
        # Header with the current runtime configuration.
        gr.Markdown(
            f"""
            # 🌍 AI4Bharat IndicTrans2 Translation
            **Current Configuration:**
            - **Source Language:** {src_lang} (English)
            - **Target Language:** {tgt_lang} (Kannada)
            - **Model:** {model_name}
            - **Device:** {DEVICE}
            Enter text below to translate from English to Kannada.
            """)
        # Two-column layout: input + buttons on the left, output on the right.
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label=f"Input Text ({src_lang})",
                    placeholder="Enter English text to translate...",
                    lines=5,
                    max_lines=10
                )
                with gr.Row():
                    translate_btn = gr.Button("πŸ”„ Translate", variant="primary")
                    clear_btn = gr.Button("πŸ—‘οΈ Clear")
            with gr.Column():
                # Read-only output box for the translation result.
                output_text = gr.Textbox(
                    label=f"Translation ({tgt_lang})",
                    lines=5,
                    max_lines=10,
                    interactive=False
                )
        # Clickable example inputs.
        gr.Markdown("### πŸ“ Example Inputs:")
        examples = [
            ["Hello, how are you?"],
            ["I am going to the market today."],
            ["This is a very beautiful place."],
            ["Can you help me?"],
        ]
        # NOTE(review): cache_examples=True runs translate_text on every example
        # at startup to pre-compute outputs — adds startup latency by design.
        gr.Examples(
            examples=examples,
            inputs=[input_text],
            outputs=[output_text],
            fn=translate_text,
            cache_examples=True
        )
        # Wire the buttons: translate on click; clear resets both boxes.
        translate_btn.click(
            fn=translate_text,
            inputs=[input_text],
            outputs=[output_text]
        )
        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[input_text, output_text]
        )
        # Footer separator.
        gr.Markdown("---")
    return demo
if __name__ == "__main__":
# Create and launch the interface
demo = create_interface()
# Launch the app
demo.launch(
server_name="0.0.0.0", # Allow external connections
server_port=7860, # Default Gradio port
share=False, # Set to True if you want a public link
debug=True,
show_error=True
)