| | from transformers import DebertaV2TokenizerFast |
| | import torch |
| |
|
| | from multi_head_model import MultiHeadModel |
| | from utils import get_torch_device |
| |
|
| |
|
class MultiHeadPredictor:
    """Multi-head token classifier over whitespace-split tokens of a text.

    Wraps a ``MultiHeadModel`` checkpoint plus its DeBERTa-v2 fast tokenizer
    and produces one label sequence per classification head.
    """

    def __init__(self, model_name_or_path: str):
        """
        Load the tokenizer and model from a local path or hub identifier,
        move the model to the best available device, and switch to eval mode.

        :param model_name_or_path: Directory or hub id of the trained checkpoint.
        """
        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(model_name_or_path, add_prefix_space=True)
        self.model = MultiHeadModel.from_pretrained(model_name_or_path)
        # Per-head label maps from the model config — presumably
        # {head_name: {label_id: label_str}}; verify against MultiHeadModel's config.
        self.id2label = self.model.config.label_maps

        self.device = get_torch_device()
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text: str):
        """
        Perform multi-headed token classification on a single piece of text.

        :param text: The raw text string.

        :return: A dict with ``"text"`` (the input), ``"tokens"`` (whitespace
            split), and ``{head_name: [predicted_label_for_each_token]}`` for
            every head in ``self.id2label``. Words left unlabeled default to "O".
        """
        raw_tokens = text.split()

        # Long inputs overflow into multiple 512-token chunks with a 128-token
        # stride, so every word is covered by at least one chunk.
        encoded = self.tokenizer(
            raw_tokens,
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,
            padding="max_length",
        )
        # NOTE: the original computed an unused `sample_map` from
        # "overflow_to_sample_mapping"; it was dead code and has been removed
        # (single-text input means every chunk maps to sample 0 anyway).

        chunk_preds = []     # per chunk: {head_name: [label_id per position]}
        chunk_word_ids = []  # per chunk: word index per position (None for special/pad tokens)

        for i in range(len(encoded["input_ids"])):
            # One forward pass per chunk, batch dimension of 1.
            input_ids_tensor = torch.tensor([encoded["input_ids"][i]], dtype=torch.long).to(self.device)
            attention_mask_tensor = torch.tensor([encoded["attention_mask"][i]], dtype=torch.long).to(self.device)

            with torch.no_grad():
                logits_dict = self.model(
                    input_ids=input_ids_tensor,
                    attention_mask=attention_mask_tensor,
                )

            # Greedy decode: argmax over the last (label) dimension per head.
            pred_ids_dict = {
                head_name: torch.argmax(logits, dim=-1)[0].cpu().numpy().tolist()
                for head_name, logits in logits_dict.items()
            }
            chunk_preds.append(pred_ids_dict)
            chunk_word_ids.append(encoded.word_ids(batch_index=i))

        # Seed every head with "O" for every word; predictions overwrite below.
        final_pred_labels = {
            "text": text,
            "tokens": raw_tokens,
            **{head: ["O"] * len(raw_tokens) for head in self.id2label.keys()},
        }

        # Merge chunk predictions back to word level. First occurrence wins:
        # a word is labeled from its first sub-token in the earliest chunk
        # that contains it; later sub-tokens and overlapping chunks are skipped.
        assigned_tokens = set()
        for pred_dict, w_ids in zip(chunk_preds, chunk_word_ids):
            for pos, w_id in enumerate(w_ids):
                if w_id is None or w_id in assigned_tokens:
                    # None = special/padding token; already-assigned = later
                    # sub-token or overlap from the stride window.
                    continue

                for head_name, pred_ids in pred_dict.items():
                    label_id = pred_ids[pos]
                    final_pred_labels[head_name][w_id] = self.id2label[head_name][label_id]

                assigned_tokens.add(w_id)

        return final_pred_labels
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: load a locally saved checkpoint and print per-head labels
    # for a handful of sample sentences.
    predictor = MultiHeadPredictor("./o3-mini_20250218_final")

    sample_texts = [
        "How to convince my parents to let me get a Ball python?",
    ]
    for sample in sample_texts:
        result = predictor.predict(sample)
        for key, value in result.items():
            print(f"{key}: {value}")
|