from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel

class LinearTokenSelector(nn.Module):
    """Scores every input token with a bias-free linear head on top of a
    pretrained encoder: label 1 means keep the token, 0 means drop it."""

    def __init__(self, encoder, embedding_size=768):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(embedding_size, 2, bias=False)

    def forward(self, x):
        output = self.encoder(x, output_hidden_states=True)
        x = output["hidden_states"][-1]  # last layer: (batch, seq_len, hidden)
        x = self.classifier(x)  # per-token keep/drop scores: (batch, seq_len, 2)
        x = F.log_softmax(x, dim=2)
        return x

    def save(self, classifier_path, encoder_path):
        # Persist only the classifier head here; the encoder is saved
        # separately in Hugging Face format via save_pretrained.
        state = self.state_dict()
        state = {k: v for k, v in state.items() if k.startswith("classifier")}
        torch.save(state, classifier_path)
        self.encoder.save_pretrained(encoder_path)

    def predict(self, texts, tokenizer, device):
        input_ids = tokenizer(texts)["input_ids"]
        # Pad with the tokenizer's pad token; pad_sequence's default of 0 is
        # only correct for tokenizers whose pad token id happens to be 0.
        input_ids = pad_sequence(
            [torch.tensor(ids) for ids in input_ids],
            batch_first=True,
            padding_value=tokenizer.pad_token_id,
        ).to(device)
        with torch.no_grad():  # inference only, no gradients needed
            log_probs = self.forward(input_ids)
        argmax_labels = torch.argmax(log_probs, dim=2)  # 1 = keep, 0 = drop
        return labels_to_summary(input_ids, argmax_labels, tokenizer)


def load_model(model_dir, device="cuda", prefix="best"):
    if isinstance(model_dir, str):
        model_dir = Path(model_dir)
    # Pick a checkpoint whose name starts with `prefix`; fail loudly instead
    # of raising NameError when nothing matches.
    checkpoint_dir = None
    for p in (model_dir / "checkpoints").iterdir():
        if p.name.startswith(prefix):
            checkpoint_dir = p
    if checkpoint_dir is None:
        raise FileNotFoundError(
            f"no checkpoint matching '{prefix}*' in {model_dir / 'checkpoints'}"
        )
    return load_checkpoint(checkpoint_dir, device=device)


def load_checkpoint(checkpoint_dir, device="cuda"):
    if isinstance(checkpoint_dir, str):
        checkpoint_dir = Path(checkpoint_dir)

    encoder_path = checkpoint_dir / "encoder.bin"
    classifier_path = checkpoint_dir / "classifier.bin"

    encoder = AutoModel.from_pretrained(encoder_path).to(device)
    # Infer the hidden size from the encoder's word embeddings so the head
    # matches any encoder, not just 768-dimensional ones.
    embedding_size = encoder.state_dict()["embeddings.word_embeddings.weight"].shape[1]

    # Rebuild the model around an empty encoder slot, load only the
    # classifier weights, then attach the freshly loaded encoder.
    classifier = LinearTokenSelector(None, embedding_size).to(device)
    classifier_state = torch.load(classifier_path, map_location=device)
    classifier_state = {
        k: v for k, v in classifier_state.items() if k.startswith("classifier")
    }
    classifier.load_state_dict(classifier_state)
    classifier.encoder = encoder
    return classifier.to(device)
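

# Expected on-disk layout, inferred from load_model/load_checkpoint above
# (the "best-..." name is illustrative; any entry whose name starts with
# `prefix` is accepted):
#
#   <model_dir>/
#       checkpoints/
#           best-.../
#               encoder.bin/     # directory written by encoder.save_pretrained
#               classifier.bin   # classifier state dict written by torch.save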


def labels_to_summary(input_batch, label_batch, tokenizer):
    summaries = []
    for input_ids, labels in zip(input_batch, label_batch):
        # Keep exactly the tokens labeled 1 and decode them back into text.
        selected = [
            int(input_ids[i]) for i in range(len(input_ids)) if labels[i] == 1
        ]
        summary = tokenizer.decode(selected, skip_special_tokens=True)
        summaries.append(summary)
    return summaries
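

# Minimal usage sketch, assuming a trained checkpoint exists under
# "model_dir/checkpoints/best*" and that the encoder was trained from
# "bert-base-uncased"; both names are placeholders, not fixed by this module.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = load_model("model_dir", device=device)
    summaries = model.predict(
        ["The quick brown fox jumps over the lazy dog."], tokenizer, device
    )
    print(summaries)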