# morphological_analysis / bert_model_variant.py
# Special for morphological analysis — by Zarinaaa (commit 7486641)
from transformers import BertConfig, BertModel
import torch
import re
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, cross_validate
import pytorch_lightning as pl
import pandas as pd
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim import AdamW
from sklearn.metrics import f1_score
MAX_LEN = 96  # every encoded example is truncated/right-padded to this length
PAD_ID = 0  # id of "<pad>" (index 0 in get_char2idx); not referenced elsewhere in this chunk
# Small character-level BERT encoder configuration.
config = BertConfig(
    vocab_size=40,  # character vocabulary size; must cover the ids from get_char2idx
    hidden_size=64,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=256,
    max_position_embeddings=MAX_LEN,  # sequences never exceed MAX_LEN positions
    type_vocab_size=4  # token_type_ids range over 0..3 (see get_dataset*)
)
class MyDataset(Dataset):
    """Torch dataset wrapping the rows encoded by ``get_dataset3``.

    In train mode each item is a ``(features, label)`` tuple; otherwise each
    item is just the feature dict.
    """

    def __init__(self, df, char2idx, label2idx, is_train=True):
        """Encode ``df`` eagerly, once, at construction time.

        Args:
            df: preprocessed dataframe (see ``get_preprocessed_dfs2``).
            char2idx: character -> id vocabulary.
            label2idx: tag -> class-index mapping.
            is_train: when True, items also carry the gold label.
        """
        super().__init__()
        # Bug fix: removed leftover debug prints of char2idx/label2idx.
        self.is_train = is_train
        self.dataset = get_dataset3(df, char2idx, label2idx, is_train=is_train)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def collate_fn(self, batch):
        """Stack a list of items into batched tensors for the DataLoader."""
        # In train mode each element is (features, label); unwrap once.
        features = [x[0] if self.is_train else x for x in batch]
        collated = {
            "input_ids": torch.IntTensor([f["input_ids"] for f in features]),
            "attention_mask": torch.Tensor([f["attention_mask"] for f in features]),
            "token_type_ids": torch.IntTensor([f["token_type_ids"] for f in features]),
        }
        if self.is_train:
            collated = collated, torch.IntTensor([x[1] for x in batch])
        return collated
def get_preprocessed_dfs(folder):
    """Load and normalise the train/test CSVs, splitting composite tags.

    Returns a dict with:
      * ``"train"``: dataframe whose ``tag`` column is a *list* of atomic
        tags (longest-match decomposition of the composite ``Tag`` string);
      * ``"test"``: dataframe with the ``Tag`` column removed.

    All cell values are stripped and lower-cased, and column names are
    lower-cased.
    """
    df = pd.read_csv(f"{folder}/train_data.csv").drop_duplicates()
    # Collapse every causative variant except CAUS_1 into a single CAUS_2 class.
    df.loc[:, "Tag"] = df.Tag.apply(lambda x: "CAUS_2" if x.startswith("CAUS_") and x != "CAUS_1" else x)
    # Inventory of atomic tag names a composite Tag string can contain.
    # Bug fix: the original listed 'NEG' twice; the duplicate was dead code
    # because str.replace below removes every occurrence on first match.
    cats = ['FUT_INDF_3PLF', 'FUT_INDF_NEG', 'PST_INDF_PS', 'PCP_FUT_NEG', 'PCP_FUT_DEF', 'PRES_CONT', 'PRES_2SGF', 'POSS_2SGF', 'POSS_2PLF', 'NUM_APPR3', 'NUM_APPR2', 'NUM_APPR1', 'ADVV_CONT', 'ADJECTIVE', 'PST_ITER', 'PST_INDF', 'PST_EVID', 'PRES_PST', 'POSS_3SG', 'POSS_3PL', 'POSS_2SG', 'POSS_2PL', 'POSS_1SG', 'POSS_1PL', 'NUM_COLL', 'FUT_INDF', 'ADVV_SUC', 'ADVV_NEG', 'ADVV_INT', 'ADVV_ACC', 'PST_DEF', 'NUM_ORD', 'NUMERAL', 'IMP_SGF', 'IMP_PLF', 'FUT_DEF', 'PREC_1', 'PCP_PS', 'PCP_PR', 'JUS_SG', 'JUS_PL', 'IMP_SG', 'IMP_PL', 'HOR_SG', 'HOR_PL', 'DESIDE', 'CAUS_2', 'CAUS_1', 'INF_5', 'INF_4', 'INF_3', 'INF_2', 'INF_1', 'VERB', 'REFL', 'RECP', 'PRES', 'PREM', 'PERS', 'PASS', 'COND', 'COMP', '2SGF', '2PLF', 'SUC', 'OPT', 'NOM', 'NEG', 'LOC', 'INT', 'GEN', 'DAT', 'ACT', 'ACC', 'ABL', '3SG', '3PL', '2SG', '2PL', '1SG', '1PL', 'SG', 'PL']
    # Longest tags first, so e.g. "pst_def" is matched before "pst" could be.
    cats = sorted({x.lower() for x in cats}, key=lambda x: (len(x), x), reverse=True)
    for col in df.columns:
        df.loc[:, col] = df[col].apply(lambda x: x.strip().lower())

    def tag2list(t):
        """Greedily decompose a composite tag string into atomic tags."""
        res = []
        for c in cats:
            if c in t:
                res.append(c)
                t = t.replace(c, "")
        return res

    df.loc[:, "Tag"] = df.Tag.apply(tag2list)
    tdf = pd.read_csv(f"{folder}/test_data.csv")
    tdf.pop("Tag")  # test set carries no (usable) labels
    for col in tdf.columns:
        tdf.loc[:, col] = tdf[col].apply(lambda x: x.strip().lower())
    return {"train": df.rename(columns={x: x.lower() for x in df.columns}), "test": tdf.rename(columns={x: x.lower() for x in tdf.columns})}
def get_preprocessed_dfs2(folder):
    """Load the train/test CSVs and normalise them (tags stay single strings).

    Returns ``{"train": ..., "test": ...}`` where every cell is stripped and
    lower-cased, column names are lower-cased, and the test frame has its
    ``Tag`` column dropped.
    """

    def _clean(frame):
        # Strip/lower every cell, then lower-case the column names.
        for column in frame.columns:
            frame.loc[:, column] = frame[column].apply(lambda v: v.strip().lower())
        return frame.rename(columns={c: c.lower() for c in frame.columns})

    train_df = pd.read_csv(f"{folder}/train_data.csv").drop_duplicates()
    # Collapse every causative variant except CAUS_1 into CAUS_2.
    train_df.loc[:, "Tag"] = train_df.Tag.apply(
        lambda t: "CAUS_2" if t.startswith("CAUS_") and t != "CAUS_1" else t
    )
    test_df = pd.read_csv(f"{folder}/test_data.csv")
    test_df.pop("Tag")
    return {"train": _clean(train_df), "test": _clean(test_df)}
def get_splits(df, test_size=0.2):
    """Split ``df`` by unique root so no root appears in both partitions.

    Returns ``(train_df, validation_df)``; the split is deterministic
    (fixed random_state).
    """
    roots = df.root.unique()
    print("unique roots", len(roots))
    # Split the *roots*, not the rows, to avoid root leakage across splits.
    train_roots, val_roots = train_test_split(roots, test_size=test_size, random_state=2023)
    print("unique train roots", len(train_roots))
    print("unique validation roots", len(val_roots))
    return df[df.root.isin(train_roots)], df[df.root.isin(val_roots)]
def get_char2idx(all_splits, special_chars=("<pad>", "<s>", "</s>")):
    """Build a character -> id vocabulary over every split's root+affix text.

    Special tokens come first (so "<pad>" gets id 0), followed by the
    observed characters in sorted order.
    """
    chars = set()
    for frame in all_splits.values():
        # Concatenate root+affix per row, then collect the characters.
        text = "".join(frame.apply(lambda row: row.root + row.affix, axis=1))
        chars.update(text)
    vocab = list(special_chars) + sorted(chars)
    return {ch: idx for idx, ch in enumerate(vocab)}
def get_dataset(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True):
pos2idx = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}
result = []
for r in split.itertuples():
input_ids = [char2idx["<s>"], pos2idx[r.pos_word], pos2idx[r.pos_root]]
attention_mask = [1, 1, 1]
token_type_ids = [0, 0, 0]
# print(r.word, r.root, r.affix)
for c in r.word:
input_ids.append(char2idx[c])
attention_mask.append(1)
token_type_ids.append(1)
for c in r.root:
input_ids.append(char2idx[c])
attention_mask.append(1)
token_type_ids.append(2)
for c in r.affix:
input_ids.append(char2idx[c])
attention_mask.append(1)
token_type_ids.append(3)
input_ids.append(char2idx["</s>"])
attention_mask.append(1)
token_type_ids.append(3)
input_ids = input_ids[:MAX_LEN]
attention_mask = attention_mask[:MAX_LEN]
token_type_ids = token_type_ids[:MAX_LEN]
for _ in range(MAX_LEN - len(input_ids)):
input_ids.append(char2idx["<pad>"])
attention_mask.append(0)
token_type_ids.append(3)
result.append(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
)
if is_train:
result[-1] = (result[-1], [0 for _ in range(len(label2idx))])
for tag in r.tag:
result[-1][-1][label2idx[tag]] = 1
return result
def get_dataset3(split, char2idx, label2idx, max_len=None, is_train=True):
    """Encode each row as ``<s> pos(root) root-chars affix-chars </s>`` ids.

    token_type_ids mark the segments: 0 header, 1 root, 2 affix/tail/pad.

    Args:
        split: dataframe with ``pos_root``/``root``/``affix`` columns (and a
            single string ``tag`` column when training).
        char2idx: character vocabulary containing <s>, </s> and <pad>.
        label2idx: tag -> class index (used only when ``is_train``).
        max_len: sequence length to truncate/pad to (defaults to MAX_LEN).
            Bug fix: the original accepted ``max_len`` but always used the
            module-level MAX_LEN constant.
        is_train: when True each item is ``(features, class_index)``.

    Returns:
        List of feature dicts, or (features, label) tuples when training.
    """
    if max_len is None:
        max_len = MAX_LEN
    pos2idx = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}
    result = []
    for r in split.itertuples():
        input_ids = [char2idx["<s>"], pos2idx[r.pos_root]]
        attention_mask = [1, 1]
        token_type_ids = [0, 0]
        for c in r.root:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(1)
        for c in r.affix:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(2)
        input_ids.append(char2idx["</s>"])
        attention_mask.append(1)
        token_type_ids.append(2)
        # Truncate, then right-pad with <pad> (mask 0) to exactly max_len.
        input_ids = input_ids[:max_len]
        attention_mask = attention_mask[:max_len]
        token_type_ids = token_type_ids[:max_len]
        for _ in range(max_len - len(input_ids)):
            input_ids.append(char2idx["<pad>"])
            attention_mask.append(0)
            token_type_ids.append(2)
        features = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        # Dead debug block removed: its guard `xs + 1 % 1000 == 0` parsed as
        # `xs + (1 % 1000) == 0`, i.e. `xs + 1 == 0`, which never held.
        result.append((features, label2idx[r.tag]) if is_train else features)
    return result
def get_dataset2(split, char2idx, label2idx, max_len=None, is_train=True):
    """Encode each row as word+root+affix character ids (single-label variant).

    Same layout as ``get_dataset`` (<s>, pos(word), pos(root), word chars,
    root chars, affix chars, </s>, padding; token_type_ids 0/1/2/3), but the
    label is a single class index rather than a multi-hot vector.

    Args:
        split: dataframe with ``word``/``root``/``affix``/``pos_word``/
            ``pos_root`` columns (and a single string ``tag`` when training).
        char2idx: character vocabulary containing <s>, </s> and <pad>.
        label2idx: tag -> class index (used only when ``is_train``).
        max_len: sequence length to truncate/pad to (defaults to MAX_LEN).
            Bug fix: the original accepted ``max_len`` but always used the
            module-level MAX_LEN constant.
        is_train: when True each item is ``(features, class_index)``.

    Returns:
        List of feature dicts, or (features, label) tuples when training.
    """
    if max_len is None:
        max_len = MAX_LEN
    pos2idx = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}
    result = []
    for r in split.itertuples():
        input_ids = [char2idx["<s>"], pos2idx[r.pos_word], pos2idx[r.pos_root]]
        attention_mask = [1, 1, 1]
        token_type_ids = [0, 0, 0]
        for c in r.word:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(1)
        for c in r.root:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(2)
        for c in r.affix:
            input_ids.append(char2idx[c])
            attention_mask.append(1)
            token_type_ids.append(3)
        input_ids.append(char2idx["</s>"])
        attention_mask.append(1)
        token_type_ids.append(3)
        # Truncate, then right-pad with <pad> (mask 0) to exactly max_len.
        input_ids = input_ids[:max_len]
        attention_mask = attention_mask[:max_len]
        token_type_ids = token_type_ids[:max_len]
        for _ in range(max_len - len(input_ids)):
            input_ids.append(char2idx["<pad>"])
            attention_mask.append(0)
            token_type_ids.append(3)
        features = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        # Dead debug block removed: its guard `xs + 1 % 10000 == 0` parsed as
        # `xs + (1 % 10000) == 0`, i.e. `xs + 1 == 0`, which never held.
        result.append((features, label2idx[r.tag]) if is_train else features)
    return result
def train_model(epochs=100, batch_size=400, data_folder="../Downloads/"):
    """Train MyModel2 on the preprocessed data and return the best checkpoint.

    Args:
        epochs: maximum number of training epochs.
        batch_size: DataLoader batch size for train and validation.
        data_folder: folder containing train_data.csv / test_data.csv.

    Returns:
        (model, train_df, validation_df, test_df) with the weights of the
        checkpoint that maximised validation "fmicro" loaded into ``model``.
    """
    dfs = get_preprocessed_dfs2(data_folder)
    train, val = get_splits(dfs["train"])
    char2idx = get_char2idx(dfs)
    # One class per distinct composite tag string.
    label2idx = {l: i for i, l in enumerate(dfs["train"].tag.unique())}
    model = MyModel2(config, label2idx, char2idx, 0.5)
    checkpoint_callback = ModelCheckpoint(
        dirpath="fmicro_weights",
        save_top_k=3,
        monitor="fmicro",
        mode="max",
        filename="{epoch}-{step}",
    )
    trainer = pl.Trainer(
        deterministic=True,
        max_epochs=epochs,
        callbacks=[EarlyStopping(monitor="fmicro", mode="max"), checkpoint_callback],
        log_every_n_steps=30,
    )
    train_dataset = MyDataset(train, char2idx, label2idx)
    validation_dataset = MyDataset(val, char2idx, label2idx)
    # Bug fix: the DataLoaders previously hard-coded batch_size=400 and
    # ignored the batch_size argument.
    trainer.fit(
        model,
        DataLoader(train_dataset, batch_size=batch_size, collate_fn=train_dataset.collate_fn),
        DataLoader(validation_dataset, batch_size=batch_size, collate_fn=validation_dataset.collate_fn),
    )
    # Reload the best checkpoint (highest validation fmicro); we already hold
    # the ModelCheckpoint instance, no need to scan trainer.callbacks.
    model.load_state_dict(torch.load(checkpoint_callback.best_model_path)["state_dict"])
    return model, train, val, dfs["test"]
class MyModel(pl.LightningModule):
    """Multi-label BERT classifier: BCE-with-logits over all tags at once.

    Bug fix: ``__init__`` previously read a ``char2idx`` name that was never
    a parameter (and has no module-level binding), raising NameError on
    construction; it is now an optional argument, matching MyModel2.
    """

    def __init__(self, config, label2idx, threshold, char2idx=None, *args, **kwargs):
        """Args:
            config: BertConfig for the encoder.
            label2idx: tag -> output-unit index.
            threshold: decision threshold (stored; not used in this file).
            char2idx: optional character vocabulary, stored for later use.
        """
        super().__init__(*args, **kwargs)
        self.threshold = threshold
        self.char2idx = char2idx
        self.label2idx = label2idx
        self.idx2label = {i: l for l, i in label2idx.items()}
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(0.3)
        self.proj = torch.nn.Linear(config.hidden_size, len(label2idx))

    def common_step(self, batch):
        """Run BERT and project the pooled ([1]) output to per-tag logits."""
        X, _ = batch
        hidden = self.bert(**X)[1]
        return self.proj(self.dropout(hidden))

    def training_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = torch.nn.BCEWithLogitsLoss()(logits, batch[1].float())
        self.log("train_loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = torch.nn.BCEWithLogitsLoss()(logits, batch[1].float())
        self.log("loss", loss.mean(), prog_bar=True, on_epoch=True)
        return logits, loss

    def test_step(self, batch, batch_idx):
        # Wrap the bare batch so common_step can unpack (features, labels).
        return self.common_step((batch, []))

    def forward(self, batch, batch_idx):
        return self.common_step((batch, []))

    def configure_optimizers(self):
        # All defaults; lr etc. left at AdamW's defaults on purpose.
        return AdamW(params=self.parameters())
class MyModel2(pl.LightningModule):
    """Single-label BERT classifier over morphological tags.

    Trains with cross-entropy over ``len(label2idx)`` classes and tracks a
    micro-averaged F1 ("fmicro") accumulated over each validation epoch.
    """

    def __init__(self, config, label2idx, char2idx, threshold, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # threshold is stored but not read by any method in this file.
        self.threshold = threshold
        self.char2idx = char2idx
        # Micro-F1 of the last *completed* validation epoch (0.0 before any).
        self.fscore = 0.0
        self.label2idx = label2idx
        self.idx2label = {i: l for l, i in label2idx.items()}
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(0.3)
        self.proj = torch.nn.Linear(config.hidden_size, len(label2idx))

    def common_step(self, batch):
        """Run BERT on the features and project the pooled output to logits."""
        X, _ = batch
        # [1] selects the pooled (sentence-level) output of BertModel.
        hidden = self.bert(**X)[1]
        return self.proj(self.dropout(hidden))

    def training_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = torch.nn.CrossEntropyLoss()(logits.view(-1, len(self.label2idx)), batch[1].view(-1).long())
        self.log("train_loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = torch.nn.CrossEntropyLoss()(logits.view(-1, len(self.label2idx)), batch[1].view(-1).long())
        # Accumulate predicted and gold tag names for the epoch-level F1.
        for p in logits:
            self.predos.append(self.idx2label[p.argmax().cpu().item()])
        for t in batch[1]:
            self.trues.append(self.idx2label[t.cpu().item()])
        self.log("loss", loss.mean(), prog_bar=True, on_epoch=True)
        # NOTE(review): self.fscore is only refreshed in on_validation_end,
        # so the "fmicro" logged here lags by one validation epoch — confirm
        # that is intended, since checkpointing/early stopping monitor it.
        self.log("fmicro", self.fscore, prog_bar=True, on_epoch=True)
        return logits, loss

    def on_validation_start(self):
        # Reset the per-epoch accumulators used by validation_step.
        self.predos = []
        self.trues = []

    def on_validation_end(self):
        # Micro-F1 over everything accumulated during this validation epoch.
        self.fscore = f1_score(self.trues, self.predos, average="micro")

    def test_step(self, batch, batch_idx):
        # Wrap the bare batch so common_step can unpack (features, labels).
        return self.common_step((batch, []))

    def forward(self, batch, batch_idx):
        return self.common_step((batch, []))

    def configure_optimizers(self):
        return AdamW(params=self.parameters())

    def predict(self, dataloader):
        # Not implemented.
        pass