|
|
from transformers import DebertaV2TokenizerFast |
|
|
import torch |
|
|
|
|
|
from multi_head_model import MultiHeadModel |
|
|
from utils import get_torch_device |
|
|
|
|
|
|
|
|
class MultiHeadPredictor:
    """Multi-headed token classification over whitespace-tokenized text.

    Wraps a DeBERTa-v2 fast tokenizer and a project ``MultiHeadModel`` whose
    forward pass returns a dict of per-head logits. Long inputs are handled by
    splitting into overlapping 512-token chunks (stride 128) and merging the
    per-chunk predictions back onto the original whitespace tokens.
    """

    def __init__(self, model_name_or_path: str):
        """Load tokenizer and model, move the model to the best device, set eval mode.

        :param model_name_or_path: HF hub id or local path of the fine-tuned model.
        """
        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(model_name_or_path, add_prefix_space=True)
        self.model = MultiHeadModel.from_pretrained(model_name_or_path)
        # Per-head label maps stored on the model config.
        # Presumably {head_name: {label_id: label_str}} — verify against MultiHeadModel's config.
        self.id2label = self.model.config.label_maps
        self.device = get_torch_device()
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text: str):
        """Perform multi-headed token classification on a single piece of text.

        :param text: The raw text string (tokenized by whitespace split).
        :return: A dict containing ``"text"`` (the input), ``"tokens"`` (the
            whitespace tokens), and, for each head, a list with one predicted
            label string per token. Tokens never covered by any chunk keep the
            default label ``"O"``.
        """
        raw_tokens = text.split()

        # Empty/whitespace-only input: return the result skeleton directly
        # instead of asking the tokenizer to encode an empty token list.
        if not raw_tokens:
            return {
                "text": text,
                "tokens": [],
                **{head: [] for head in self.id2label},
            }

        # Overlapping chunks so tokens near a truncation boundary still get
        # predicted with surrounding context.
        encoded = self.tokenizer(
            raw_tokens,
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,
            padding="max_length",
        )

        chunk_preds = []      # one {head_name: [label_id per position]} dict per chunk
        chunk_word_ids = []   # one word_ids list per chunk (None for special/pad tokens)

        for i in range(len(encoded["input_ids"])):
            input_ids_tensor = torch.tensor([encoded["input_ids"][i]], dtype=torch.long).to(self.device)
            attention_mask_tensor = torch.tensor([encoded["attention_mask"][i]], dtype=torch.long).to(self.device)

            with torch.no_grad():
                logits_dict = self.model(
                    input_ids=input_ids_tensor,
                    attention_mask=attention_mask_tensor,
                )

            # Greedy per-position decoding for every head; .tolist() moves the
            # data to Python ints directly (no NumPy round-trip needed).
            pred_ids_dict = {
                head_name: torch.argmax(logits, dim=-1)[0].cpu().tolist()
                for head_name, logits in logits_dict.items()
            }
            chunk_preds.append(pred_ids_dict)
            chunk_word_ids.append(encoded.word_ids(batch_index=i))

        # Result skeleton: every head starts as all-"O" over the word tokens.
        final_pred_labels = {
            "text": text,
            "tokens": raw_tokens,
            **{head: ["O"] * len(raw_tokens) for head in self.id2label},
        }

        # Merge chunk predictions onto words. First-come-first-served: the
        # earliest chunk position covering a word wins, which also means only
        # the first subword of each word contributes its label.
        assigned_tokens = set()
        for pred_dict, w_ids in zip(chunk_preds, chunk_word_ids):
            for pos, w_id in enumerate(w_ids):
                if w_id is None or w_id in assigned_tokens:
                    continue
                for head_name, pred_ids in pred_dict.items():
                    label_str = self.id2label[head_name][pred_ids[pos]]
                    final_pred_labels[head_name][w_id] = label_str
                assigned_tokens.add(w_id)

        return final_pred_labels
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load the local checkpoint, run a few sample
    # sentences through the predictor, and dump every head's labels.
    predictor = MultiHeadPredictor("./o3-mini_20250218_final")

    samples = [
        "How to convince my parents to let me get a Ball python?",
    ]

    for sentence in samples:
        result = predictor.predict(sentence)
        for head, labels in result.items():
            print(f"{head}: {labels}")
|