|
|
""" |
|
|
Custom handler for TranslateGemma on HuggingFace Inference Endpoints. |
|
|
Properly handles the special chat template format. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForImageTextToText, AutoProcessor |
|
|
from typing import Dict, Any |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Custom inference handler for TranslateGemma on HF Inference Endpoints.

    TranslateGemma's chat template expects content entries carrying
    ``source_lang_code`` / ``target_lang_code`` fields alongside the text,
    which the default endpoint handler cannot produce — hence this class.
    """

    def __init__(self, path: str = ""):
        """Load the TranslateGemma model and processor.

        Args:
            path: Local model path supplied by the endpoint runtime
                (unused here; the model is always pulled from the Hub).
        """
        model_id = "google/translategemma-12b-it"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            device_map="auto",           # shard/place across available devices
            torch_dtype=torch.bfloat16   # halves memory footprint vs fp32
        )

    def _build_custom_prompt_inputs(self, text: str, instruction: str) -> Dict[str, Any]:
        """Tokenize a free-form 'Translate to ...' instruction.

        Bypasses the chat template (which only supports explicit language-code
        pairs) and hand-writes the Gemma turn markers instead.
        """
        prompt = f"<start_of_turn>user\n{instruction} Output only the translation, no explanations.\n\n{text}<end_of_turn>\n<start_of_turn>model\n"
        tokenized = self.processor.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=True
        )
        return {k: v.to(self.model.device) for k, v in tokenized.items()}

    def _build_chat_template_inputs(self, text: str, source_lang: str, target_lang: str) -> Dict[str, Any]:
        """Tokenize via TranslateGemma's language-code-aware chat template."""
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "source_lang_code": source_lang,
                        "target_lang_code": target_lang,
                        "text": text,
                    }
                ],
            }
        ]
        # NOTE: the dtype cast only touches floating-point tensors; for pure
        # text input every tensor is integer ids, so it is a harmless no-op.
        return self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process a translation request.

        Expected input format::

            {
                "inputs": {
                    "text": "Text to translate",
                    "source_lang_code": "en",
                    "target_lang_code": "ja_JP"
                },
                "parameters": {
                    "max_new_tokens": 200
                }
            }

        Returns a dict with ``translation``, ``source_lang`` and
        ``target_lang`` keys, or a dict with a single ``error`` key when the
        request is malformed.
        """
        # Accept both {"inputs": {...}} and a bare top-level payload.
        inputs_data = data.get("inputs", data)
        parameters = data.get("parameters", {})
        if not isinstance(parameters, dict):
            # A malformed "parameters" payload would otherwise raise on .get
            # below; fall back to defaults instead of crashing the endpoint.
            parameters = {}

        if not isinstance(inputs_data, dict):
            return {"error": "Expected dict with text, source_lang_code, target_lang_code"}

        text = inputs_data.get("text", "")
        source_lang = inputs_data.get("source_lang_code", "en")
        target_lang = inputs_data.get("target_lang_code", "en")

        if not text:
            # Don't invoke the model on an empty prompt — it would only
            # hallucinate output. Mirror the existing error-dict convention.
            return {"error": "Field 'text' must be a non-empty string"}

        # Escape hatch: a target_lang_code of the form "Translate to X ..."
        # is treated as a raw instruction rather than a language code.
        if target_lang.startswith("Translate to"):
            inputs = self._build_custom_prompt_inputs(text, target_lang)
        else:
            inputs = self._build_chat_template_inputs(text, source_lang, target_lang)

        max_new_tokens = parameters.get("max_new_tokens", 2000)

        with torch.inference_mode():
            generation = self.model.generate(
                **inputs,
                do_sample=False,     # deterministic decoding for translation
                max_new_tokens=max_new_tokens
            )

        # Strip the prompt tokens so only the newly generated text is decoded.
        input_len = inputs["input_ids"].shape[1]
        generated_tokens = generation[0][input_len:]
        translation = self.processor.decode(generated_tokens, skip_special_tokens=True)

        return {
            "translation": translation.strip(),
            "source_lang": source_lang,
            "target_lang": target_lang
        }
|
|
|