"""
Custom ViSoNorm model class for BartPho-based models.

This preserves the custom heads needed for text normalization and
is loadable via auto_map without custom model_type.
"""

import math

import torch
import torch.nn as nn
from transformers import MBartModel, MBartConfig, MBartPreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput

# Number of classes predicted by the mask-count head (how many <mask> tokens a
# non-standard word should expand to).
NUM_LABELS_N_MASKS = 5


def gelu(x):
    """Erf-based (exact) GELU activation."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class MBartLMHead(nn.Module):
    """Normalization (masked-LM style) head tied to the shared embedding matrix."""

    def __init__(self, config, bart_model_embedding_weights):
        super().__init__()

        actual_hidden_size = bart_model_embedding_weights.size(1)
        self.dense = nn.Linear(actual_hidden_size, actual_hidden_size)
        self.layer_norm = nn.LayerNorm(actual_hidden_size, eps=1e-12)

        num_labels = bart_model_embedding_weights.size(0)
        self.decoder = nn.Linear(actual_hidden_size, num_labels, bias=False)
        # Tie the output projection to the input embedding weights.
        self.decoder.weight = bart_model_embedding_weights
        self.decoder.bias = nn.Parameter(torch.zeros(num_labels))

    def forward(self, features):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)
        x = self.decoder(x)
        return x
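
# Shape sketch (illustrative): MBartLMHead maps hidden states of shape
# (batch, seq_len, hidden_size) to vocabulary logits of shape
# (batch, seq_len, vocab_size); because decoder.weight is tied to the shared
# embedding matrix, the output dimension tracks the embedding table.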


class BartMaskNPredictionHead(nn.Module):
    """Predicts the number of <mask> tokens for each position (NUM_LABELS_N_MASKS classes)."""

    def __init__(self, config, actual_hidden_size):
        super(BartMaskNPredictionHead, self).__init__()
        self.mask_predictor_dense = nn.Linear(actual_hidden_size, 50)
        self.mask_predictor_proj = nn.Linear(50, NUM_LABELS_N_MASKS)
        self.activation = gelu

    def forward(self, sequence_output):
        mask_predictor_state = self.activation(self.mask_predictor_dense(sequence_output))
        prediction_scores = self.mask_predictor_proj(mask_predictor_state)
        return prediction_scores


class BartBinaryPredictor(nn.Module):
    """Token-level binary classifier used for non-standard-word (NSW) detection."""

    def __init__(self, hidden_size, dense_dim=100):
        super(BartBinaryPredictor, self).__init__()
        self.dense = nn.Linear(hidden_size, dense_dim)
        self.predictor = nn.Linear(dense_dim, 2)
        self.activation = gelu

    def forward(self, sequence_output):
        state = self.activation(self.dense(sequence_output))
        prediction_scores = self.predictor(state)
        return prediction_scores
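
# Both auxiliary heads run position-wise on hidden states of shape
# (batch, seq_len, hidden_size):
#   BartMaskNPredictionHead -> (batch, seq_len, NUM_LABELS_N_MASKS)
#   BartBinaryPredictor     -> (batch, seq_len, 2)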


class ViSoNormBartPhoForMaskedLM(MBartPreTrainedModel):
    config_class = MBartConfig

    def __init__(self, config: MBartConfig):
        super().__init__(config)

        # Backbone seq2seq model built directly from the passed-in config.
        self.bart = MBartModel(self.config)

        # Use the real embedding width rather than config.hidden_size, in case
        # the two ever disagree.
        actual_hidden_size = self.bart.shared.weight.size(1)

        # Normalization head tied to the shared embeddings, plus the two
        # auxiliary heads for mask-count prediction and NSW detection.
        self.cls = MBartLMHead(self.config, self.bart.shared.weight)
        self.mask_n_predictor = BartMaskNPredictionHead(self.config, actual_hidden_size)
        self.nsw_detector = BartBinaryPredictor(actual_hidden_size, dense_dim=100)
        self.num_labels_n_mask = NUM_LABELS_N_MASKS

        self.post_init()

    def _load_state_dict(self, state_dict, strict=True):
        """
        Custom state dict loading that handles shape mismatches gracefully.
        """
        if 'bart.encoder.embed_positions.weight' in state_dict:
            checkpoint_pos_shape = state_dict['bart.encoder.embed_positions.weight'].shape
            model_pos_shape = self.bart.encoder.embed_positions.weight.shape

            if checkpoint_pos_shape != model_pos_shape:
                # Re-create the positional embedding tensors with the checkpoint's
                # shape so that load_state_dict below can copy them in.
                self.bart.encoder.embed_positions.weight = torch.nn.Parameter(
                    torch.zeros(checkpoint_pos_shape[0], checkpoint_pos_shape[1])
                )
                self.bart.decoder.embed_positions.weight = torch.nn.Parameter(
                    torch.zeros(checkpoint_pos_shape[0], checkpoint_pos_shape[1])
                )

        missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
        return missing_keys, unexpected_keys

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Override from_pretrained to use our custom state dict loading.
        """
        config = MBartConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)

        import os
        from huggingface_hub import hf_hub_download

        model_file = None

        # Resolution order: hub pytorch_model.bin -> hub model.safetensors ->
        # local pytorch_model.bin -> local model.safetensors.
        try:
            model_file = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin")
            state_dict = torch.load(model_file, map_location='cpu')
        except Exception:
            try:
                model_file = hf_hub_download(pretrained_model_name_or_path, "model.safetensors")
                from safetensors.torch import load_file
                state_dict = load_file(model_file)
            except Exception:
                if os.path.exists(pretrained_model_name_or_path):
                    pytorch_file = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
                    safetensors_file = os.path.join(pretrained_model_name_or_path, "model.safetensors")

                    if os.path.exists(pytorch_file):
                        state_dict = torch.load(pytorch_file, map_location='cpu')
                    elif os.path.exists(safetensors_file):
                        from safetensors.torch import load_file
                        state_dict = load_file(safetensors_file)
                    else:
                        raise FileNotFoundError(f"No model file found in {pretrained_model_name_or_path}")
                else:
                    raise FileNotFoundError(f"Model file not found for {pretrained_model_name_or_path}")

        model._load_state_dict(state_dict)
        return model
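
    # Usage sketch (the repo id below is a placeholder assumption, not a
    # published checkpoint):
    #
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained("your-org/visonorm-bartpho")
    #   model = ViSoNormBartPhoForMaskedLM.from_pretrained("your-org/visonorm-bartpho")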

    def fix_classification_head_for_tokenizer(self, tokenizer):
        """
        Fix the classification head to match the tokenizer's vocabulary size.
        This is needed when there's a vocabulary mismatch between model and tokenizer.
        """
        tokenizer_vocab_size = len(tokenizer)
        model_vocab_size = self.config.vocab_size

        if tokenizer_vocab_size != model_vocab_size:
            if '<space>' not in tokenizer.get_vocab():
                tokenizer.add_tokens(['<space>'])
            new_vocab_size = len(tokenizer)

            # Resize the backbone embedding table. Note: the `cls` head created in
            # __init__ keeps its own (tied-at-construction) weight reference.
            self.bart.resize_token_embeddings(new_vocab_size)

            with torch.no_grad():
                # Initialise the newly added row with the mean of the existing embeddings.
                new_token_id = new_vocab_size - 1
                existing_embeddings = self.bart.shared.weight[:-1]
                avg_embedding = existing_embeddings.mean(dim=0)
                self.bart.shared.weight[new_token_id] = avg_embedding

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            sequence_output = outputs.last_hidden_state
        else:
            sequence_output = outputs[0]

        # The three task heads all run on the same final hidden states.
        logits_norm = self.cls(sequence_output)
        logits_n_masks_pred = self.mask_n_predictor(sequence_output)
        logits_nsw_detection = self.nsw_detector(sequence_output)

        if not return_dict:
            return (logits_norm, logits_n_masks_pred, logits_nsw_detection) + outputs[1:]

        # Lightweight output container; `logits` aliases `logits_norm` so generic
        # utilities that expect `.logits` keep working.
        class ViSoNormOutput:
            def __init__(self, logits_norm, logits_n_masks_pred, logits_nsw_detection,
                         hidden_states=None, attentions=None):
                self.logits = logits_norm
                self.logits_norm = logits_norm
                self.logits_n_masks_pred = logits_n_masks_pred
                self.logits_nsw_detection = logits_nsw_detection
                self.hidden_states = hidden_states
                self.attentions = attentions

        hidden_states = getattr(outputs, 'encoder_hidden_states', None) or getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'encoder_attentions', None) or getattr(outputs, 'attentions', None)

        return ViSoNormOutput(
            logits_norm=logits_norm,
            logits_n_masks_pred=logits_n_masks_pred,
            logits_nsw_detection=logits_nsw_detection,
            hidden_states=hidden_states,
            attentions=attentions,
        )
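
    # Sketch of a direct forward() call (names and input are illustrative):
    #
    #   enc = tokenizer("xin chào", return_tensors="pt")
    #   out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    #   out.logits_norm.shape           # (batch, seq_len, vocab_size)
    #   out.logits_n_masks_pred.shape   # (batch, seq_len, NUM_LABELS_N_MASKS)
    #   out.logits_nsw_detection.shape  # (batch, seq_len, 2)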

    def normalize_text(self, tokenizer, text, device='cpu'):
        """
        Normalize text using the ViSoNorm BartPho model with proper NSW detection and masking.

        Args:
            tokenizer: HuggingFace tokenizer (should be BartphoTokenizer)
            text: Input text to normalize
            device: Device to run inference on

        Returns:
            Tuple of (normalized_text, source_tokens, prediction_tokens)
        """
        self.to(device)
        self.fix_classification_head_for_tokenizer(tokenizer)

        # Encode the raw text and keep a token-level view of the input.
        encoded = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
        input_tokens_tensor = encoded.to(device)
        input_tokens = tokenizer.convert_ids_to_tokens(encoded[0])

        input_tokens_tensor, _, token_type_ids, input_mask = self._truncate_and_build_masks(input_tokens_tensor)

        # First pass: run the encoder once to get NSW-detection logits.
        self.eval()
        with torch.no_grad():
            bart_outputs = self.bart(input_tokens_tensor, attention_mask=input_mask, output_hidden_states=True)
            sequence_output = bart_outputs.encoder_last_hidden_state

            logits_norm = self.cls(sequence_output)
            logits_n_masks_pred = self.mask_n_predictor(sequence_output)
            logits_nsw_detection = self.nsw_detector(sequence_output)

        # Minimal local container mirroring the forward() output attributes.
        class ViSoNormOutput:
            def __init__(self, logits_norm, logits_n_masks_pred, logits_nsw_detection):
                self.logits = logits_norm
                self.logits_norm = logits_norm
                self.logits_n_masks_pred = logits_n_masks_pred
                self.logits_nsw_detection = logits_nsw_detection

        outputs = ViSoNormOutput(logits_norm, logits_n_masks_pred, logits_nsw_detection)

        tokens = tokenizer.convert_ids_to_tokens(input_tokens_tensor[0])

        # Decide, per token, whether it needs normalization.
        if hasattr(outputs, 'logits_nsw_detection') and outputs.logits_nsw_detection is not None:
            if outputs.logits_nsw_detection.dim() == 3:
                nsw_predictions = torch.argmax(outputs.logits_nsw_detection[0], dim=-1) == 1
            else:
                nsw_predictions = torch.sigmoid(outputs.logits_nsw_detection[0]) > 0.5

            tokens_need_norm = []
            for i, token in enumerate(tokens):
                if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
                    tokens_need_norm.append(False)
                else:
                    if i < len(nsw_predictions):
                        tokens_need_norm.append(nsw_predictions[i].item())
                    else:
                        tokens_need_norm.append(False)
        else:
            # Fallback: treat every non-special token as a normalization candidate.
            tokens_need_norm = [token not in ['<s>', '</s>', '<pad>', '<unk>', '<mask>'] for token in tokens]

        # Surface forms of the detected NSW tokens (not used further below).
        nsw_tokens = [tokens[i] for i, need in enumerate(tokens_need_norm) if need]

        def _score_sequence(input_ids_tensor: torch.Tensor) -> float:
            """Average per-position max log-probability; used to compare candidate maskings."""
            with torch.no_grad():
                bart_outputs = self.bart(input_ids_tensor, attention_mask=torch.ones_like(input_ids_tensor), output_hidden_states=True)
                sequence_output = bart_outputs.encoder_last_hidden_state
                logits = self.cls(sequence_output)
                log_probs = torch.log_softmax(logits[0], dim=-1)
                position_scores, _ = torch.max(log_probs, dim=-1)
                return float(position_scores.mean().item())

        # Greedily insert a <mask> after each detected NSW token when doing so
        # improves the sequence score (handles one-to-many expansions).
        mask_token_id = tokenizer.convert_tokens_to_ids('<mask>')
        working_ids = input_tokens_tensor[0].detach().clone().cpu().tolist()
        nsw_indices = [i for i, need in enumerate(tokens_need_norm) if need]

        offset = 0
        for i in nsw_indices:
            pos = i + offset

            cand_a = working_ids
            score_a = _score_sequence(torch.tensor([cand_a], device=device))

            cand_b = working_ids[:pos + 1] + [mask_token_id] + working_ids[pos + 1:]
            score_b = _score_sequence(torch.tensor([cand_b], device=device))
            if score_b > score_a:
                working_ids = cand_b
                offset += 1

        # Second pass: predict a replacement id for every position of the masked input.
        masked_input_ids = torch.tensor([working_ids], device=device)
        with torch.no_grad():
            bart_outputs = self.bart(masked_input_ids, attention_mask=torch.ones_like(masked_input_ids), output_hidden_states=True)
            sequence_output = bart_outputs.encoder_last_hidden_state
            logits_final = self.cls(sequence_output)
            pred_ids = torch.argmax(logits_final, dim=-1)[0].cpu().tolist()

        # Keep special tokens from the source; take the model prediction elsewhere.
        final_tokens = []
        for idx, src_id in enumerate(working_ids):
            tok = tokenizer.convert_ids_to_tokens([src_id])[0]
            if tok in ['<s>', '</s>', '<pad>', '<unk>']:
                final_tokens.append(src_id)
            else:
                pred_id = pred_ids[idx] if idx < len(pred_ids) else src_id
                # Guard against ids outside the tokenizer's vocabulary.
                if pred_id >= len(tokenizer):
                    pred_id = len(tokenizer) - 1
                final_tokens.append(pred_id)

        def remove_special_tokens(token_list):
            special_tokens = ['<s>', '</s>', '<pad>', '<unk>', '<mask>', '<space>']
            return [token for token in token_list if token not in special_tokens]

        def _safe_ids_to_text(token_ids):
            if not token_ids:
                return ""
            try:
                tokens = tokenizer.convert_ids_to_tokens(token_ids)
                cleaned = remove_special_tokens(tokens)
                if not cleaned:
                    return ""
                return tokenizer.convert_tokens_to_string(cleaned)
            except Exception:
                return ""

        final_tokens = [tid for tid in final_tokens if tid != -1]
        pred_str = _safe_ids_to_text(final_tokens)

        if pred_str:
            pred_str = ' '.join(pred_str.split())

        decoded_source = tokenizer.convert_ids_to_tokens(working_ids)
        decoded_pred = tokenizer.convert_ids_to_tokens(final_tokens)

        return pred_str, decoded_source, decoded_pred
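
    # Example call (sketch; the input string is illustrative social-media text):
    #
    #   normalized, src_tokens, pred_tokens = model.normalize_text(tokenizer, "ko bít dc", device="cpu")
    #   normalized   -> the normalized string
    #   src_tokens   -> tokens of the (possibly mask-expanded) source sequence
    #   pred_tokens  -> tokens predicted for each position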

    def detect_nsw(self, tokenizer, text, device='cpu'):
        """
        Detect Non-Standard Words (NSW) in text and return detailed information.
        This method aligns with normalize_text to ensure consistent NSW detection.

        Args:
            tokenizer: HuggingFace tokenizer
            text: Input text to analyze
            device: Device to run inference on

        Returns:
            List of dictionaries containing NSW information:
            [{'index': int, 'start_index': int, 'end_index': int, 'nsw': str,
              'prediction': str, 'confidence_score': float}, ...]
        """
        self.to(device)
        self.fix_classification_head_for_tokenizer(tokenizer)

        encoded = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
        input_tokens_tensor = encoded.to(device)
        input_tokens = tokenizer.convert_ids_to_tokens(encoded[0])

        input_tokens_tensor, _, token_type_ids, input_mask = self._truncate_and_build_masks(input_tokens_tensor)

        # First pass: NSW-detection logits (and confidences) for every token.
        self.eval()
        with torch.no_grad():
            bart_outputs = self.bart(input_tokens_tensor, attention_mask=input_mask, output_hidden_states=True)
            sequence_output = bart_outputs.encoder_last_hidden_state

            logits_norm = self.cls(sequence_output)
            logits_n_masks_pred = self.mask_n_predictor(sequence_output)
            logits_nsw_detection = self.nsw_detector(sequence_output)

        # Minimal local container mirroring the forward() output attributes.
        class ViSoNormOutput:
            def __init__(self, logits_norm, logits_n_masks_pred, logits_nsw_detection):
                self.logits = logits_norm
                self.logits_norm = logits_norm
                self.logits_n_masks_pred = logits_n_masks_pred
                self.logits_nsw_detection = logits_nsw_detection

        outputs = ViSoNormOutput(logits_norm, logits_n_masks_pred, logits_nsw_detection)

        tokens = tokenizer.convert_ids_to_tokens(input_tokens_tensor[0])

        if hasattr(outputs, 'logits_nsw_detection') and outputs.logits_nsw_detection is not None:
            if outputs.logits_nsw_detection.dim() == 3:
                nsw_predictions = torch.argmax(outputs.logits_nsw_detection[0], dim=-1) == 1
                nsw_confidence = torch.softmax(outputs.logits_nsw_detection[0], dim=-1)[:, 1]
            else:
                nsw_predictions = torch.sigmoid(outputs.logits_nsw_detection[0]) > 0.5
                nsw_confidence = torch.sigmoid(outputs.logits_nsw_detection[0])

            tokens_need_norm = []
            for i, token in enumerate(tokens):
                if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
                    tokens_need_norm.append(False)
                else:
                    if i < len(nsw_predictions):
                        tokens_need_norm.append(nsw_predictions[i].item())
                    else:
                        tokens_need_norm.append(False)
        else:
            tokens_need_norm = [token not in ['<s>', '</s>', '<pad>', '<unk>', '<mask>'] for token in tokens]

        def _score_sequence(input_ids_tensor: torch.Tensor) -> float:
            """Average per-position max log-probability; same scoring as normalize_text."""
            with torch.no_grad():
                bart_outputs = self.bart(input_ids_tensor, attention_mask=torch.ones_like(input_ids_tensor), output_hidden_states=True)
                sequence_output = bart_outputs.encoder_last_hidden_state
                logits = self.cls(sequence_output)
                log_probs = torch.log_softmax(logits[0], dim=-1)
                position_scores, _ = torch.max(log_probs, dim=-1)
                return float(position_scores.mean().item())

        # Greedy <mask> insertion after detected NSW tokens, as in normalize_text.
        mask_token_id = tokenizer.convert_tokens_to_ids('<mask>')
        working_ids = input_tokens_tensor[0].detach().clone().cpu().tolist()
        nsw_indices = [i for i, need in enumerate(tokens_need_norm) if need]

        offset = 0
        for i in nsw_indices:
            pos = i + offset

            cand_a = working_ids
            score_a = _score_sequence(torch.tensor([cand_a], device=device))

            cand_b = working_ids[:pos + 1] + [mask_token_id] + working_ids[pos + 1:]
            score_b = _score_sequence(torch.tensor([cand_b], device=device))
            if score_b > score_a:
                working_ids = cand_b
                offset += 1

        # Second pass over the masked sequence to obtain per-position predictions.
        masked_input_ids = torch.tensor([working_ids], device=device)
        with torch.no_grad():
            bart_outputs = self.bart(masked_input_ids, attention_mask=torch.ones_like(masked_input_ids), output_hidden_states=True)
            sequence_output = bart_outputs.encoder_last_hidden_state
            logits_final = self.cls(sequence_output)
            pred_ids = torch.argmax(logits_final, dim=-1)[0].cpu().tolist()

        nsw_results = []

        # Keep special tokens from the source; take the model prediction elsewhere.
        final_tokens = []
        for idx, src_id in enumerate(working_ids):
            tok = tokenizer.convert_ids_to_tokens([src_id])[0]
            if tok in ['<s>', '</s>', '<pad>', '<unk>']:
                final_tokens.append(src_id)
            else:
                final_tokens.append(pred_ids[idx] if idx < len(pred_ids) else src_id)

        def remove_special_tokens(token_list):
            special_tokens = ['<s>', '</s>', '<pad>', '<unk>', '<mask>', '<space>']
            return [token for token in token_list if token not in special_tokens]

        def _safe_ids_to_text(token_ids):
            if not token_ids:
                return ""
            try:
                tokens = tokenizer.convert_ids_to_tokens(token_ids)
                cleaned = remove_special_tokens(tokens)
                if not cleaned:
                    return ""
                return tokenizer.convert_tokens_to_string(cleaned)
            except Exception:
                return ""

        final_tokens_cleaned = [tid for tid in final_tokens if tid != -1]
        normalized_text = _safe_ids_to_text(final_tokens_cleaned)

        if normalized_text:
            normalized_text = ' '.join(normalized_text.split())

        original_tokens = tokenizer.tokenize(text)
        normalized_tokens = tokenizer.tokenize(normalized_text)

        decoded_source = tokenizer.convert_ids_to_tokens(working_ids)
        decoded_pred = tokenizer.convert_ids_to_tokens(final_tokens)

        def clean_token(token):
            if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
                return None
            return token.strip().lstrip('▁')

        # Walk the aligned source/prediction token streams and emit one result per
        # NSW span (a source token plus any <mask> expansions that follow it).
        i = 0
        while i < len(decoded_source):
            src_token = decoded_source[i]
            clean_src = clean_token(src_token)

            if clean_src is None:
                i += 1
                continue

            pred_token = decoded_pred[i]
            clean_pred = clean_token(pred_token)

            if clean_pred is None:
                i += 1
                continue

            if clean_src != clean_pred:
                expansion_tokens = [clean_pred]
                j = i + 1

                # Collect predictions for the <mask> tokens inserted after this NSW.
                while j < len(decoded_source) and j < len(decoded_pred):
                    next_src = decoded_source[j]
                    next_pred = decoded_pred[j]

                    if next_src == '<mask>':
                        clean_next_pred = clean_token(next_pred)
                        if clean_next_pred is not None:
                            expansion_tokens.append(clean_next_pred)
                        j += 1
                    else:
                        clean_next_src = clean_token(next_src)
                        clean_next_pred = clean_token(next_pred)

                        if clean_next_src is not None and clean_next_pred is not None and clean_next_src != clean_next_pred:
                            # The next position starts a new NSW span; stop here.
                            break
                        else:
                            break

                expansion_text = ' '.join(expansion_tokens)

                # Character offsets of the NSW in the original text (best effort).
                start_idx = text.find(clean_src)
                end_idx = start_idx + len(clean_src) if start_idx != -1 else len(clean_src)

                if hasattr(outputs, 'logits_nsw_detection') and outputs.logits_nsw_detection is not None:
                    # Confidence of the NSW detector for the original token position.
                    orig_pos = None
                    for k, tok in enumerate(tokens):
                        if tok.strip().lstrip('▁') == clean_src:
                            orig_pos = k
                            break

                    if orig_pos is not None and orig_pos < len(nsw_confidence):
                        nsw_conf = nsw_confidence[orig_pos].item()
                    else:
                        nsw_conf = 0.5

                    # Confidence of the normalization head for the predicted token.
                    norm_logits = logits_final[0]
                    norm_confidence = torch.softmax(norm_logits, dim=-1)
                    norm_conf = norm_confidence[i][final_tokens[i]].item()
                    combined_confidence = (nsw_conf + norm_conf) / 2
                else:
                    combined_confidence = 0.5

                nsw_results.append({
                    'index': i,
                    'start_index': start_idx,
                    'end_index': end_idx,
                    'nsw': clean_src,
                    'prediction': expansion_text,
                    'confidence_score': round(combined_confidence, 4)
                })

                i = j
            else:
                i += 1

        return nsw_results
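
    # Example call (sketch; the returned values shown are purely illustrative):
    #
    #   results = model.detect_nsw(tokenizer, "ko bít dc", device="cpu")
    #   # [{'index': 1, 'start_index': 0, 'end_index': 2, 'nsw': 'ko',
    #   #   'prediction': 'không', 'confidence_score': 0.93}, ...]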

    def _truncate_and_build_masks(self, input_tokens_tensor, output_tokens_tensor=None):
        """Build the attention mask (and placeholders) used at inference time.

        Mirrors the interface of the training-time truncation/masking helper; here
        no truncation is applied and the attention mask is all ones.
        """
        pad_id_model = 1  # MBart pad token id; kept for parity with the training helper.
        input_mask = torch.ones_like(input_tokens_tensor)
        token_type_ids = None
        return input_tokens_tensor, output_tokens_tensor, token_type_ids, input_mask


__all__ = ["ViSoNormBartPhoForMaskedLM"]
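

# Minimal end-to-end smoke test (a sketch, assuming a hosted checkpoint whose id
# is a placeholder below). Requires network access plus the transformers,
# huggingface_hub and, for .safetensors weights, safetensors packages.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "your-org/visonorm-bartpho"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = ViSoNormBartPhoForMaskedLM.from_pretrained(repo_id)

    sample = "ko bít dc"  # illustrative Vietnamese social-media input
    normalized, _, _ = model.normalize_text(tokenizer, sample, device="cpu")
    print("normalized:", normalized)
    print("nsw spans:", model.detect_nsw(tokenizer, sample, device="cpu"))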