import gradio as gr
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration


# Hugging Face Hub checkpoint used for both the tokenizer and the model.
# The language pair is not fixed here; it is selected per translation call
# through mBART-50 language codes.
model_name = "facebook/mbart-large-50-many-to-many-mmt"

# Loaded once, at import time, and shared by every translation call.
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)
# Translation function (defaults to English -> Swahili).
def translate(text, src_lang="en_XX", tgt_lang="sw_KE"):
    """Translate *text* with the module-level mBART-50 model.

    Defaults to English -> Swahili; any other pair supported by the
    tokenizer can be selected with mBART-50 language codes.

    Args:
        text: Source-language text. ``None``, empty, or whitespace-only
            input returns ``""`` without running the model.
        src_lang: mBART-50 code of the source language (e.g. ``"en_XX"``).
        tgt_lang: mBART-50 code of the target language (e.g. ``"sw_KE"``).

    Returns:
        The decoded translation (special tokens stripped).

    Raises:
        KeyError: if *tgt_lang* is not a key of ``tokenizer.lang_code_to_id``.
    """
    # Nothing to translate (an untouched Gradio textbox yields ""): skip the
    # model instead of generating from an input made only of special tokens.
    if text is None or not text.strip():
        return ""

    # Tells the tokenizer which source-language code token to add.
    # NOTE(review): this mutates shared tokenizer state; concurrent requests
    # with different src_lang values could race — confirm the app's
    # concurrency settings before exposing src_lang to users.
    tokenizer.src_lang = src_lang
    # truncation=True: over-long input is cut to the tokenizer's
    # model_max_length instead of failing inside the model.
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    # Forcing the first generated token to the target-language code is what
    # selects the output language for this many-to-many checkpoint.
    generated_tokens = model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Gradio UI: one input box, one read-only output box, and a Translate button.
with gr.Blocks() as demo:
    gr.Markdown("## English to Swahili Translator")
    src_text = gr.Textbox(label="Enter English text")
    # Read-only: this box only displays the model's output.
    output_text = gr.Textbox(label="Swahili Translation", interactive=False)
    translate_btn = gr.Button("Translate")
    translate_btn.click(translate, inputs=src_text, outputs=output_text)
    # Pressing Enter in the input box runs the same translation as the button.
    src_text.submit(translate, inputs=src_text, outputs=output_text)

# Start the web server only when run as a script, so importing this module
# (e.g. to reuse or test `translate`) does not launch the app.
if __name__ == "__main__":
    demo.launch()