File size: 3,781 Bytes
76fa2d4
 
 
 
 
 
 
 
 
 
 
 
e2341f1
 
 
76fa2d4
e2341f1
76fa2d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b34c972
 
 
 
 
9d0793e
 
 
b34c972
 
 
9d0793e
 
b34c972
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76fa2d4
b34c972
 
 
 
 
 
 
 
76fa2d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
Custom handler for TranslateGemma on HuggingFace Inference Endpoints.
Properly handles the special chat template format.
"""

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from typing import Dict, Any


class EndpointHandler:
    """Custom handler for TranslateGemma on HuggingFace Inference Endpoints.

    TranslateGemma embeds the source/target language codes inside a special
    chat-template message format that the stock endpoint pipeline does not
    produce, hence this handler. Requests carrying a plain language code go
    through the structured chat template; requests whose ``target_lang_code``
    is a free-form "Translate to ..." instruction (used for languages the
    template does not support) go through a hand-built Gemma prompt instead.
    """

    # NOTE(review): weights are always pulled from the Hub, never from the
    # ``path`` the endpoint passes to __init__ — presumably the deployed repo
    # contains only this handler and not the model weights. Confirm before
    # switching to ``path or MODEL_ID``.
    MODEL_ID = "google/translategemma-12b-it"

    def __init__(self, path: str = ""):
        """Load processor and model.

        Args:
            path: Local snapshot directory supplied by Inference Endpoints.
                Currently ignored on purpose; see the class NOTE above.
        """
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
        self.model = AutoModelForImageTextToText.from_pretrained(
            self.MODEL_ID,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

    def _custom_prompt_inputs(self, instruction: str, text: str) -> Dict[str, Any]:
        """Tokenize a free-form translation instruction (unsupported languages).

        ``instruction`` is the caller-supplied "Translate to ..." prompt that
        arrived in the ``target_lang_code`` field.
        """
        # Hand-built Gemma turn markup; the explicit "output only the
        # translation" suffix keeps the model from adding commentary.
        prompt = (
            f"<start_of_turn>user\n{instruction} Output only the translation, "
            f"no explanations.\n\n{text}<end_of_turn>\n<start_of_turn>model\n"
        )
        tokenized = self.processor.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=True,
        )
        return {k: v.to(self.model.device) for k, v in tokenized.items()}

    def _chat_template_inputs(self, text: str, source_lang: str, target_lang: str):
        """Build model inputs via TranslateGemma's structured chat template."""
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "source_lang_code": source_lang,
                        "target_lang_code": target_lang,
                        "text": text,
                    }
                ],
            }
        ]
        # dtype cast applies to floating-point tensors only; token ids stay int.
        return self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=torch.bfloat16)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process a translation request.

        Expected input format:
        {
            "inputs": {
                "text": "Text to translate",
                "source_lang_code": "en",
                "target_lang_code": "ja_JP"
            },
            "parameters": {
                "max_new_tokens": 200
            }
        }

        Returns:
            {"translation": ..., "source_lang": ..., "target_lang": ...} on
            success, or {"error": ...} for malformed input.
        """
        inputs_data = data.get("inputs", data)
        parameters = data.get("parameters", {})

        if not isinstance(inputs_data, dict):
            # Fallback for simple string input
            return {"error": "Expected dict with text, source_lang_code, target_lang_code"}

        text = inputs_data.get("text", "")
        source_lang = inputs_data.get("source_lang_code", "en")
        target_lang = inputs_data.get("target_lang_code", "en")

        # Guard against JSON nulls / wrong types: without these checks a
        # non-string target_lang_code crashed with AttributeError on
        # .startswith(), and empty text was pointlessly sent to the model.
        if not isinstance(text, str) or not text:
            return {"error": "Field 'text' must be a non-empty string"}
        if not isinstance(source_lang, str) or not isinstance(target_lang, str):
            return {"error": "Language codes must be strings"}

        # A "Translate to ..." target marks a custom prompt for languages the
        # chat template does not support.
        if target_lang.startswith("Translate to"):
            inputs = self._custom_prompt_inputs(target_lang, text)
        else:
            inputs = self._chat_template_inputs(text, source_lang, target_lang)

        max_new_tokens = parameters.get("max_new_tokens", 2000)

        with torch.inference_mode():
            generation = self.model.generate(
                **inputs,
                do_sample=False,
                max_new_tokens=max_new_tokens,
            )

        # Decode only the newly generated tokens, skipping the echoed prompt.
        input_len = inputs["input_ids"].shape[1]
        generated_tokens = generation[0][input_len:]
        translation = self.processor.decode(generated_tokens, skip_special_tokens=True)

        return {
            "translation": translation.strip(),
            "source_lang": source_lang,
            "target_lang": target_lang,
        }