Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, M2M100ForConditionalGeneration | |
# Per-language routing table.
#   code        - language code for the entry
#   model_type  - which backend translate_text dispatches to
#   *_code      - the language code in that backend's own naming scheme
# Insertion order matters: it is the order shown in the language dropdown.
LANGUAGE_CONFIG = {
    "Amharic": dict(code="amh", model_type="nllb", nllb_code="amh_Ethi"),
    "Swahili": dict(code="swh", model_type="helsinki_swahili", helsinki_code="swc"),
    "Somali": dict(code="som", model_type="m2m", m2m_code="so"),
    "Afan Oromo": dict(code="gaz", model_type="nllb", nllb_code="gaz_Latn"),
    "Tigrinya": dict(code="tir", model_type="nllb", nllb_code="tir_Ethi"),
    "Chichewa": dict(code="nya", model_type="nllb", nllb_code="nya_Latn"),
}
# Model instances, keyed by backend name ('helsinki_swahili', 'm2m', 'nllb').
models = {}
tokenizers = {}


def _load_translation_model(key, model_id, model_cls, loading_label, label):
    """Load one tokenizer/model pair into the module-level registries.

    The tokenizer is stored first, then the model. If either step raises,
    the error is printed and ``models[key]`` is set to ``None`` so the
    translate functions can detect the missing backend.
    """
    try:
        print(f"π₯ Loading {loading_label}...")
        tokenizers[key] = AutoTokenizer.from_pretrained(model_id)
        models[key] = model_cls.from_pretrained(model_id)
        print(f"β {label} loaded successfully!")
    except Exception as e:
        print(f"β Failed to load {label}: {e}")
        models[key] = None


print("π Initializing translation models...")

# Load Helsinki-NLP Swahili model
swahili_model_id = "Helsinki-NLP/opus-mt-swc-en"
_load_translation_model(
    'helsinki_swahili',
    swahili_model_id,
    AutoModelForSeq2SeqLM,
    "Helsinki-NLP Swahili model",
    "Helsinki-NLP Swahili model",
)

# Load M2M100 model for Somali
m2m_model_id = "facebook/m2m100_418M"
_load_translation_model(
    'm2m',
    m2m_model_id,
    M2M100ForConditionalGeneration,
    "M2M100 model for Somali",
    "M2M100 model",
)

# Load NLLB model for other languages
nllb_model_id = "facebook/nllb-200-distilled-600M"
_load_translation_model(
    'nllb',
    nllb_model_id,
    AutoModelForSeq2SeqLM,
    "NLLB model",
    "NLLB model",
)
def translate_with_helsinki_swahili(text):
    """Translate Swahili text to English, preferring the Helsinki-NLP model.

    If the Helsinki-NLP model is not loaded, or if running it raises, the
    request is handed to M2M100 (source code "sw") when that model is loaded,
    otherwise to NLLB (source code "swh_Latn") when that one is loaded.

    Args:
        text: The Swahili text to translate.

    Returns:
        The English translation, or a human-readable message when no model
        could produce one. This function does not raise.
    """
    tokenizer = tokenizers.get('helsinki_swahili')
    model = models.get('helsinki_swahili')

    # Message returned if neither the primary model nor a fallback is usable.
    failure = "Swahili translation model not available"

    if model is not None and tokenizer is not None:
        try:
            # Tokenize input
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            # Generate translation
            with torch.no_grad():
                generated_tokens = model.generate(
                    **inputs,
                    max_length=256,
                    num_beams=5,
                    early_stopping=True
                )
            # Decode
            return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        except Exception as e:
            print(f"Helsinki Swahili translation error: {e}")
            failure = f"Translation failed: {str(e)[:200]}"

    # Reached when the primary model is missing or failed at inference time.
    # Previously a model that failed to load skipped these fallbacks entirely.
    if models.get('m2m') is not None:
        return translate_with_m2m(text, "sw")
    if models.get('nllb') is not None:
        return translate_with_nllb(text, "swh_Latn")
    return failure
def translate_with_m2m(text, source_lang_code):
    """Translate text to English with M2M100, falling back to NLLB.

    The NLLB fallback is used both when M2M100 is not loaded and when running
    it raises, provided the NLLB model is loaded.

    Args:
        text: The text to translate.
        source_lang_code: M2M100 language code of ``text`` (e.g. "so", "sw").

    Returns:
        The English translation, or a human-readable message when no model
        could produce one. This function does not raise.
    """
    tokenizer = tokenizers.get('m2m')
    model = models.get('m2m')

    # Message returned if neither M2M100 nor the NLLB fallback is usable.
    failure = "M2M100 model not available"

    if model is not None and tokenizer is not None:
        try:
            # Set source language
            tokenizer.src_lang = source_lang_code
            # Tokenize input
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            # Generate translation to English
            with torch.no_grad():
                generated_tokens = model.generate(
                    **inputs,
                    forced_bos_token_id=tokenizer.get_lang_id("en"),
                    max_length=256,
                    num_beams=3,
                    early_stopping=True
                )
            # Decode
            return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        except Exception as e:
            print(f"M2M100 translation error: {e}")
            failure = f"Translation failed: {str(e)[:200]}"

    # Fallback to NLLB if available. Previously this was skipped when M2M100
    # had simply failed to load.
    if models.get('nllb') is not None:
        # Map M2M100 codes onto NLLB codes; codes without a mapping keep the
        # original default of "eng_Latn".
        lang_map = {"so": "som_Latn", "sw": "swh_Latn"}
        nllb_code = lang_map.get(source_lang_code, "eng_Latn")
        return translate_with_nllb(text, nllb_code)
    return failure
def translate_with_nllb(text, source_lang_code):
    """Translate text to English with the NLLB model.

    Args:
        text: The text to translate.
        source_lang_code: NLLB language code of ``text`` (e.g. "amh_Ethi").

    Returns:
        The English translation, "NLLB model not available" when the model or
        tokenizer is missing, or a "Translation failed: ..." message when
        translation raises. This function does not raise.
    """
    try:
        tokenizer = tokenizers.get('nllb')
        model = models.get('nllb')
        if model is None or tokenizer is None:
            return "NLLB model not available"

        # Tell the tokenizer which language the input is written in.
        # source_lang_code used to be accepted but never applied, so every
        # language was tokenized with the tokenizer's default source language.
        tokenizer.src_lang = source_lang_code

        # Tokenize input
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        # Define target language (English)
        forced_bos_token_id = tokenizer.convert_tokens_to_ids("eng_Latn")
        # Generate translation
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=forced_bos_token_id,
                max_length=256,
                num_beams=3,
                early_stopping=True
            )
        # Decode
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"NLLB translation error: {e}")
        return f"Translation failed: {str(e)[:200]}"
def translate_text(text, source_language):
    """Translate ``text`` from ``source_language`` into English.

    Looks up the language in LANGUAGE_CONFIG and dispatches to the backend
    named by its "model_type" entry.

    Args:
        text: The text to translate. ``None``, non-string and
            whitespace-only values are treated as empty input.
        source_language: A key of LANGUAGE_CONFIG (e.g. "Swahili").

    Returns:
        The English translation, or a human-readable message for empty
        input, an unsupported language, or a translation failure.
    """
    # Guard against None / non-string input so the caller gets the prompt
    # message instead of an AttributeError from .strip().
    if not isinstance(text, str) or not text.strip():
        return "Please enter text to translate"

    if source_language not in LANGUAGE_CONFIG:
        return f"Translation for {source_language} is not supported"

    config = LANGUAGE_CONFIG[source_language]
    try:
        model_type = config["model_type"]
        if model_type == "helsinki_swahili":
            return translate_with_helsinki_swahili(text)
        if model_type == "m2m":
            return translate_with_m2m(text, config["m2m_code"])
        # Every other model type is routed to NLLB.
        return translate_with_nllb(text, config["nllb_code"])
    except Exception as e:
        print(f"Translation error for {source_language}: {e}")
        return f"Translation failed: {str(e)[:200]}"
# Example texts for each language, keyed by the same names as LANGUAGE_CONFIG.
# The UI's "<language> Example" buttons copy these into the input box.
_EXAMPLE_PAIRS = (
    ("Amharic", "ααα α°α α ααα αα₯αΆα½ α₯α©α ααα’"),
    ("Swahili", "Habari za asubuhi, leo tunajifunza teknolojia ya usemi."),
    ("Somali", "Maanta waa maalin qurux badan oo qoraxdu si wanaagsan u iftiimayso."),
    ("Afan Oromo", "Akkam bulte, har'a technology dubbachuu baranna."),
    ("Tigrinya", "αααα² α°ααα‘ αα α΄αααα αα¨α£ αααα₯α’"),
    ("Chichewa", "Alipo wina aliyense ali ndi ufulu wachibadwidwe."),
)
EXAMPLE_TEXTS = dict(_EXAMPLE_PAIRS)
# Test the models on startup
def test_models():
    """Run one short sample per language through translate_text and print the outcome."""
    print("π§ͺ Testing translation models...")
    # One sample phrase per supported language, in the order they are exercised.
    samples = {
        "Swahili": "Habari za asubuhi",
        "Somali": "Maanta waa maalin fiican",
        "Amharic": "α°αα",
        "Afan Oromo": "Akkam jirta",
        "Tigrinya": "α°αα",
        "Chichewa": "Moni",
    }
    for language, sample in samples.items():
        try:
            translated = translate_text(sample, language)
            print(f"β {language} test: '{sample}' β '{translated}'")
        except Exception as err:
            print(f"β {language} test failed: {err}")


# Run tests on startup
test_models()
# Create Gradio interface
app_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="green")

with gr.Blocks(
    theme=app_theme,
    title="π GihonTech - Local Language to English Translation",
) as demo:
    gr.Markdown("# π GihonTech Local Language to English Translation")
    gr.Markdown("Translate text from African languages to English using specialized AI models")

    with gr.Row():
        # Left column: input text, language choice, example buttons, translate button.
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Source Text",
                placeholder="Enter text to translate...",
                lines=4,
                show_copy_button=True,
            )
            language_select = gr.Dropdown(
                choices=list(LANGUAGE_CONFIG),
                value="Swahili",
                label="Source Language",
                info="Select the language of your text",
            )

            # Example buttons in two rows
            for row_langs in (
                ("Amharic", "Swahili", "Somali"),
                ("Afan Oromo", "Tigrinya", "Chichewa"),
            ):
                with gr.Row():
                    for lang in row_langs:
                        example_btn = gr.Button(f"{lang} Example", size="sm")
                        # Bind lang as a default argument so each button keeps
                        # its own language rather than the loop's last value.
                        example_btn.click(
                            lambda l=lang: EXAMPLE_TEXTS[l],
                            outputs=text_input,
                        )

            translate_btn = gr.Button(
                "π― Translate to English",
                variant="primary",
                size="lg",
            )

        # Right column: the translated text.
        with gr.Column(scale=1):
            translation_output = gr.Textbox(
                label="English Translation",
                placeholder="Your translated text will appear here...",
                lines=5,
                show_copy_button=True,
            )

    # Connect the translate button, and also allow pressing Enter to translate
    for trigger in (translate_btn.click, text_input.submit):
        trigger(
            fn=translate_text,
            inputs=[text_input, language_select],
            outputs=translation_output,
        )

    # Model status and information
    with gr.Row():
        with gr.Column():
            gr.Markdown("### π§ Model Information")

            # Create status display
            helsinki_status = "β Loaded" if models.get('helsinki_swahili') else "β Failed"
            m2m_status = "β Loaded" if models.get('m2m') else "β Failed"
            nllb_status = "β Loaded" if models.get('nllb') else "β Failed"
            status_text = " | ".join([
                f"Helsinki Swahili: {helsinki_status}",
                f"M2M100: {m2m_status}",
                f"NLLB: {nllb_status}",
            ])
            gr.Textbox(
                value=status_text,
                label="Model Status",
                interactive=False,
            )

            # Create model info
            gr.Markdown("""
**Specialized Models:**
- **Swahili:** Helsinki-NLP/opus-mt-swc-en (Specialized SwahiliβEnglish)
- **Somali:** Facebook M2M100
- **Other Languages:** Facebook NLLB-200
**Features:**
- High-quality specialized model for Swahili translation
- Optimized models for each language family
- Cross-model fallback for reliability
- Fast and accurate results
""")

    # Add CSS for better styling
    gr.HTML("""
<style>
.gradio-container {
max-width: 1200px !important;
}
.textbox textarea {
min-height: 120px;
}
</style>
""")
if __name__ == "__main__":
    # Listen on every interface at port 7860, without a public share link,
    # and show errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)