File size: 2,622 Bytes
832c573
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from pathlib import Path

import numpy as np
import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, RobertaConfig, RobertaModel

# Output columns of the regression head, in the order the model emits them.
FEATURE_NAMES = ["nFix", "FFD", "GPT", "TRT", "fixProp"]
# Column of the "TRT" feature; derived from FEATURE_NAMES so the two cannot drift.
TRT_INDEX = FEATURE_NAMES.index("TRT")
# Weight file (inside the model directory) loaded by default.
DEFAULT_WEIGHT = "et_predictor2_iitb_scalezero_seed42.safetensors"


class RobertaRegressionModel(torch.nn.Module):
    """RoBERTa encoder topped with a per-token linear regression head.

    The head maps every token's final hidden state to one output per entry
    of ``FEATURE_NAMES``. All weights are freshly initialised from the
    config; a trained state dict has to be loaded afterwards.
    """

    def __init__(self, config_path="."):
        """Build the encoder and the head from the config at *config_path*.

        Args:
            config_path: Anything accepted by ``RobertaConfig.from_pretrained``
                (the current directory by default).
        """
        super().__init__()
        cfg = RobertaConfig.from_pretrained(config_path)
        self.roberta = RobertaModel(cfg)
        n_outputs = len(FEATURE_NAMES)
        self.decoder = torch.nn.Linear(cfg.hidden_size, n_outputs)

    def forward(self, input_ids, attention_mask):
        """Return per-token feature predictions.

        The result is the linear head applied to the encoder's
        ``last_hidden_state``, so its last dimension is
        ``len(FEATURE_NAMES)`` (presumably shaped
        ``(batch, seq_len, n_features)``).
        """
        encoder_output = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        token_states = encoder_output.last_hidden_state
        return self.decoder(token_states)


def load_et_predictor(model_dir, weight_name=DEFAULT_WEIGHT, device=None):
    """Load the tokenizer and the trained regression model from *model_dir*.

    Args:
        model_dir: Directory holding the tokenizer files, the RoBERTa config
            and the safetensors weight file.
        weight_name: File name of the weights inside *model_dir*.
        device: Target device. Any falsy value selects CUDA when available,
            otherwise CPU.

    Returns:
        ``(model, tokenizer)``. The model sits on the chosen device in eval
        mode. The tokenizer is created with ``add_prefix_space=True``, which
        RoBERTa tokenizers need for ``is_split_into_words`` input.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) when the weight
            file's keys or shapes do not match the model.
    """
    root = Path(model_dir)
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    target = torch.device(device)

    tokenizer = AutoTokenizer.from_pretrained(root, add_prefix_space=True)
    weights_path = root / weight_name
    model = RobertaRegressionModel(config_path=root).to(target)
    # Load tensors straight onto the target device to avoid an extra copy.
    model.load_state_dict(load_file(str(weights_path), device=str(target)))
    model.eval()
    return model, tokenizer


@torch.no_grad()
def predict_word_features(text, model, tokenizer, device=None, max_length=512):
    """Predict per-word features for every whitespace-separated word in *text*.

    The text is split with ``str.split()``, tokenized as pre-split words and
    run through *model* in a single forward pass. Each word receives the
    prediction of its first sub-word token, and negative predictions are
    clamped to 0.

    Args:
        text: Input string.
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns per-token predictions with one column per entry of
            ``FEATURE_NAMES`` (e.g. ``RobertaRegressionModel``).
        tokenizer: Hugging Face *fast* tokenizer; ``BatchEncoding.word_ids``
            is used to map tokens back to words.
        device: Device for the input tensors. Defaults to the device of the
            model's first parameter.
        max_length: Maximum number of tokens, special tokens included. Longer
            inputs are truncated.

    Returns:
        ``(words, features)``. ``words`` is the list of words. ``features`` is
        a float32 array of shape ``(len(words), len(FEATURE_NAMES))`` whose
        columns follow ``FEATURE_NAMES``. Words that fall past the truncation
        point keep all-zero rows.
    """
    words = text.split()
    output = np.zeros((len(words), len(FEATURE_NAMES)), dtype=np.float32)
    if not words:
        # Empty or whitespace-only text: nothing to predict, so skip the
        # tokenizer and the forward pass entirely.
        return words, output

    device = torch.device(device or next(model.parameters()).device)
    encoded = tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=False,
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    predictions = model(input_ids=input_ids, attention_mask=attention_mask)
    # Cast to float32 before .numpy(): NumPy has no bfloat16, so a model run
    # in bf16 (or fp16) would otherwise fail or leak a low-precision dtype.
    predictions = predictions.squeeze(0).clamp_min(0.0).float().cpu().numpy()

    seen = set()
    for token_index, word_index in enumerate(encoded.word_ids(batch_index=0)):
        # Special tokens map to None. Later sub-word pieces repeat a word
        # index that has already been filled from that word's first piece.
        if word_index is None or word_index in seen or word_index >= len(words):
            continue
        output[word_index] = predictions[token_index]
        seen.add(word_index)

    return words, output


@torch.no_grad()
def predict_word_trt(text, model, tokenizer, device=None, max_length=512):
    """Return the per-word ``"TRT"`` predictions for *text*.

    This is a thin wrapper over ``predict_word_features``. It keeps only the
    column at ``TRT_INDEX`` (presumably total reading time).

    Returns:
        ``(words, trt)``, where ``trt`` is a 1-D array aligned with ``words``.
    """
    result = predict_word_features(
        text,
        model,
        tokenizer,
        device=device,
        max_length=max_length,
    )
    words, features = result
    trt_column = features[:, TRT_INDEX]
    return words, trt_column