# setu/modeling_setu_translation.py

import json
import os
from typing import List, Tuple

import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import torch
from huggingface_hub import snapshot_download
from transformers import PreTrainedModel

try:
    from .configuration_setu_translation import SetuTranslationConfig
except ImportError:
    from configuration_setu_translation import SetuTranslationConfig


class SetuTranslationModel(PreTrainedModel):
"""SETU Translation Model for Hugging Face Hub
This model performs script-agnostic translation to unified English output.
It handles multiscript, multilingual, and informal text translation.
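
    Example (a minimal usage sketch; the repo id below is a placeholder,
    not a confirmed checkpoint name):

        >>> model = SetuTranslationModel.from_pretrained("user/setu-translation")
        >>> model("some source-language text")  # doctest: +SKIP
        'its English translation'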
"""
    config_class = SetuTranslationConfig

    def __init__(self, config):
super().__init__(config)
self.config = config
# Initialize model components
self.encoder_session = None
self.decoder_session = None
self.sp = None
# Load model files if they exist
self._load_model_components()

    def _load_model_components(self):
        """Load the ONNX encoder/decoder sessions and the SentencePiece processor."""
        model_dir = getattr(self.config, '_name_or_path', '.')

        # Paths to model files in the assets folder
        assets_dir = os.path.join(model_dir, 'assets')
        encoder_path = os.path.join(assets_dir, 'encoder.onnx')
        decoder_path = os.path.join(assets_dir, 'decoder.onnx')
        spm_path = os.path.join(assets_dir, 'spm.model')

        # Configure ONNX Runtime providers: prefer CUDA when available and fall
        # back to CPU (the CUDA provider only takes effect with onnxruntime-gpu)
        providers = (
            ['CUDAExecutionProvider', 'CPUExecutionProvider']
            if torch.cuda.is_available()
            else ['CPUExecutionProvider']
        )
        if os.path.exists(encoder_path):
            self.encoder_session = ort.InferenceSession(encoder_path, providers=providers)
        if os.path.exists(decoder_path):
            self.decoder_session = ort.InferenceSession(decoder_path, providers=providers)

        # Load the SentencePiece model
        if os.path.exists(spm_path):
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_path)

    def encode_text(self, text: str) -> np.ndarray:
        """Encode text to token IDs with SentencePiece, appending EOS."""
if self.sp is None:
raise ValueError("SentencePiece model not loaded")
# Encode using SentencePiece
tokens = self.sp.EncodeAsIds(text)
# Add EOS token
tokens = tokens + [self.config.eos_idx]
return np.array(tokens, dtype=np.int64)
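
    # For illustration only: encode_text("hello") might return something like
    # np.array([17, 942, 2]), i.e. subword ids followed by the EOS id, but the
    # exact ids depend entirely on the loaded SentencePiece vocabulary.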

    def decode_tokens(self, tokens: List[int]) -> str:
        """Decode token IDs to text using SentencePiece."""
        if self.sp is None:
            raise ValueError("SentencePiece model not loaded")

        # Strip special tokens before detokenizing
        special = {self.config.bos_idx, self.config.eos_idx, self.config.pad_idx}
        tokens = [t for t in tokens if t not in special]
# Decode using SentencePiece
text = self.sp.DecodeIds(tokens)
return text.strip()

    def encode_source(self, src_tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run the encoder on source tokens."""
        if self.encoder_session is None:
            raise ValueError("Encoder model not loaded")

        # Batch the source tokens and record their length
        src_tokens_batch = src_tokens.reshape(1, -1)  # [1, src_len]
        src_lengths = np.array([len(src_tokens)], dtype=np.int64)
# Check encoder input names
encoder_inputs = [inp.name for inp in self.encoder_session.get_inputs()]
# Build input dict based on what encoder expects
input_dict = {'src_tokens': src_tokens_batch}
if 'src_lengths' in encoder_inputs:
input_dict['src_lengths'] = src_lengths
# Run encoder
outputs = self.encoder_session.run(None, input_dict)
# Handle encoder outputs
encoder_out = outputs[0]
encoder_padding_mask = outputs[1] if len(outputs) > 1 else None
return encoder_out, encoder_padding_mask

    def decode_step(self, prev_tokens, encoder_out, encoder_padding_mask) -> np.ndarray:
        """Run the decoder for one step and return its logits."""
        if self.decoder_session is None:
            raise ValueError("Decoder model not loaded")

        # Accept either a ready-made numpy batch or a plain list of token ids
if isinstance(prev_tokens, np.ndarray):
prev_tokens_np = prev_tokens # Already formatted correctly
else:
prev_tokens_np = np.array([prev_tokens], dtype=np.int64) # [1, seq_len]
try:
# Run decoder
outputs = self.decoder_session.run(
None, # Get all outputs
{
'prev_output_tokens': prev_tokens_np,
'encoder_out': encoder_out,
'encoder_padding_mask': encoder_padding_mask
}
)
# Return logits (first output)
return outputs[0]
        except Exception as e:
            raise RuntimeError(f"Decoder step failed: {e}") from e
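
    # Note: decode_step re-runs the decoder over the entire prefix at every
    # step; the session exposes no incremental-state inputs, so decoding cost
    # grows quadratically with output length.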

    def beam_search_translate(self, src_tokens: np.ndarray) -> List[int]:
        """Beam search decoding that mirrors the original ONNX export's loop."""
# Encode source
encoder_out, encoder_padding_mask = self.encode_source(src_tokens)
# Initialize beam search parameters
beam_size = self.config.beam_size
max_len = self.config.max_len
len_penalty = self.config.len_penalty
vocab_size = self.config.tgt_vocab_size
        # Each beam is a (cumulative log-prob, token list) pair.
        # NOTE: decoding starts from EOS rather than BOS, following the
        # fairseq-style convention used by the exported decoder.
        beams = [(0.0, [self.config.eos_idx])]
completed = []
for step in range(max_len):
# Stop early if we have enough good completed hypotheses
if len(completed) >= beam_size * 2:
break
all_candidates = []
for score, tokens in beams:
# Check if beam is completed (don't mark as complete if it's just the starting EOS token)
if tokens[-1] == self.config.eos_idx and len(tokens) > 1:
completed.append((score, tokens))
continue
# Prevent EOS at first step (min_len=1)
should_skip_eos = (step == 0 and len(tokens) == 1)
# Prepare decoder input
prev_tokens = np.array([tokens], dtype=np.int64) # [1, tgt_len]
try:
# Get logits for next token
logits = self.decode_step(prev_tokens, encoder_out, encoder_padding_mask)
# Check logits validity
if logits is None or logits.size == 0:
completed.append((score, tokens + [self.config.eos_idx]))
continue
                    # Log probabilities for the last position via a numerically
                    # stable log-softmax: log_softmax(x) = x - logsumexp(x)
                    last_logits = logits[0, -1, :]  # [vocab_size]
                    max_logit = np.max(last_logits)
                    log_sum_exp = np.log(np.sum(np.exp(last_logits - max_logit))) + max_logit
                    log_probs = last_logits - log_sum_exp
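                    # Worked example with illustrative numbers (not model output):
                    # for logits [2.0, 1.0, 0.0], logsumexp is about 2.4076, so
                    # the log-probs are about [-0.41, -1.41, -2.41], whose
                    # exponentials sum to 1 as a proper distribution requires.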
# Get top-k candidates (expand more than beam_size for diversity)
top_k = min(beam_size * 2, vocab_size)
top_k_indices = np.argpartition(log_probs, -top_k)[-top_k:]
top_k_indices = top_k_indices[np.argsort(log_probs[top_k_indices])][::-1]
for idx in top_k_indices[:beam_size * 2]: # Check more candidates
# Skip EOS on first step (min_len=1 constraint)
if should_skip_eos and int(idx) == self.config.eos_idx:
continue
candidate_score = score + log_probs[idx]
candidate_tokens = tokens + [int(idx)]
all_candidates.append((candidate_score, candidate_tokens))
                        # Stop expanding this beam once the candidate pool is large enough
                        if len(all_candidates) >= beam_size:
                            break
                except Exception:
                    # Force completion for this beam if decoding fails
                    completed.append((score, tokens + [self.config.eos_idx]))
                    continue
if not all_candidates:
# All beams completed
break
# Select top beam_size candidates
# Sort by cumulative score (no length penalty during search, only at finalization)
ordered = sorted(
all_candidates,
key=lambda x: x[0],
reverse=True
)
beams = ordered[:beam_size]
# Add remaining beams to completed
completed.extend(beams)
# Ensure we have at least one hypothesis
if not completed:
completed = [(0.0, [self.config.eos_idx, self.config.eos_idx])]
# Sort by score with length penalty
# Length = number of generated tokens (excluding starting EOS, including final EOS)
# tokens = [EOS, tok1, tok2, ..., EOS], so length = len(tokens) - 1
# Use max(1, ...) to avoid division by zero for very short sequences
completed = sorted(
completed,
key=lambda x: x[0] / (max(1, len(x[1]) - 1) ** len_penalty),
reverse=True
)
# Return best translation tokens
best_score, best_tokens = completed[0]
return best_tokens
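
    # Length-penalty illustration with hypothetical scores: at len_penalty = 1.0,
    # a 4-token hypothesis with cumulative log-prob -2.0 normalizes to -0.5 and
    # beats a 2-token hypothesis at -1.2 (normalized -0.6), so longer fluent
    # outputs are not unduly penalized by raw cumulative score.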

    def translate(self, text: str) -> str:
        """Translate input text to English.

        Args:
            text: Input text in any supported script or language.

        Returns:
            The translated English text.
        """
# Encode input text
src_tokens = self.encode_text(text)
# Perform beam search translation
output_tokens = self.beam_search_translate(src_tokens)
# Decode output tokens
translated_text = self.decode_tokens(output_tokens)
return translated_text

    def forward(self, text: str) -> str:
        """Forward pass: an alias for translate() for simple usage."""
        return self.translate(text)

    def __call__(self, text: str) -> str:
        """Make the model callable, enabling model("text").

        Overriding __call__ bypasses torch.nn.Module's hook machinery, which
        is acceptable here since inference runs in ONNX Runtime, not PyTorch.
        """
        return self.translate(text)

    @classmethod
    def from_pretrained(cls,
pretrained_model_name_or_path,
*,
force_download=False,
resume_download=None,
proxies=None,
token=None,
cache_dir=None,
local_files_only=False,
revision=None,
**kwargs):
"""Load model from Hugging Face Hub or local directory"""
        # Download a snapshot from the Hub unless given a local directory
        if not os.path.isdir(pretrained_model_name_or_path):
model_dir = snapshot_download(
repo_id=pretrained_model_name_or_path,
token=token,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
revision=revision
)
else:
model_dir = pretrained_model_name_or_path
# Load config
config_path = os.path.join(model_dir, 'config.json')
if os.path.exists(config_path):
config = SetuTranslationConfig.from_json_file(config_path)
else:
# Load from model_config.json if config.json doesn't exist
model_config_path = os.path.join(model_dir, 'model_config.json')
if os.path.exists(model_config_path):
with open(model_config_path, 'r') as f:
model_config = json.load(f)
config = SetuTranslationConfig(**model_config, **kwargs)
else:
config = SetuTranslationConfig(**kwargs)
# Set the model directory path
config._name_or_path = model_dir
# Create model instance
model = cls(config)
return model
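

# A minimal end-to-end usage sketch. The path below assumes a local directory
# containing config.json and an assets/ folder with encoder.onnx, decoder.onnx,
# and spm.model; substitute a Hub repo id as appropriate. The sample input and
# output are purely illustrative.
if __name__ == "__main__":
    model = SetuTranslationModel.from_pretrained("./setu-translation")
    print(model("example source text"))  # prints the model's English output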