"""Character-level BERT tag classifier trained with PyTorch Lightning.

Rows of ``train_data.csv`` / ``test_data.csv`` are encoded as character-id
sequences (root, affix and, for some encoders, the word itself), fed to a
small randomly initialised ``BertModel`` and classified into tags.
"""
from transformers import BertConfig, BertModel
import torch
import re
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, cross_validate
import pytorch_lightning as pl
import pandas as pd
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim import AdamW
from sklearn.metrics import f1_score

# Fixed sequence length: encoded rows are truncated/padded to this by default,
# and it also sizes the position embeddings of ``config`` below.
MAX_LEN = 96
# NOTE(review): not referenced in this file; padding actually uses
# ``char2idx[""]`` (see get_char2idx).
PAD_ID = 0

# Small BERT encoder. ``type_vocab_size=4`` covers token_type ids 0..3 used by
# get_dataset / get_dataset2 (get_dataset3 only uses 0..2).
# NOTE(review): vocab_size must exceed every id produced by get_char2idx —
# confirm for the dataset in use.
config = BertConfig(
    vocab_size=40,
    hidden_size=64,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=256,
    max_position_embeddings=MAX_LEN,
    type_vocab_size=4,
)

# Part-of-speech -> id. The encoders put these ids directly into
# ``input_ids``, so they share the id space with characters.
# NOTE(review): ids 0..3 can coincide with character ids from get_char2idx
# (e.g. "adjective" -> 3 == first regular character); only token_type_ids
# distinguishes them.
_POS2IDX = {x: i for i, x in enumerate(["noun", "verb", "num", "adjective"])}

# Known tag components, lowercased and ordered longest-first (ties in reverse
# alphabetical order) so that e.g. "poss_2sgf" is consumed before "2sgf".
_TAG_CATEGORIES = sorted(
    {
        x.lower()
        for x in [
            'FUT_INDF_3PLF', 'FUT_INDF_NEG', 'PST_INDF_PS', 'PCP_FUT_NEG', 'PCP_FUT_DEF',
            'PRES_CONT', 'PRES_2SGF', 'POSS_2SGF', 'POSS_2PLF', 'NUM_APPR3', 'NUM_APPR2',
            'NUM_APPR1', 'ADVV_CONT', 'ADJECTIVE', 'PST_ITER', 'PST_INDF', 'PST_EVID',
            'PRES_PST', 'POSS_3SG', 'POSS_3PL', 'POSS_2SG', 'POSS_2PL', 'POSS_1SG',
            'POSS_1PL', 'NUM_COLL', 'FUT_INDF', 'ADVV_SUC', 'ADVV_NEG', 'ADVV_INT',
            'ADVV_ACC', 'PST_DEF', 'NUM_ORD', 'NUMERAL', 'IMP_SGF', 'IMP_PLF', 'FUT_DEF',
            'PREC_1', 'PCP_PS', 'PCP_PR', 'JUS_SG', 'JUS_PL', 'IMP_SG', 'IMP_PL', 'HOR_SG',
            'HOR_PL', 'DESIDE', 'CAUS_2', 'CAUS_1', 'INF_5', 'INF_4', 'INF_3', 'INF_2',
            'INF_1', 'VERB', 'REFL', 'RECP', 'PRES', 'PREM', 'PERS', 'PASS', 'COND', 'COMP',
            '2SGF', '2PLF', 'SUC', 'OPT', 'NOM', 'NEG', 'LOC', 'INT', 'GEN', 'DAT', 'ACT',
            'ACC', 'ABL', '3SG', '3PL', '2SG', '2PL', '1SG', '1PL', 'SG', 'PL',
        ]
    },
    key=lambda x: (len(x), x),
    reverse=True,
)


class MyDataset(Dataset):
    """Map-style dataset over rows encoded by :func:`get_dataset3`.

    Args:
        df: DataFrame with ``root``, ``affix``, ``pos_root`` (and ``tag`` when
            ``is_train``) columns.
        char2idx: character -> id mapping from :func:`get_char2idx`.
        label2idx: tag -> class index mapping.
        is_train: when True each item is ``(features, label_index)``;
            otherwise just ``features``.
    """

    def __init__(self, df, char2idx, label2idx, is_train=True):
        super().__init__()
        print(char2idx)
        print(label2idx)
        self.is_train = is_train
        self.dataset = get_dataset3(df, char2idx, label2idx, is_train=is_train)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def collate_fn(self, batch):
        """Stack a list of items into tensors.

        Returns a dict with ``input_ids`` / ``token_type_ids`` as ``IntTensor``
        and ``attention_mask`` as a default-dtype ``Tensor``. When
        ``is_train`` the dict is returned paired with an ``IntTensor`` of
        labels.
        """
        features = [item[0] if self.is_train else item for item in batch]
        collated = {
            "input_ids": torch.IntTensor([f["input_ids"] for f in features]),
            "attention_mask": torch.Tensor([f["attention_mask"] for f in features]),
            "token_type_ids": torch.IntTensor([f["token_type_ids"] for f in features]),
        }
        if self.is_train:
            return collated, torch.IntTensor([item[1] for item in batch])
        return collated


def _strip_lower_columns(df):
    """Strip and lowercase every cell in place (all columns must hold str)."""
    for col in df.columns:
        df[col] = df[col].apply(lambda x: x.strip().lower())
    return df


def _lowercase_column_names(df):
    """Return a copy of ``df`` with lowercased column names."""
    return df.rename(columns={x: x.lower() for x in df.columns})


def _load_frames(folder):
    """Read and normalise ``train_data.csv`` / ``test_data.csv`` from ``folder``.

    Train rows are de-duplicated and every ``CAUS_*`` tag other than
    ``CAUS_1`` becomes ``CAUS_2``. The test ``Tag`` column is dropped. All
    cells are stripped and lowercased; column names are left unchanged.
    """
    train = pd.read_csv(f"{folder}/train_data.csv").drop_duplicates()
    train["Tag"] = train.Tag.apply(
        lambda x: "CAUS_2" if x.startswith("CAUS_") and x != "CAUS_1" else x
    )
    _strip_lower_columns(train)

    test = pd.read_csv(f"{folder}/test_data.csv")
    test.pop("Tag")
    _strip_lower_columns(test)
    return train, test


def _tag2list(tag):
    """Split a lowercased compound tag into the known components it contains.

    Components are matched longest-first; each match is removed from the
    remaining text so shorter components are not matched inside it.
    """
    found = []
    for cat in _TAG_CATEGORIES:
        if cat in tag:
            found.append(cat)
            tag = tag.replace(cat, "")
    return found


def get_preprocessed_dfs(folder):
    """Load train/test frames with the train ``tag`` split into component lists.

    Returns ``{"train": df, "test": df}`` with lowercased column names; the
    train ``tag`` column holds lists of tag components (multi-label).
    """
    train, test = _load_frames(folder)
    train["Tag"] = train.Tag.apply(_tag2list)
    return {"train": _lowercase_column_names(train), "test": _lowercase_column_names(test)}


def get_preprocessed_dfs2(folder):
    """Load train/test frames keeping each train ``tag`` as a single string.

    Returns ``{"train": df, "test": df}`` with lowercased column names.
    """
    train, test = _load_frames(folder)
    return {"train": _lowercase_column_names(train), "test": _lowercase_column_names(test)}


def get_splits(df, test_size=0.2):
    """Split ``df`` into train/validation frames by unique ``root``.

    Roots (not rows) are split, so no root appears in both parts. The split
    is seeded (``random_state=2023``) and the root counts are printed.
    """
    unique_roots = df.root.unique()
    print("unique roots", len(unique_roots))
    train, validation = train_test_split(unique_roots, test_size=test_size, random_state=2023)
    print("unique train roots", len(train))
    print("unique validation roots", len(validation))
    return df[df.root.isin(train)], df[df.root.isin(validation)]


def get_char2idx(all_splits, special_chars=("", "", "")):
    """Build a character -> id vocabulary from the ``root``/``affix`` columns.

    Special tokens take the first ids, followed by every character seen in
    ``root + affix`` across all frames in ``all_splits`` (a dict of
    DataFrames), in sorted order. If a key occurs more than once, the later
    id wins.

    NOTE(review): with the default ``special_chars`` all three specials are the
    empty string, so they collapse into the single key ``""`` (id 2) and ids 0
    and 1 are never produced. The encoders below use ``char2idx[""]`` for the
    leading token, the separator and padding alike — confirm this is intended.
    NOTE(review): the ``word`` column is not included, so get_dataset /
    get_dataset2 raise KeyError for word characters absent from root/affix.
    """
    charset = set()
    for df in all_splits.values():
        for root, affix in zip(df.root, df.affix):
            charset.update(root + affix)
    return {ch: i for i, ch in enumerate([*special_chars, *sorted(charset)])}


def _encode(prefix_ids, segments, char2idx, max_len):
    """Encode one row as fixed-length BERT inputs.

    Layout: ``prefix_ids`` (token type 0), then the characters of each
    ``(text, type_id)`` segment, then one ``char2idx[""]`` separator carrying
    the last segment's type id. The sequence is truncated to ``max_len`` and
    right-padded with ``char2idx[""]`` (attention 0, last segment's type id).
    ``max_len`` must not exceed the model's ``max_position_embeddings``.
    """
    special_id = char2idx[""]
    last_type = segments[-1][1]

    input_ids = list(prefix_ids)
    token_type_ids = [0] * len(input_ids)
    for text, type_id in segments:
        input_ids.extend(char2idx[c] for c in text)
        token_type_ids.extend([type_id] * len(text))
    input_ids.append(special_id)
    token_type_ids.append(last_type)

    input_ids = input_ids[:max_len]
    token_type_ids = token_type_ids[:max_len]
    attention_mask = [1] * len(input_ids)

    n_pad = max_len - len(input_ids)
    return {
        "input_ids": input_ids + [special_id] * n_pad,
        "attention_mask": attention_mask + [0] * n_pad,
        "token_type_ids": token_type_ids + [last_type] * n_pad,
    }


def _encode_word_root_affix(r, char2idx, max_len):
    """Encode ``["", pos_word, pos_root] + word + root + affix + ""`` (types 0/1/2/3)."""
    return _encode(
        [char2idx[""], _POS2IDX[r.pos_word], _POS2IDX[r.pos_root]],
        [(r.word, 1), (r.root, 2), (r.affix, 3)],
        char2idx,
        max_len,
    )


def _print_encoded(item):
    """Debug dump of one encoded row."""
    print(item["input_ids"])
    print(item["attention_mask"])
    print(item["token_type_ids"])


def get_dataset(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True):
    """Encode rows of ``split`` with word, root and affix segments (multi-label).

    When ``is_train`` each item is ``(features, multi_hot)`` where
    ``multi_hot`` has a 1 at ``label2idx[t]`` for every ``t`` in the row's
    ``tag`` list (as produced by :func:`get_preprocessed_dfs`).
    """
    result = []
    for r in split.itertuples():
        item = _encode_word_root_affix(r, char2idx, max_len)
        if is_train:
            labels = [0] * len(label2idx)
            for tag in r.tag:
                labels[label2idx[tag]] = 1
            item = (item, labels)
        result.append(item)
    return result


def get_dataset3(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True, log_every=1000):
    """Encode rows as ``["", pos_root] + root + affix + ""`` (single-label).

    Token types: 0 for the prefix, 1 for root, 2 for affix/separator/padding.
    When ``is_train`` each item is ``(features, label2idx[tag])``. Every
    ``log_every``-th encoded row is printed for debugging (falsy disables).
    """
    result = []
    for row_no, r in enumerate(split.itertuples(), start=1):
        item = _encode(
            [char2idx[""], _POS2IDX[r.pos_root]],
            [(r.root, 1), (r.affix, 2)],
            char2idx,
            max_len,
        )
        result.append((item, label2idx[r.tag]) if is_train else item)
        # Explicit modulo on the 1-based row number; ``xs + 1 % n == 0`` would
        # parse as ``xs + (1 % n) == 0`` and never fire.
        if log_every and row_no % log_every == 0:
            _print_encoded(item)
    return result


def get_dataset2(split, char2idx, label2idx, max_len=MAX_LEN, is_train=True, log_every=10000):
    """Encode rows with word, root and affix segments (single-label).

    Same layout as :func:`get_dataset`, but when ``is_train`` each item is
    ``(features, label2idx[tag])``. Every ``log_every``-th encoded row is
    printed for debugging (falsy disables).
    """
    result = []
    for row_no, r in enumerate(split.itertuples(), start=1):
        item = _encode_word_root_affix(r, char2idx, max_len)
        result.append((item, label2idx[r.tag]) if is_train else item)
        if log_every and row_no % log_every == 0:
            _print_encoded(item)
    return result


def train_model(epochs=100, batch_size=400, data_folder="../Downloads/"):
    """Train :class:`MyModel2` on a root-level train/validation split.

    Checkpoints (top 3 by validation ``fmicro``) go to ``fmicro_weights/``;
    training stops early on ``fmicro``. Returns
    ``(model, train_df, validation_df, test_df)``, with ``model`` reloaded
    from the best checkpoint when one was saved.
    """
    dfs = get_preprocessed_dfs2(data_folder)
    train, val = get_splits(dfs["train"])
    char2idx = get_char2idx(dfs)
    # Single-label setup: one class per distinct tag string.
    label2idx = {l: i for i, l in enumerate(dfs["train"].tag.unique())}
    model = MyModel2(config, label2idx, char2idx, 0.5)

    checkpoint_callback = ModelCheckpoint(
        dirpath="fmicro_weights",
        save_top_k=3,
        monitor="fmicro",
        mode="max",
        filename="{epoch}-{step}",
    )
    trainer = pl.Trainer(
        deterministic=True,
        max_epochs=epochs,
        callbacks=[EarlyStopping(monitor="fmicro", mode="max"), checkpoint_callback],
        log_every_n_steps=30,
    )

    train_dataset = MyDataset(train, char2idx, label2idx)
    validation_dataset = MyDataset(val, char2idx, label2idx)
    trainer.fit(
        model,
        DataLoader(train_dataset, batch_size=batch_size, collate_fn=train_dataset.collate_fn),
        DataLoader(validation_dataset, batch_size=batch_size, collate_fn=validation_dataset.collate_fn),
    )

    # Empty when no checkpoint was written; keep the final weights then.
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        checkpoint = torch.load(best_model_path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    return model, train, val, dfs["test"]


class MyModel(pl.LightningModule):
    """Multi-label classifier: BERT pooled output -> dropout -> linear, BCE loss.

    Targets are multi-hot vectors (as built by :func:`get_dataset`).
    NOTE(review): not compatible with :class:`MyDataset`, which yields one
    integer class per row. ``threshold`` is stored but unused here.
    """

    def __init__(self, config, label2idx, threshold, *args, char2idx=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold
        # Optional keyword: the attribute used to read an undefined name.
        self.char2idx = char2idx
        self.label2idx = label2idx
        self.idx2label = {i: l for l, i in label2idx.items()}
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(0.3)
        self.proj = torch.nn.Linear(config.hidden_size, len(label2idx))

    def common_step(self, batch):
        """Return logits for ``batch = (features_dict, labels)``; labels are ignored."""
        X, _ = batch
        hidden = self.bert(**X)[1]  # index 1 of BertModel output: pooled output
        return self.proj(self.dropout(hidden))

    def _bce(self, logits, labels):
        return torch.nn.BCEWithLogitsLoss()(logits, labels.float())

    def training_step(self, batch, batch_idx):
        loss = self._bce(self.common_step(batch), batch[1])
        self.log("train_loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = self._bce(logits, batch[1])
        self.log("loss", loss.mean(), prog_bar=True, on_epoch=True)
        return logits, loss

    def test_step(self, batch, batch_idx):
        return self.common_step((batch, []))

    def forward(self, batch, batch_idx):
        return self.common_step((batch, []))

    def configure_optimizers(self):
        return AdamW(params=self.parameters())


class MyModel2(pl.LightningModule):
    """Single-label classifier: BERT pooled output -> dropout -> linear, CE loss.

    Logs validation ``loss`` per step and micro-F1 (``fmicro``) once per
    validation epoch. ``threshold`` and ``char2idx`` are stored but unused in
    this class.
    """

    def __init__(self, config, label2idx, char2idx, threshold, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold
        self.char2idx = char2idx
        self.fscore = 0.0
        self.label2idx = label2idx
        self.idx2label = {i: l for l, i in label2idx.items()}
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(0.3)
        self.proj = torch.nn.Linear(config.hidden_size, len(label2idx))
        # Predicted / gold tag strings for the current validation run.
        self.predos = []
        self.trues = []

    def common_step(self, batch):
        """Return logits for ``batch = (features_dict, labels)``; labels are ignored."""
        X, _ = batch
        hidden = self.bert(**X)[1]  # index 1 of BertModel output: pooled output
        return self.proj(self.dropout(hidden))

    def _cross_entropy(self, logits, labels):
        return torch.nn.CrossEntropyLoss()(
            logits.view(-1, len(self.label2idx)), labels.view(-1).long()
        )

    def training_step(self, batch, batch_idx):
        loss = self._cross_entropy(self.common_step(batch), batch[1])
        self.log("train_loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.common_step(batch)
        loss = self._cross_entropy(logits, batch[1])
        self.predos.extend(self.idx2label[i] for i in logits.argmax(dim=-1).tolist())
        self.trues.extend(self.idx2label[i] for i in batch[1].view(-1).tolist())
        self.log("loss", loss.mean(), prog_bar=True, on_epoch=True)
        return logits, loss

    def on_validation_start(self):
        self.predos = []
        self.trues = []

    def on_validation_epoch_end(self):
        # Computed and logged here, after all validation steps, so that the
        # "fmicro" monitored by EarlyStopping / ModelCheckpoint describes the
        # epoch that just finished rather than the previous one.
        self.fscore = (
            float(f1_score(self.trues, self.predos, average="micro")) if self.trues else 0.0
        )
        self.log("fmicro", self.fscore, prog_bar=True)

    def test_step(self, batch, batch_idx):
        return self.common_step((batch, []))

    def forward(self, batch, batch_idx):
        return self.common_step((batch, []))

    def configure_optimizers(self):
        return AdamW(params=self.parameters())

    def predict(self, dataloader):
        # NOTE(review): placeholder, not implemented; returns None.
        pass