from transformers import PreTrainedModel
try:
    from .configuration_setu_translation import SetuTranslationConfig
except ImportError:
    from configuration_setu_translation import SetuTranslationConfig
import torch
import os
import numpy as np
import json
import onnxruntime as ort
import sentencepiece as spm
from typing import List, Tuple
from huggingface_hub import snapshot_download


class SetuTranslationModel(PreTrainedModel):
    """SETU Translation Model for Hugging Face Hub
    
    This model performs script-agnostic translation to unified English output.
    It handles multiscript, multilingual, and informal text translation.
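
    Example (the repo id below is a placeholder, not a published checkpoint)::

        model = SetuTranslationModel.from_pretrained("org/setu-translation")
        english = model.translate("source-language text")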
    """
    
    config_class = SetuTranslationConfig

    def __init__(self, config):
        super().__init__(config)
        
        self.config = config
        
        # Initialize model components
        self.encoder_session = None
        self.decoder_session = None
        self.sp = None
        
        # Load model files if they exist
        self._load_model_components()
        
    def _load_model_components(self):
        """Load ONNX models and SentencePiece processor"""
        model_dir = getattr(self.config, '_name_or_path', '.')
        
        # Paths to model files in assets folder
        assets_dir = os.path.join(model_dir, 'assets')
        encoder_path = os.path.join(assets_dir, 'encoder.onnx')
        decoder_path = os.path.join(assets_dir, 'decoder.onnx')
        spm_path = os.path.join(assets_dir, 'spm.model')
        
        # Load ONNX models
        # Configure providers: use CUDA if available, fallback to CPU
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider']
        
        if os.path.exists(encoder_path):
            self.encoder_session = ort.InferenceSession(
                encoder_path,
                providers=providers
            )
            
        if os.path.exists(decoder_path):
            self.decoder_session = ort.InferenceSession(
                decoder_path,
                providers=providers
            )
            
        # Load SentencePiece model
        if os.path.exists(spm_path):
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_path)
    
    def encode_text(self, text: str) -> np.ndarray:
        """Encode text to token IDs using SentencePiece"""
        if self.sp is None:
            raise ValueError("SentencePiece model not loaded")
            
        # Encode using SentencePiece
        tokens = self.sp.EncodeAsIds(text)
        
        # Add EOS token
        tokens = tokens + [self.config.eos_idx]
        
        return np.array(tokens, dtype=np.int64)
    
    def decode_tokens(self, tokens: List[int]) -> str:
        """Decode token IDs to text using SentencePiece"""
        if self.sp is None:
            raise ValueError("SentencePiece model not loaded")
            
        # Remove special tokens
        tokens = [t for t in tokens if t not in [self.config.bos_idx, self.config.eos_idx, self.config.pad_idx]]
        
        # Decode using SentencePiece
        text = self.sp.DecodeIds(tokens)
        
        return text.strip()
    
    def encode_source(self, src_tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run encoder on source tokens"""
        if self.encoder_session is None:
            raise ValueError("Encoder model not loaded")
            
        # Prepare inputs
        src_tokens_batch = src_tokens.reshape(1, -1)  # [1, src_len]
        src_lengths = np.array([len(src_tokens)], dtype=np.int64)
        
        # Check encoder input names
        encoder_inputs = [inp.name for inp in self.encoder_session.get_inputs()]
        
        # Build input dict based on what encoder expects
        input_dict = {'src_tokens': src_tokens_batch}
        if 'src_lengths' in encoder_inputs:
            input_dict['src_lengths'] = src_lengths
        
        # Run encoder
        outputs = self.encoder_session.run(None, input_dict)
        
        # Handle encoder outputs
        encoder_out = outputs[0]
        encoder_padding_mask = outputs[1] if len(outputs) > 1 else None
        
        return encoder_out, encoder_padding_mask
    
    def decode_step(self, prev_tokens, encoder_out, encoder_padding_mask):
        """Run decoder for one step"""
        if self.decoder_session is None:
            raise ValueError("Decoder model not loaded")
            
        # Prepare inputs - check if already numpy array
        if isinstance(prev_tokens, np.ndarray):
            prev_tokens_np = prev_tokens  # Already formatted correctly
        else:
            prev_tokens_np = np.array([prev_tokens], dtype=np.int64)  # [1, seq_len]
        
        try:
            # Run decoder
            outputs = self.decoder_session.run(
                None,  # Get all outputs
                {
                    'prev_output_tokens': prev_tokens_np,
                    'encoder_out': encoder_out,
                    'encoder_padding_mask': encoder_padding_mask
                }
            )
            
            # Return logits (first output)
            return outputs[0]
            
        except Exception as e:
            raise RuntimeError(f"Decoder step failed: {e}") from e
    
    def beam_search_translate(self, src_tokens: np.ndarray) -> List[int]:
        """Perform beam search translation - matches ONNX implementation"""
        # Encode source
        encoder_out, encoder_padding_mask = self.encode_source(src_tokens)
        
        # Initialize beam search parameters
        beam_size = self.config.beam_size
        max_len = self.config.max_len
        len_penalty = self.config.len_penalty
        vocab_size = self.config.tgt_vocab_size
        
        # Initialize beams: (score, tokens)
        # NOTE: start with EOS token, not BOS!
        beams = [(0.0, [self.config.eos_idx])]
        completed = []
        
        for step in range(max_len):
            # Stop early if we have enough good completed hypotheses
            if len(completed) >= beam_size * 2:
                break
                
            all_candidates = []
            
            for score, tokens in beams:
                # Check if beam is completed (don't mark as complete if it's just the starting EOS token)
                if tokens[-1] == self.config.eos_idx and len(tokens) > 1:
                    completed.append((score, tokens))
                    continue
                
                # Prevent EOS at first step (min_len=1)
                should_skip_eos = (step == 0 and len(tokens) == 1)
                
                # Prepare decoder input
                prev_tokens = np.array([tokens], dtype=np.int64)  # [1, tgt_len]
                
                try:
                    # Get logits for next token
                    logits = self.decode_step(prev_tokens, encoder_out, encoder_padding_mask)
                    
                    # Check logits validity
                    if logits is None or logits.size == 0:
                        completed.append((score, tokens + [self.config.eos_idx]))
                        continue
                    
                    # Get log probabilities for last position
                    log_probs = logits[0, -1, :]  # [vocab_size]
                    
                    # Proper log softmax: log(exp(x) / sum(exp(x))) = x - log(sum(exp(x)))
                    max_logit = np.max(log_probs)
                    log_probs_shifted = log_probs - max_logit
                    log_sum_exp = np.log(np.sum(np.exp(log_probs_shifted))) + max_logit
                    log_probs = log_probs - log_sum_exp
                    
                    # Get top-k candidates (expand more than beam_size for diversity)
                    top_k = min(beam_size * 2, vocab_size)
                    top_k_indices = np.argpartition(log_probs, -top_k)[-top_k:]
                    top_k_indices = top_k_indices[np.argsort(log_probs[top_k_indices])][::-1]
                    
                    # Expand this beam by at most beam_size candidates;
                    # counting per beam (rather than across all beams)
                    # keeps every surviving hypothesis in the running
                    beam_expansions = 0
                    for idx in top_k_indices:
                        # Skip EOS on first step (min_len=1 constraint)
                        if should_skip_eos and int(idx) == self.config.eos_idx:
                            continue
                        
                        candidate_score = score + log_probs[idx]
                        candidate_tokens = tokens + [int(idx)]
                        all_candidates.append((candidate_score, candidate_tokens))
                        beam_expansions += 1
                        
                        if beam_expansions >= beam_size:
                            break
                        
                except Exception:
                    # Force-complete this beam if decoding fails
                    completed.append((score, tokens + [self.config.eos_idx]))
                    continue
            
            if not all_candidates:
                # All beams completed
                break
            
            # Select top beam_size candidates
            # Sort by cumulative score (no length penalty during search, only at finalization)
            ordered = sorted(
                all_candidates,
                key=lambda x: x[0],
                reverse=True
            )
            beams = ordered[:beam_size]
        
        # Add remaining beams to completed
        completed.extend(beams)
        
        # Ensure we have at least one hypothesis
        if not completed:
            completed = [(0.0, [self.config.eos_idx, self.config.eos_idx])]
        
        # Sort by score with length penalty
        # Length = number of generated tokens (excluding starting EOS, including final EOS)
        # tokens = [EOS, tok1, tok2, ..., EOS], so length = len(tokens) - 1
        # Use max(1, ...) to avoid division by zero for very short sequences
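        # (e.g. with len_penalty == 1.0 the key reduces to the average
        # log-probability per generated token)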
        completed = sorted(
            completed,
            key=lambda x: x[0] / (max(1, len(x[1]) - 1) ** len_penalty),
            reverse=True
        )
        
        # Return best translation tokens
        best_score, best_tokens = completed[0]
        return best_tokens
    
    def translate(self, text: str) -> str:
        """Translate input text to English
        
        Args:
            text: Input text in any supported script/language
            
        Returns:
            Translated English text
        """
        # Encode input text
        src_tokens = self.encode_text(text)
        
        # Perform beam search translation
        output_tokens = self.beam_search_translate(src_tokens)
        
        # Decode output tokens
        translated_text = self.decode_tokens(output_tokens)
        
        return translated_text
    
    def forward(self, text: str) -> str:
        """Forward pass - alias for translate method for simple usage"""
        return self.translate(text)
    
    def __call__(self, text: str) -> str:
        """Make model callable - enables model("text") usage"""
        return self.translate(text)
    
    @classmethod
    def from_pretrained(cls,
        pretrained_model_name_or_path,
        *,
        force_download=False,
        resume_download=None,
        proxies=None,
        token=None,
        cache_dir=None,
        local_files_only=False,
        revision=None, 
        **kwargs):
        """Load model from Hugging Face Hub or local directory"""
        
        # Download model if it's a hub model
        if not os.path.isdir(pretrained_model_name_or_path):
            model_dir = snapshot_download(
                repo_id=pretrained_model_name_or_path, 
                token=token,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                revision=revision
            )
        else:
            model_dir = pretrained_model_name_or_path
            
        # Load config
        config_path = os.path.join(model_dir, 'config.json')
        if os.path.exists(config_path):
            config = SetuTranslationConfig.from_json_file(config_path)
        else:
            # Load from model_config.json if config.json doesn't exist
            model_config_path = os.path.join(model_dir, 'model_config.json')
            if os.path.exists(model_config_path):
                with open(model_config_path, 'r') as f:
                    model_config = json.load(f)
                config = SetuTranslationConfig(**model_config, **kwargs)
            else:
                config = SetuTranslationConfig(**kwargs)
        
        # Set the model directory path
        config._name_or_path = model_dir
        
        # Create model instance
        model = cls(config)
        
        return model
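

# Minimal smoke test, included as a sketch: the default path and the input
# text are placeholders, and it assumes the checkpoint directory contains
# config.json plus the assets/ folder (encoder.onnx, decoder.onnx, spm.model).
if __name__ == "__main__":
    import sys

    # Pass a local directory or Hub repo id; defaults to the current directory.
    path = sys.argv[1] if len(sys.argv) > 1 else "."
    model = SetuTranslationModel.from_pretrained(path)
    print(model.translate("example input text"))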