|
|
from transformers import PreTrainedModel, AutoConfig, AutoModel |
|
|
try: |
|
|
from .configuration_setu_translation import SetuTranslationConfig |
|
|
except ImportError: |
|
|
from configuration_setu_translation import SetuTranslationConfig |
|
|
import torch |
|
|
import os |
|
|
import numpy as np |
|
|
import json |
|
|
import onnxruntime as ort |
|
|
import sentencepiece as spm |
|
|
from typing import List, Tuple |
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
|
|
|
class SetuTranslationModel(PreTrainedModel):
    """SETU Translation Model for Hugging Face Hub

    This model performs script-agnostic translation to unified English output.
    It handles multiscript, multilingual, and informal text translation.

    Unlike a typical ``PreTrainedModel`` there are no PyTorch weights:
    inference is delegated to two ONNX Runtime sessions (encoder and decoder)
    plus a SentencePiece tokenizer, all loaded from the ``assets/`` subfolder
    of the model directory.
    """

    # Hooks this class into the transformers config/auto-loading machinery.
    config_class = SetuTranslationConfig

    def __init__(self, config):
        """Create the model and eagerly load its ONNX/SentencePiece assets.

        Args:
            config: ``SetuTranslationConfig`` whose ``_name_or_path`` points
                at the directory containing the ``assets/`` subfolder
                (set by :meth:`from_pretrained`).
        """
        super().__init__(config)

        self.config = config

        # Each of these stays None when its asset file is missing on disk;
        # the encode/decode helpers below raise ValueError in that case.
        self.encoder_session = None  # ONNX encoder InferenceSession
        self.decoder_session = None  # ONNX decoder InferenceSession
        self.sp = None               # SentencePiece processor

        self._load_model_components()

    def _load_model_components(self) -> None:
        """Load ONNX models and SentencePiece processor.

        Asset paths are resolved relative to ``config._name_or_path``
        (falling back to the current directory). Missing files are skipped
        silently, leaving the corresponding attribute as None.
        """
        model_dir = getattr(self.config, '_name_or_path', '.')

        assets_dir = os.path.join(model_dir, 'assets')
        encoder_path = os.path.join(assets_dir, 'encoder.onnx')
        decoder_path = os.path.join(assets_dir, 'decoder.onnx')
        smp_path = os.path.join(assets_dir, 'spm.model')

        # Prefer CUDA when the local torch build sees a GPU; onnxruntime
        # falls back to the CPU provider if CUDA is unavailable at runtime.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider']

        if os.path.exists(encoder_path):
            self.encoder_session = ort.InferenceSession(
                encoder_path,
                providers=providers
            )

        if os.path.exists(decoder_path):
            self.decoder_session = ort.InferenceSession(
                decoder_path,
                providers=providers
            )

        if os.path.exists(smp_path):
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(smp_path)

    def encode_text(self, text: str) -> np.ndarray:
        """Encode text to token IDs using SentencePiece.

        Args:
            text: Raw input string in any supported script.

        Returns:
            1-D int64 array of token IDs with ``eos_idx`` appended.

        Raises:
            ValueError: If the SentencePiece model was not loaded.
        """
        if self.sp is None:
            raise ValueError("SentencePiece model not loaded")

        tokens = self.sp.EncodeAsIds(text)

        # Terminate the source sequence with EOS for the encoder.
        tokens = tokens + [self.config.eos_idx]

        return np.array(tokens, dtype=np.int64)

    def decode_tokens(self, tokens: List[int]) -> str:
        """Decode token IDs to text using SentencePiece.

        Strips the special BOS/EOS/PAD ids before detokenizing, then trims
        surrounding whitespace.

        Raises:
            ValueError: If the SentencePiece model was not loaded.
        """
        if self.sp is None:
            raise ValueError("SentencePiece model not loaded")

        # Drop special tokens; only real subword ids go to the detokenizer.
        tokens = [t for t in tokens if t not in [self.config.bos_idx, self.config.eos_idx, self.config.pad_idx]]

        text = self.sp.DecodeIds(tokens)

        return text.strip()

    def encode_source(self, src_tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run encoder on source tokens.

        Args:
            src_tokens: 1-D int64 array of source token ids.

        Returns:
            Tuple ``(encoder_out, encoder_padding_mask)``; the mask is None
            when the exported encoder produces only a single output.

        Raises:
            ValueError: If the encoder ONNX model was not loaded.
        """
        if self.encoder_session is None:
            raise ValueError("Encoder model not loaded")

        # Add a batch dimension of 1 — assumes the graph expects [batch, seq].
        src_tokens_batch = src_tokens.reshape(1, -1)
        src_lengths = np.array([len(src_tokens)], dtype=np.int64)

        # Some exports take only src_tokens; feed src_lengths only if the
        # graph actually declares that input.
        encoder_inputs = [inp.name for inp in self.encoder_session.get_inputs()]

        input_dict = {'src_tokens': src_tokens_batch}
        if 'src_lengths' in encoder_inputs:
            input_dict['src_lengths'] = src_lengths

        outputs = self.encoder_session.run(None, input_dict)

        encoder_out = outputs[0]
        encoder_padding_mask = outputs[1] if len(outputs) > 1 else None

        return encoder_out, encoder_padding_mask

    def decode_step(self, prev_tokens, encoder_out, encoder_padding_mask):
        """Run decoder for one step.

        Args:
            prev_tokens: Previously generated token ids — either an ndarray
                already shaped for the graph, or a scalar-like value that is
                wrapped into a one-element int64 array.
            encoder_out: Encoder hidden states from :meth:`encode_source`.
            encoder_padding_mask: Padding mask from :meth:`encode_source`
                (may be None).

        Returns:
            The decoder logits (first ONNX output).

        Raises:
            ValueError: If the decoder ONNX model was not loaded.
            RuntimeError: If the ONNX decoder call itself fails.
                NOTE(review): re-raising without ``from e`` loses the
                original traceback — confirm whether chaining is wanted.
        """
        if self.decoder_session is None:
            raise ValueError("Decoder model not loaded")

        if isinstance(prev_tokens, np.ndarray):
            prev_tokens_np = prev_tokens
        else:
            prev_tokens_np = np.array([prev_tokens], dtype=np.int64)

        try:
            outputs = self.decoder_session.run(
                None,
                {
                    'prev_output_tokens': prev_tokens_np,
                    'encoder_out': encoder_out,
                    'encoder_padding_mask': encoder_padding_mask
                }
            )

            return outputs[0]

        except Exception as e:
            raise RuntimeError(f"Decoder step failed: {e}")

    def beam_search_translate(self, src_tokens: np.ndarray) -> List[int]:
        """Perform beam search translation - matches ONNX implementation

        Args:
            src_tokens: 1-D int64 array of source token ids (EOS-terminated).

        Returns:
            Token-id list of the best hypothesis, including the leading EOS
            used to seed decoding (stripped later by :meth:`decode_tokens`).
        """
        encoder_out, encoder_padding_mask = self.encode_source(src_tokens)

        beam_size = self.config.beam_size
        max_len = self.config.max_len
        len_penalty = self.config.len_penalty
        vocab_size = self.config.tgt_vocab_size

        # Fairseq-style decoding: prev_output_tokens starts with EOS.
        # Each beam is a (cumulative log-prob, token list) pair.
        beams = [(0.0, [self.config.eos_idx])]
        completed = []

        for step in range(max_len):

            # Enough finished hypotheses collected — stop searching early.
            if len(completed) >= beam_size * 2:
                break

            all_candidates = []

            for score, tokens in beams:

                # A beam that re-emitted EOS (beyond the seed token) is done.
                if tokens[-1] == self.config.eos_idx and len(tokens) > 1:
                    completed.append((score, tokens))
                    continue

                # Never emit EOS as the very first output token — that would
                # yield an empty translation.
                should_skip_eos = (step == 0 and len(tokens) == 1)

                prev_tokens = np.array([tokens], dtype=np.int64)

                try:
                    logits = self.decode_step(prev_tokens, encoder_out, encoder_padding_mask)

                    if logits is None or logits.size == 0:
                        # Defensive: treat a degenerate decoder output as a
                        # forced end-of-hypothesis.
                        completed.append((score, tokens + [self.config.eos_idx]))
                        continue

                    # Scores for the last position of the (single) batch row.
                    log_probs = logits[0, -1, :]

                    # Numerically stable log-softmax over the vocabulary.
                    max_logit = np.max(log_probs)
                    log_probs_shifted = log_probs - max_logit
                    log_sum_exp = np.log(np.sum(np.exp(log_probs_shifted))) + max_logit
                    log_probs = log_probs - log_sum_exp

                    # Top-k selection (argpartition then sort, best-first).
                    top_k = min(beam_size * 2, vocab_size)
                    top_k_indices = np.argpartition(log_probs, -top_k)[-top_k:]
                    top_k_indices = top_k_indices[np.argsort(log_probs[top_k_indices])][::-1]

                    for idx in top_k_indices[:beam_size * 2]:

                        if should_skip_eos and int(idx) == self.config.eos_idx:
                            continue

                        candidate_score = score + log_probs[idx]
                        candidate_tokens = tokens + [int(idx)]
                        all_candidates.append((candidate_score, candidate_tokens))

                        # Cheap pruning: stop collecting once beam_size
                        # candidates exist, so later beams contribute at most
                        # one continuation each.
                        # NOTE(review): confirm this early exit matches the
                        # exported ONNX search behavior.
                        if len(all_candidates) >= beam_size:
                            break

                except Exception as e:
                    # Best-effort: a failing decode finalizes this beam with
                    # EOS instead of aborting the whole search.
                    completed.append((score, tokens + [self.config.eos_idx]))
                    continue

            if not all_candidates:
                # Every beam finished (or failed) — nothing left to expand.
                break

            # Keep the beam_size highest-scoring continuations for next step.
            ordered = sorted(
                all_candidates,
                key=lambda x: x[0],
                reverse=True
            )
            beams = ordered[:beam_size]

        # Hypotheses still alive at max_len compete with the completed ones.
        completed.extend(beams)

        if not completed:
            # Absolute fallback: an effectively empty translation (EOS only).
            completed = [(0.0, [self.config.eos_idx, self.config.eos_idx])]

        # Rank by length-normalized score; the -1 discounts the seed EOS and
        # max(1, ...) guards against a zero-length divisor.
        completed = sorted(
            completed,
            key=lambda x: x[0] / (max(1, len(x[1]) - 1) ** len_penalty),
            reverse=True
        )

        best_score, best_tokens = completed[0]
        return best_tokens

    def translate(self, text: str) -> str:
        """Translate input text to English

        Args:
            text: Input text in any supported script/language

        Returns:
            Translated English text
        """
        # Tokenize source text (EOS-terminated int64 ids).
        src_tokens = self.encode_text(text)

        # Run the full encoder + beam-search decode pipeline.
        output_tokens = self.beam_search_translate(src_tokens)

        # Strip special tokens and detokenize back to a string.
        translated_text = self.decode_tokens(output_tokens)

        return translated_text

    def forward(self, text: str) -> str:
        """Forward pass - alias for translate method for simple usage"""
        return self.translate(text)

    def __call__(self, text: str) -> str:
        """Make model callable - enables model("text") usage

        NOTE(review): overriding __call__ at class level bypasses
        nn.Module's hook machinery (forward/pre-forward hooks) — confirm
        this is intentional.
        """
        return self.translate(text)

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path,
                        *,
                        force_download=False,
                        resume_download=None,
                        proxies=None,
                        token=None,
                        cache_dir=None,
                        local_files_only=False,
                        revision=None,
                        **kwargs):
        """Load model from Hugging Face Hub or local directory

        Args:
            pretrained_model_name_or_path: Hub repo id, or a local directory
                containing the config and the ``assets/`` subfolder.
            force_download, resume_download, proxies, token, cache_dir,
            local_files_only, revision: Forwarded to
                ``huggingface_hub.snapshot_download`` for Hub repos.
            **kwargs: Extra config overrides, applied only when no
                ``config.json`` is present.

        Returns:
            An initialized :class:`SetuTranslationModel`.
        """
        # Resolve to a local directory: download the snapshot for Hub ids.
        if not os.path.isdir(pretrained_model_name_or_path):
            model_dir = snapshot_download(
                repo_id=pretrained_model_name_or_path,
                token=token,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                revision=revision
            )
        else:
            model_dir = pretrained_model_name_or_path

        # Config resolution order: config.json, then the legacy
        # model_config.json, then pure kwargs/defaults.
        config_path = os.path.join(model_dir, 'config.json')
        if os.path.exists(config_path):
            config = SetuTranslationConfig.from_json_file(config_path)
        else:
            model_config_path = os.path.join(model_dir, 'model_config.json')
            if os.path.exists(model_config_path):
                with open(model_config_path, 'r') as f:
                    model_config = json.load(f)
                config = SetuTranslationConfig(**model_config, **kwargs)
            else:
                config = SetuTranslationConfig(**kwargs)

        # Record the resolved directory so _load_model_components can find
        # the assets/ folder.
        config._name_or_path = model_dir

        model = cls(config)

        return model