"""Gradio web app for English -> Kannada translation with AI4Bharat IndicTrans2."""

import os

import torch
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor

# Get token from environment variable (None is acceptable for non-gated models)
token = os.getenv("HUGGINGFACE_HUB_TOKEN")

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model configuration - English to Kannada translation
src_lang, tgt_lang = "eng_Latn", "kan_Knda"
model_name = "ai4bharat/indictrans2-en-indic-dist-200M"

# Global variables to store model and tokenizer (populated by load_model)
model = None
tokenizer = None
ip = None


def load_model():
    """Load the translation model, tokenizer and IndicProcessor into globals.

    Returns:
        bool: True if everything loaded, False on any failure (error printed).
    """
    global model, tokenizer, ip

    try:
        print(f"Loading model: {model_name}")

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            token=token,
        )

        # FIX: half precision is only reliable on CUDA; many CPU ops have no
        # float16 kernels, so fall back to float32 when running on CPU.
        model_dtype = torch.float16 if DEVICE == "cuda" else torch.float32

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            dtype=model_dtype,
            token=token,
        ).to(DEVICE)

        ip = IndicProcessor(inference=True)

        print(f"Model loaded successfully on {DEVICE}")
        return True
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return False


def translate_text(input_text):
    """
    Translate input text using the loaded model.

    Args:
        input_text: Single sentence to translate (may be None/empty from the UI).

    Returns:
        str: The translated sentence, or a user-facing error message.
    """
    # FIX: explicit None checks — object truthiness is not a reliable
    # "model is loaded" signal.
    if model is None or tokenizer is None or ip is None:
        return "❌ Model not loaded. Please check the model configuration."

    # FIX: guard against None before .strip() — Gradio sends None for a
    # cleared textbox, which previously raised AttributeError here.
    if not input_text or not input_text.strip():
        return "Please enter some text to translate."

    try:
        # Single sentence translation (batch of one)
        input_sentences = [input_text.strip()]

        # Preprocess the input (script normalization + language tags)
        batch = ip.preprocess_batch(
            input_sentences,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )

        # Tokenize the sentences
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations (beam search, no gradients needed)
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=False,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens
        generated_tokens = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Postprocess the translations (detokenize target script)
        translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)

        # Return single translation
        return translations[0] if translations else "Translation failed."
    except Exception as e:
        return f"❌ Translation error: {str(e)}"


def create_interface():
    """Create and configure the Gradio interface.

    Loads the model eagerly; if loading fails, returns a minimal error page
    instead of the translation UI.

    Returns:
        gr.Blocks: The assembled Gradio app.
    """
    # Load model on startup
    model_loaded = load_model()

    if not model_loaded:
        # Create a simple error interface
        with gr.Blocks(title="Translation App - Error") as demo:
            gr.Markdown("## ❌ Model Loading Error")
            gr.Markdown("Failed to load the translation model. Please check:")
            gr.Markdown("- Your Hugging Face token is set correctly")
            gr.Markdown("- You have access to the gated model")
            gr.Markdown("- Your internet connection is working")
        return demo

    # Create the main interface
    with gr.Blocks(
        title="AI4Bharat IndicTrans2 Translation",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            f"""
            # 🌍 AI4Bharat IndicTrans2 Translation

            **Current Configuration:**
            - **Source Language:** {src_lang} (English)
            - **Target Language:** {tgt_lang} (Kannada)
            - **Model:** {model_name}
            - **Device:** {DEVICE}

            Enter text below to translate from English to Kannada.
            """
        )

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label=f"Input Text ({src_lang})",
                    placeholder="Enter English text to translate...",
                    lines=5,
                    max_lines=10,
                )
                with gr.Row():
                    translate_btn = gr.Button("🔄 Translate", variant="primary")
                    clear_btn = gr.Button("🗑️ Clear")
            with gr.Column():
                output_text = gr.Textbox(
                    label=f"Translation ({tgt_lang})",
                    lines=5,
                    max_lines=10,
                    interactive=False,
                )

        # Example inputs
        gr.Markdown("### 📝 Example Inputs:")
        examples = [
            ["Hello, how are you?"],
            ["I am going to the market today."],
            ["This is a very beautiful place."],
            ["Can you help me?"],
        ]
        # NOTE: cache_examples=True runs all example translations at startup
        # (slower launch, but instant example clicks).
        gr.Examples(
            examples=examples,
            inputs=[input_text],
            outputs=[output_text],
            fn=translate_text,
            cache_examples=True,
        )

        # Event handlers
        translate_btn.click(
            fn=translate_text,
            inputs=[input_text],
            outputs=[output_text],
        )
        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[input_text, output_text],
        )

        # Add footer
        gr.Markdown("---")

    return demo


if __name__ == "__main__":
    # Create and launch the interface
    demo = create_interface()

    # Launch the app
    demo.launch(
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,       # Default Gradio port
        share=False,            # Set to True if you want a public link
        debug=True,
        show_error=True,
    )