""" see __init__.py """
from datetime import datetime
import dataclasses
import json
import logging
import os
import random
import re
from typing import Any, Dict, List, Optional, Set, Tuple
import numpy as np # type: ignore
try:
import tomllib
except ImportError:
import tomli as tomllib
import torch
import transformers # type: ignore
from pickle import UnpicklingError
import warnings
from stanza.utils.get_tqdm import get_tqdm # type: ignore
tqdm = get_tqdm()
from stanza.models.coref import bert, conll, utils
from stanza.models.coref.anaphoricity_scorer import AnaphoricityScorer
from stanza.models.coref.cluster_checker import ClusterChecker
from stanza.models.coref.config import Config
from stanza.models.coref.const import CorefResult, Doc
from stanza.models.coref.loss import CorefLoss
from stanza.models.coref.pairwise_encoder import PairwiseEncoder
from stanza.models.coref.rough_scorer import RoughScorer
from stanza.models.coref.span_predictor import SpanPredictor
from stanza.models.coref.utils import GraphNode
from stanza.models.coref.utils import sigmoid_focal_loss
from stanza.models.coref.word_encoder import WordEncoder
from stanza.models.coref.dataset import CorefDataset
from stanza.models.coref.tokenizer_customization import *
from stanza.models.common.bert_embedding import load_tokenizer
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
import torch.nn as nn
logger = logging.getLogger('stanza')
class CorefModel: # pylint: disable=too-many-instance-attributes
"""Combines all coref modules together to find coreferent spans.
Attributes:
config (coref.config.Config): the model's configuration,
see config.toml for the details
epochs_trained (int): number of epochs the model has been trained for
trainable (Dict[str, torch.nn.Module]): trainable submodules with their
names used as keys
training (bool): used to toggle train/eval modes
Submodules (in the order of their usage in the pipeline):
tokenizer (transformers.AutoTokenizer)
bert (transformers.AutoModel)
we (WordEncoder)
rough_scorer (RoughScorer)
pw (PairwiseEncoder)
a_scorer (AnaphoricityScorer)
sp (SpanPredictor)
"""
def __init__(self,
epochs_trained: int = 0,
build_optimizers: bool = True,
config: Optional[dict] = None,
foundation_cache=None):
"""
A newly created model is set to evaluation mode.
Args:
config_path (str): the path to the toml file with the configuration
section (str): the selected section of the config file
epochs_trained (int): the number of epochs finished
(useful for warm start)
"""
if config is None:
raise ValueError("Cannot create a model without a config")
self.config = config
self.epochs_trained = epochs_trained
self._docs: Dict[str, List[Doc]] = {}
self._build_model(foundation_cache)
self.optimizers = {}
self.schedulers = {}
if build_optimizers:
self._build_optimizers()
self._set_training(False)
# final coreference resolution score
self._coref_criterion = CorefLoss(self.config.bce_loss_weight)
# score simply for the top-k choices out of the rough scorer
self._rough_criterion = CorefLoss(0)
# exact span matches
self._span_criterion = torch.nn.CrossEntropyLoss(reduction="sum")
@property
def training(self) -> bool:
""" Represents whether the model is in the training mode """
return self._training
@training.setter
def training(self, new_value: bool):
if self._training is new_value:
return
self._set_training(new_value)
# ========================================================== Public methods
@torch.no_grad()
def evaluate(self,
data_split: str = "dev",
word_level_conll: bool = False,
eval_lang: Optional[str] = None
) -> Tuple[float, ...]:
""" Evaluates the model on the data split provided.
Args:
data_split (str): one of 'dev'/'test'/'train'
word_level_conll (bool): if True, outputs conll files on the word level
eval_lang (str): if set, only documents in this language are evaluated
Returns:
mean loss, span-level LEA (f1, precision, recall),
word-level LEA (f1, precision, recall), span- and word-level MBC scores,
and the word- and span-level CoNLL bakeoff averages
"""
self.training = False
w_checker = ClusterChecker()
s_checker = ClusterChecker()
try:
data_split_data = f"{data_split}_data"
data_path = self.config.__dict__[data_split_data]
docs = self._get_docs(data_path)
except FileNotFoundError as e:
raise FileNotFoundError("Unable to find data split %s at file %s" % (data_split_data, data_path)) from e
running_loss = 0.0
s_correct = 0
s_total = 0
z_correct = 0
z_total = 0
with conll.open_(self.config, self.epochs_trained, data_split) \
as (gold_f, pred_f):
pbar = tqdm(docs, unit="docs", ncols=0)
for doc in pbar:
if eval_lang and doc.get("lang", "") != eval_lang:
# skip this document; eval_lang is only used for ablations where we
# want to evaluate on a single language
continue
res = self.run(doc, True)
# measure zero prediction accuracy
zero_preds = (res.zero_scores > 0).view(-1).to(device=res.zero_scores.device)
is_zero = doc.get("is_zero")
if is_zero is None:
zero_targets = torch.zeros_like(zero_preds, device=zero_preds.device)
else:
zero_targets = torch.tensor(is_zero, device=zero_preds.device)
z_correct += (zero_preds == zero_targets).sum().item()
z_total += zero_targets.numel()
if (res.coref_y.argmax(dim=1) == 1).all():
logger.warning(f"EVAL: skipping document with no corefs...")
continue
running_loss += self._coref_criterion(res.coref_scores, res.coref_y).item()
if res.word_clusters is None or res.span_clusters is None:
logger.warning(f"EVAL: skipping document with no clusters...")
continue
if res.span_y:
pred_starts = res.span_scores[:, :, 0].argmax(dim=1)
pred_ends = res.span_scores[:, :, 1].argmax(dim=1)
s_correct += ((res.span_y[0] == pred_starts) * (res.span_y[1] == pred_ends)).sum().item()
s_total += len(pred_starts)
if word_level_conll:
raise NotImplementedError("We now write CoNLL-U conforming to UDCoref, which means the span_clusters annotations include headword info; the word_level option is meaningless.")
else:
conll.write_conll(doc, doc["span_clusters"], doc["word_clusters"], gold_f)
conll.write_conll(doc, res.span_clusters, res.word_clusters, pred_f)
w_checker.add_predictions(doc["word_clusters"], res.word_clusters)
w_lea = w_checker.total_lea
s_checker.add_predictions(doc["span_clusters"], res.span_clusters)
s_lea = s_checker.total_lea
del res
pbar.set_description(
f"{data_split}:"
f" | WL: "
f" loss: {running_loss / (pbar.n + 1):<.5f},"
f" f1: {w_lea[0]:.5f},"
f" p: {w_lea[1]:.5f},"
f" r: {w_lea[2]:<.5f}"
f" | SL: "
f" sa: {s_correct / s_total:<.5f},"
f" f1: {s_lea[0]:.5f},"
f" p: {s_lea[1]:.5f},"
f" r: {s_lea[2]:<.5f}"
f" | ZA: {z_correct / z_total:<.5f}"
)
logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}")
logger.info(f"Zero prediction accuracy: {z_correct / z_total:.5f}")
return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc, w_checker.bakeoff, s_checker.bakeoff)
def load_weights(self,
path: Optional[str] = None,
ignore: Optional[Set[str]] = None,
map_location: Optional[str] = None,
noexception: bool = False) -> None:
"""
Loads pretrained weights of modules saved in a file located at path.
If path is None, the last saved model with current configuration
in save_dir is loaded.
Assumes files are named like {configuration}_(e{epoch}_{time})*.pt.
"""
if path is None:
# pattern = rf"{self.config.save_name}_\(e(\d+)_[^()]*\).*\.pt"
# tries to load the last checkpoint in the same dir
pattern = rf"{self.config.save_name}.*?\.checkpoint\.pt"
files = []
os.makedirs(self.config.save_dir, exist_ok=True)
for f in os.listdir(self.config.save_dir):
match_obj = re.match(pattern, f)
if match_obj:
files.append(f)
if not files:
if noexception:
logger.debug("No weights have been loaded", flush=True)
return
raise OSError(f"No weights found in {self.config.save_dir}!")
path = sorted(files)[-1]
path = os.path.join(self.config.save_dir, path)
if map_location is None:
map_location = self.config.device
logger.debug(f"Loading from {path}...")
try:
state_dicts = torch.load(path, map_location=map_location, weights_only=True)
except UnpicklingError:
state_dicts = torch.load(path, map_location=map_location, weights_only=False)
warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the coref model using this version ASAP.")
self.epochs_trained = state_dicts.pop("epochs_trained", 0)
# just ignore a config in the model, since we should already have one
# TODO: some config elements may be fixed parameters of the model,
# such as the dimensions of the head,
# so we would want to use the ones from the config even if the
# user created a weird shaped model
config = state_dicts.pop("config", {})
self.load_state_dicts(state_dicts, ignore)
def load_state_dicts(self,
state_dicts: dict,
ignore: Optional[Set[str]] = None):
"""
Process the dictionaries from the save file
Loads the weights into the tensors of this model
May also have optimizer and/or schedule state
"""
for key, state_dict in state_dicts.items():
logger.debug("Loading state: %s", key)
if not ignore or key not in ignore:
if key.endswith("_optimizer"):
self.optimizers[key].load_state_dict(state_dict)
elif key.endswith("_scheduler"):
self.schedulers[key].load_state_dict(state_dict)
elif key == "bert_lora":
assert self.config.lora, "Unable to load state dict of LoRA model into model initialized without LoRA!"
self.bert = load_peft_wrapper(self.bert, state_dict, vars(self.config), logger, self.peft_name)
else:
self.trainable[key].load_state_dict(state_dict, strict=False)
logger.debug(f"Loaded {key}")
if self.config.log_norms:
self.log_norms()
def build_doc(self, doc: dict) -> dict:
filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,
lambda _: True)
token_map = TOKENIZER_MAPS.get(self.config.bert_model, {})
word2subword = []
subwords = []
word_id = []
for i, word in enumerate(doc["cased_words"]):
tokenized_word = (token_map[word]
if word in token_map
else self.tokenizer.tokenize(word))
tokenized_word = list(filter(filter_func, tokenized_word))
word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))
subwords.extend(tokenized_word)
word_id.extend([i] * len(tokenized_word))
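# e.g. (hypothetical): if word i is "playing" and tokenizes to ["play", "##ing"] after
# 10 earlier subwords, word2subword gets (10, 12) and word_id gets two more copies of i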
doc["word2subword"] = word2subword
doc["subwords"] = subwords
doc["word_id"] = word_id
doc["head2span"] = []
if "speaker" not in doc:
doc["speaker"] = ["_" for _ in doc["cased_words"]]
doc["word_clusters"] = []
doc["span_clusters"] = []
return doc
@staticmethod
def load_model(path: str,
map_location: str = "cpu",
ignore: Optional[Set[str]] = None,
config_update: Optional[dict] = None,
foundation_cache = None):
if not path:
raise FileNotFoundError("coref model got an invalid path |%s|" % path)
if not os.path.exists(path):
raise FileNotFoundError("coref model file %s not found" % path)
try:
state_dicts = torch.load(path, map_location=map_location, weights_only=True)
except UnpicklingError:
state_dicts = torch.load(path, map_location=map_location, weights_only=False)
warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the coref model using this version ASAP.")
epochs_trained = state_dicts.pop("epochs_trained", 0)
config = state_dicts.pop('config', None)
if config is None:
raise ValueError("Cannot load this format model without config in the dicts")
if 'max_train_len' not in config:
# TODO: this is to keep old models working.
# Can get rid of it if those models are rebuilt
config['max_train_len'] = 5000
if isinstance(config, dict):
config = Config(**config)
if config_update:
for key, value in config_update.items():
setattr(config, key, value)
model = CorefModel(config=config, build_optimizers=False,
epochs_trained=epochs_trained, foundation_cache=foundation_cache)
model.load_state_dicts(state_dicts, ignore)
return model
def run(self, # pylint: disable=too-many-locals
doc: Doc,
use_gold_spans_for_zeros = False
) -> CorefResult:
"""
This is a massive method, but it made sense to me not to split it into
several ones, so that the data flow is easy to follow.
Args:
doc (Doc): a dictionary with the document data.
Returns:
CorefResult (see const.py)
"""
# Encode words with bert
# words [n_words, span_emb]
# cluster_ids [n_words]
words, cluster_ids = self.we(doc, self._bertify(doc))
# Obtain bilinear scores and leave only top-k antecedents for each word
# top_rough_scores [n_words, n_ants]
# top_indices [n_words, n_ants]
top_rough_scores, top_indices, rough_scores = self.rough_scorer(words)
# Get pairwise features [n_words, n_ants, n_pw_features]
pw = self.pw(top_indices, doc)
batch_size = self.config.a_scoring_batch_size
a_scores_lst: List[torch.Tensor] = []
for i in range(0, len(words), batch_size):
pw_batch = pw[i:i + batch_size]
words_batch = words[i:i + batch_size]
top_indices_batch = top_indices[i:i + batch_size]
top_rough_scores_batch = top_rough_scores[i:i + batch_size]
# a_scores_batch [batch_size, n_ants]
a_scores_batch = self.a_scorer(
top_mentions=words[top_indices_batch], mentions_batch=words_batch,
pw_batch=pw_batch, top_rough_scores_batch=top_rough_scores_batch
)
a_scores_lst.append(a_scores_batch)
res = CorefResult()
# coref_scores [n_spans, n_ants]
res.coref_scores = torch.cat(a_scores_lst, dim=0)
res.coref_y = self._get_ground_truth(
cluster_ids, top_indices, (top_rough_scores > float("-inf")),
self.config.clusters_starts_are_singletons,
self.config.singletons
)
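# coref_y: [n_words, n_ants + 1] with a non-coref dummy column,
# or [n_words, n_ants + 2] when singletons add a start-of-cluster dummy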
res.word_clusters = self._clusterize(
doc, res.coref_scores, top_indices,
self.config.singletons
)
res.span_scores, res.span_y = self.sp.get_training_data(doc, words)
if not self.training:
res.span_clusters = self.sp.predict(doc, words, res.word_clusters)
if not self.training and not use_gold_spans_for_zeros:
zero_words = words[[word_id
for cluster in res.word_clusters
for word_id in cluster]]
else:
zero_words = words[[i[0] for i in sorted(doc["head2span"])]]
res.zero_scores = self.zeros_predictor(zero_words)
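# zero_scores: [n_candidate_words, 1] logits from the zeros predictor;
# downstream, scores > 0 are treated as predicted zero anaphors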
return res
def save_weights(self, save_path=None, save_optimizers=True):
""" Saves trainable models as state dicts. """
to_save: List[Tuple[str, Any]] = \
[(key, value) for key, value in self.trainable.items()
if (self.config.bert_finetune and not self.config.lora) or key != "bert"]
if save_optimizers:
to_save.extend(self.optimizers.items())
to_save.extend(self.schedulers.items())
time = datetime.strftime(datetime.now(), "%Y.%m.%d_%H.%M")
if save_path is None:
save_path = os.path.join(self.config.save_dir,
f"{self.config.save_name}"
f"_e{self.epochs_trained}_{time}.pt")
savedict = {name: module.state_dict() for name, module in to_save}
if self.config.lora:
# so that this dependency remains optional
from peft import get_peft_model_state_dict
savedict["bert_lora"] = get_peft_model_state_dict(self.bert, adapter_name="coref")
savedict["epochs_trained"] = self.epochs_trained # type: ignore
# save as a dictionary because the weights_only=True load option
# doesn't allow for arbitrary @dataclass configs
savedict["config"] = dataclasses.asdict(self.config)
save_dir = os.path.split(save_path)[0]
if save_dir:
os.makedirs(save_dir, exist_ok=True)
torch.save(savedict, save_path)
def log_norms(self):
lines = ["NORMS FOR MODEL PARAMTERS"]
for t_name, trainable in self.trainable.items():
for name, param in trainable.named_parameters():
if param.requires_grad:
lines.append(" %s: %s %.6g (%d)" % (t_name, name, torch.norm(param).item(), param.numel()))
logger.info("\n".join(lines))
def train(self, log=False):
"""
Trains all the trainable blocks in the model using the config provided.
log: whether or not to log using wandb
"""
if log:
import wandb
wandb.watch((self.bert, self.pw,
self.a_scorer, self.we,
self.rough_scorer, self.sp))
docs = self._get_docs(self.config.train_data)
docs_ids = list(range(len(docs)))
avg_spans = docs.avg_span
# for a brand new model, we set the zeros prediction to all 0 if the dataset has no zeros
training_has_zeros = any('is_zero' in doc for doc in docs)
if not training_has_zeros:
logger.info("No zeros found in the dataset. The zeros predictor will set to 0")
if self.epochs_trained == 0:
# new model, set it to always predict not-zero
self.disable_zeros_predictor()
attenuated_languages = set()
if self.config.lang_lr_attenuation:
attenuated_languages = set(self.config.lang_lr_attenuation.split(","))
logger.info("Attenuating LR for the following languages: %s", attenuated_languages)
lr_scaled_languages = dict()
if self.config.lang_lr_weights:
scaled_languages = self.config.lang_lr_weights.split(",")
for piece in scaled_languages:
pieces = piece.split("=")
lr_scaled_languages[pieces[0]] = float(pieces[1])
logger.info("Scaling LR for the following languages: %s", lr_scaled_languages)
best_f1 = None
for epoch in range(self.epochs_trained, self.config.train_epochs):
self.training = True
if self.config.log_norms:
self.log_norms()
running_c_loss = 0.0
running_s_loss = 0.0
running_z_loss = 0.0
random.shuffle(docs_ids)
pbar = tqdm(docs_ids, unit="docs", ncols=0)
for doc_indx, doc_id in enumerate(pbar):
doc = docs[doc_id]
# skip very long documents during training time
if len(doc["subwords"]) > self.config.max_train_len:
continue
for optim in self.optimizers.values():
optim.zero_grad()
res = self.run(doc)
if res.zero_scores.size(0) == 0 or not training_has_zeros:
z_loss = 0.0 # no zero candidates in this doc, or the dataset has no zeros at all
else:
is_zero = doc.get("is_zero")
if is_zero is None:
is_zero = torch.zeros_like(res.zero_scores.squeeze(-1), device=res.zero_scores.device, dtype=torch.float)
else:
is_zero = torch.tensor(is_zero).to(res.zero_scores.device).float()
z_loss = sigmoid_focal_loss(res.zero_scores.squeeze(-1), is_zero, reduction="mean")
c_loss = self._coref_criterion(res.coref_scores, res.coref_y)
if res.span_y:
s_loss = (self._span_criterion(res.span_scores[:, :, 0], res.span_y[0])
+ self._span_criterion(res.span_scores[:, :, 1], res.span_y[1])) / avg_spans / 2
else:
s_loss = torch.zeros_like(c_loss)
lr_scale = lr_scaled_languages.get(doc.get("lang"), 1.0)
if doc.get("lang") in attenuated_languages:
lr_scale = lr_scale / max(epoch, 1.0)
c_loss = c_loss * lr_scale
s_loss = s_loss * lr_scale
z_loss = z_loss * lr_scale
(c_loss + s_loss + z_loss).backward()
running_c_loss += c_loss.item()
running_s_loss += s_loss.item()
if res.zero_scores.size(0) != 0 and training_has_zeros:
running_z_loss += z_loss.item()
# log every 100 docs
if log and doc_indx % 100 == 0:
logged = {
'train_c_loss': c_loss.item(),
'train_s_loss': s_loss.item(),
}
if res.zero_scores.size(0) != 0 and training_has_zeros:
logged['train_z_loss'] = z_loss.item()
wandb.log(logged)
del c_loss, s_loss, z_loss, res
for optim in self.optimizers.values():
optim.step()
for scheduler in self.schedulers.values():
scheduler.step()
pbar.set_description(
f"Epoch {epoch + 1}:"
f" {doc['document_id']:26}"
f" c_loss: {running_c_loss / (pbar.n + 1):<.5f}"
f" s_loss: {running_s_loss / (pbar.n + 1):<.5f}"
f" z_loss: {running_z_loss / (pbar.n + 1):<.5f}"
)
self.epochs_trained += 1
scores = self.evaluate()
prev_best_f1 = best_f1
if log:
wandb.log({'dev_score': scores[1]})
wandb.log({'dev_bakeoff': scores[-1]})
if best_f1 is None or scores[1] > best_f1:
if best_f1 is None:
logger.info("Saving new best model: F1 %.4f", scores[1])
else:
logger.info("Saving new best model: F1 %.4f > %.4f", scores[1], best_f1)
best_f1 = scores[1]
if self.config.save_name.endswith(".pt"):
save_path = os.path.join(self.config.save_dir,
f"{self.config.save_name}")
else:
save_path = os.path.join(self.config.save_dir,
f"{self.config.save_name}.pt")
self.save_weights(save_path, save_optimizers=False)
if self.config.save_each_checkpoint:
self.save_weights()
else:
if self.config.save_name.endswith(".pt"):
checkpoint_path = os.path.join(self.config.save_dir,
f"{self.config.save_name[:-3]}.checkpoint.pt")
else:
checkpoint_path = os.path.join(self.config.save_dir,
f"{self.config.save_name}.checkpoint.pt")
self.save_weights(checkpoint_path)
if prev_best_f1 is not None and prev_best_f1 != best_f1:
logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f\nPrevious best F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1, prev_best_f1)
else:
logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1)
# ========================================================= Private methods
def _bertify(self, doc: Doc) -> torch.Tensor:
all_batches = bert.get_subwords_batches(doc, self.config, self.tokenizer)
# we process the batches 1024 at a time to prevent OOM
result = []
for i in range(0, all_batches.shape[0], 1024):
subwords_batches = all_batches[i:i+1024]
special_tokens = np.array([self.tokenizer.cls_token_id,
self.tokenizer.sep_token_id,
self.tokenizer.pad_token_id,
self.tokenizer.eos_token_id])
subword_mask = ~(np.isin(subwords_batches, special_tokens))
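# subword_mask drops special tokens (cls/sep/pad/eos) when selecting output embeddings;
# attention_mask below only masks padding for the transformer itself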
subwords_batches_tensor = torch.tensor(subwords_batches,
device=self.config.device,
dtype=torch.long)
subword_mask_tensor = torch.tensor(subword_mask,
device=self.config.device)
# Obtain bert output for selected batches only
attention_mask = (subwords_batches != self.tokenizer.pad_token_id)
if "t5" in self.config.bert_model:
out = self.bert.encoder(
input_ids=subwords_batches_tensor,
attention_mask=torch.tensor(
attention_mask, device=self.config.device))
else:
out = self.bert(
subwords_batches_tensor,
attention_mask=torch.tensor(
attention_mask, device=self.config.device))
out = out['last_hidden_state']
# [n_subwords, bert_emb]
result.append(out[subword_mask_tensor])
# stack returns and return
return torch.cat(result)
def _build_model(self, foundation_cache):
if hasattr(self.config, 'lora') and self.config.lora:
self.bert, self.tokenizer, peft_name = load_bert_with_peft(self.config.bert_model, "coref", foundation_cache)
# vars() converts a dataclass to a dict, so we can index things like args["lora_*"]
self.bert = build_peft_wrapper(self.bert, vars(self.config), logger, adapter_name=peft_name)
self.peft_name = peft_name
else:
if self.config.bert_finetune:
logger.debug("Coref model requested a finetuned transformer; we are not using the foundation model cache to prevent we accidentally leak the finetuning weights elsewhere.")
foundation_cache = NoTransformerFoundationCache(foundation_cache)
self.bert, self.tokenizer = load_bert(self.config.bert_model, foundation_cache)
base_bert_name = self.config.bert_model.split("/")[-1]
tokenizer_kwargs = self.config.tokenizer_kwargs.get(base_bert_name, {})
if tokenizer_kwargs:
logger.debug(f"Using tokenizer kwargs: {tokenizer_kwargs}")
# we just downloaded the tokenizer, so for simplicity, we don't make another request to HF
self.tokenizer = load_tokenizer(self.config.bert_model, tokenizer_kwargs, local_files_only=True)
if self.config.bert_finetune or (hasattr(self.config, 'lora') and self.config.lora):
self.bert = self.bert.train()
self.bert = self.bert.to(self.config.device)
self.pw = PairwiseEncoder(self.config).to(self.config.device)
bert_emb = self.bert.config.hidden_size
pair_emb = bert_emb * 3 + self.pw.shape
# pylint: disable=line-too-long
self.a_scorer = AnaphoricityScorer(pair_emb, self.config).to(self.config.device)
self.we = WordEncoder(bert_emb, self.config).to(self.config.device)
self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device)
self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device)
self.zeros_predictor = nn.Sequential(
nn.Linear(bert_emb, bert_emb),
nn.ReLU(),
nn.Linear(bert_emb, 1)
).to(self.config.device)
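# a two-layer MLP producing one logit per candidate word; disable_zeros_predictor below
# zeroes the final layer so the model never predicts a zero anaphor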
if not hasattr(self.config, 'use_zeros') or not self.config.use_zeros:
self.disable_zeros_predictor()
self.trainable: Dict[str, torch.nn.Module] = {
"bert": self.bert, "we": self.we,
"rough_scorer": self.rough_scorer,
"pw": self.pw, "a_scorer": self.a_scorer,
"sp": self.sp, "zeros_predictor": self.zeros_predictor
}
def disable_zeros_predictor(self):
nn.init.zeros_(self.zeros_predictor[-1].weight)
nn.init.zeros_(self.zeros_predictor[-1].bias)
def _build_optimizers(self):
n_docs = len(self._get_docs(self.config.train_data))
self.optimizers: Dict[str, torch.optim.Optimizer] = {}
self.schedulers: Dict[str, torch.optim.lr_scheduler.LRScheduler] = {}
if not getattr(self.config, 'lora', False):
for param in self.bert.parameters():
param.requires_grad = self.config.bert_finetune
if self.config.bert_finetune:
logger.debug("Making bert optimizer with LR of %f", self.config.bert_learning_rate)
self.optimizers["bert_optimizer"] = torch.optim.Adam(
self.bert.parameters(), lr=self.config.bert_learning_rate
)
start_finetuning = int(n_docs * self.config.bert_finetune_begin_epoch)
if start_finetuning > 0:
logger.info("Will begin finetuning transformer at iteration %d", start_finetuning)
zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizers["bert_optimizer"], factor=0, total_iters=start_finetuning)
warmup_scheduler = transformers.get_linear_schedule_with_warmup(
self.optimizers["bert_optimizer"],
start_finetuning, n_docs * self.config.train_epochs - start_finetuning)
self.schedulers["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR(
self.optimizers["bert_optimizer"],
schedulers=[zero_scheduler, warmup_scheduler],
milestones=[start_finetuning])
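# net effect: the transformer LR is held at 0 for the first start_finetuning documents,
# then warms up and decays linearly over the remaining training steps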
# Must ensure the same ordering of parameters between launches
modules = sorted((key, value) for key, value in self.trainable.items()
if key != "bert")
params = []
for _, module in modules:
for param in module.parameters():
param.requires_grad = True
params.append(param)
self.optimizers["general_optimizer"] = torch.optim.Adam(
params, lr=self.config.learning_rate)
self.schedulers["general_scheduler"] = \
transformers.get_linear_schedule_with_warmup(
self.optimizers["general_optimizer"],
0, n_docs * self.config.train_epochs
)
def _clusterize(self, doc: Doc, scores: torch.Tensor, top_indices: torch.Tensor,
singletons: bool = True):
if singletons:
antecedents = scores[:,1:].argmax(dim=1) - 1
# set the dummy values to -1, so that they are not coref to themselves
is_start = (scores[:, :2].argmax(dim=1) == 0)
else:
antecedents = scores.argmax(dim=1) - 1
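# with singletons, scores has two dummy columns (0: start-of-cluster, 1: no antecedent);
# without singletons, only column 0 is the no-antecedent dummy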
not_dummy = antecedents >= 0
coref_span_heads = torch.arange(0, len(scores), device=not_dummy.device)[not_dummy]
antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
nodes = [GraphNode(i) for i in range(len(doc["cased_words"]))]
for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
nodes[i].link(nodes[j])
assert nodes[i] is not nodes[j]
visited = {}
clusters = []
for node in nodes:
if len(node.links) > 0 and not node.visited:
cluster = []
stack = [node]
while stack:
current_node = stack.pop()
current_node.visited = True
cluster.append(current_node.id)
stack.extend(link for link in current_node.links if not link.visited)
assert len(cluster) > 1
for i in cluster:
visited[i] = True
clusters.append(sorted(cluster))
if singletons:
# go through the is_start nodes; if no clusters contain that node
# i.e. visited[i] == False, we add it as a singleton
for indx, i in enumerate(is_start):
if i and not visited.get(indx, False):
clusters.append([indx])
return sorted(clusters)
def _get_docs(self, path: str) -> List[Doc]:
if path not in self._docs:
self._docs[path] = CorefDataset(path, self.config, self.tokenizer)
return self._docs[path]
@staticmethod
def _get_ground_truth(cluster_ids: torch.Tensor,
top_indices: torch.Tensor,
valid_pair_map: torch.Tensor,
cluster_starts: bool,
singletons:bool = True) -> torch.Tensor:
"""
Args:
cluster_ids: tensor of shape [n_words], containing cluster indices
for each word. Non-gold words have cluster id of zero.
top_indices: tensor of shape [n_words, n_ants],
indices of antecedents of each word
valid_pair_map: boolean tensor of shape [n_words, n_ants],
True at [i, j] if the j-th antecedent candidate
of the i-th word is valid (i.e. it precedes the word)
Returns:
tensor of shape [n_words, n_ants + 1] (dummy added), or
[n_words, n_ants + 2] when singletons add a start-of-cluster dummy,
containing 1 at position [i, j] if the i-th word and its j-th candidate corefer.
"""
y = cluster_ids[top_indices] * valid_pair_map # [n_words, n_ants]
y[y == 0] = -1 # -1 for non-gold words
y = utils.add_dummy(y) # [n_words, n_cands + 1]
if singletons:
if not cluster_starts:
unique, counts = cluster_ids.unique(return_counts=True)
singleton_clusters = unique[(counts == 1) & (unique != 0)]
first_corefs = [(cluster_ids == i).nonzero().flatten()[0] for i in singleton_clusters]
if len(first_corefs) > 0:
first_coref = torch.stack(first_corefs)
else:
first_coref = torch.tensor([]).to(cluster_ids.device).long()
else:
# I apologize for this abuse of everything that's good about PyTorch.
# in essence, this line finds the INDEX of the FIRST OCCURRENCE of each NON-ZERO value
# from cluster_ids. We need this information because we use it to mark the
# special "is-start-of-ref" marker used to detect singletons.
first_coref = (cluster_ids ==
cluster_ids.unique().sort().values[1:].unsqueeze(1)
).float().topk(k=1, dim=1).indices.squeeze()
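# e.g. (illustrative): cluster_ids = [0, 3, 3, 5]
#   non-zero unique values -> [3, 5]
#   equality matrix        -> [[F, T, T, F], [F, F, F, T]]
#   topk(k=1) indices      -> first_coref = [1, 3], the first word of each cluster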
y = (y == cluster_ids.unsqueeze(1)) # True if coreferent
# For all rows with no gold antecedents setting dummy to True
y[y.sum(dim=1) == 0, 0] = True
if singletons:
# add another dummy for first coref
y = utils.add_dummy(y) # [n_words, n_cands + 2]
# for every row that is the first mention of a cluster, set its singleton dummy
# to True and unset the non-coref dummy
y[first_coref, 0] = True
y[first_coref, 1] = False
return y.to(torch.float)
@staticmethod
def _load_config(config_path: str,
section: str) -> Config:
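# Illustrative config.toml layout (section name and values here are hypothetical):
#   [DEFAULT]
#   bert_model = "xlm-roberta-base"
#   learning_rate = 0.0003
#
#   [multilingual]
#   bert_model = "xlm-roberta-large"
# Keys in the selected section override DEFAULT; keys not present in DEFAULT are rejected below.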
with open(config_path, "rb") as fin:
config = tomllib.load(fin)
default_section = config["DEFAULT"]
current_section = config[section]
unknown_keys = (set(current_section.keys())
- set(default_section.keys()))
if unknown_keys:
raise ValueError(f"Unexpected config keys: {unknown_keys}")
return Config(section, **{**default_section, **current_section})
def _set_training(self, value: bool):
self._training = value
for module in self.trainable.values():
module.train(self._training)