"""
Custom handler for TranslateGemma on HuggingFace Inference Endpoints.

Properly handles the special chat template format.
"""

from typing import Any, Dict

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor


class EndpointHandler:
    """Inference Endpoints handler that translates text with TranslateGemma.

    Two request styles are supported, selected by ``target_lang_code``:

    * A standard language code (e.g. ``"ja_JP"``): the request is rendered
      through the processor's chat template using structured
      ``source_lang_code`` / ``target_lang_code`` message content.
    * A free-form instruction starting with ``"Translate to"`` (for languages
      without a supported code): the prompt is assembled by hand in Gemma's
      ``<start_of_turn>`` / ``<end_of_turn>`` turn format.
    """

    # Weights are pulled from the HuggingFace Hub rather than from the
    # endpoint's local ``path``. Class attribute so subclasses can override.
    MODEL_ID = "google/translategemma-12b-it"

    # A ``target_lang_code`` starting with this prefix is treated as a custom
    # instruction prompt instead of a language code.
    CUSTOM_PROMPT_PREFIX = "Translate to"

    # Generation budget used when the request does not set ``max_new_tokens``.
    DEFAULT_MAX_NEW_TOKENS = 2000

    def __init__(self, path: str = ""):
        """Load the processor and model.

        Args:
            path: Repository path passed in by Inference Endpoints. Unused:
                the model is loaded from ``MODEL_ID`` on the Hub instead.
        """
        model_id = self.MODEL_ID
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id, device_map="auto", torch_dtype=torch.bfloat16
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process translation request.

        Expected input format:
        {
            "inputs": {
                "text": "Text to translate",
                "source_lang_code": "en",
                "target_lang_code": "ja_JP"
            },
            "parameters": {
                "max_new_tokens": 200
            }
        }

        If ``"inputs"`` is absent, ``data`` itself is read as the inputs dict.

        Returns:
            ``{"translation", "source_lang", "target_lang"}`` on success, or
            ``{"error": <message>}`` when the request is malformed.
        """
        inputs_data = data.get("inputs", data)
        # "parameters": null must behave like an absent key, not crash on .get.
        parameters = data.get("parameters") or {}

        if not isinstance(inputs_data, dict):
            # Bare string inputs carry no language codes, so they are rejected.
            return {"error": "Expected dict with text, source_lang_code, target_lang_code"}
        if not isinstance(parameters, dict):
            return {"error": "Expected parameters to be a dict"}

        text = inputs_data.get("text", "")
        source_lang = inputs_data.get("source_lang_code", "en")
        target_lang = inputs_data.get("target_lang_code", "en")
        # Guard before string methods (.startswith/.strip) are used below, so a
        # JSON null or number yields a structured error instead of a crash.
        if not all(isinstance(v, str) for v in (text, source_lang, target_lang)):
            return {"error": "text, source_lang_code and target_lang_code must be strings"}

        # JSON clients may send the budget as a string such as "200".
        try:
            max_new_tokens = int(
                parameters.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS)
            )
        except (TypeError, ValueError):
            return {"error": "max_new_tokens must be an integer"}
        if max_new_tokens <= 0:
            return {"error": "max_new_tokens must be positive"}

        if not text.strip():
            # Nothing to translate: skip the (expensive) model call, which
            # would otherwise generate text from an empty source.
            translation = ""
        else:
            inputs = self._build_inputs(text, source_lang, target_lang)
            translation = self._generate(inputs, max_new_tokens)

        return {
            "translation": translation,
            "source_lang": source_lang,
            "target_lang": target_lang,
        }

    def _build_inputs(self, text: str, source_lang: str, target_lang: str):
        """Tokenize the request into model inputs on the model's device."""
        if target_lang.startswith(self.CUSTOM_PROMPT_PREFIX):
            # Custom prompt for unsupported languages. The chat template only
            # understands language codes, so the Gemma turn structure is
            # written out explicitly; the tokenizer adds <bos> via
            # add_special_tokens=True. The extra sentence asks the model to
            # return ONLY the translation.
            prompt = (
                "<start_of_turn>user\n"
                f"{target_lang} Output only the translation, no explanations.\n\n"
                f"{text}<end_of_turn>\n"
                "<start_of_turn>model\n"
            )
            tokenized = self.processor.tokenizer(
                prompt, return_tensors="pt", add_special_tokens=True
            )
            return {k: v.to(self.model.device) for k, v in tokenized.items()}

        # Standard language code: use the structured message format that the
        # TranslateGemma chat template expects.
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "source_lang_code": source_lang,
                        "target_lang_code": target_lang,
                        "text": text,
                    }
                ],
            }
        ]
        # BatchFeature.to applies the dtype only to floating-point tensors,
        # so integer token ids stay integral.
        return self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=torch.bfloat16)

    def _generate(self, inputs, max_new_tokens: int) -> str:
        """Greedy-decode and return only the newly generated text, stripped."""
        with torch.inference_mode():
            generation = self.model.generate(
                **inputs, do_sample=False, max_new_tokens=max_new_tokens
            )
        # generate() returns prompt + continuation; drop the prompt tokens.
        input_len = inputs["input_ids"].shape[1]
        generated_tokens = generation[0][input_len:]
        translation = self.processor.decode(generated_tokens, skip_special_tokens=True)
        return translation.strip()