| import argparse |
| import json |
| from pathlib import Path |
| import re |
| from typing import Dict, Optional, Union |
| import torch |
| import torch.nn.functional as F |
| from .modules.layers import LstmSeq2SeqEncoder |
| from .modules.base import InstructBase |
| from .modules.evaluator import Evaluator, greedy_search |
| from .modules.span_rep import SpanRepLayer |
| from .modules.token_rep import TokenRepLayer |
| from torch import nn |
| from torch.nn.utils.rnn import pad_sequence |
| from huggingface_hub import PyTorchModelHubMixin, hf_hub_download |
| from huggingface_hub.utils import HfHubHTTPError |
|
|
|
|
|
|
class GLiNER(InstructBase, PyTorchModelHubMixin):
    """Span-based NER model that matches text spans against prompted entity types.

    Entity-type names are prepended to the input words as a prompt
    ``<<ENT>> type_1 <<ENT>> type_2 ... <<ENT>> type_C <<SEP>>``. After encoding,
    the representation at each ``<<ENT>>`` position is projected into that
    type's embedding, and every candidate span is scored against every type
    with a dot product (see ``compute_score_train`` / ``compute_score_eval``).
    """

    def __init__(self, config):
        """Build the sub-layers of the model from ``config``.

        Args:
            config: Attribute-style configuration. Read here: ``model_name``,
                ``fine_tune``, ``subtoken_pooling``, ``hidden_size``,
                ``span_mode``, ``max_width`` and ``dropout``.
        """
        super().__init__(config)

        self.config = config

        # Marker tokens used to build the entity-type prompt; they are handed
        # to the token encoder through ``add_tokens``.
        self.entity_token = "<<ENT>>"
        self.sep_token = "<<SEP>>"

        # Word-level encoder over (prompt + text) tokens.
        self.token_rep_layer = TokenRepLayer(
            model_name=config.model_name,
            fine_tune=config.fine_tune,
            subtoken_pooling=config.subtoken_pooling,
            hidden_size=config.hidden_size,
            add_tokens=[self.entity_token, self.sep_token],
        )

        # Bidirectional LSTM with hidden_size // 2 per direction (presumably
        # concatenated back to hidden_size, which the span layer expects).
        self.rnn = LstmSeq2SeqEncoder(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            num_layers=1,
            bidirectional=True,
        )

        # Builds one representation per (start, width) candidate span.
        self.span_rep_layer = SpanRepLayer(
            span_mode=config.span_mode,
            hidden_size=config.hidden_size,
            max_width=config.max_width,
            dropout=config.dropout,
        )

        # Projects the <<ENT>> token representations into entity-type embeddings.
        self.prompt_rep_layer = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size),
        )

    def compute_score_train(self, x):
        """Score every candidate span against each example's entity types.

        Each example carries its own set of types (``x['classes_to_id'][i]``),
        so prompts are built per example and the type embeddings are padded to
        the largest type count in the batch.

        Args:
            x: Collated batch dict. Reads ``span_idx``, ``span_mask``,
                ``seq_length``, ``tokens`` (one token list per example) and
                ``classes_to_id`` (one mapping per example, type name -> id).

        Returns:
            ``(scores, num_classes, entity_type_mask)``:

            - ``scores``: ``(B, L, K, C)`` span/type logits from the einsum below.
            - ``num_classes``: ``C``, the padded number of types.
            - ``entity_type_mask``: ``(B, C)`` bool, True for real (non-padding)
              type slots.
        """
        # Zero out the indices of padded spans.
        span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)
        device = x['span_mask'].device

        new_length = x['seq_length'].clone()
        new_tokens = []
        all_len_prompt = []
        num_classes_all = []

        # Prepend "<<ENT>> t_1 ... <<ENT>> t_C <<SEP>>" to each example.
        for i in range(len(x['tokens'])):
            all_types_i = list(x['classes_to_id'][i].keys())
            num_classes_all.append(len(all_types_i))

            entity_prompt = []
            for entity_type in all_types_i:
                entity_prompt.append(self.entity_token)
                entity_prompt.append(entity_type)
            entity_prompt.append(self.sep_token)

            new_tokens.append(entity_prompt + x['tokens'][i])
            new_length[i] = new_length[i] + len(entity_prompt)
            all_len_prompt.append(len(entity_prompt))

        # (B, C_max) mask: True where slot j < number of types of example b.
        max_num_classes = max(num_classes_all)
        entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(
            len(num_classes_all), -1).to(device)
        entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(device)

        bert_output = self.token_rep_layer(new_tokens, new_length)
        word_rep_w_prompt = bert_output["embeddings"]
        mask_w_prompt = bert_output["mask"]

        # Split the encoder output back into the text part and the prompt part.
        word_rep = []
        mask = []
        entity_type_rep = []
        for i in range(len(x['tokens'])):
            prompt_entity_length = all_len_prompt[i]
            word_end = prompt_entity_length + x['seq_length'][i]
            word_rep.append(word_rep_w_prompt[i, prompt_entity_length:word_end])
            mask.append(mask_w_prompt[i, prompt_entity_length:word_end])

            # Drop the trailing <<SEP>>, then keep every other position: the
            # <<ENT>> markers at offsets 0, 2, ..., 2C - 2.
            entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1]
            entity_type_rep.append(entity_rep[0::2])

        word_rep = pad_sequence(word_rep, batch_first=True)
        mask = pad_sequence(mask, batch_first=True)
        entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)

        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        entity_type_rep = self.prompt_rep_layer(entity_type_rep)
        num_classes = entity_type_rep.shape[1]

        # scores[b, l, k, c] = <span_rep[b, l, k, :], entity_type_rep[b, c, :]>
        scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return scores, num_classes, entity_type_mask

    def forward(self, x):
        """Compute the summed, masked binary cross-entropy training loss.

        Args:
            x: Collated batch dict. In addition to the keys read by
                ``compute_score_train``, reads ``span_label``:

                - ``-1`` marks a padding span (excluded from the loss);
                - ``0`` marks a span with no entity;
                - ``c >= 1`` is the 1-based id of the span's entity type.

                ``x`` is not modified.

        Returns:
            Scalar tensor: the sum of per-(span, type) BCE losses over valid
            spans and real entity types, with positive targets weighted 2x.
        """
        scores, num_classes, entity_type_mask = self.compute_score_train(x)
        batch_size = scores.shape[0]

        logits_label = scores.view(-1, num_classes)
        labels = x["span_label"].view(-1)
        mask_label = labels != -1
        # Out-of-place on purpose: ``view`` shares storage with
        # x["span_label"], so an in-place fill would erase the caller's -1
        # padding markers and corrupt the mask if the batch is scored again.
        labels = labels.masked_fill(~mask_label, 0)

        # One-hot over (num_classes + 1) columns, then drop column 0 ("no
        # entity"): label c maps to column c - 1 and label 0 to all zeros.
        labels_one_hot = torch.zeros(
            labels.size(0), num_classes + 1, dtype=torch.float32, device=scores.device)
        labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
        labels_one_hot = labels_one_hot[:, 1:]

        all_losses = F.binary_cross_entropy_with_logits(
            logits_label, labels_one_hot, reduction='none')

        # Zero the losses of padded entity-type slots.
        masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
        all_losses = masked_loss.view(-1, num_classes)

        # Zero the losses of padded spans.
        mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)

        # Positive targets get weight 2, negatives weight 1.
        weight_c = labels_one_hot + 1

        all_losses = all_losses * mask_label.float() * weight_c
        return all_losses.sum()

    def compute_score_eval(self, x, device):
        """Score candidate spans for a batch that shares one set of entity types.

        Args:
            x: Collated batch dict. Reads ``span_idx``, ``span_mask``,
                ``tokens``, ``seq_length`` and ``classes_to_id``, which must be
                a single dict (type name -> id) shared by every example.
            device: Device that ``span_idx`` is moved to.

        Returns:
            ``(B, L, K, C)`` span/type logits, with ``C`` equal to the number of
            types in ``classes_to_id``, in key order.

        Raises:
            AssertionError: If ``x['classes_to_id']`` is not a dict.
        """
        assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"

        # Zero out the indices of padded spans.
        span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)

        # The same prompt "<<ENT>> t_1 ... <<ENT>> t_C <<SEP>>" for every example.
        entity_prompt = []
        for entity_type in x['classes_to_id'].keys():
            entity_prompt.append(self.entity_token)
            entity_prompt.append(entity_type)
        entity_prompt.append(self.sep_token)

        prompt_entity_length = len(entity_prompt)

        tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
        seq_length_p = x['seq_length'] + prompt_entity_length

        out = self.token_rep_layer(tokens_p, seq_length_p)
        word_rep_w_prompt = out["embeddings"]
        mask_w_prompt = out["mask"]

        # Everything after the prompt is the text itself.
        word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
        mask = mask_w_prompt[:, prompt_entity_length:]

        # Drop the trailing <<SEP>> and keep the <<ENT>> positions (even offsets).
        entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
        entity_type_rep = entity_type_rep[:, 0::2, :]
        entity_type_rep = self.prompt_rep_layer(entity_type_rep)

        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
        return local_scores

    @torch.no_grad()
    def predict(self, x, flat_ner=False, threshold=0.5):
        """Decode entity spans whose sigmoid score exceeds ``threshold``.

        Runs without gradients and switches the module to eval mode (it is not
        switched back afterwards).

        Args:
            x: Collated batch dict as read by ``compute_score_eval``; also reads
                ``id_to_classes`` (1-based type id -> type name).
            flat_ner: Forwarded to ``greedy_search`` (presumably selects whether
                overlapping/nested spans may be kept).
            threshold: Cut-off applied to ``sigmoid(score)``.

        Returns:
            One list per example with whatever ``greedy_search`` keeps from the
            ``(start, end, type_name, score)`` candidates. ``end`` is an
            inclusive word index.
        """
        self.eval()
        local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
        spans = []
        for i, tokens in enumerate(x["tokens"]):
            local_i = local_scores[i]
            # Indices (start, width offset, class) of every score above threshold.
            wh_i = [idx.tolist() for idx in torch.where(torch.sigmoid(local_i) > threshold)]
            span_i = []
            for s, k, c in zip(*wh_i):
                # Discard spans that run past the end of this example's text.
                if s + k < len(tokens):
                    span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
            spans.append(greedy_search(span_i, flat_ner))
        return spans

    def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
        """Extract entities of the given ``labels`` from raw ``text``.

        The text is split into words with the regex
        ``\\w+(?:[-_]\\w+)*|\\S`` (words, optionally joined by ``-``/``_``,
        or single non-space characters).

        Args:
            text: Input string.
            labels: Entity-type names to look for; passed to ``collate_fn``.
            flat_ner: Forwarded to ``predict``.
            threshold: Forwarded to ``predict``.

        Returns:
            A list of dicts with keys ``start`` / ``end`` (character offsets into
            ``text``, ``end`` exclusive), ``text`` (the matched substring) and
            ``label`` (the entity type).
        """
        tokens = []
        start_token_idx_to_text_idx = []
        end_token_idx_to_text_idx = []
        for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
            tokens.append(match.group())
            start_token_idx_to_text_idx.append(match.start())
            end_token_idx_to_text_idx.append(match.end())

        input_x = {"tokenized_text": tokens, "ner": None}
        x = self.collate_fn([input_x], labels)
        output = self.predict(x, flat_ner=flat_ner, threshold=threshold)

        entities = []
        # NOTE(review): assumes greedy_search returns (start, end, label) triples,
        # with the score dropped — confirm against modules.evaluator.
        for start_token_idx, end_token_idx, ent_type in output[0]:
            start_text_idx = start_token_idx_to_text_idx[start_token_idx]
            end_text_idx = end_token_idx_to_text_idx[end_token_idx]
            entities.append({
                "start": start_text_idx,
                "end": end_text_idx,
                "text": text[start_text_idx:end_text_idx],
                "label": ent_type,
            })
        return entities

    def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
        """Run prediction over ``test_data`` and score it with ``Evaluator``.

        Args:
            test_data: Dataset passed to ``create_dataloader``.
            flat_ner: Forwarded to ``predict``.
            threshold: Forwarded to ``predict``.
            batch_size: Batch size of the (unshuffled) data loader.
            entity_types: Forwarded to ``create_dataloader``.

        Returns:
            ``(out, f1)`` as returned by ``Evaluator.evaluate``.
        """
        self.eval()
        data_loader = self.create_dataloader(
            test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
        device = next(self.parameters()).device
        all_preds = []
        all_trues = []
        for x in data_loader:
            # Move tensor fields to the model's device; other fields stay as-is.
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            all_preds.extend(self.predict(x, flat_ner, threshold))
            all_trues.extend(x["entities"])
        evaluator = Evaluator(all_trues, all_preds)
        out, f1 = evaluator.evaluate()
        return out, f1

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load a model from a local directory or a Hugging Face Hub repo.

        First tries the legacy single-file checkpoints ``gliner_base.pt`` and
        ``gliner_multi.pt`` (a dict with ``config`` and ``model_weights``). It
        then falls back to the ``pytorch_model.bin`` + ``gliner_config.json``
        pair written by ``save_pretrained``. Local files take precedence over
        Hub downloads.

        Args:
            model_id: Local directory or Hub repo id.
            map_location: Device for ``torch.load`` and for the returned model.
            strict: Forwarded to ``load_state_dict``.
            Remaining keyword arguments are forwarded to ``hf_hub_download``.

        Returns:
            The loaded model, moved to ``map_location``.
        """
        from huggingface_hub.utils import HFValidationError

        # Legacy format: one pickled file with config + weights.
        for filename in ("gliner_base.pt", "gliner_multi.pt"):
            model_file = Path(model_id) / filename
            if not model_file.exists():
                try:
                    model_file = hf_hub_download(
                        repo_id=model_id,
                        filename=filename,
                        revision=revision,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        resume_download=resume_download,
                        token=token,
                        local_files_only=local_files_only,
                    )
                except (HfHubHTTPError, HFValidationError):
                    # HFValidationError: model_id is a local path (e.g. absolute)
                    # that is not a valid repo id, so there is no Hub legacy file.
                    continue
            # SECURITY NOTE(review): torch.load unpickles arbitrary objects; only
            # load checkpoints from trusted sources. The legacy file stores a
            # pickled config object, so weights_only=True cannot be used here.
            dict_load = torch.load(model_file, map_location=torch.device(map_location))
            config = dict_load["config"]
            state_dict = dict_load["model_weights"]
            config.model_name = (
                "microsoft/deberta-v3-base" if filename == "gliner_base.pt"
                else "microsoft/mdeberta-v3-base"
            )
            model = cls(config)
            model.load_state_dict(state_dict, strict=strict, assign=True)
            model.to(map_location)
            return model

        # Current format: state dict + JSON config.
        from .train import load_config_as_namespace

        model_file = Path(model_id) / "pytorch_model.bin"
        if not model_file.exists():
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="pytorch_model.bin",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config_file = Path(model_id) / "gliner_config.json"
        if not config_file.exists():
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="gliner_config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config = load_config_as_namespace(config_file)
        model = cls(config)
        # SECURITY NOTE(review): torch.load unpickles; only load trusted files.
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict, assign=True)
        model.to(map_location)
        return model

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Writes ``pytorch_model.bin`` (the state dict) and, when a config is
        available, ``gliner_config.json``.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict`, `argparse.Namespace` or `DataclassInstance`, *optional*):
                Model configuration. Defaults to ``self.config``. Namespaces and
                dataclass instances are converted to a dict before JSON encoding.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

        Returns:
            The result of ``push_to_hub`` when ``push_to_hub`` is True, else None.
        """
        import dataclasses

        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

        if config is None:
            config = self.config
        if config is not None:
            if isinstance(config, argparse.Namespace):
                config = vars(config)
            elif dataclasses.is_dataclass(config) and not isinstance(config, type):
                # json.dumps cannot encode dataclass instances directly.
                config = dataclasses.asdict(config)
            (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()
            if config is not None:
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None

    def to(self, *args, **kwargs):
        """Move/cast the module like ``nn.Module.to`` and keep ``flair.device`` in sync.

        Accepts every call form of ``nn.Module.to`` (device, dtype, tensor,
        ``non_blocking``, keyword arguments); ``model.to(device)`` behaves as
        before. ``flair.device`` is set to the device the parameters end up on,
        so a dtype-only call never stores a dtype there.

        Returns:
            ``self``.
        """
        super().to(*args, **kwargs)
        import flair

        flair.device = next(self.parameters()).device
        return self
|
|