| from transformers.modeling_outputs import TokenClassifierOutput |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig |
| from torch.nn import CrossEntropyLoss |
| from typing import Optional, Tuple, Union |
| import logging, json, os |
| import floret |
|
|
|
|
| |
|
|
|
|
def get_info(label_map):
    """Return a mapping from each task name to how many labels it has.

    Args:
        label_map: Mapping of task name -> collection of that task's labels.

    Returns:
        dict: ``{task: len(labels)}`` for every entry of ``label_map``,
        in the same iteration order as ``label_map``.
    """
    label_counts = {}
    for task_name, task_labels in label_map.items():
        label_counts[task_name] = len(task_labels)
    return label_counts
|
|
|
|
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a floret model.

    The floret model is loaded from ``config.filename`` when the wrapper is
    constructed, and ``predict`` delegates directly to it. No torch layers are
    defined in this class.
    """

    def __init__(self, config):
        """Load the floret model referenced by ``config.filename``.

        Args:
            config: Model configuration; must expose a ``filename`` attribute
                that ``floret.load_model`` can open.
        """
        super().__init__(config)
        self.config = config
        logger = logging.getLogger(__name__)
        # Progress is reported via logging (opt-in) rather than bare print()
        # markers, so constructing the model does not spam stdout.
        logger.debug("Loading floret model from %s", self.config.filename)
        self.model = floret.load_model(self.config.filename)
        logger.debug("Loaded floret model from %s", self.config.filename)

    def predict(self, text, k=1):
        """Return the top-``k`` floret predictions for ``text``.

        Args:
            text: Input passed unchanged to the floret model's ``predict``.
            k: Number of top labels to request. Defaults to 1.

        Returns:
            Whatever ``self.model.predict`` returns; presumably a
            ``(labels, probabilities)`` pair as in fastText. TODO confirm.
        """
        return self.model.predict(text, k)