from transformers import PreTrainedModel, AutoConfig, AutoModel

try:
from .configuration_setu_translation import SetuTranslationConfig
except ImportError:
from configuration_setu_translation import SetuTranslationConfig
import torch
import os
import numpy as np
import json
import onnxruntime as ort
import sentencepiece as spm
from typing import List, Optional, Tuple
from huggingface_hub import snapshot_download
class SetuTranslationModel(PreTrainedModel):
"""SETU Translation Model for Hugging Face Hub
This model performs script-agnostic translation to unified English output.
It handles multiscript, multilingual, and informal text translation.
"""
config_class = SetuTranslationConfig
def __init__(self, config):
super().__init__(config)
self.config = config
# Initialize model components
self.encoder_session = None
self.decoder_session = None
self.sp = None
# Load model files if they exist
self._load_model_components()

    def _load_model_components(self):
        """Load ONNX models and the SentencePiece processor"""
        model_dir = getattr(self.config, '_name_or_path', '.')
        # Paths to model files in the assets folder
        assets_dir = os.path.join(model_dir, 'assets')
        encoder_path = os.path.join(assets_dir, 'encoder.onnx')
        decoder_path = os.path.join(assets_dir, 'decoder.onnx')
        spm_path = os.path.join(assets_dir, 'spm.model')

        # Load ONNX models
        # Configure providers: use CUDA if available, fall back to CPU
        providers = (
            ['CUDAExecutionProvider', 'CPUExecutionProvider']
            if torch.cuda.is_available()
            else ['CPUExecutionProvider']
        )
        if os.path.exists(encoder_path):
            self.encoder_session = ort.InferenceSession(
                encoder_path,
                providers=providers
            )
        if os.path.exists(decoder_path):
            self.decoder_session = ort.InferenceSession(
                decoder_path,
                providers=providers
            )

        # Load the SentencePiece model
        if os.path.exists(spm_path):
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_path)
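
    # Directory layout this loader expects (any missing file is skipped,
    # leaving the corresponding component as None):
    #   <model_dir>/assets/encoder.onnx
    #   <model_dir>/assets/decoder.onnx
    #   <model_dir>/assets/spm.model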

    def encode_text(self, text: str) -> np.ndarray:
        """Encode text to token IDs using SentencePiece"""
        if self.sp is None:
            raise ValueError("SentencePiece model not loaded")
        # Encode using SentencePiece, then append the EOS token
        tokens = self.sp.EncodeAsIds(text)
        tokens = tokens + [self.config.eos_idx]
        return np.array(tokens, dtype=np.int64)

    def decode_tokens(self, tokens: List[int]) -> str:
        """Decode token IDs to text using SentencePiece"""
        if self.sp is None:
            raise ValueError("SentencePiece model not loaded")
        # Remove special tokens before detokenizing
        tokens = [t for t in tokens if t not in (self.config.bos_idx, self.config.eos_idx, self.config.pad_idx)]
        # Decode using SentencePiece
        text = self.sp.DecodeIds(tokens)
        return text.strip()

    def encode_source(self, src_tokens: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Run the encoder on source tokens"""
        if self.encoder_session is None:
            raise ValueError("Encoder model not loaded")
        # Prepare inputs
        src_tokens_batch = src_tokens.reshape(1, -1)  # [1, src_len]
        src_lengths = np.array([len(src_tokens)], dtype=np.int64)
        # Build the input dict based on what the exported encoder expects
        encoder_inputs = [inp.name for inp in self.encoder_session.get_inputs()]
        input_dict = {'src_tokens': src_tokens_batch}
        if 'src_lengths' in encoder_inputs:
            input_dict['src_lengths'] = src_lengths
        # Run the encoder
        outputs = self.encoder_session.run(None, input_dict)
        # Handle encoder outputs; the padding mask is optional
        encoder_out = outputs[0]
        encoder_padding_mask = outputs[1] if len(outputs) > 1 else None
        return encoder_out, encoder_padding_mask

    def decode_step(self, prev_tokens, encoder_out, encoder_padding_mask):
        """Run the decoder for one step"""
        if self.decoder_session is None:
            raise ValueError("Decoder model not loaded")
        # Prepare inputs - accept either a numpy array or a plain token list
        if isinstance(prev_tokens, np.ndarray):
            prev_tokens_np = prev_tokens  # Already shaped [1, seq_len]
        else:
            prev_tokens_np = np.array([prev_tokens], dtype=np.int64)  # [1, seq_len]
        # Only feed the padding mask if the encoder actually produced one;
        # ONNX Runtime rejects None values in the input feed
        input_dict = {
            'prev_output_tokens': prev_tokens_np,
            'encoder_out': encoder_out,
        }
        if encoder_padding_mask is not None:
            input_dict['encoder_padding_mask'] = encoder_padding_mask
        try:
            # Run the decoder and return the logits (first output)
            outputs = self.decoder_session.run(None, input_dict)
            return outputs[0]
        except Exception as e:
            raise RuntimeError(f"Decoder step failed: {e}") from e

    def beam_search_translate(self, src_tokens: np.ndarray) -> List[int]:
        """Perform beam search translation - matches the ONNX implementation"""
        # Encode the source once; the decoder reuses it at every step
        encoder_out, encoder_padding_mask = self.encode_source(src_tokens)
        # Beam search parameters
        beam_size = self.config.beam_size
        max_len = self.config.max_len
        len_penalty = self.config.len_penalty
        vocab_size = self.config.tgt_vocab_size
        # Initialize beams as (score, tokens) pairs
        # NOTE: decoding starts with the EOS token, not BOS!
        beams = [(0.0, [self.config.eos_idx])]
        completed = []
        for step in range(max_len):
            # Stop early if we have enough completed hypotheses
            if len(completed) >= beam_size * 2:
                break
            all_candidates = []
            for score, tokens in beams:
                # A beam is complete when it emits EOS
                # (the starting EOS token alone does not count)
                if tokens[-1] == self.config.eos_idx and len(tokens) > 1:
                    completed.append((score, tokens))
                    continue
                # Prevent EOS at the first step (min_len=1)
                should_skip_eos = (step == 0 and len(tokens) == 1)
                # Prepare decoder input
                prev_tokens = np.array([tokens], dtype=np.int64)  # [1, tgt_len]
                try:
                    # Get logits for the next token
                    logits = self.decode_step(prev_tokens, encoder_out, encoder_padding_mask)
                    # Check logits validity
                    if logits is None or logits.size == 0:
                        completed.append((score, tokens + [self.config.eos_idx]))
                        continue
                    # Log softmax over the last position:
                    # log(exp(x) / sum(exp(x))) = x - log(sum(exp(x))),
                    # shifted by the max logit for numerical stability
                    log_probs = logits[0, -1, :]  # [vocab_size]
                    max_logit = np.max(log_probs)
                    log_probs_shifted = log_probs - max_logit
                    log_sum_exp = np.log(np.sum(np.exp(log_probs_shifted))) + max_logit
                    log_probs = log_probs - log_sum_exp
                    # Take the top-k candidates (expand beyond beam_size for diversity)
                    top_k = min(beam_size * 2, vocab_size)
                    top_k_indices = np.argpartition(log_probs, -top_k)[-top_k:]
                    top_k_indices = top_k_indices[np.argsort(log_probs[top_k_indices])][::-1]
                    for idx in top_k_indices:
                        # Skip EOS on the first step (min_len=1 constraint)
                        if should_skip_eos and int(idx) == self.config.eos_idx:
                            continue
                        candidate_score = score + log_probs[idx]
                        candidate_tokens = tokens + [int(idx)]
                        all_candidates.append((candidate_score, candidate_tokens))
                        # Cap the candidate pool once it is large enough
                        if len(all_candidates) >= beam_size:
                            break
                except Exception:
                    # Force completion if decoding fails
                    completed.append((score, tokens + [self.config.eos_idx]))
                    continue
            if not all_candidates:
                # All beams completed
                break
            # Keep the top beam_size candidates, ranked by cumulative score
            # (no length penalty during search, only at finalization)
            ordered = sorted(all_candidates, key=lambda x: x[0], reverse=True)
            beams = ordered[:beam_size]
        # Add any remaining (unfinished) beams to the completed pool
        completed.extend(beams)
        # Ensure we have at least one hypothesis
        if not completed:
            completed = [(0.0, [self.config.eos_idx, self.config.eos_idx])]
        # Rank hypotheses by length-normalized score.
        # tokens = [EOS, tok1, tok2, ..., EOS], so the generated length is
        # len(tokens) - 1; max(1, ...) avoids division by zero for very
        # short sequences.
        completed = sorted(
            completed,
            key=lambda x: x[0] / (max(1, len(x[1]) - 1) ** len_penalty),
            reverse=True
        )
        # Return the best translation's tokens
        best_score, best_tokens = completed[0]
        return best_tokens
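
    # Worked example of the ranking above (illustrative numbers only):
    # with len_penalty = 1.0, hypothesis [EOS, a, b, EOS] with cumulative
    # log-prob -3.0 scores -3.0 / 3 = -1.0, while [EOS, a, EOS] with -2.4
    # scores -2.4 / 2 = -1.2, so the longer hypothesis ranks first despite
    # its lower raw score.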

    def translate(self, text: str) -> str:
        """Translate input text to English

        Args:
            text: Input text in any supported script/language

        Returns:
            Translated English text
        """
        # Encode the input text
        src_tokens = self.encode_text(text)
        # Perform beam search translation
        output_tokens = self.beam_search_translate(src_tokens)
        # Decode the output tokens
        translated_text = self.decode_tokens(output_tokens)
        return translated_text

    def forward(self, text: str) -> str:
        """Forward pass - alias for the translate method for simple usage"""
        return self.translate(text)

    def __call__(self, text: str) -> str:
        """Make the model callable - enables model("text") usage"""
        return self.translate(text)

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path,
                        *,
                        force_download=False,
                        resume_download=None,
                        proxies=None,
                        token=None,
                        cache_dir=None,
                        local_files_only=False,
                        revision=None,
                        **kwargs):
        """Load the model from the Hugging Face Hub or a local directory"""
        # Download a snapshot if the path is a hub repo id rather than a local directory
        if not os.path.isdir(pretrained_model_name_or_path):
            model_dir = snapshot_download(
                repo_id=pretrained_model_name_or_path,
                token=token,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                revision=revision
            )
        else:
            model_dir = pretrained_model_name_or_path
        # Load the config, preferring config.json
        config_path = os.path.join(model_dir, 'config.json')
        if os.path.exists(config_path):
            config = SetuTranslationConfig.from_json_file(config_path)
        else:
            # Fall back to model_config.json if config.json doesn't exist
            model_config_path = os.path.join(model_dir, 'model_config.json')
            if os.path.exists(model_config_path):
                with open(model_config_path, 'r') as f:
                    model_config = json.load(f)
                config = SetuTranslationConfig(**model_config, **kwargs)
            else:
                config = SetuTranslationConfig(**kwargs)
        # Record the resolved directory so _load_model_components can find the assets
        config._name_or_path = model_dir
        # Create the model instance (this loads the ONNX sessions and tokenizer)
        model = cls(config)
        return model
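

# ---------------------------------------------------------------------------
# Optional wiring and usage sketch (assumptions, not part of the exported
# model): the Auto-class registration below presumes the config's model_type
# is "setu_translation"; check configuration_setu_translation.py before
# enabling it. The repo id in the demo is a placeholder.
# ---------------------------------------------------------------------------
# AutoConfig.register("setu_translation", SetuTranslationConfig)
# AutoModel.register(SetuTranslationConfig, SetuTranslationModel)

if __name__ == "__main__":
    # Point this at a real hub repo or a local directory containing
    # config.json plus assets/{encoder.onnx, decoder.onnx, spm.model}.
    model = SetuTranslationModel.from_pretrained("your-org/setu-translation")
    print(model.translate("नमस्ते दुनिया"))  # e.g. "hello world"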