"""
Custom handler for TranslateGemma on HuggingFace Inference Endpoints.
Properly handles the special chat template format.
"""
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from typing import Dict, Any
class EndpointHandler:
    """Custom inference handler for TranslateGemma on HF Inference Endpoints.

    Translates text between languages using either the model's structured
    chat-template message format (standard language codes) or a free-form
    "Translate to ..." instruction in ``target_lang_code`` for languages the
    template does not cover.
    """

    def __init__(self, path: str = ""):
        """Load the TranslateGemma processor and model.

        Args:
            path: Local directory supplied by the Inference Endpoints runtime.
                Deliberately unused here — presumably the handler repo contains
                only this file, so weights are always pulled from the Hub
                (TODO confirm against the deployed repo layout).
        """
        # Load from HuggingFace Hub directly (see note on `path` above).
        model_id = "google/translategemma-12b-it"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            device_map="auto",           # place/shard across available devices
            torch_dtype=torch.bfloat16,  # half the fp32 memory footprint
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process a translation request.

        Expected input format::

            {
                "inputs": {
                    "text": "Text to translate",
                    "source_lang_code": "en",
                    "target_lang_code": "ja_JP"
                },
                "parameters": {"max_new_tokens": 200}
            }

        ``target_lang_code`` may alternatively be a free-form prompt starting
        with "Translate to" for languages without a supported code.

        Returns:
            ``{"translation", "source_lang", "target_lang"}`` on success, or
            ``{"error": ...}`` for malformed input.
        """
        inputs_data = data.get("inputs", data)
        # Robustness fix: an explicit ``"parameters": null`` no longer crashes.
        parameters = data.get("parameters", {}) or {}

        if not isinstance(inputs_data, dict):
            # Fallback for simple string input
            return {"error": "Expected dict with text, source_lang_code, target_lang_code"}

        text = inputs_data.get("text", "")
        source_lang = inputs_data.get("source_lang_code", "en")
        target_lang = inputs_data.get("target_lang_code", "en")

        # Robustness fix: reject empty text instead of running a pointless
        # (and slow) generation pass on it.
        if not text:
            return {"error": "Field 'text' must be a non-empty string"}

        # Check if target_lang is a custom prompt (for unsupported languages).
        # Bug fix: isinstance guard — a non-string target_lang_code previously
        # raised AttributeError on .startswith() and 500'd the endpoint.
        is_custom_prompt = (
            isinstance(target_lang, str) and target_lang.startswith("Translate to")
        )

        if is_custom_prompt:
            # Custom prompt format for unsupported languages, using Gemma's
            # raw turn markers. Add explicit instruction to return ONLY the
            # translation.
            prompt = f"<start_of_turn>user\n{target_lang} Output only the translation, no explanations.\n\n{text}<end_of_turn>\n<start_of_turn>model\n"
            tokenized = self.processor.tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=True,
            )
            inputs = {k: v.to(self.model.device) for k, v in tokenized.items()}
        else:
            # Standard language code: structured message format consumed by
            # the TranslateGemma chat template.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "source_lang_code": source_lang,
                            "target_lang_code": target_lang,
                            "text": text,
                        }
                    ],
                }
            ]
            # Bug fix: dropped the ``dtype=torch.bfloat16`` argument from
            # ``.to()``. Token ids are integer tensors, so the cast is at best
            # a no-op, and it raises TypeError on transformers versions where
            # BatchEncoding.to() accepts only a device.
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device)

        max_new_tokens = parameters.get("max_new_tokens", 2000)
        with torch.inference_mode():
            generation = self.model.generate(
                **inputs,
                do_sample=False,  # deterministic greedy decoding
                max_new_tokens=max_new_tokens,
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        input_len = inputs["input_ids"].shape[1]
        generated_tokens = generation[0][input_len:]
        translation = self.processor.decode(generated_tokens, skip_special_tokens=True)

        return {
            "translation": translation.strip(),
            "source_lang": source_lang,
            "target_lang": target_lang,
        }