import logging
import os

import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
| |
|
| | |
# Hugging Face access token read from the environment. Startup is aborted when
# it is missing or empty: an empty string cannot authenticate either, so it is
# rejected the same way as an unset variable.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN environment variable is not set or is empty. "
        "Please set it before running the script."
    )
| |
|
| | |
# Choose the compute device once at import time; the models and every input
# tensor in this app are moved onto it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Startup sanity check: put a throwaway tensor on the chosen device and log
# where it actually ended up.
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")
| |
|
| | |
# Hugging Face Hub checkpoints, one per translation direction, both published
# under the same organisation.
_HUB_ORG = "Omartificial-Intelligence-Space"
MSA_TO_SYRIAN_MODEL = f"{_HUB_ORG}/Shami-MT"        # MSA -> Syrian dialect
SYRIAN_TO_MSA_MODEL = f"{_HUB_ORG}/SHAMI-MT-2MSA"   # Syrian dialect -> MSA
| |
|
| | |
def _load_translation_pair(model_id):
    """Load the tokenizer and seq2seq model for *model_id*.

    The model is moved onto ``device``. ``HF_TOKEN`` is passed explicitly, so
    the token validated at startup is the one used for the Hub download. Without
    it, huggingface_hub falls back to its own implicit token lookup.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=HF_TOKEN).to(device)
    return tokenizer, model


print("Loading MSA to Syrian model...")
msa_to_syrian_tokenizer, msa_to_syrian_model = _load_translation_pair(MSA_TO_SYRIAN_MODEL)

print("Loading Syrian to MSA model...")
syrian_to_msa_tokenizer, syrian_to_msa_model = _load_translation_pair(SYRIAN_TO_MSA_MODEL)

print("Models loaded successfully!")
| |
|
def _generate_translation(text, tokenizer, model, max_length=128, num_beams=5):
    """Translate *text* with one tokenizer/model pair using beam search.

    This helper is shared by both direction-specific entry points below, which
    differ only in the checkpoint they use.

    Args:
        text: Source text from the UI. ``None`` or whitespace-only input yields "".
        tokenizer: Tokenizer matching *model*.
        model: Seq2seq model already placed on ``device``.
        max_length: Upper bound on the generated sequence length, in tokens.
        num_beams: Beam width passed to ``generate``.

    Returns:
        One of:
        - the decoded translation;
        - "" for blank input;
        - a ``"Translation error: ..."`` message if tokenization, generation or
          decoding raises. The traceback is logged server-side.
    """
    # `text.strip()` alone would raise AttributeError on None.
    if text is None or not text.strip():
        return ""

    try:
        # BatchEncoding.to moves every returned tensor (ids and mask) at once.
        encoded = tokenizer(text, return_tensors="pt").to(device)
        outputs = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as exc:
        # Best effort for the UI: show the error as the output text, but keep
        # the full traceback in the server logs rather than discarding it.
        logging.getLogger(__name__).exception("Translation failed")
        return f"Translation error: {exc}"


# ZeroGPU: a GPU is attached for the duration of each call (duration in seconds).
@spaces.GPU(duration=120)
def translate_msa_to_syrian(text):
    """Translate from Modern Standard Arabic to Syrian dialect."""
    return _generate_translation(text, msa_to_syrian_tokenizer, msa_to_syrian_model)


@spaces.GPU(duration=120)
def translate_syrian_to_msa(text):
    """Translate from Syrian dialect to Modern Standard Arabic."""
    return _generate_translation(text, syrian_to_msa_tokenizer, syrian_to_msa_model)
| |
|
# Dropdown labels double as dispatch keys. The dropdown, the examples and
# bidirectional_translate must all use exactly the same strings, so they are
# defined once here.
# NOTE(review): the arrow/emoji glyphs and the Arabic example sentences in this
# UI were mis-encoded in the reviewed copy of the file. The values below are
# reconstructions; confirm them against the original.
DIRECTION_MSA_TO_SYRIAN = "MSA → Syrian"
DIRECTION_SYRIAN_TO_MSA = "Syrian → MSA"

_TRANSLATORS = {
    DIRECTION_MSA_TO_SYRIAN: translate_msa_to_syrian,
    DIRECTION_SYRIAN_TO_MSA: translate_syrian_to_msa,
}


def bidirectional_translate(text, direction):
    """Translate *text* in the direction chosen in the UI.

    Args:
        text: Source sentence from the input textbox.
        direction: A dropdown label, either ``DIRECTION_MSA_TO_SYRIAN`` or
            ``DIRECTION_SYRIAN_TO_MSA``.

    Returns:
        The translation, or a prompt asking the user to pick a direction when
        *direction* is not a known label.
    """
    translator = _TRANSLATORS.get(direction)
    if translator is None:
        return "Please select a translation direction"
    return translator(text)


with gr.Blocks(title="SHAMI-MT: Bidirectional Syrian Arabic Dialect MT Framework") as demo:

    gr.HTML("""
    <div style="text-align: center; margin-bottom: 2rem;">
        <h1>🌍 SHAMI-MT: Bidirectional Arabic Translation</h1>
        <p>Translate between Modern Standard Arabic (MSA) and Syrian Dialect</p>
        <p><strong>Built on AraT5v2-base-1024 architecture</strong></p>
    </div>
    """)

    with gr.Row():
        # Left column: static model information.
        with gr.Column(scale=1):
            gr.HTML("""
            <div style="background: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0;">
                <h3>📋 Model Information</h3>
                <ul>
                    <li><strong>Model Type:</strong> Sequence-to-Sequence Translation</li>
                    <li><strong>Base Model:</strong> UBC-NLP/AraT5v2-base-1024</li>
                    <li><strong>Languages:</strong> Arabic (MSA ↔ Syrian Dialect)</li>
                    <li><strong>Device:</strong> GPU/CPU Auto-detection</li>
                </ul>
            </div>
            """)

        # Right column: direction picker, input, translate button and output.
        with gr.Column(scale=2):
            direction = gr.Dropdown(
                choices=[DIRECTION_MSA_TO_SYRIAN, DIRECTION_SYRIAN_TO_MSA],
                value=DIRECTION_MSA_TO_SYRIAN,
                label="Translation Direction",
            )

            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter Arabic text here...",
                lines=5,
            )

            translate_btn = gr.Button("🔄 Translate", variant="primary")

            output_text = gr.Textbox(
                label="Translation",
                lines=5,
            )

    translate_btn.click(
        fn=bidirectional_translate,
        inputs=[input_text, direction],
        outputs=output_text,
    )

    # Clicking an example fills the inputs. fn/outputs let Gradio run (and,
    # depending on its caching settings, pre-compute) each example translation.
    gr.Examples(
        examples=[
            ["أنا لا أعرف إذا كان سيتمكن من الحضور اليوم أم لا.", DIRECTION_MSA_TO_SYRIAN],
            ["كيف حالك؟", DIRECTION_MSA_TO_SYRIAN],
            ["ما بعرف إذا رح يقدر يجي اليوم ولا لأ.", DIRECTION_SYRIAN_TO_MSA],
            ["شلونك؟", DIRECTION_SYRIAN_TO_MSA],
        ],
        inputs=[input_text, direction],
        outputs=output_text,
        fn=bidirectional_translate,
    )
| |
|
| | |
if __name__ == "__main__":
    # Bind on all interfaces at the conventional Spaces port. share=True also
    # asks Gradio for a public share link.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_kwargs)