emotion_et_model / model.py
skboy's picture
Upload 13 files
832c573 verified
Raw
History Blame Contribute Delete
2.62 kB
from pathlib import Path
import numpy as np
import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, RobertaConfig, RobertaModel
# Names of the values the model regresses for every token. The list length
# sizes the decoder head and the column count of the per-word feature array.
# NOTE(review): the names look like eye-tracking reading measures -- confirm
# their exact definitions against the training data.
FEATURE_NAMES = ["nFix", "FFD", "GPT", "TRT", "fixProp"]

# Column of the "TRT" feature, derived from FEATURE_NAMES so the two cannot
# drift apart if the list is ever reordered.
TRT_INDEX = FEATURE_NAMES.index("TRT")

# Weight file read from the model directory when no other name is given.
DEFAULT_WEIGHT = "et_predictor2_iitb_scalezero_seed42.safetensors"
class RobertaRegressionModel(torch.nn.Module):
    """RoBERTa encoder followed by a linear head applied to every token.

    The head maps each token's final hidden state to ``len(FEATURE_NAMES)``
    regression outputs, one per entry of ``FEATURE_NAMES``.
    """

    def __init__(self, config_path="."):
        """Build the network from the RoBERTa configuration at *config_path*.

        Only the architecture is instantiated here (``RobertaModel(config)``);
        this constructor does not load any trained weights.
        """
        super().__init__()
        roberta_config = RobertaConfig.from_pretrained(config_path)
        self.roberta = RobertaModel(roberta_config)
        self.decoder = torch.nn.Linear(
            in_features=roberta_config.hidden_size,
            out_features=len(FEATURE_NAMES),
        )

    def forward(self, input_ids, attention_mask):
        """Return the decoder output for every position of the input.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.

        Returns:
            The linear head applied to the encoder's ``last_hidden_state``;
            the last dimension has ``len(FEATURE_NAMES)`` entries.
        """
        encoder_output = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        token_states = encoder_output.last_hidden_state
        return self.decoder(token_states)
def load_et_predictor(model_dir, weight_name=DEFAULT_WEIGHT, device=None):
    """Load the regression model and its tokenizer from *model_dir*.

    Args:
        model_dir: Directory holding the tokenizer files, the RoBERTa config
            and the safetensors weight file.
        weight_name: File name of the weights inside *model_dir*.
        device: Target device (anything ``torch.device`` accepts). When
            falsy, CUDA is used if available, otherwise the CPU.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is on the chosen device and
        already switched to evaluation mode.
    """
    root = Path(model_dir)

    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    target = torch.device(device)

    # add_prefix_space=True is what lets the RoBERTa tokenizer accept the
    # pre-split word lists that predict_word_features feeds it.
    tokenizer = AutoTokenizer.from_pretrained(root, add_prefix_space=True)

    model = RobertaRegressionModel(config_path=root)
    model = model.to(target)

    weights = load_file(str(root / weight_name), device=str(target))
    model.load_state_dict(weights)
    model.eval()

    return model, tokenizer
@torch.no_grad()
def predict_word_features(text, model, tokenizer, device=None, max_length=512):
    """Predict one feature vector per whitespace-separated word of *text*.

    The text is split on whitespace, tokenized as pre-split words and run
    through *model* once. Each word receives the prediction of its FIRST
    sub-token; predictions are clamped to be non-negative.

    Args:
        text: Input string.
        model: Callable as ``model(input_ids=..., attention_mask=...)`` and
            exposing ``parameters()`` (used to find its device).
        tokenizer: Fast tokenizer supporting ``is_split_into_words`` and
            ``word_ids``.
        device: Device for the input tensors; defaults to the device of the
            model's first parameter.
        max_length: Maximum token length; longer inputs are truncated.

    Returns:
        ``(words, features)`` where ``words`` is the list of split words and
        ``features`` is a float32 array of shape
        ``(len(words), len(FEATURE_NAMES))``. Words whose tokens were cut off
        by truncation keep an all-zero row.
    """
    words = text.strip().split()
    output = np.zeros((len(words), len(FEATURE_NAMES)), dtype=np.float32)
    if not words:
        # Nothing to predict: skip tokenization and the forward pass.
        return words, output

    device = torch.device(device or next(model.parameters()).device)
    encoded = tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=False,
    )
    predictions = model(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
    )
    # .float() keeps this working for reduced-precision models: NumPy has no
    # bfloat16, so converting such a tensor directly would raise.
    predictions = predictions.squeeze(0).clamp_min(0.0).float().cpu().numpy()

    seen = set()
    for token_index, word_index in enumerate(encoded.word_ids(batch_index=0)):
        # None marks special tokens; a repeated index is a continuation
        # sub-token of a word that already has its value.
        if word_index is None or word_index in seen:
            continue
        output[word_index] = predictions[token_index]
        seen.add(word_index)
    return words, output
@torch.no_grad()
def predict_word_trt(text, model, tokenizer, device=None, max_length=512):
    """Return only the "TRT" prediction for each word of *text*.

    Delegates to ``predict_word_features`` with the same arguments and keeps
    the column selected by ``TRT_INDEX``.

    Returns:
        ``(words, trt)`` where ``trt`` is the 1-D slice of the feature array
        holding one value per word.
    """
    split_words, feature_matrix = predict_word_features(
        text, model, tokenizer, device=device, max_length=max_length
    )
    trt_column = feature_matrix[:, TRT_INDEX]
    return split_words, trt_column