# Minte
# Refactor Swahili and Somali model configurations and update loading logic
# (commit f525548)
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, M2M100ForConditionalGeneration
# Language configuration with specialized models
LANGUAGE_CONFIG = {
"Amharic": {
"code": "amh",
"model_type": "nllb",
"nllb_code": "amh_Ethi"
},
"Swahili": {
"code": "swh",
"model_type": "helsinki_swahili",
"helsinki_code": "swc"
},
"Somali": {
"code": "som",
"model_type": "m2m",
"m2m_code": "so"
},
"Afan Oromo": {
"code": "gaz",
"model_type": "nllb",
"nllb_code": "gaz_Latn"
},
"Tigrinya": {
"code": "tir",
"model_type": "nllb",
"nllb_code": "tir_Ethi"
},
"Chichewa": {
"code": "nya",
"model_type": "nllb",
"nllb_code": "nya_Latn"
}
}
# Model instances
models = {}
tokenizers = {}
print("πŸš€ Initializing translation models...")
# Load Helsinki-NLP Swahili model
try:
print("πŸ“₯ Loading Helsinki-NLP Swahili model...")
swahili_model_id = "Helsinki-NLP/opus-mt-swc-en"
tokenizers['helsinki_swahili'] = AutoTokenizer.from_pretrained(swahili_model_id)
models['helsinki_swahili'] = AutoModelForSeq2SeqLM.from_pretrained(swahili_model_id)
print("βœ… Helsinki-NLP Swahili model loaded successfully!")
except Exception as e:
print(f"❌ Failed to load Helsinki-NLP Swahili model: {e}")
models['helsinki_swahili'] = None
# Load M2M100 model for Somali
try:
print("πŸ“₯ Loading M2M100 model for Somali...")
m2m_model_id = "facebook/m2m100_418M"
tokenizers['m2m'] = AutoTokenizer.from_pretrained(m2m_model_id)
models['m2m'] = M2M100ForConditionalGeneration.from_pretrained(m2m_model_id)
print("βœ… M2M100 model loaded successfully!")
except Exception as e:
print(f"❌ Failed to load M2M100 model: {e}")
models['m2m'] = None
# Load NLLB model for other languages
try:
print("πŸ“₯ Loading NLLB model...")
nllb_model_id = "facebook/nllb-200-distilled-600M"
tokenizers['nllb'] = AutoTokenizer.from_pretrained(nllb_model_id)
models['nllb'] = AutoModelForSeq2SeqLM.from_pretrained(nllb_model_id)
print("βœ… NLLB model loaded successfully!")
except Exception as e:
print(f"❌ Failed to load NLLB model: {e}")
models['nllb'] = None
def translate_with_helsinki_swahili(text):
"""Translate Swahili text using Helsinki-NLP model"""
try:
if models.get('helsinki_swahili') is None or tokenizers.get('helsinki_swahili') is None:
return "Swahili translation model not available"
# Tokenize input
inputs = tokenizers['helsinki_swahili'](text, return_tensors="pt", truncation=True, max_length=512)
# Generate translation
with torch.no_grad():
generated_tokens = models['helsinki_swahili'].generate(
**inputs,
max_length=256,
num_beams=5,
early_stopping=True
)
# Decode
translation = tokenizers['helsinki_swahili'].batch_decode(generated_tokens, skip_special_tokens=True)[0]
return translation
except Exception as e:
print(f"Helsinki Swahili translation error: {e}")
# Fallback to M2M100 if available
if models.get('m2m') is not None:
return translate_with_m2m(text, "sw")
# Fallback to NLLB if available
elif models.get('nllb') is not None:
return translate_with_nllb(text, "swh_Latn")
return f"Translation failed: {str(e)[:200]}"
def translate_with_m2m(text, source_lang_code):
"""Translate text using M2M100 model"""
try:
if models.get('m2m') is None or tokenizers.get('m2m') is None:
return "M2M100 model not available"
# Set source language
tokenizers['m2m'].src_lang = source_lang_code
# Tokenize input
inputs = tokenizers['m2m'](text, return_tensors="pt", truncation=True, max_length=512)
# Generate translation to English
with torch.no_grad():
generated_tokens = models['m2m'].generate(
**inputs,
forced_bos_token_id=tokenizers['m2m'].get_lang_id("en"),
max_length=256,
num_beams=3,
early_stopping=True
)
# Decode
translation = tokenizers['m2m'].batch_decode(generated_tokens, skip_special_tokens=True)[0]
return translation
except Exception as e:
print(f"M2M100 translation error: {e}")
# Fallback to NLLB if available
if models.get('nllb') is not None:
lang_map = {"so": "som_Latn", "sw": "swh_Latn"}
nllb_code = lang_map.get(source_lang_code, "eng_Latn")
return translate_with_nllb(text, nllb_code)
return f"Translation failed: {str(e)[:200]}"
def translate_with_nllb(text, source_lang_code):
"""Translate text using NLLB model"""
try:
if models.get('nllb') is None or tokenizers.get('nllb') is None:
return "NLLB model not available"
# Tokenize input
inputs = tokenizers['nllb'](text, return_tensors="pt", truncation=True, max_length=512)
# Define target language (English)
forced_bos_token_id = tokenizers['nllb'].convert_tokens_to_ids("eng_Latn")
# Generate translation
with torch.no_grad():
generated_tokens = models['nllb'].generate(
**inputs,
forced_bos_token_id=forced_bos_token_id,
max_length=256,
num_beams=3,
early_stopping=True
)
# Decode
translation = tokenizers['nllb'].batch_decode(generated_tokens, skip_special_tokens=True)[0]
return translation
except Exception as e:
print(f"NLLB translation error: {e}")
return f"Translation failed: {str(e)[:200]}"
def translate_text(text, source_language):
"""Main translation function"""
if not text.strip():
return "Please enter text to translate"
if source_language not in LANGUAGE_CONFIG:
return f"Translation for {source_language} is not supported"
config = LANGUAGE_CONFIG[source_language]
try:
if config["model_type"] == "helsinki_swahili":
return translate_with_helsinki_swahili(text)
elif config["model_type"] == "m2m":
return translate_with_m2m(text, config["m2m_code"])
else: # nllb
return translate_with_nllb(text, config["nllb_code"])
except Exception as e:
print(f"Translation error for {source_language}: {e}")
return f"Translation failed: {str(e)[:200]}"
# Example texts for each language
EXAMPLE_TEXTS = {
"Amharic": "αˆαˆ‰αˆ αˆ°α‹ α‰ αˆαˆ‰αˆ መα‰₯ቢች αŠ₯ኩል αŠα‹α’",
"Swahili": "Habari za asubuhi, leo tunajifunza teknolojia ya usemi.",
"Somali": "Maanta waa maalin qurux badan oo qoraxdu si wanaagsan u iftiimayso.",
"Afan Oromo": "Akkam bulte, har'a technology dubbachuu baranna.",
"Tigrinya": "αˆ˜α‹“αˆα‰² αˆ°αŠ“α‹­α‘ ሎሚ α‰΄αŠ­αŠ–αˆŽαŒ‚ α‹˜αˆ¨α‰£ αŠ•αˆαˆαŒ₯ፒ",
"Chichewa": "Alipo wina aliyense ali ndi ufulu wachibadwidwe."
}
# Test the models on startup
def test_models():
print("πŸ§ͺ Testing translation models...")
test_cases = [
("Swahili", "Habari za asubuhi"),
("Somali", "Maanta waa maalin fiican"),
("Amharic", "αˆ°αˆ‹αˆ"),
("Afan Oromo", "Akkam jirta"),
("Tigrinya", "αˆ°αˆ‹αˆ"),
("Chichewa", "Moni")
]
for lang, text in test_cases:
try:
result = translate_text(text, lang)
print(f"βœ… {lang} test: '{text}' β†’ '{result}'")
except Exception as e:
print(f"❌ {lang} test failed: {e}")
# Run tests on startup
test_models()
# Create Gradio interface
with gr.Blocks(
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="green"
),
title="🌍 GihonTech - Local Language to English Translation"
) as demo:
gr.Markdown("# 🌍 GihonTech Local Language to English Translation")
gr.Markdown("Translate text from African languages to English using specialized AI models")
with gr.Row():
with gr.Column(scale=1):
text_input = gr.Textbox(
label="Source Text",
placeholder="Enter text to translate...",
lines=4,
show_copy_button=True
)
language_select = gr.Dropdown(
choices=list(LANGUAGE_CONFIG.keys()),
value="Swahili",
label="Source Language",
info="Select the language of your text"
)
# Example buttons in two rows
with gr.Row():
for lang in ["Amharic", "Swahili", "Somali"]:
gr.Button(
f"{lang} Example",
size="sm"
).click(
lambda l=lang: EXAMPLE_TEXTS[l],
outputs=text_input
)
with gr.Row():
for lang in ["Afan Oromo", "Tigrinya", "Chichewa"]:
gr.Button(
f"{lang} Example",
size="sm"
).click(
lambda l=lang: EXAMPLE_TEXTS[l],
outputs=text_input
)
translate_btn = gr.Button(
"🎯 Translate to English",
variant="primary",
size="lg"
)
with gr.Column(scale=1):
translation_output = gr.Textbox(
label="English Translation",
placeholder="Your translated text will appear here...",
lines=5,
show_copy_button=True
)
# Connect the translate button
translate_btn.click(
fn=translate_text,
inputs=[text_input, language_select],
outputs=translation_output
)
# Also allow pressing Enter to translate
text_input.submit(
fn=translate_text,
inputs=[text_input, language_select],
outputs=translation_output
)
# Model status and information
with gr.Row():
with gr.Column():
gr.Markdown("### πŸ”§ Model Information")
# Create status display
helsinki_status = "βœ… Loaded" if models.get('helsinki_swahili') else "❌ Failed"
m2m_status = "βœ… Loaded" if models.get('m2m') else "❌ Failed"
nllb_status = "βœ… Loaded" if models.get('nllb') else "❌ Failed"
status_text = f"Helsinki Swahili: {helsinki_status} | M2M100: {m2m_status} | NLLB: {nllb_status}"
gr.Textbox(
value=status_text,
label="Model Status",
interactive=False
)
# Create model info
gr.Markdown(f"""
**Specialized Models:**
- **Swahili:** Helsinki-NLP/opus-mt-swc-en (Specialized Swahili→English)
- **Somali:** Facebook M2M100
- **Other Languages:** Facebook NLLB-200
**Features:**
- High-quality specialized model for Swahili translation
- Optimized models for each language family
- Cross-model fallback for reliability
- Fast and accurate results
""")
# Add CSS for better styling
gr.HTML("""
<style>
.gradio-container {
max-width: 1200px !important;
}
.textbox textarea {
min-height: 120px;
}
</style>
""")
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True
)