import logging
import os
import sys
import time
from typing import Dict, List, Optional

import torch
from torch import Tensor
from tqdm import tqdm

from fairseq import options, tasks, utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("inference")
def write_result(results, output_file):
    with open(output_file, 'w') as f:
        for line in results:
            f.write(line + '\n')
def fairseq_generate(data_lines, cfg, models, task, batch_size, device):
    # fairseq's original beam-search decoding implementation
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    generator = task.build_generator(models, cfg.generation)
    data_size = len(data_lines)
    all_results = []
    logger.info(f'Fairseq generate batch {batch_size}')
    start = time.perf_counter()
    for start_idx in tqdm(range(0, data_size, batch_size)):
        batch_lines = data_lines[start_idx: min(start_idx + batch_size, data_size)]
        batch_ids = [src_dict.encode_line(sentence, add_if_not_exist=False).long() for sentence in batch_lines]
        lengths = torch.LongTensor([t.numel() for t in batch_ids])
        batch_dataset = task.build_dataset_for_inference(batch_ids, lengths)
        batch = batch_dataset.collater(batch_dataset)
        batch = utils.apply_to_sample(lambda t: t.to(device), batch)
        translations = generator.generate(models, batch, prefix_tokens=None)
        # the collater sorts samples by source length, so restore the input order
        results = []
        for sample_id, hypos in zip(batch["id"].tolist(), translations):
            results.append((sample_id, hypos))
        batched_hypos = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
        all_results.extend([tgt_dict.string(hypos[0]['tokens']) for hypos in batched_hypos])
    delta = time.perf_counter() - start
    remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
    return remove_bpe_results, delta
def baseline_forward_decoder(model,
                             input_tokens,
                             encoder_out: Dict[str, List[Tensor]],
                             incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
                             parallel_forward_start_pos=None,
                             temperature: float = 1.0):
    # one greedy decoding step; incremental_state caches the decoder's keys/values
    # so each call only computes the newest position (parallel_forward_start_pos
    # is an argument added by this project's modified fairseq decoder)
    decoder_out = model.decoder.forward(input_tokens,
                                        encoder_out=encoder_out,
                                        incremental_state=incremental_state,
                                        parallel_forward_start_pos=parallel_forward_start_pos)
    decoder_out_tuple = (decoder_out[0].div_(temperature), decoder_out[1])
    pred_tokens = torch.argmax(decoder_out_tuple[0], dim=-1).squeeze(0)
    return pred_tokens
def baseline_generate(data_lines, model, task, batch_size, device, max_len=200):
    # simplified AR greedy decoding
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    data_size = len(data_lines)
    all_results = []
    start = time.perf_counter()
    logger.info('Baseline generate')
    for start_idx in tqdm(range(0, data_size, batch_size)):
        # the final batch may be smaller than batch_size
        batch_size = min(data_size - start_idx, batch_size)
        batch_lines = data_lines[start_idx: start_idx + batch_size]
        batch_ids = [src_dict.encode_line(sentence, add_if_not_exist=False).long() for sentence in batch_lines]
        lengths = torch.LongTensor([t.numel() for t in batch_ids])
        batch_dataset = task.build_dataset_for_inference(batch_ids, lengths)
        batch_dataset.left_pad_source = False
        batch = batch_dataset.collater(batch_dataset)
        batch = utils.apply_to_sample(lambda t: t.to(device), batch)
        net_input = batch['net_input']
        encoder_out = model.encoder.forward(net_input['src_tokens'], net_input['src_lengths'])
        incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
        # every sentence starts with eos, following fairseq's decoder convention
        batch_tokens = [[tgt_dict.eos()] for _ in range(batch_size)]
        finish_list = []
        for step in range(0, max_len):
            cur_input_tokens = torch.tensor(batch_tokens).to(device).long()
            pred_tokens = baseline_forward_decoder(model,
                                                   cur_input_tokens,
                                                   encoder_out,
                                                   incremental_state=incremental_state)
            for i, pred_tok in enumerate(pred_tokens):
                if len(batch_tokens[i]) == 1:
                    batch_tokens[i].append(pred_tok.item())
                else:
                    if batch_tokens[i][-1] != tgt_dict.eos():
                        batch_tokens[i].append(pred_tok.item())
                    else:
                        # keep padding finished sentences with eos so the batch stays rectangular
                        if i not in finish_list:
                            finish_list.append(i)
                        batch_tokens[i].append(tgt_dict.eos())
            if len(finish_list) == batch_size:
                break
        # the collater sorts samples by source length, so restore the input order
        batch_tokens = [y for x, y in sorted(zip(batch['id'].cpu().tolist(), batch_tokens))]
        for tokens in batch_tokens:
            all_results.append(tgt_dict.string(tokens[1:]))
    remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
    delta = time.perf_counter() - start
    return remove_bpe_results, delta
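# Note (editor's summary, not from the original source): the greedy baseline
# above commits exactly one token per decoder call. GAD (below) instead accepts
# up to block_size + 1 tokens per iteration: every drafted token the AR verifier
# agrees with, plus the verifier's own prediction at the first disagreement.
# That per-iteration acceptance is where the speedup comes from.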
def forward_decoder(model, input_tokens, encoder_out, incremental_state=None,
                    parallel_forward_start_pos=None, temperature=1.0, beta=1, tau=0.0):
    decoder_out = model.decoder.forward(input_tokens,
                                        encoder_out=encoder_out,
                                        incremental_state=incremental_state,
                                        parallel_forward_start_pos=parallel_forward_start_pos)
    decoder_out_tuple = (decoder_out[0].div_(temperature), decoder_out[1])
    # keep the top-beta candidates at every position, but mask out (set to -1)
    # any candidate whose logit falls more than tau below the top-1 logit
    topk_scores, indexes = torch.topk(decoder_out_tuple[0], beta, dim=-1)
    topk_scores_list = topk_scores.tolist()
    indexes_list = indexes.tolist()
    for i in range(indexes.size(0)):
        for j in range(indexes.size(1)):
            for k, s in enumerate(topk_scores_list[i][j]):
                if topk_scores_list[i][j][0] - s > tau:
                    indexes_list[i][j][k] = -1
    return indexes_list
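# Illustrative example (not from the original source): with beta=2 and tau=1.0,
# if the top-2 logits at some position are [5.0, 3.5], the gap 1.5 exceeds tau,
# so the runner-up index is masked to -1 and only the top-1 token can match a
# drafted token during verification; with tau=2.0 both candidates would count.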
def gad_generate(data_lines, model, AR_model, task, block_size, batch_size, device, beta=1, tau=0, max_len=200):
    # Generalized Aggressive Decoding: the NAR model drafts a block of tokens,
    # which the AR model then verifies in a single parallel forward pass
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    data_size = len(data_lines)
    all_results = []
    logger.info('GAD generate')
    start = time.perf_counter()
    for start_idx in tqdm(range(0, data_size, batch_size)):
        # the final batch may be smaller than batch_size
        batch_size = min(data_size - start_idx, batch_size)
        batch_lines = data_lines[start_idx: start_idx + batch_size]
        batch_ids = [src_dict.encode_line(sentence, add_if_not_exist=False).long() for sentence in batch_lines]
        lengths = torch.LongTensor([t.numel() for t in batch_ids])
        batch_dataset = task.build_dataset_for_inference(batch_ids, lengths)
        batch_dataset.left_pad_source = False
        batch = batch_dataset.collater(batch_dataset)
        batch = utils.apply_to_sample(lambda t: t.to(device), batch)
        net_input = batch['net_input']
        AR_encoder_out = AR_model.encoder.forward(net_input['src_tokens'], net_input['src_lengths'])
        encoder_out = model.encoder.forward(net_input['src_tokens'], net_input['src_lengths'])
        sentences = [[tgt_dict.eos()] for _ in range(batch_size)]
        # the NAR drafter fills unk placeholders one block at a time
        prev_output_tokens = [[tgt_dict.unk()] * block_size for _ in range(batch_size)]
        start_pos_list = [0] * batch_size
        finish_list = []
        for step in range(0, max_len):
            prev_output_tokens, start_pos_list = gad_forward(start_pos_list, block_size, batch_size,
                                                             tgt_dict, prev_output_tokens,
                                                             encoder_out, AR_encoder_out, model, AR_model, beta, tau)
            for i, start_pos in enumerate(start_pos_list):
                if i not in finish_list:
                    if start_pos == -1:
                        finish_list.append(i)
                        sentences[i] = prev_output_tokens[i]
            if len(finish_list) == batch_size:
                break
        # the collater sorts samples by source length, so restore the input order
        batch_sents = [y for x, y in sorted(zip(batch['id'].cpu().tolist(), sentences))]
        for s in batch_sents:
            all_results.append(tgt_dict.string(s))
    remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
    delta = time.perf_counter() - start
    return remove_bpe_results, delta
def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_tokens,
                encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200):
    # note: relies on the module-level `device` set in the __main__ block below
    pad_tokens = [[tgt_dict.pad()] * (max_len + block_size) for _ in range(batch_size)]
    for i in range(batch_size):
        pad_tokens[i][:len(prev_output_tokens[i])] = prev_output_tokens[i]
    output_tokens = torch.tensor(pad_tokens).to(device)
    output_tokens = output_tokens[:, : output_tokens.ne(tgt_dict.pad()).sum(1).max()]
    # NAR draft: fill the current unk block with the drafter's argmax predictions
    _, tensor_tokens = model.decoder(
        normalize=False,
        prev_output_tokens=output_tokens,
        encoder_out=encoder_out,
    ).max(-1)
    _tokens = tensor_tokens.tolist()
    for i, start_pos in enumerate(start_pos_list):
        if start_pos_list[i] != -1:
            output_tokens[i, start_pos:start_pos + block_size] = tensor_tokens[i, start_pos:start_pos + block_size]
            prev_output_tokens[i][start_pos:start_pos + block_size] = _tokens[i][start_pos:start_pos + block_size]
    # AR verification: one parallel forward pass over the whole draft, prefixed
    # with eos as the decoder's start-of-sentence token
    append_eos = torch.tensor([[tgt_dict.eos()] for _ in range(batch_size)]).to(device)
    cur_span_input_tokens = torch.cat((append_eos, output_tokens), dim=-1)
    AR_verify_tokens = forward_decoder(AR_model, cur_span_input_tokens, AR_encoder_out, beta=beta, tau=tau)
    next_output_tokens = prev_output_tokens.copy()
    for i in range(batch_size):
        if start_pos_list[i] != -1:
            # find the first drafted token the AR verifier rejects (the bifurcation point)
            bifurcation = block_size
            for j, (token, AR_verify_token) in enumerate(
                    zip(prev_output_tokens[i][start_pos_list[i]:], AR_verify_tokens[i][start_pos_list[i]:-1])):
                if token not in AR_verify_token:
                    bifurcation = j
                    break
            # keep the accepted prefix, append the verifier's own top-1 token at the
            # bifurcation point, then add fresh unk placeholders for the next block
            next_output_tokens[i] = prev_output_tokens[i][:start_pos_list[i] + bifurcation] + \
                                    [AR_verify_tokens[i][start_pos_list[i] + bifurcation][0]] + \
                                    [tgt_dict.unk()] * block_size
            find_eos = False
            for j, o in enumerate(next_output_tokens[i][start_pos_list[i]:start_pos_list[i] + bifurcation + 1]):
                if o == tgt_dict.eos() or start_pos_list[i] + j == max_len:
                    next_output_tokens[i] = next_output_tokens[i][:start_pos_list[i] + j]
                    start_pos_list[i] = -1  # mark this sentence as finished
                    find_eos = True
                    break
            if not find_eos:
                start_pos_list[i] = start_pos_list[i] + bifurcation + 1
    return next_output_tokens, start_pos_list
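# Worked example (illustrative, not from the original source): with block_size=5,
# suppose the drafter fills a block with [A, B, C, D, E] and the verifier's
# top-beta candidate sets accept A and B but not C. Then bifurcation=2, the
# sentence keeps [..., A, B], the verifier's own top-1 prediction replaces C,
# five fresh unk placeholders are appended, and start_pos advances by
# bifurcation + 1 = 3 for the next drafting round.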
if __name__ == '__main__':
    parser = options.get_generation_parser()
    parser.add_argument('--input-path', type=str, required=True,
                        help='path to eval file')
    parser.add_argument('--output-path', type=str, default=None,
                        help='path to output file')
    parser.add_argument('--AR-path', type=str, default=None,
                        help='path to AR model')
    parser.add_argument('--strategy', type=str, default='fairseq',
                        help='decoding strategy, choose from: fairseq, AR, gad')
    parser.add_argument('--batch', type=int, default=None,
                        help='batch size')
    parser.add_argument('--block-size', type=int, default=5,
                        help='block size')
    parser.add_argument('--beta', type=int, default=1,
                        help='top-beta hyperparameter')
    parser.add_argument('--tau', type=float, default=0,
                        help='tolerance hyperparameter')
    cmd_args = options.parse_args_and_arch(parser)
    cmd_args.input_path = os.path.expanduser(cmd_args.input_path)
    if cmd_args.output_path is not None:
        cmd_args.output_path = os.path.expanduser(cmd_args.output_path)
    cfg = convert_namespace_to_omegaconf(cmd_args)
    task = tasks.setup_task(cfg.task)
    # NAR drafter
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args, _model_task = load_model_ensemble_and_task(filenames=[cfg.common_eval.path], task=task)
    if cmd_args.cpu:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda')
    model = models[0].to(device).eval()
    if cfg.common.fp16:
        logger.info("NAR fp16 enabled!")
        model.half()
    # AR verifier (required by all three strategies as written: 'AR' and 'fairseq'
    # decode with it directly, and 'gad' uses it to verify the NAR draft)
    AR_model = None
    AR_models = None
    _AR_model_task = None
    if cmd_args.AR_path is not None:
        AR_models, _AR_model_args, _AR_model_task = load_model_ensemble_and_task(filenames=[cmd_args.AR_path],
                                                                                 arg_overrides={'data': cfg.task.data})
        if cfg.common.fp16:
            logger.info("AR fp16 enabled!")
            for AR_model in AR_models:
                AR_model.half()
        AR_model = AR_models[0].to(device).eval()
        logger.info("AR model loaded!")
    with open(cmd_args.input_path, 'r') as f:
        bpe_sents = [l.strip() for l in f.readlines()]
    if cmd_args.strategy == 'AR':
        logger.info("Decoding Strategy: Simplified AR")
        remove_bpe_results, delta = baseline_generate(bpe_sents, AR_model, _AR_model_task, cmd_args.batch, device)
        logger.info(f'Simplified AR generate: {delta}')
    elif cmd_args.strategy == 'gad':
        logger.info("Decoding Strategy: GAD")
        remove_bpe_results, delta = gad_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, cmd_args.batch,
                                                 device, beta=cmd_args.beta, tau=cmd_args.tau)
        logger.info(f'GAD generate: {delta}')
    else:
        logger.info("Decoding Strategy: fairseq")
        remove_bpe_results, delta = fairseq_generate(bpe_sents, cfg, AR_models, _AR_model_task, cmd_args.batch, device)
        logger.info(f'Fairseq generate batch {cmd_args.batch}, beam {cfg.generation.beam}: {delta}')
    if cmd_args.output_path is not None:
        write_result(remove_bpe_results, cmd_args.output_path)
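
# Example invocation (illustrative; all paths below are placeholders, not files
# shipped with this script -- the positional data dir and --path come from
# fairseq's generation parser, the remaining flags are defined above):
#
#   python inference.py data-bin/wmt14.en-de \
#       --path checkpoints/nar_drafter.pt \
#       --AR-path checkpoints/ar_verifier.pt \
#       --input-path data/test.en.bpe \
#       --output-path output/test.de \
#       --strategy gad --batch 1 --block-size 5 --beta 1 --tau 0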