| from transformers import DebertaV2TokenizerFast |
| import torch |
|
|
| from multi_head_model import MultiHeadModel |
| from utils import get_torch_device |
|
|
|
|
class MultiHeadPredictor:
    """Multi-headed token classifier built on a DeBERTa-v2 backbone.

    Wraps a fast tokenizer plus a ``MultiHeadModel`` checkpoint and exposes
    :meth:`predict`, which labels every whitespace token of an input string
    once per classification head.
    """

    def __init__(self, model_name_or_path: str):
        """
        Load tokenizer and model from a checkpoint and move the model to the
        best available device in eval mode.

        :param model_name_or_path: HF-style model identifier or local path.
        """
        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(model_name_or_path, add_prefix_space=True)
        self.model = MultiHeadModel.from_pretrained(model_name_or_path)
        # Per-head label maps from the model config.
        # NOTE(review): assumed indexable as id2label[head_name][label_id] —
        # confirm against MultiHeadModel's config schema.
        self.id2label = self.model.config.label_maps

        self.device = get_torch_device()
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text: str):
        """
        Perform multi-headed token classification on a single piece of text.

        Long inputs are split into overlapping 512-token chunks (stride 128).
        For a word that appears in more than one chunk, the prediction from
        the first chunk containing it is kept; words never covered (should
        not happen with overflow enabled) keep the default "O" label.

        :param text: The raw text string.

        :return: A dict containing the original ``text``, the whitespace
                 ``tokens``, and {head_name: [predicted_label_for_each_token]}.
        """
        raw_tokens = text.split()

        encoded = self.tokenizer(
            raw_tokens,
            is_split_into_words=True,
            max_length=512,
            stride=128,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=False,
            padding="max_length",
        )

        chunk_preds = []     # per-chunk {head_name: [label_id per subtoken]}
        chunk_word_ids = []  # per-chunk word index for each subtoken (None = special token)

        for i in range(len(encoded["input_ids"])):
            # Allocate directly on the target device instead of tensor().to(),
            # avoiding an intermediate CPU copy.
            input_ids = torch.tensor([encoded["input_ids"][i]], dtype=torch.long, device=self.device)
            attention_mask = torch.tensor([encoded["attention_mask"][i]], dtype=torch.long, device=self.device)

            with torch.no_grad():
                logits_dict = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )

            # Argmax over the label dimension for every head; Tensor.tolist()
            # converts directly without a NumPy round-trip.
            chunk_preds.append({
                head_name: torch.argmax(logits, dim=-1)[0].cpu().tolist()
                for head_name, logits in logits_dict.items()
            })
            chunk_word_ids.append(encoded.word_ids(batch_index=i))

        # Seed the result with "O" for every word so each head always returns
        # one label per whitespace token.
        final_pred_labels = {
            "text": text,
            "tokens": raw_tokens,
            **{head: ["O"] * len(raw_tokens) for head in self.id2label},
        }

        # Word indices already labelled — first chunk / first subtoken wins.
        assigned_tokens = set()
        for pred_dict, w_ids in zip(chunk_preds, chunk_word_ids):
            for pos, w_id in enumerate(w_ids):
                if w_id is None or w_id in assigned_tokens:
                    # Special token, or word already labelled by an earlier
                    # overlapping chunk / an earlier subtoken of the same word.
                    continue

                for head_name, pred_ids in pred_dict.items():
                    label_id = pred_ids[pos]
                    final_pred_labels[head_name][w_id] = self.id2label[head_name][label_id]

                assigned_tokens.add(w_id)

        return final_pred_labels
|
|
|
|
| if __name__ == "__main__": |
| predictor = MultiHeadPredictor("./o3-mini_20250218_final") |
|
|
| test_cases = [ |
| "How to convince my parents to let me get a Ball python?", |
| ] |
| for case in test_cases: |
| predictions = predictor.predict(case) |
| for head_name, labels in predictions.items(): |
| print(f"{head_name}: {labels}") |
|
|