|
|
from transformers import BertConfig, BertModel |
|
|
import torch |
|
|
import re |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from sklearn.model_selection import train_test_split, cross_validate |
|
|
import pytorch_lightning as pl |
|
|
import pandas as pd |
|
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint |
|
|
from torch.optim import AdamW |
|
|
from sklearn.metrics import f1_score |
|
|
|
|
|
# Fixed length every encoded example is truncated/padded to; also used as
# BERT's max_position_embeddings so every position has an embedding.
MAX_LEN = 96

# Padding id; matches "<pad>" being the first special token in get_char2idx's
# default ordering. Not referenced by the other code in this section.
PAD_ID = 0

# Small BERT encoder built from scratch over character ids.
# NOTE(review): vocab_size=40 must be greater than every id produced by
# get_char2idx (3 special tokens + distinct characters); confirm against data.
config = BertConfig(
    vocab_size=40,
    type_vocab_size=4,  # segment ids 0..3 are emitted by the get_dataset* encoders
    max_position_embeddings=MAX_LEN,
    hidden_size=64,
    intermediate_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
)
|
|
|
|
|
|
|
|
|
|
|
class MyDataset(Dataset):
    """Map-style dataset over the encodings produced by ``get_dataset3``.

    Items are feature dicts with ``input_ids``, ``attention_mask`` and
    ``token_type_ids`` lists or, when ``is_train`` is True,
    ``(features, label)`` pairs. Pass :meth:`collate_fn` to the DataLoader.

    Args:
        df: DataFrame handed to ``get_dataset3``.
        char2idx: character -> id vocabulary.
        label2idx: label -> class index.
        is_train: whether items carry a label.
        verbose: print both vocabularies on construction. The original always
            printed them; this is now opt-in.
    """

    def __init__(self, df, char2idx, label2idx, is_train=True, verbose=False):
        super().__init__()
        if verbose:
            # Debug aid: show which vocabularies this dataset was encoded with.
            print(char2idx)
            print(label2idx)
        self.is_train = is_train
        self.dataset = get_dataset3(df, char2idx, label2idx, is_train=is_train)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def collate_fn(self, batch):
        """Stack a list of items into model-ready tensors.

        Returns a dict of tensors (ids and token types as int64, attention
        mask as float32). When ``is_train`` is True, returns
        ``(dict, int64 label tensor)`` instead.
        """
        features = [item[0] for item in batch] if self.is_train else batch
        collated = {
            "input_ids": torch.tensor([f["input_ids"] for f in features], dtype=torch.long),
            "attention_mask": torch.tensor([f["attention_mask"] for f in features], dtype=torch.float),
            "token_type_ids": torch.tensor([f["token_type_ids"] for f in features], dtype=torch.long),
        }
        if self.is_train:
            return collated, torch.tensor([item[1] for item in batch], dtype=torch.long)
        return collated
|
|
|
|
|
|
|
|
def get_preprocessed_dfs(folder):
    """Load train/test CSVs, normalise text and split tags into category lists.

    Every column is stripped and lower-cased first. Then all causative tags
    other than ``caus_1`` are collapsed into ``caus_2`` and duplicate rows are
    dropped. Finally each training tag is decomposed into the known
    grammatical categories it contains. Normalising first means rows that
    differ only in case/whitespace are deduplicated, and padded tags such as
    " CAUS_3" are still collapsed. The test file's ``Tag`` column is discarded.

    Args:
        folder: directory containing ``train_data.csv`` and ``test_data.csv``.

    Returns:
        ``{"train": DataFrame, "test": DataFrame}`` with lower-cased column
        names; the train ``tag`` column holds a list of category strings.

    Raises:
        AttributeError: if a column contains non-string values (e.g. NaN).
    """
    cats = ['FUT_INDF_3PLF', 'FUT_INDF_NEG', 'PST_INDF_PS', 'PCP_FUT_NEG', 'PCP_FUT_DEF', 'PRES_CONT', 'PRES_2SGF', 'POSS_2SGF', 'POSS_2PLF', 'NUM_APPR3', 'NUM_APPR2', 'NUM_APPR1', 'ADVV_CONT', 'ADJECTIVE', 'PST_ITER', 'PST_INDF', 'PST_EVID', 'PRES_PST', 'POSS_3SG', 'POSS_3PL', 'POSS_2SG', 'POSS_2PL', 'POSS_1SG', 'POSS_1PL', 'NUM_COLL', 'FUT_INDF', 'ADVV_SUC', 'ADVV_NEG', 'ADVV_INT', 'ADVV_ACC', 'PST_DEF', 'NUM_ORD', 'NUMERAL', 'IMP_SGF', 'IMP_PLF', 'FUT_DEF', 'PREC_1', 'PCP_PS', 'PCP_PR', 'JUS_SG', 'JUS_PL', 'IMP_SG', 'IMP_PL', 'HOR_SG', 'HOR_PL', 'DESIDE', 'CAUS_2', 'CAUS_1', 'INF_5', 'INF_4', 'INF_3', 'INF_2', 'INF_1', 'VERB', 'REFL', 'RECP', 'PRES', 'PREM', 'PERS', 'PASS', 'COND', 'COMP', '2SGF', '2PLF', 'SUC', 'OPT', 'NOM', 'NEG', 'LOC', 'INT', 'GEN', 'DAT', 'ACT', 'ACC', 'ABL', '3SG', '3PL', '2SG', '2PL', '1SG', '1PL', 'SG', 'PL']
    # Longest first, so e.g. "3sg" is consumed before "sg" can match inside it.
    # The set removes duplicate entries.
    cats = sorted({x.lower() for x in cats}, key=lambda x: (len(x), x), reverse=True)

    def normalise(frame):
        # Strip and lower-case every cell, column by column, in place.
        for col in frame.columns:
            frame[col] = frame[col].apply(lambda x: x.strip().lower())
        return frame

    def tag2list(t):
        # Greedy substring decomposition: each matched category is removed
        # from the tag so shorter categories cannot re-match inside it.
        res = []
        for c in cats:
            if c in t:
                res.append(c)
                t = t.replace(c, "")
        return res

    df = normalise(pd.read_csv(f"{folder}/train_data.csv"))
    df["Tag"] = df["Tag"].apply(lambda x: "caus_2" if x.startswith("caus_") and x != "caus_1" else x)
    df = df.drop_duplicates()
    # assign() returns a fresh frame, avoiding chained-assignment warnings.
    df = df.assign(Tag=df["Tag"].apply(tag2list))

    tdf = pd.read_csv(f"{folder}/test_data.csv")
    tdf.pop("Tag")  # dropped before normalising; test tags are not used
    tdf = normalise(tdf)

    return {"train": df.rename(columns=str.lower), "test": tdf.rename(columns=str.lower)}
|
|
|
|
|
def get_preprocessed_dfs2(folder):
    """Load train/test CSVs and normalise text, keeping each tag as one string.

    Every column is stripped and lower-cased first. Then all causative tags
    other than ``caus_1`` are collapsed into ``caus_2`` and duplicate rows are
    dropped. The original deduplicated and mapped before normalising, so rows
    differing only in case/whitespace survived and padded tags such as
    " CAUS_3" were not collapsed. The test file's ``Tag`` column is discarded.

    Args:
        folder: directory containing ``train_data.csv`` and ``test_data.csv``.

    Returns:
        ``{"train": DataFrame, "test": DataFrame}`` with lower-cased column names.

    Raises:
        AttributeError: if a column contains non-string values (e.g. NaN).
    """
    def normalise(frame):
        # Strip and lower-case every cell, column by column, in place.
        for col in frame.columns:
            frame[col] = frame[col].apply(lambda x: x.strip().lower())
        return frame

    df = normalise(pd.read_csv(f"{folder}/train_data.csv"))
    df["Tag"] = df["Tag"].apply(lambda x: "caus_2" if x.startswith("caus_") and x != "caus_1" else x)
    df = df.drop_duplicates()

    tdf = pd.read_csv(f"{folder}/test_data.csv")
    tdf.pop("Tag")  # dropped before normalising; test tags are not used
    tdf = normalise(tdf)

    return {"train": df.rename(columns=str.lower), "test": tdf.rename(columns=str.lower)}
|
|
|
|
|
def get_splits(df, test_size=0.2):
    """Split *df* into train/validation frames with no root shared between them.

    Distinct ``root`` values (not rows) are partitioned with
    ``train_test_split`` using a fixed ``random_state`` of 2023. Every row then
    follows its root. Root counts are printed for inspection.

    Args:
        df: DataFrame with a ``root`` column.
        test_size: fraction of distinct roots assigned to validation.

    Returns:
        ``(train_df, validation_df)`` row subsets of *df*.
    """
    roots = df.root.unique()
    print("unique roots", len(roots))

    train_roots, val_roots = train_test_split(roots, test_size=test_size, random_state=2023)
    print("unique train roots", len(train_roots))
    print("unique validation roots", len(val_roots))

    in_train = df.root.isin(train_roots)
    in_val = df.root.isin(val_roots)
    return df[in_train], df[in_val]
|
|
|
|
|
def get_char2idx(all_splits, special_chars=("<pad>", "<s>", "</s>"), columns=("root", "affix")):
    """Build a character -> id vocabulary over every split.

    Special tokens take ids ``0 .. len(special_chars) - 1`` in the given order.
    The distinct characters found in *columns* of every split follow in
    sorted order.

    Args:
        all_splits: mapping of split name -> DataFrame.
        special_chars: tokens placed at the start of the vocabulary.
        columns: text columns whose characters are collected. The default
            reproduces the original root+affix behaviour. Pass
            ``("word", "root", "affix")`` for encoders that also read ``word``
            (``get_dataset`` / ``get_dataset2``).

    Returns:
        dict mapping each token/character to its integer id.

    Raises:
        KeyError: if a DataFrame lacks one of *columns*.
    """
    charset = set()
    for df in all_splits.values():
        for col in columns:
            for text in df[col]:
                charset.update(text)
    return {token: i for i, token in enumerate([*special_chars, *sorted(charset)])}
|
|
|
|
|
def get_dataset(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True):
    """Encode rows of *split* as fixed-length BERT inputs with multi-hot labels.

    Sequence layout is ``<s> pos_word pos_root word root affix </s>``. The
    token_type_ids are 0 for the first three tokens, 1 for word characters,
    2 for root characters and 3 for affix characters and ``</s>``. Sequences
    are truncated to ``max_len`` and right-padded with ``<pad>`` (attention 0,
    token type 3).

    Args:
        split: DataFrame with ``word``, ``root``, ``affix``, ``pos_word``,
            ``pos_root`` and (when ``is_train``) ``tag`` columns; ``tag`` holds
            an iterable of label strings per row.
        char2idx: character vocabulary containing ``<s>``, ``</s>`` and ``<pad>``.
        label2idx: label -> index used to build the multi-hot target.
        max_len: output sequence length. The original ignored this argument and
            always used the global MAX_LEN.
        is_train: when True each item is ``(features, multi_hot_list)``.

    Returns:
        list of feature dicts, or ``(features, labels)`` tuples when training.

    Raises:
        KeyError: for a character, POS value or tag missing from the mappings.
    """
    # NOTE(review): POS ids 0..3 share the id space with char2idx, where
    # 0..2 are <pad>/<s>/</s>; "noun" therefore gets the <pad> id. Confirm this
    # is intended (BERT may treat its pad id specially in the embedding).
    pos2idx = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}
    bos, eos, pad = char2idx["<s>"], char2idx["</s>"], char2idx["<pad>"]

    result = []
    for r in split.itertuples():
        input_ids = [bos, pos2idx[r.pos_word], pos2idx[r.pos_root]]
        token_type_ids = [0, 0, 0]
        for segment, type_id in ((r.word, 1), (r.root, 2), (r.affix, 3)):
            input_ids.extend(char2idx[c] for c in segment)
            token_type_ids.extend([type_id] * len(segment))
        input_ids.append(eos)
        token_type_ids.append(3)

        input_ids = input_ids[:max_len]
        token_type_ids = token_type_ids[:max_len]
        attention_mask = [1] * len(input_ids)

        n_pad = max_len - len(input_ids)
        input_ids += [pad] * n_pad
        attention_mask += [0] * n_pad
        token_type_ids += [3] * n_pad

        features = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        if is_train:
            labels = [0] * len(label2idx)
            for tag in r.tag:
                labels[label2idx[tag]] = 1
            result.append((features, labels))
        else:
            result.append(features)

    return result
|
|
|
|
|
def get_dataset3(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True, log_every=1000):
    """Encode rows of *split* as fixed-length BERT inputs with one class index.

    Sequence layout is ``<s> pos_root root affix </s>``. The token_type_ids
    are 0 for the first two tokens, 1 for root characters and 2 for affix
    characters and ``</s>``. Sequences are truncated to ``max_len`` and
    right-padded with ``<pad>`` (attention 0, token type 2).

    Args:
        split: DataFrame with ``root``, ``affix``, ``pos_root`` and (when
            ``is_train``) ``tag`` columns; ``tag`` holds one label per row.
        char2idx: character vocabulary containing ``<s>``, ``</s>`` and ``<pad>``.
        label2idx: label -> class index.
        max_len: output sequence length. The original ignored this argument and
            always used the global MAX_LEN.
        is_train: when True each item is ``(features, class_index)``.
        log_every: print the encoding of every ``log_every``-th row; a falsy
            value disables printing.

    Returns:
        list of feature dicts, or ``(features, class_index)`` tuples when training.

    Raises:
        KeyError: for a character, POS value or tag missing from the mappings.
    """
    # NOTE(review): POS ids 0..3 share the id space with char2idx, where
    # 0..2 are <pad>/<s>/</s>; "noun" therefore gets the <pad> id. Confirm this
    # is intended (BERT may treat its pad id specially in the embedding).
    pos2idx = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}
    bos, eos, pad = char2idx["<s>"], char2idx["</s>"], char2idx["<pad>"]

    result = []
    for row_no, r in enumerate(split.itertuples(), start=1):
        input_ids = [bos, pos2idx[r.pos_root]]
        token_type_ids = [0, 0]
        for segment, type_id in ((r.root, 1), (r.affix, 2)):
            input_ids.extend(char2idx[c] for c in segment)
            token_type_ids.extend([type_id] * len(segment))
        input_ids.append(eos)
        token_type_ids.append(2)

        input_ids = input_ids[:max_len]
        token_type_ids = token_type_ids[:max_len]
        attention_mask = [1] * len(input_ids)

        n_pad = max_len - len(input_ids)
        input_ids += [pad] * n_pad
        attention_mask += [0] * n_pad
        token_type_ids += [2] * n_pad

        features = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        result.append((features, label2idx[r.tag]) if is_train else features)

        # The original `xs + 1 % 1000 == 0` parsed as `xs + (1 % 1000) == 0`
        # and never fired; this logs every `log_every` rows as intended.
        if log_every and row_no % log_every == 0:
            print(input_ids)
            print(attention_mask)
            print(token_type_ids)

    return result
|
|
|
|
|
def get_dataset2(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True, log_every=10000):
    """Encode rows of *split* as fixed-length BERT inputs with one class index.

    Sequence layout is ``<s> pos_word pos_root word root affix </s>``. The
    token_type_ids are 0 for the first three tokens, 1 for word characters,
    2 for root characters and 3 for affix characters and ``</s>``. Sequences
    are truncated to ``max_len`` and right-padded with ``<pad>`` (attention 0,
    token type 3).

    Args:
        split: DataFrame with ``word``, ``root``, ``affix``, ``pos_word``,
            ``pos_root`` and (when ``is_train``) ``tag`` columns; ``tag``
            holds one label per row.
        char2idx: character vocabulary containing ``<s>``, ``</s>`` and ``<pad>``.
        label2idx: label -> class index.
        max_len: output sequence length. The original ignored this argument and
            always used the global MAX_LEN.
        is_train: when True each item is ``(features, class_index)``.
        log_every: print the encoding of every ``log_every``-th row; a falsy
            value disables printing.

    Returns:
        list of feature dicts, or ``(features, class_index)`` tuples when training.

    Raises:
        KeyError: for a character, POS value or tag missing from the mappings.
    """
    # NOTE(review): POS ids 0..3 share the id space with char2idx, where
    # 0..2 are <pad>/<s>/</s>; "noun" therefore gets the <pad> id. Confirm this
    # is intended (BERT may treat its pad id specially in the embedding).
    pos2idx = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}
    bos, eos, pad = char2idx["<s>"], char2idx["</s>"], char2idx["<pad>"]

    result = []
    for row_no, r in enumerate(split.itertuples(), start=1):
        input_ids = [bos, pos2idx[r.pos_word], pos2idx[r.pos_root]]
        token_type_ids = [0, 0, 0]
        for segment, type_id in ((r.word, 1), (r.root, 2), (r.affix, 3)):
            input_ids.extend(char2idx[c] for c in segment)
            token_type_ids.extend([type_id] * len(segment))
        input_ids.append(eos)
        token_type_ids.append(3)

        input_ids = input_ids[:max_len]
        token_type_ids = token_type_ids[:max_len]
        attention_mask = [1] * len(input_ids)

        n_pad = max_len - len(input_ids)
        input_ids += [pad] * n_pad
        attention_mask += [0] * n_pad
        token_type_ids += [3] * n_pad

        features = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        result.append((features, label2idx[r.tag]) if is_train else features)

        # The original `xs + 1 % 10000 == 0` parsed as `xs + (1 % 10000) == 0`
        # and never fired; this logs every `log_every` rows as intended.
        if log_every and row_no % log_every == 0:
            print(input_ids)
            print(attention_mask)
            print(token_type_ids)

    return result
|
|
|
|
|
def train_model(epochs=100, batch_size=400, data_folder="../Downloads/", seed=2023):
    """Train a ``MyModel2`` tagger and reload the weights of its best checkpoint.

    Steps:
        1. Load and normalise the CSVs (``get_preprocessed_dfs2``).
        2. Split the training data by root.
        3. Build the character and label vocabularies.
        4. Fit with early stopping and top-3 checkpointing on the validation
           micro-F1 logged as ``fmicro``.

    Args:
        epochs: maximum number of training epochs.
        batch_size: batch size for both loaders. The original ignored this and
            hard-coded 400.
        data_folder: directory holding ``train_data.csv`` and ``test_data.csv``.
        seed: passed to ``pl.seed_everything`` so that ``deterministic=True``
            also fixes initialisation and shuffling; ``None`` skips seeding.

    Returns:
        ``(model, train_df, validation_df, test_df)``. ``model`` holds the best
        checkpoint's weights when a checkpoint was saved, otherwise its final
        in-memory weights.
    """
    if seed is not None:
        pl.seed_everything(seed)

    dfs = get_preprocessed_dfs2(data_folder)
    train, val = get_splits(dfs["train"])
    char2idx = get_char2idx(dfs)
    label2idx = {l: i for i, l in enumerate(dfs["train"].tag.unique())}

    model = MyModel2(config, label2idx, char2idx, 0.5)

    checkpoint_callback = ModelCheckpoint(
        dirpath="fmicro_weights",
        save_top_k=3,
        monitor="fmicro",
        mode="max",
        filename="{epoch}-{step}",
    )
    trainer = pl.Trainer(
        deterministic=True,
        max_epochs=epochs,
        callbacks=[EarlyStopping(monitor="fmicro", mode="max"), checkpoint_callback],
        log_every_n_steps=30,
    )

    train_dataset = MyDataset(train, char2idx, label2idx)
    validation_dataset = MyDataset(val, char2idx, label2idx)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,  # the original never shuffled training batches
        collate_fn=train_dataset.collate_fn,
    )
    validation_loader = DataLoader(
        validation_dataset,
        batch_size=batch_size,
        collate_fn=validation_dataset.collate_fn,
    )
    trainer.fit(model, train_loader, validation_loader)

    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        # map_location keeps loading independent of the device trained on;
        # load_state_dict copies into the existing parameters.
        model.load_state_dict(torch.load(best_model_path, map_location="cpu")["state_dict"])

    return model, train, val, dfs["test"]
|
|
|
|
|
|
|
|
class MyModel(pl.LightningModule):
    """Multi-label tagger: BERT pooled output -> dropout -> linear, BCE loss.

    Targets are multi-hot vectors with one column per entry of ``label2idx``,
    the format produced by ``get_dataset``.

    Args:
        config: BertConfig used to build the encoder from scratch.
        label2idx: label -> output column index.
        threshold: stored as ``self.threshold``; not used by this class's methods.
        char2idx: optional character vocabulary stored as ``self.char2idx``.
            It is keyword-only so existing positional calls keep working. The
            original read an undefined global here and raised NameError.
    """

    def __init__(self, config, label2idx, threshold, *args, char2idx=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold
        self.char2idx = char2idx
        self.label2idx = label2idx
        self.idx2label = {i: l for l, i in label2idx.items()}
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(0.3)
        self.proj = torch.nn.Linear(config.hidden_size, len(label2idx))

    def common_step(self, batch):
        """Return logits with ``len(label2idx)`` columns for ``batch = (features, _)``."""
        X, _ = batch
        hidden = self.bert(**X)[1]  # second BertModel output: pooled representation
        return self.proj(self.dropout(hidden))

    def _loss(self, logits, targets):
        # Independent sigmoid per label; targets are cast to float as BCE requires.
        return torch.nn.BCEWithLogitsLoss()(logits, targets.float())

    def training_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = self._loss(logits, batch[1])
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = self._loss(logits, batch[1])
        self.log("loss", loss, prog_bar=True, on_epoch=True)
        return logits, loss

    def test_step(self, batch, batch_idx):
        # Test batches carry features only; wrap them to match common_step.
        return self.common_step((batch, []))

    def forward(self, batch, batch_idx=None):
        """Return logits for a feature dict; ``batch_idx`` is accepted but unused."""
        return self.common_step((batch, []))

    def configure_optimizers(self):
        return AdamW(params=self.parameters())
|
|
|
|
|
class MyModel2(pl.LightningModule):
    """Single-label tagger: BERT pooled output -> dropout -> linear, CE loss.

    Validation predictions are accumulated over each validation epoch. Their
    micro-F1 is logged as ``fmicro`` in ``on_validation_epoch_end``, after the
    epoch's predictions are complete. The original logged it from
    ``validation_step`` using the previous run's score, so callbacks monitoring
    ``fmicro`` saw a value one validation run stale.

    Args:
        config: BertConfig used to build the encoder from scratch.
        label2idx: label -> class index.
        char2idx: character vocabulary, stored as ``self.char2idx``.
        threshold: stored as ``self.threshold``; not used by this class's methods.
    """

    def __init__(self, config, label2idx, char2idx, threshold, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold
        self.char2idx = char2idx
        self.fscore = 0.0  # micro-F1 of the most recently completed validation epoch
        self.label2idx = label2idx
        self.idx2label = {i: l for l, i in label2idx.items()}
        self.predos = []  # predicted labels for the running validation epoch
        self.trues = []  # gold labels for the running validation epoch
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(0.3)
        self.proj = torch.nn.Linear(config.hidden_size, len(label2idx))

    def common_step(self, batch):
        """Return logits with ``len(label2idx)`` columns for ``batch = (features, _)``."""
        X, _ = batch
        hidden = self.bert(**X)[1]  # second BertModel output: pooled representation
        return self.proj(self.dropout(hidden))

    def _loss(self, logits, labels):
        # Flatten to (N, n_classes) / (N,) as CrossEntropyLoss expects.
        return torch.nn.CrossEntropyLoss()(logits.view(-1, len(self.label2idx)), labels.view(-1).long())

    def training_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = self._loss(logits, batch[1])
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        labels = batch[1]
        loss = self._loss(logits, labels)
        self.predos.extend(self.idx2label[i] for i in logits.argmax(dim=-1).tolist())
        self.trues.extend(self.idx2label[i] for i in labels.view(-1).tolist())
        self.log("loss", loss, prog_bar=True, on_epoch=True)
        return logits, loss

    def on_validation_epoch_start(self):
        self.predos = []
        self.trues = []

    def on_validation_epoch_end(self):
        # Score this epoch's predictions and log them now, so EarlyStopping and
        # ModelCheckpoint compare against the current epoch's value.
        if self.trues:
            self.fscore = float(f1_score(self.trues, self.predos, average="micro"))
        else:
            self.fscore = 0.0
        self.log("fmicro", self.fscore, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Test batches carry features only; wrap them to match common_step.
        return self.common_step((batch, []))

    def forward(self, batch, batch_idx=None):
        """Return logits for a feature dict; ``batch_idx`` is accepted but unused."""
        return self.common_step((batch, []))

    def configure_optimizers(self):
        return AdamW(params=self.parameters())

    def predict(self, dataloader):
        """Return the argmax label string for every example in *dataloader*.

        Batches may be feature dicts (``MyDataset`` with ``is_train=False``) or
        ``(features, labels)`` pairs; labels are ignored. Runs in eval mode
        without gradients, then restores the previous train/eval mode.
        """
        was_training = self.training
        self.eval()
        predictions = []
        try:
            with torch.no_grad():
                for batch in dataloader:
                    features = batch[0] if isinstance(batch, (tuple, list)) else batch
                    features = {k: v.to(self.device) for k, v in features.items()}
                    logits = self.common_step((features, None))
                    predictions.extend(self.idx2label[i] for i in logits.argmax(dim=-1).tolist())
        finally:
            self.train(was_training)
        return predictions
|
|
|
|
|
|