# handler.py — uploaded by mrmiku via huggingface_hub (commit 9d0793e, verified)
"""
Custom handler for TranslateGemma on HuggingFace Inference Endpoints.
Properly handles the special chat template format.
"""
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from typing import Dict, Any
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for TranslateGemma.

    Loads ``google/translategemma-12b-it`` and exposes a ``__call__``
    interface that translates a single text payload between languages,
    using either the model's structured chat template (for supported
    language codes) or a hand-built free-form prompt (for unsupported
    targets expressed as a "Translate to ..." instruction).
    """

    def __init__(self, path: str = ""):
        """Load processor and model.

        NOTE(review): ``path`` (the local checkout Inference Endpoints
        passes in) is deliberately ignored -- the endpoint repo ships only
        this handler, not the weights, so the model is pulled from the Hub.
        """
        model_id = "google/translategemma-12b-it"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            device_map="auto",            # place/shard across available devices
            torch_dtype=torch.bfloat16,   # half-precision weights to fit memory
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process translation request.
        Expected input format:
        {
            "inputs": {
                "text": "Text to translate",
                "source_lang_code": "en",
                "target_lang_code": "ja_JP"
            },
            "parameters": {
                "max_new_tokens": 200
            }
        }

        Returns a dict with ``translation``, ``source_lang`` and
        ``target_lang`` on success, or a dict with an ``error`` key when
        the payload is malformed.
        """
        inputs_data = data.get("inputs", data)
        parameters = data.get("parameters", {})

        # Guard clause: reject anything that isn't the documented dict shape
        # (covers the bare-string payload some clients send). Also reject an
        # empty text field early instead of running a pointless generation.
        if not isinstance(inputs_data, dict) or not inputs_data.get("text"):
            return {"error": "Expected dict with text, source_lang_code, target_lang_code"}

        text = inputs_data.get("text", "")
        source_lang = inputs_data.get("source_lang_code", "en")
        target_lang = inputs_data.get("target_lang_code", "en")

        # A target of the form "Translate to <X> ..." is a free-form prompt
        # for languages the structured template doesn't support. The
        # isinstance check keeps a non-string JSON value on the error-free
        # path instead of raising AttributeError on .startswith.
        is_custom_prompt = isinstance(target_lang, str) and target_lang.startswith("Translate to")

        if is_custom_prompt:
            # Build the Gemma turn markers by hand and instruct the model to
            # emit only the translation (no explanations).
            prompt = (
                f"<start_of_turn>user\n{target_lang} Output only the translation, "
                f"no explanations.\n\n{text}<end_of_turn>\n<start_of_turn>model\n"
            )
            tokenized = self.processor.tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=True,
            )
            inputs = {k: v.to(self.model.device) for k, v in tokenized.items()}
        else:
            # Standard language code: the processor's chat template consumes
            # the source/target codes directly from the content entry.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "source_lang_code": source_lang,
                            "target_lang_code": target_lang,
                            "text": text,
                        }
                    ],
                }
            ]
            # Only move to the model's device. The previous code also passed
            # dtype=torch.bfloat16: for a text-only request the batch holds
            # integer input_ids/attention_mask, so the cast is at best a
            # no-op (BatchFeature.to only casts floating tensors) and would
            # corrupt the token ids anywhere .to casts all tensors.
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device)

        max_new_tokens = parameters.get("max_new_tokens", 2000)
        with torch.inference_mode():
            generation = self.model.generate(
                **inputs,
                do_sample=False,              # deterministic greedy decoding
                max_new_tokens=max_new_tokens,
            )

        # Decode only the newly generated tokens: generate() echoes the
        # prompt, so slice it off before decoding.
        input_len = inputs["input_ids"].shape[1]
        generated_tokens = generation[0][input_len:]
        translation = self.processor.decode(generated_tokens, skip_special_tokens=True)

        return {
            "translation": translation.strip(),
            "source_lang": source_lang,
            "target_lang": target_lang,
        }