from typing import Any, Dict, List

import torch
from transformers import MarianMTModel, MarianTokenizer
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for MarianMT translation.

    Loads a MarianMT checkpoint (tokenizer + model) from ``path`` and
    translates incoming text payloads.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from the checkpoint directory ``path``."""
        self.tokenizer = MarianTokenizer.from_pretrained(path)
        self.model = MarianMTModel.from_pretrained(path)
        # Inference-only handler: switch off train-mode layers (dropout etc.)
        # so outputs are deterministic.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Translate the payload's ``"inputs"`` text(s).

        Args:
            data: Request payload. ``"inputs"`` is a string or a list of
                strings (defaults to ``""`` when absent). An optional
                ``"parameters"`` dict is forwarded to ``model.generate()``
                (e.g. ``num_beams``, ``max_length``).

        Returns:
            One ``{"translation_text": ...}`` dict per input string, in order.
        """
        inputs = data.get("inputs", "")
        if isinstance(inputs, str):
            # Normalize a single string to a one-element batch.
            inputs = [inputs]
        if not inputs:
            # Explicit empty batch: nothing to translate (the tokenizer
            # would raise on an empty list).
            return []
        # NOTE(review): list items are assumed to be strings; the tokenizer
        # raises on other types — confirm against callers if relevant.
        gen_kwargs = data.get("parameters") or {}
        encoded = self.tokenizer(
            inputs, return_tensors="pt", padding=True, truncation=True
        )
        # no_grad: pure inference — skip autograd graph construction to cut
        # peak memory and speed up generation.
        with torch.no_grad():
            translated = self.model.generate(**encoded, **gen_kwargs)
        decoded = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
        return [{"translation_text": text} for text in decoded]