import argparse
import json
import logging
import os
import pickle as pkl
import random

import epitran
import numpy as np
import torch
from epitran.backoff import Backoff
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

from model import SentencesPairExtract
from utils.acc_and_f1 import cal_acc_and_f1

logger = logging.getLogger('parl_sent_pair_extract')
logger.setLevel(logging.INFO)


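# A minimal sketch of the JSON layout TextDataset expects (field names taken
# from the reads below; the concrete values are hypothetical):
#
#   [
#     {"src": "<Lao sentence>", "tgt": "<Thai sentence>", "label": 1},
#     ...
#   ]
#
# Each example becomes [src_vec, tgt_vec, label], where src_vec/tgt_vec are the
# sentence embeddings produced by the SentenceTransformer passed in as `plm`.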
class TextDataset(Dataset):
    def __init__(self, plm, file_path=None, file_name=None):
        self.plm = plm
        self.examples = []

        # Reuse the binarized cache when it exists.
        pkl_path = os.path.join(file_path, file_name[:-5] + "_binarized.pkl")
        if os.path.exists(pkl_path):
            with open(pkl_path, "rb") as binarized_file:
                self.examples = pkl.load(binarized_file)
            del self.plm
            return

        with open(os.path.join(file_path, file_name), 'r', encoding='utf-8') as f:
            data = json.load(f)

        bar = tqdm(data)
        bar.set_description("GEN TextDataset", refresh=True)
        for js in bar:
            label = js.get('label') or 0
            src = js['src']
            tgt = js['tgt']
            self.examples.append(
                [torch.from_numpy(self.plm.encode(src)), torch.from_numpy(self.plm.encode(tgt)),
                 torch.tensor(label)])
        del self.plm

        with open(pkl_path, "wb") as binarized_file:
            pkl.dump(self.examples, binarized_file)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        """Return (src_vec, tgt_vec, label)."""
        return self.examples[i][0], \
               self.examples[i][1], \
               self.examples[i][2]


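# Helper that transliterates the 'tgt' side of a corpus into IPA strings.
# "mix" backs off between the Lao and Thai G2P tables; otherwise a single
# epitran model for the requested language is used.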
def IPADataset(lang, file_path=None):
    lang2lang_code = {"lo": "lao-Laoo",
                      "th": "tha-Thai"}

    if lang == "mix":
        epi = Backoff([lang2lang_code["lo"], lang2lang_code["th"]])
    else:
        epi = epitran.Epitran(lang2lang_code[lang])

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    bar = tqdm(data)
    bar.set_description("text", refresh=True)
    examples = []
    for js in bar:
        tgt = js['tgt']
        IPA_tgt = epi.transliterate(tgt)
        examples.append(IPA_tgt)
    return examples


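# Seeds every RNG the script uses. cudnn.deterministic trades speed for
# reproducible kernels; note that PYTHONHASHSEED set at runtime only affects
# subprocesses, since the interpreter reads it at startup.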
def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


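# Linear warmup/decay schedule used below: the learning rate ramps from 0 to
# args.learning_rate over warmup_steps, then decays linearly to 0 at t_total.
# Weight decay is applied to all parameters except biases and LayerNorm
# weights, following the usual transformer fine-tuning recipe.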
def train(args, train_dataset, model, IPA_dataset, MODE):
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,
                                  num_workers=0, pin_memory=True)
    args.save_steps = len(train_dataloader) if args.save_steps <= 0 else args.save_steps
    args.warmup_steps = len(train_dataloader) if args.warmup_steps <= 0 else args.warmup_steps
    args.logging_steps = len(train_dataloader)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps)
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, args.warmup_steps, t_total)

    model.to(args.device)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Resume optimizer/scheduler state from the last checkpoint when present.
    checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last.' + args.mode + str(args.seed))
    scheduler_last = os.path.join(checkpoint_last, 'scheduler.pt')
    optimizer_last = os.path.join(checkpoint_last, 'optimizer.pt')
    if os.path.exists(scheduler_last):
        scheduler.load_state_dict(torch.load(scheduler_last))
    if os.path.exists(optimizer_last):
        optimizer.load_state_dict(torch.load(optimizer_last))

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    global_step = args.start_step
    tr_loss, logging_loss, avg_loss, tr_nb, tr_num, train_loss = 0.0, 0.0, 0.0, 0, 0, 0

    best_results = {"acc": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0, "acc_and_f1": 0.0}
    model.zero_grad()
    logger.info(model)

    for idx in range(args.start_epoch, int(args.num_train_epochs)):
        bar = tqdm(enumerate(train_dataloader))
        tr_num = 0
        train_loss = 0
        for step, batch in bar:
            src_inputs = batch[0].to(args.device)
            tgt_inputs = batch[1].to(args.device)
            labels = batch[2].to(args.device)

            model.train()
            if MODE == 3:
                # NOTE: MODE arrives as a string (args.mode), so this comparison never
                # matches; the branch and its stub IPA_tgt_inputs are kept as-is.
                IPA_tgt_inputs = 123
                loss, predictions = model(src_inputs, tgt_inputs, labels, IPA_tgt_inputs)
            else:
                loss, predictions = model(src_inputs, tgt_inputs, labels)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()
            tr_num += 1
            train_loss += loss.item()
            if avg_loss == 0:
                avg_loss = tr_loss
            avg_loss = round(train_loss / tr_num, 5)
            bar.set_description("epoch {} step {} loss {}".format(idx, step + 1, avg_loss))

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                avg_loss = round(np.exp((tr_loss - logging_loss) / (global_step - tr_nb)), 4)
                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logging_loss = tr_loss
                    tr_nb = global_step

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    if args.evaluate_during_training:
                        results = evaluate(args, model)
                        for key, value in results.items():
                            logger.info("  %s = %s", key, round(value, 4))

                        if results['f1'] >= best_results['f1']:
                            best_results = results
                            for key in best_results:
                                logger.info("  Best %s = %s", key, str(round(best_results[key], 4)))

                            checkpoint_prefix = 'checkpoint-best-aver.' + args.mode + str(args.seed)
                            output_dir = os.path.join(args.output_dir, checkpoint_prefix)
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)
                            model_to_save = model.module if hasattr(model, 'module') else model

                            torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
                            torch.save(model_to_save.state_dict(),
                                       os.path.join(output_dir, 'training_{}.bin'.format(idx)))
                            logger.info("Saving model checkpoint to %s", output_dir)
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            logger.info("Saving optimizer and scheduler states to %s", output_dir)

                    checkpoint_prefix = 'checkpoint-last.' + args.mode + str(args.seed)
                    output_dir = os.path.join(args.output_dir, checkpoint_prefix)
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))

                    idx_file = os.path.join(output_dir, 'idx_file.txt')
                    with open(idx_file, 'w', encoding='utf-8') as idxf:
                        idxf.write(str(args.start_epoch + idx) + '\n')
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
                    step_file = os.path.join(output_dir, 'step_file.txt')
                    with open(step_file, 'w', encoding='utf-8') as stepf:
                        stepf.write(str(global_step) + '\n')


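# train_triple interleaves two objectives per step: a triple (lo/th/zh anchor)
# batch drawn sequentially from the concatenated triple+IPA dataset, and a
# regular parallel-pair batch. Each objective gets its own backward pass, and
# the optimizer steps once per accumulation window.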
def train_triple(args, triple_dataset, model, triple_IPA_dataset, train_dataset, IPA_dataset, MODE):
    """Concatenate the triple text dataset with its IPA dataset and train jointly."""
    concated_dataset = []
    for x in zip(triple_dataset, triple_IPA_dataset):
        concated_dataset.append([item for sublist in x for item in sublist])
    # Pad with placeholders so zip() below covers the full parallel dataset.
    while len(concated_dataset) < len(train_dataset):
        concated_dataset.append('placeholder')

    if MODE == '73':
        t_d = []
        for x in range(len(train_dataset.examples)):
            t_d.append(train_dataset.examples[x] + IPA_dataset[x])
        train_dataset = t_d
    triple_sampler = SequentialSampler(concated_dataset)
    triple_dataloader = DataLoader(concated_dataset, sampler=triple_sampler, batch_size=args.train_batch_size,
                                   num_workers=0, pin_memory=True, collate_fn=lambda x: x)
    train_sampler = RandomSampler(train_dataset)
    if MODE == '73':
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,
                                      num_workers=0, pin_memory=True, collate_fn=lambda x: x)
    else:
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,
                                      num_workers=0, pin_memory=True)
    args.save_steps = len(train_dataloader) if args.save_steps <= 0 else args.save_steps
    args.warmup_steps = len(train_dataloader) if args.warmup_steps <= 0 else args.warmup_steps
    args.logging_steps = len(train_dataloader)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps)
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, args.warmup_steps, t_total)

    model.to(args.device)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Resume optimizer/scheduler state from the last checkpoint when present.
    checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last.' + args.mode + str(args.seed))
    scheduler_last = os.path.join(checkpoint_last, 'scheduler.pt')
    optimizer_last = os.path.join(checkpoint_last, 'optimizer.pt')
    if os.path.exists(scheduler_last):
        scheduler.load_state_dict(torch.load(scheduler_last))
    if os.path.exists(optimizer_last):
        optimizer.load_state_dict(torch.load(optimizer_last))

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    global_step = args.start_step
    tr_loss, logging_loss, avg_loss, tr_nb, tr_num, train_loss = 0.0, 0.0, 0.0, 0, 0, 0

    best_results = {"acc": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0, "acc_and_f1": 0.0}
    model.zero_grad()
    logger.info(model)

    max_length = args.max_seq_length

    for idx in range(args.start_epoch, int(args.num_train_epochs)):
        bar = tqdm(enumerate(zip(triple_dataloader, train_dataloader)))
        tr_num = 0
        train_loss = 0
        for step, (batch, batch2) in bar:
            # Triple (anchor) batch; the hard-coded 618 appears to bound the
            # real (non-placeholder) triple batches padded above.
            if step < 618 and MODE == '73':
                src_inputs = torch.stack([x[0] for x in batch]).to(args.device)
                tgt_inputs = torch.stack([x[1] for x in batch]).to(args.device)
                labels = torch.ones(len(batch)).to(args.device)
                anchor_vec = torch.stack([x[2] for x in batch]).to(args.device)
                IPA_lo = torch.stack(
                    [torch.nn.functional.pad(t.reshape(1, len(t)), (0, max_length - len(t), 0, 0)) for t in
                     [torch.stack(x[3]) for x in batch]]).to(args.device)
                IPA_th = torch.stack(
                    [torch.nn.functional.pad(t.reshape(1, len(t)), (0, max_length - len(t), 0, 0)) for t in
                     [torch.stack(x[4]) for x in batch]]).to(args.device)

                model.train()
                loss, predictions = model(src_inputs, tgt_inputs, labels, MODE=MODE, anchor_vec=anchor_vec,
                                          src_IPA=IPA_lo, tgt_IPA=IPA_th)

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            # Parallel-pair batch.
            if MODE == '73':
                sc_inputs = torch.stack([x[0] for x in batch2]).to(args.device)
                tt_inputs = torch.stack([x[1] for x in batch2]).to(args.device)
                lbls = torch.stack([x[2] for x in batch2]).to(args.device)
            else:
                sc_inputs = batch2[0].to(args.device)
                tt_inputs = batch2[1].to(args.device)
                lbls = batch2[2].to(args.device)

            if MODE == '73' and IPA_dataset is not None:
                IPA_fragment = torch.stack(
                    [torch.nn.functional.pad(t.reshape(1, len(t)), (0, max_length - len(t), 0, 0)) for t in
                     [torch.stack(x[3]) if len(x[3]) > 0 else torch.zeros(1, dtype=torch.long) for x in batch2]]).to(
                    args.device)

                loss2, predictions2 = model(sc_inputs, tt_inputs, lbls, IPA_fragment)
            else:
                loss2, predictions2 = model(sc_inputs, tt_inputs, lbls)

            if args.gradient_accumulation_steps > 1:
                loss2 = loss2 / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss2, optimizer) as scaled_loss2:
                    scaled_loss2.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss2.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss2.item()
            tr_num += 1
            train_loss += loss2.item()
            if avg_loss == 0:
                avg_loss = tr_loss
            avg_loss = round(train_loss / tr_num, 5)
            bar.set_description("epoch {} step {} loss {}".format(idx, step + 1, avg_loss))

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                avg_loss = round(np.exp((tr_loss - logging_loss) / (global_step - tr_nb)), 4)
                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logging_loss = tr_loss
                    tr_nb = global_step

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    if args.evaluate_during_training:
                        results = evaluate(args, model)
                        for key, value in results.items():
                            logger.info("  %s = %s", key, round(value, 4))

                        if results['f1'] >= best_results['f1']:
                            best_results = results
                            for key in best_results:
                                logger.info("  Best %s = %s", key, str(round(best_results[key], 4)))

                            checkpoint_prefix = 'checkpoint-best-aver.' + args.mode + str(args.seed)
                            output_dir = os.path.join(args.output_dir, checkpoint_prefix)
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)
                            model_to_save = model.module if hasattr(model, 'module') else model

                            torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
                            torch.save(model_to_save.state_dict(),
                                       os.path.join(output_dir, 'training_{}.bin'.format(idx)))
                            logger.info("Saving model checkpoint to %s", output_dir)
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            logger.info("Saving optimizer and scheduler states to %s", output_dir)

                    checkpoint_prefix = 'checkpoint-last.' + args.mode + str(args.seed)
                    output_dir = os.path.join(args.output_dir, checkpoint_prefix)
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))

                    idx_file = os.path.join(output_dir, 'idx_file.txt')
                    with open(idx_file, 'w', encoding='utf-8') as idxf:
                        idxf.write(str(args.start_epoch + idx) + '\n')
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
                    step_file = os.path.join(output_dir, 'step_file.txt')
                    with open(step_file, 'w', encoding='utf-8') as stepf:
                        stepf.write(str(global_step) + '\n')


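# In mode '73' the dataloaders use collate_fn=lambda x: x, which hands back the
# raw list of examples so the variable-length IPA fragment tensors can be
# padded to max_seq_length by hand before stacking.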
def test(args, test_dataset, model, IPA_dataset, MODE):
    if not args.test_predictions_output:
        args.test_predictions_output = os.path.join(args.output_dir,
                                                    'predictions' + '_' + args.mode + str(
                                                        args.seed) + '_' + args.test_data_file + '_' + '.txt')
    if args.mode == '73':
        t_d = []
        for x in range(len(test_dataset.examples)):
            t_d.append(test_dataset.examples[x] + IPA_dataset[x])
        test_dataset = t_d

    if not os.path.exists(os.path.dirname(args.test_predictions_output)):
        os.makedirs(os.path.dirname(args.test_predictions_output))
    eval_dataset = test_dataset

    eval_sampler = SequentialSampler(eval_dataset)
    if args.mode == '73':
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                     num_workers=0, pin_memory=True, collate_fn=lambda x: x)
    else:
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    model.to(args.device)
    model.eval()

    logger.info("***** Running Test *****")
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    eval_loss = 0.0
    nb_eval_steps = 0
    all_predictions = []
    all_labels = []
    if args.mode == '73':
        max_length = args.max_seq_length
        for batch in eval_dataloader:
            sc_inputs = torch.stack([x[0] for x in batch]).to(args.device)
            tt_inputs = torch.stack([x[1] for x in batch]).to(args.device)
            lbls = torch.stack([x[2] for x in batch]).to(args.device)
            IPA_fragment = torch.stack(
                [torch.nn.functional.pad(t.reshape(1, len(t)), (0, max_length - len(t), 0, 0)) for t in
                 [torch.stack(x[3]) if len(x[3]) > 0 else torch.zeros(1, dtype=torch.long) for x in batch]]).to(
                args.device)
            with torch.no_grad():
                lm_loss, predictions = model(sc_inputs, tt_inputs, lbls, IPA_fragment)
                eval_loss += lm_loss.mean().item()
                all_predictions.append(predictions.cpu())
                all_labels.append(lbls.cpu())
            nb_eval_steps += 1
    else:
        for batch in eval_dataloader:
            src_inputs = batch[0].to(args.device)
            tgt_inputs = batch[1].to(args.device)
            labels = batch[2].to(args.device)
            with torch.no_grad():
                lm_loss, predictions = model(src_inputs, tgt_inputs, labels)
                eval_loss += lm_loss.mean().item()
                all_predictions.append(predictions.cpu())
                all_labels.append(labels.cpu())
            nb_eval_steps += 1
    all_predictions = torch.cat(all_predictions, 0).squeeze().numpy()
    all_labels = torch.cat(all_labels, 0).squeeze().numpy()
    eval_loss = torch.tensor(eval_loss / nb_eval_steps)
    results = cal_acc_and_f1(all_predictions, all_labels)
    results.update({"eval_loss": float(eval_loss)})

    with open(args.test_predictions_output, 'w') as f:
        for key in results.keys():
            logger.info("  Final test %s = %s", key, str(results[key]))
            f.write(f"Final test {key} = {results[key]}\n")
    logger.info("  " + "*" * 20)


# Dev-set cache shared by evaluate() across calls.
eval_dataset = None


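# evaluate() lazily builds the dev dataset once and reuses it via the
# module-level eval_dataset cache; in mode '73' the cached dataset eventually
# becomes a plain list with IPA fragments appended, so the concatenation step
# is skipped on later calls (hence the `type(eval_dataset) is not list` check).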
def evaluate(args, model):
    eval_output_dir = args.output_dir
    global eval_dataset
    eval_data_path = os.path.join(args.data_dir, args.lang)
    if eval_dataset is None:
        eval_dataset = TextDataset(plm=SentenceTransformer('foundation/E5').to(args.device),
                                   file_path=eval_data_path, file_name=args.eval_data_file)
    if args.mode == '73' and not os.path.exists(
            os.path.join(args.data_dir, "replace_word_level_IPA" + args.eval_data_file[:-5])):
        with open(os.path.join(eval_data_path, args.eval_data_file), 'r', encoding='utf-8') as f:
            data = json.load(f)
        bar = tqdm(data)
        bar.set_description("IPA", refresh=True)

        IPA_dataset = []
        same_list = "./utils/same_list"

        with open(same_list, 'rb') as f:
            statistics = pkl.load(f)
        epi_lo = epitran.Epitran("lao-Laoo")
        epi_th = epitran.Epitran("tha-Thai")
        epi = Backoff(['deu-Latn', 'lao-Laoo', 'tha-Thai'])
        statistics_key = [statistics[i][0] for i in range(len(statistics))]
        statistics_dict = {x: idx for idx, x in enumerate(statistics_key)}
        for content in bar:
            IPA_, _ = text2IPA_by_statistics(content, statistics_dict, statistics_key, epi_lo, epi_th, epi)
            IPA_dataset.append([IPA_, _])

        pkl_path = os.path.join(args.data_dir, "replace_word_level_IPA")
        with open(pkl_path + args.eval_data_file[:-5], "wb") as binarized_file:
            pkl.dump(IPA_dataset, binarized_file)
    elif args.mode == '73':
        with open(os.path.join(args.data_dir, "replace_word_level_IPA" + args.eval_data_file[:-5]), 'rb') as f:
            IPA_dataset = pkl.load(f)
    else:
        IPA_dataset = None

    if not os.path.exists(eval_output_dir):
        os.makedirs(eval_output_dir)

    if args.mode == '73' and type(eval_dataset) is not list:
        t_d = []
        for x in range(len(eval_dataset.examples)):
            t_d.append(eval_dataset.examples[x] + IPA_dataset[x])
        eval_dataset = t_d

    eval_sampler = SequentialSampler(eval_dataset)
    if args.mode == '73':
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                     num_workers=0, pin_memory=True, collate_fn=lambda x: x)
    else:
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                     num_workers=0, pin_memory=True)

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()
    all_predictions = []
    all_labels = []
    if args.mode == '73':
        max_length = args.max_seq_length
        for batch in eval_dataloader:
            sc_inputs = torch.stack([x[0] for x in batch]).to(args.device)
            tt_inputs = torch.stack([x[1] for x in batch]).to(args.device)
            lbls = torch.stack([x[2] for x in batch]).to(args.device)
            IPA_fragment = torch.stack(
                [torch.nn.functional.pad(t.reshape(1, len(t)), (0, max_length - len(t), 0, 0)) for t in
                 [torch.stack(x[3]) if len(x[3]) > 0 else torch.zeros(1, dtype=torch.long) for x in batch]]).to(
                args.device)

            with torch.no_grad():
                lm_loss, predictions = model(sc_inputs, tt_inputs, lbls, IPA_fragment)
                eval_loss += lm_loss.mean().item()
                all_predictions.append(predictions.cpu())
                all_labels.append(lbls.cpu())
            nb_eval_steps += 1
    else:
        for batch in eval_dataloader:
            src_inputs = batch[0].to(args.device)
            tgt_inputs = batch[1].to(args.device)
            labels = batch[2].to(args.device)
            with torch.no_grad():
                lm_loss, predictions = model(src_inputs, tgt_inputs, labels)
                eval_loss += lm_loss.mean().item()
                all_predictions.append(predictions.cpu())
                all_labels.append(labels.cpu())
            nb_eval_steps += 1
    all_predictions = torch.cat(all_predictions, 0).squeeze().numpy()
    all_labels = torch.cat(all_labels, 0).squeeze().numpy()
    eval_loss = torch.tensor(eval_loss / nb_eval_steps)

    results = cal_acc_and_f1(all_predictions, all_labels)
    results.update({"eval_loss": float(eval_loss)})
    return results


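# The --mode flag packs two digits: the first selects the training routine
# ('0' = pairwise train(), '7' = train_triple() with anchor triples) and the
# second selects the data variant ('0' = lo, '1' = th, '2' = join,
# '3' = replace_word_level with IPA features), e.g. the default "73".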
def init_parser(parser):
    parser.add_argument("--train_data_file", default=None, type=str,
                        help="The input training data file (a text file).")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--data_dir", default=None, type=str, required=True)

    parser.add_argument("--eval_data_file", default=None, type=str,
                        help="An optional input evaluation data file to evaluate on (a text file).")
    parser.add_argument("--test_data_file", default=None, type=str,
                        help="An optional input test data file to evaluate on (a text file).")

    parser.add_argument("--checkpoint_path", default=None, type=str,
                        help="The checkpoint path of the model to continue training from.")

    parser.add_argument("--max_seq_length", default=-1, type=int,
                        help="Optional input sequence length after tokenization. "
                             "The training dataset will be truncated in blocks of this size for training. "
                             "Defaults to the model max input length for single-sentence inputs "
                             "(taking into account special tokens).")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")

    parser.add_argument("--train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=1e-2, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
    parser.add_argument("--warmup_steps", default=5, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--code_type", default='code', type=str,
                        help="Use `code` or `code_tokens` in the json file to index.")
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X update steps.")
    parser.add_argument('--save_steps', type=int, default=0,
                        help="Save checkpoint every X update steps.")
    parser.add_argument('--save_total_limit', type=int, default=None,
                        help="Limit the total number of checkpoints by deleting older checkpoints in output_dir; "
                             "does not delete by default.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as encoder_name_or_path "
                             "and ending with a step number.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available.")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory.")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets.")
    parser.add_argument('--seed', type=int, default=27,
                        help="Random seed for initialization.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit.")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="Apex AMP optimization level; read by train()/train_triple() when --fp16 is set.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank.")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")

    parser.add_argument("--pred_model_dir", default=None, type=str,
                        help="Model for prediction.")
    parser.add_argument("--test_result_dir", default='test_results.tsv', type=str,
                        help="Path to store test results.")
    parser.add_argument("--test_predictions_output", default=None, type=str,
                        help="The output file where the model predictions will be written.")

    parser.add_argument("--mode", default="73", type=str)
    parser.add_argument("--lang", default=None, type=str)

    return parser


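# mode[1] → data variant used by the dataset initializers below:
#   '0' → lo, '1' → th, '2' → join, '3' → replace_word_level (+ IPA cache).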
def init_test_dataset_by_mode(args):
    if args.mode[1] == "0":
        args.lang = "lo"
    elif args.mode[1] == "1":
        args.lang = "th"
    elif args.mode[1] == "2":
        args.lang = "join"
    elif args.mode[1] == "3":
        args.lang = "replace_word_level"

    test_data_path = os.path.join(args.data_dir, args.lang)

    test_dataset = TextDataset(plm=SentenceTransformer('foundation/E5').to(args.device),
                               file_path=test_data_path, file_name=args.test_data_file)

    if args.mode[1] == '3' and os.path.exists(
            os.path.join(args.data_dir, "replace_word_level_IPA" + args.test_data_file[:-5])):
        with open(os.path.join(args.data_dir, "replace_word_level_IPA" + args.test_data_file[:-5]), 'rb') as f:
            IPA_dataset = pkl.load(f)
    elif args.mode[1] == '3':
        with open(os.path.join(test_data_path, args.test_data_file), 'r', encoding='utf-8') as f:
            data = json.load(f)
        bar = tqdm(data)
        bar.set_description("IPA", refresh=True)

        IPA_dataset = []
        same_list = "./utils/same_list"

        with open(same_list, 'rb') as f:
            statistics = pkl.load(f)
        epi_lo = epitran.Epitran("lao-Laoo")
        epi_th = epitran.Epitran("tha-Thai")
        epi = Backoff(['deu-Latn', 'lao-Laoo', 'tha-Thai'])
        statistics_key = [statistics[i][0] for i in range(len(statistics))]
        statistics_dict = {x: idx for idx, x in enumerate(statistics_key)}
        for content in bar:
            IPA_, _ = text2IPA_by_statistics(content, statistics_dict, statistics_key, epi_lo, epi_th, epi)
            IPA_dataset.append([IPA_, _])

        pkl_path = os.path.join(args.data_dir, "replace_word_level_IPA")
        with open(pkl_path + args.test_data_file[:-5], "wb") as binarized_file:
            pkl.dump(IPA_dataset, binarized_file)
    else:
        IPA_dataset = None
    return test_dataset, IPA_dataset


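# Note the training IPA cache is a single file named "replace_word_levelIPA"
# (no split suffix), unlike the per-split "replace_word_level_IPA<split>"
# caches written for the dev and test sets above.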
def init_train_dataset_by_mode(args):
    if args.mode[1] == "0":
        args.lang = "lo"
    elif args.mode[1] == "1":
        args.lang = "th"
    elif args.mode[1] == "2":
        args.lang = "join"
    elif args.mode[1] == "3":
        args.lang = "replace_word_level"

    train_data_path = os.path.join(args.data_dir, args.lang)

    train_dataset = TextDataset(plm=SentenceTransformer('foundation/E5').to(args.device),
                                file_path=train_data_path, file_name=args.train_data_file)

    if args.mode[1] == '3' and os.path.exists(os.path.join(args.data_dir, "replace_word_levelIPA")):
        with open(os.path.join(args.data_dir, "replace_word_levelIPA"), 'rb') as f:
            IPA_dataset = pkl.load(f)
        return train_dataset, IPA_dataset
    elif args.mode[1] == '3':
        with open(os.path.join(train_data_path, args.train_data_file), 'r', encoding='utf-8') as f:
            data = json.load(f)
        bar = tqdm(data)
        bar.set_description("IPA", refresh=True)

        IPA_dataset = []
        same_list = "./utils/same_list"

        with open(same_list, 'rb') as f:
            statistics = pkl.load(f)
        epi_lo = epitran.Epitran("lao-Laoo")
        epi_th = epitran.Epitran("tha-Thai")
        epi = Backoff(['deu-Latn', 'lao-Laoo', 'tha-Thai'])
        statistics_key = [statistics[i][0] for i in range(len(statistics))]
        statistics_dict = {x: idx for idx, x in enumerate(statistics_key)}
        for content in bar:
            IPA_, _ = text2IPA_by_statistics(content, statistics_dict, statistics_key, epi_lo, epi_th, epi)
            IPA_dataset.append([IPA_, _])

        pkl_path = os.path.join(args.data_dir, "replace_word_level")
        with open(pkl_path + 'IPA', "wb") as binarized_file:
            pkl.dump(IPA_dataset, binarized_file)
        return train_dataset, IPA_dataset
    else:
        IPA_dataset = None
        return train_dataset, IPA_dataset


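# Builds the lo/th/zh anchor-triple dataset: line i of data_lo.txt, data_th.txt
# and data_zh.txt forms one triple, embedded with the passed-in
# SentenceTransformer; the lo/th sides are additionally mapped to IPA fragment
# IDs. Relies on the module-level `args` defined under __main__.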
def init_triple_text_IPA_dataset(plm):
    triple_data_path = os.path.join(args.data_dir, "triple")

    pkl_path = triple_data_path
    if os.path.exists(os.path.join(pkl_path, 'triple_binarized.pklIPA')):
        with open(os.path.join(pkl_path, 'triple_binarized.pklword'), "rb") as binarized_file1:
            dataset_ = pkl.load(binarized_file1)
        with open(os.path.join(pkl_path, 'triple_binarized.pklIPA'), "rb") as binarized_file2:
            IPA_dataset = pkl.load(binarized_file2)
        return dataset_, IPA_dataset
    else:
        file_content = []
        for data_file_name in ['data_lo.txt', 'data_th.txt', 'data_zh.txt']:
            with open(os.path.join(triple_data_path, data_file_name), 'r', encoding='utf-8') as f:
                file_content.append(f.readlines())

        data = list(map(list, zip(*file_content)))
        bar = tqdm(data)
        bar.set_description("IPA", refresh=True)
        dataset_ = []
        IPA_dataset = []
        same_list = "./utils/same_list"

        with open(same_list, 'rb') as f:
            statistics = pkl.load(f)
        epi_lo = epitran.Epitran("lao-Laoo")
        epi_th = epitran.Epitran("tha-Thai")

        statistics_key = [statistics[i][0] for i in range(len(statistics))]
        statistics_dict = {x: idx for idx, x in enumerate(statistics_key)}
        for content in bar:
            IPA_lo, IPA_th = text2IPA_by_statistics(content, statistics_dict, statistics_key, epi_lo, epi_th)
            IPA_dataset.append([IPA_lo, IPA_th])
            dataset_.append(
                [torch.from_numpy(plm.encode(content[0])),
                 torch.from_numpy(plm.encode(content[1])),
                 torch.from_numpy(plm.encode(content[2]))])

        # Write the caches to the same paths they are loaded from above.
        with open(os.path.join(pkl_path, 'triple_binarized.pklword'), "wb") as binarized_file1:
            pkl.dump(dataset_, binarized_file1)
        with open(os.path.join(pkl_path, 'triple_binarized.pklIPA'), "wb") as binarized_file2:
            pkl.dump(IPA_dataset, binarized_file2)
        return dataset_, IPA_dataset


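# Fragment IDs are offset by +1 because index 0 is reserved for padding
# (PADDING_IDX == 0); tokens whose IPA form is not in the statistics table are
# simply skipped.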
def text2IPA_by_statistics(sent_, statistics_dict, statistics_key, epi_lo, epi_th, epi=None):
    """Map corpus text onto the collected "pronunciation fragment" statistics."""
    # Loaded on every call; hoisting this would speed up large corpora.
    plm_tokenizer = AutoTokenizer.from_pretrained(os.getcwd() + '/foundation/E5')

    ipa_lo = []
    ipa_th = []
    ipa_ = []

    if epi is not None:
        word_tokenized = \
            plm_tokenizer(sent_['tgt'], max_length=512, padding=True, truncation=True,
                          return_tensors='pt').encodings[0].tokens[2:-1]
        for word in word_tokenized:
            word_ipa = epi.transliterate(word)
            if word_ipa in statistics_dict:
                ipa_.append(torch.tensor(statistics_dict[word_ipa] + 1))
    else:
        for idx, content in enumerate(sent_):
            word_tokenized = \
                plm_tokenizer(content, max_length=512, padding=True, truncation=True,
                              return_tensors='pt').encodings[0].tokens[2:-1]
            for word in word_tokenized:
                if idx == 0:
                    word_ipa = epi_lo.transliterate(word)
                    if word_ipa in statistics_dict:
                        ipa_lo.append(torch.tensor(statistics_dict[word_ipa] + 1))
                if idx == 1:
                    word_ipa = epi_th.transliterate(word)
                    if word_ipa in statistics_dict:
                        ipa_th.append(torch.tensor(statistics_dict[word_ipa] + 1))
    if epi is not None:
        return ipa_, '_'
    else:
        return ipa_lo, ipa_th


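# A typical invocation might look like the following (the script name and all
# paths/file names here are hypothetical placeholders, chosen only to match the
# arguments defined above):
#
#   python parl_sent_pair_extract.py \
#       --data_dir ./data --output_dir ./out \
#       --train_data_file train.json --eval_data_file valid.json \
#       --mode 73 --do_train --evaluate_during_training \
#       --max_seq_length 128 --train_batch_size 8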
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    init_parser(parser=parser)

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if args.do_train:
        file_handler = logging.FileHandler(f"do_train.{args.mode + str(args.seed)}.log", mode='a')
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.INFO)
        logger.addHandler(file_handler)

    set_seed(args.seed)

    args.start_epoch = 0
    args.start_step = 0
    checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last.' + args.mode + str(args.seed))
    if os.path.exists(checkpoint_last) and os.listdir(checkpoint_last):
        idx_file = os.path.join(checkpoint_last, 'idx_file.txt')
        with open(idx_file, encoding='utf-8') as idxf:
            args.start_epoch = int(idxf.readlines()[0].strip()) + 1

        step_file = os.path.join(checkpoint_last, 'step_file.txt')
        if os.path.exists(step_file):
            with open(step_file, encoding='utf-8') as stepf:
                args.start_step = int(stepf.readlines()[0].strip())

        logger.info("reload model from {}, resume from {} epoch".format(checkpoint_last, args.start_epoch))

    with open("./utils/same_list", 'rb') as f:
        IPA_vocab_size = len(pkl.load(f)) + 1
    model = SentencesPairExtract(IPA_embed_dim=32, max_seq_length=args.max_seq_length,
                                 batch_size=args.train_batch_size, IPA_vocab_size=IPA_vocab_size)

    if args.do_train:
        train_dataset, IPA_dataset = init_train_dataset_by_mode(args=args)
        if args.mode[0] == '0':
            train(args, train_dataset, model, IPA_dataset, MODE=args.mode)
        elif args.mode[0] == '7':
            triple_lang_dataset, triple_IPA_dataset = init_triple_text_IPA_dataset(
                plm=SentenceTransformer('foundation/E5').to(args.device))
            train_triple(args, triple_lang_dataset, model, triple_IPA_dataset, train_dataset, IPA_dataset,
                         MODE=args.mode)

    if args.do_test:
        logger.info("***** Testing results *****")
        output_dir = os.path.join(args.output_dir, "checkpoint-best-aver." + args.mode + str(args.seed))
        model_path = os.path.join(output_dir, 'pytorch_model.bin')
        model.load_state_dict(torch.load(model_path))
        logger.info("Test Model From: {}".format(model_path))

        test_dataset, IPA_dataset = init_test_dataset_by_mode(args=args)
        test(args, test_dataset, model, IPA_dataset, MODE=args.mode)
|