|
|
import os |
|
|
from transformers import MT5ForConditionalGeneration, MT5Tokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Resolve the model directory relative to this file (not the CWD) so the
# script works no matter where it is launched from. Previously the model
# used a CWD-relative path while the tokenizer used a file-relative one,
# which broke the model load when run from another directory.
_MODEL_DIR = os.path.join(os.path.dirname(__file__), "mt5_model")

# Load the fine-tuned MT5 model and its tokenizer once at import time,
# strictly from local files (no network / Hub lookups).
model = MT5ForConditionalGeneration.from_pretrained(_MODEL_DIR, local_files_only=True)
tokenizer = MT5Tokenizer.from_pretrained(_MODEL_DIR, local_files_only=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def call_mt5_model(user_input: str, target_language: str = "fr") -> str:
    """Translate *user_input* into *target_language* using the local MT5 model.

    Builds an MT5-style translation prompt, tokenizes it (truncated to 512
    tokens), generates with beam search, and returns the decoded text.

    Args:
        user_input: The text to translate.
        target_language: Target language code for the prompt (default ``"fr"``).

    Returns:
        The model's translation, with special tokens stripped.
    """
    # MT5 expects a task prefix in the prompt itself.
    prompt = f"translate {target_language}: {user_input}"

    # Tokenize to a PyTorch tensor, capping overly long inputs at 512 tokens.
    token_ids = tokenizer.encode(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )

    # Beam search (4 beams) with early stopping, bounded at 128 output tokens.
    generated = model.generate(
        token_ids,
        max_length=128,
        num_beams=4,
        early_stopping=True,
    )

    # Decode only the first (best) beam and drop special tokens like </s>.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
|
|