|
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import torch |
|
|
|
|
|
class ImprovedTranslator:
    """Indonesian <-> English translator backed by several Hugging Face models.

    Pipelines are created lazily on first use (see ``load_model``) and cached
    in ``self.models``, keyed by model type ("nllb", "mbart", "opus", "t5").
    """

    # Language codes expected by each model family's tokenizer, keyed by the
    # short codes returned from ``_languages``.
    _NLLB_CODES = {"id": "ind_Latn", "en": "eng_Latn"}
    _MBART_CODES = {"id": "id_ID", "en": "en_XX"}

    def __init__(self):
        """Initialize translator with multiple model options."""
        # pipeline() device convention: 0 = first CUDA GPU, -1 = CPU.
        self.device = 0 if torch.cuda.is_available() else -1
        # Cache of loaded pipelines, filled on demand by load_model().
        self.models = {}
        # Kept for callers that may read it; not consulted inside this class.
        self.current_model = "nllb"

    @staticmethod
    def _languages(direction):
        """Return ``(source, target)`` short codes ("id"/"en") for *direction*.

        Only the leading language code is inspected, so "ID → EN" is detected
        whatever arrow character (or mis-encoding of it) separates the codes.
        Anything that does not start with "ID" is treated as English to
        Indonesian, which was the previous default branch.
        """
        if isinstance(direction, str) and direction.strip().upper().startswith("ID"):
            return "id", "en"
        return "en", "id"

    def load_model(self, model_type="nllb"):
        """Load the pipeline(s) for *model_type* into ``self.models`` once.

        Calling this again for an already-loaded type is a no-op.

        Args:
            model_type (str): "nllb", "mbart", "opus", or "t5".

        Raises:
            ValueError: If *model_type* is not one of the names above.
        """
        if model_type in self.models:
            return  # already loaded

        if model_type == "nllb":
            self.models["nllb"] = pipeline(
                "translation",
                model="facebook/nllb-200-distilled-600M",
                device=self.device,
                # Half precision only when a GPU is available.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
        elif model_type == "mbart":
            self.models["mbart"] = pipeline(
                "translation",
                model="facebook/mbart-large-50-many-to-many-mmt",
                device=self.device,
            )
        elif model_type == "opus":
            # Opus-MT models are single-direction, so one pipeline per direction.
            self.models["opus"] = {
                "id_en": pipeline("translation", model="Helsinki-NLP/opus-mt-id-en", device=self.device),
                "en_id": pipeline("translation", model="Helsinki-NLP/opus-mt-en-id", device=self.device),
            }
        elif model_type == "t5":
            # NOTE(review): no translate_with_* method uses this pipeline;
            # translate() routes "t5" to NLLB.
            self.models["t5"] = pipeline(
                "translation",
                model="google/flan-t5-base",
                device=self.device,
            )
        else:
            raise ValueError(
                f"Unknown model_type {model_type!r}; "
                "expected 'nllb', 'mbart', 'opus' or 't5'"
            )

    def translate_with_nllb(self, text, direction):
        """Translate *text* with NLLB-200 and return the translated string.

        Args:
            text (str): Text to translate.
            direction (str): "ID → EN" or "EN → ID".
        """
        self.load_model("nllb")
        src, tgt = self._languages(direction)
        result = self.models["nllb"](
            text,
            src_lang=self._NLLB_CODES[src],
            tgt_lang=self._NLLB_CODES[tgt],
            max_length=512,
        )
        return result[0]["translation_text"]

    def translate_with_mbart(self, text, direction):
        """Translate *text* with mBART-50 and return the translated string.

        mBART-50 selects languages through ``src_lang``/``tgt_lang`` codes,
        which the tokenizer turns into the forced target-language token.
        Marian-style ``>>xx<<`` text prefixes are not understood by mBART and
        would be translated as literal text, so none are added.

        Args:
            text (str): Text to translate.
            direction (str): "ID → EN" or "EN → ID".
        """
        self.load_model("mbart")
        src, tgt = self._languages(direction)
        result = self.models["mbart"](
            text,
            src_lang=self._MBART_CODES[src],
            tgt_lang=self._MBART_CODES[tgt],
        )
        return result[0]["translation_text"]

    def translate_with_opus(self, text, direction):
        """Translate *text* with the Helsinki-NLP Opus-MT models.

        Args:
            text (str): Text to translate.
            direction (str): "ID → EN" or "EN → ID".

        Returns:
            str: The translated text.
        """
        self.load_model("opus")
        src, tgt = self._languages(direction)
        # Keys are "id_en" / "en_id", matching load_model().
        result = self.models["opus"][f"{src}_{tgt}"](text)
        return result[0]["translation_text"]
|
|
|
|
|
|
|
|
# Module-wide translator shared by translate() and translate_simple().
# Construction only checks CUDA availability; the models themselves are
# loaded lazily the first time a translate_with_* method is called.
translator: ImprovedTranslator = ImprovedTranslator()
|
|
|
|
|
def translate(text, direction, model_type="nllb"):
    """
    Main translation function.

    Args:
        text (str): Text to translate
        direction (str): "ID → EN" or "EN → ID"
        model_type (str): "nllb", "mbart", "opus", or "t5". Any value other
            than "mbart" or "opus" (including "t5") is served by NLLB.

    Returns:
        str: The translated text. If the chosen model raises, the Opus models
        are tried as a fallback. If that also fails, or if Opus was the model
        that failed, a "Translation failed: <error>" string is returned
        instead of raising.
    """
    try:
        if model_type == "mbart":
            handler = translator.translate_with_mbart
        elif model_type == "opus":
            handler = translator.translate_with_opus
        else:
            # "nllb", "t5" and unrecognized names all use NLLB.
            handler = translator.translate_with_nllb
        return handler(text, direction)
    except Exception as e:
        # Top-level boundary: report and degrade rather than propagate.
        print(f"Translation error with {model_type}: {e}")
        if model_type == "opus":
            return f"Translation failed: {str(e)}"

    # Fallback to the Opus models. This is guarded so that a second failure
    # (e.g. the fallback model cannot be loaded) is reported the same way
    # instead of escaping from inside the first handler.
    try:
        return translator.translate_with_opus(text, direction)
    except Exception as e:
        print(f"Translation error with opus fallback: {e}")
        return f"Translation failed: {str(e)}"
|
|
|
|
|
|
|
|
def translate_simple(text, direction):
    """Translate *text* in *direction* using the default NLLB model.

    Retained for backward compatibility; identical to calling
    ``translate(text, direction, "nllb")``.
    """
    return translate(text, direction, model_type="nllb")