# mt5-support-tickets / model_mt5.py
# Uploaded by: abidne
# Commit 4c40701 (verified): "Upload 3 files"
import os
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
# Path to the local model folder (contains config.json, pytorch_model.bin,
# the tokenizer files, etc.). It is resolved relative to this source file,
# not the current working directory, so the model and the tokenizer always
# come from the same folder no matter where the process is started.
MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mt5_model")

# Both objects are loaded once, at import time, strictly from local files.
model = MT5ForConditionalGeneration.from_pretrained(MODEL_DIR, local_files_only=True)
tokenizer = MT5Tokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
def call_mt5_model(
    user_input: str,
    target_language: str = "fr",
    *,
    max_input_length: int = 512,
    max_output_length: int = 128,
    num_beams: int = 4,
) -> str:
    """Run the local mT5 model on ``user_input`` and return the decoded text.

    The input is wrapped in a ``"translate <target_language>: <user_input>"``
    prompt before encoding. NOTE(review): this prefix presumably has to match
    the format used when the model was fine-tuned; confirm against the
    training script before changing it.

    Args:
        user_input: Raw text to feed to the model.
        target_language: Value inserted into the prompt prefix (default ``"fr"``).
        max_input_length: Token limit for the encoded prompt; longer prompts
            are truncated.
        max_output_length: ``max_length`` passed to ``model.generate``.
        num_beams: Beam-search width passed to ``model.generate``.

    Returns:
        The generated sequence, decoded with special tokens removed.
    """
    prompt = f"translate {target_language}: {user_input}"

    # Calling the tokenizer (rather than ``tokenizer.encode``) also returns the
    # attention mask. It is forwarded to ``generate`` so generation does not
    # run without one.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )
    outputs = model.generate(
        **inputs,
        max_length=max_output_length,
        num_beams=num_beams,
        early_stopping=True,
    )

    # A single prompt was encoded and one sequence is returned per prompt,
    # so the result is in row 0.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)