| | import argparse |
| | import numpy as np |
| | from pathlib import Path |
| | import tqdm |
| | from pprint import pprint |
| | import torch |
| | from torch.nn.utils.rnn import pad_sequence |
| | from scrl.config import load_config |
| | from scrl.training import setup_and_train |
| | from scrl.model import labels_to_summary |
| | from scrl.eval_metrics import compute_token_f1 |
| | import scrl.utils as utils |
| | from nltk import word_tokenize |
| |
|
| |
|
def evaluate_validation_reward(args, manager, model, tokenizer, reward_generator, dataset):
    """Compute the mean reward of greedy (argmax) summaries over a validation set.

    Args:
        args: config namespace; uses ``device``, ``batch_size`` and
            ``max_val_steps`` (``None`` means evaluate all batches).
        manager: training manager (unused here; kept for a uniform eval signature).
        model: token-labelling model; called as ``model(input_ids)`` and assumed
            to return logits of shape (batch, seq, n_labels) — TODO confirm.
        tokenizer: tokenizer used by ``labels_to_summary`` to render summaries.
        reward_generator: callable ``(documents, summaries) -> (rewards, details)``.
        dataset: indexable dataset providing "input_ids" and "document" fields.

    Returns:
        float: mean reward over all evaluated batches.
    """
    device = args.device
    idx_range = list(range(len(dataset)))
    dataset_indices = list(utils.batchify(idx_range, args.batch_size))
    rewards = []
    for i, indices in enumerate(dataset_indices):
        # Optionally cap the number of validation batches to keep eval cheap.
        # (Fixed: compare to None with `is not`, not `!=`.)
        if args.max_val_steps is not None and i >= args.max_val_steps:
            break
        batch = dataset[indices]
        input_ids = pad_sequence(
            [torch.tensor(ids) for ids in batch["input_ids"]], batch_first=True
        )
        logits = model(input_ids.to(device))
        # Greedy decoding: the argmax labels select which tokens stay in the summary.
        # (Removed an unused softmax over the logits that was computed per batch.)
        argmax_labels = torch.argmax(logits, dim=2).to(device)
        argmax_summaries = labels_to_summary(input_ids, argmax_labels, tokenizer)
        argmax_rewards, _ = reward_generator(batch["document"], argmax_summaries)
        rewards += argmax_rewards
    avg_reward = np.mean(rewards)
    return avg_reward
| |
|
| |
|
| |
|
def evaluate_validation_dataset(args, manager, model, tokenizer, reward_generator, dataset_path):
    """Evaluate greedy summaries on a labelled JSONL dataset via token-level F1.

    Each dataset item has a source "text" and one or more reference "summaries";
    the item score is the mean F1 against all references (lower-cased tokens).

    Args:
        args: config namespace; uses ``device`` and ``dump``.
        manager: training manager; provides ``dir`` and ``step`` for dump output.
        model: token-labelling model returning logits of shape
            (batch, seq, n_labels) — TODO confirm.
        tokenizer: HuggingFace-style tokenizer; also used by ``labels_to_summary``.
        reward_generator: unused here; kept for a uniform eval signature.
        dataset_path: ``pathlib.Path`` to a ``.jsonl`` dataset file.

    Returns:
        float: mean token-F1 over all dataset items.
    """
    dataset = list(utils.read_jsonl(dataset_path))
    f1_scores = []
    dump_data = []

    for item in tqdm.tqdm(dataset):
        src = item["text"]
        tgts = item["summaries"]

        input_ids = torch.tensor(tokenizer([src])["input_ids"]).to(args.device)
        logits = model(input_ids)
        argmax_labels = torch.argmax(logits, dim=2)
        pred = labels_to_summary(input_ids, argmax_labels, tokenizer)[0]

        # Lower-case the prediction once, outside the reference loop
        # (the original re-lowered pred_tokens on every loop iteration;
        # an unused tokenization of `src` was also removed).
        pred_tokens = [t.lower() for t in word_tokenize(pred)]

        item_scores = []
        for tgt in tgts:
            tgt_tokens = [t.lower() for t in word_tokenize(tgt)]
            token_f1 = compute_token_f1(
                tgt_tokens, pred_tokens, use_counts=True
            )
            item_scores.append(token_f1)

        if args.dump:
            # Per-item record for offline inspection of model probabilities.
            probs = torch.softmax(logits, dim=2)[0].detach().tolist()
            dump_data.append({
                "probs": probs,
                "source": src,
                "target": tgts[0],
                "f1-score": item_scores[0],
                "pred_summary": pred,
                "pred_labels": argmax_labels[0].tolist(),
            })

        f1_scores.append(np.mean(item_scores))
    score = np.mean(f1_scores)

    if args.dump:
        dataset_name = dataset_path.name.split(".jsonl")[0]
        dump_dir = manager.dir / f"dump-{dataset_name}"
        dump_dir.mkdir(exist_ok=True)
        utils.write_jsonl(
            dump_data,
            dump_dir / f"step-{manager.step}.jsonl",
            "w"
        )
    return score
| |
|
| |
|
def _update_best_and_log(manager, model, step, score, result_path, ckpt_name, title, verbose):
    """Append (step, score) to the JSONL history at `result_path`; save a model
    checkpoint under `ckpt_name` when `score` beats the previous best; optionally
    print the full history."""
    if result_path.exists():
        results = list(utils.read_jsonl(result_path))
        prev_max = max(x["score"] for x in results)
    else:
        results = []
        # NOTE(review): a best-so-far of 0 means negative scores never trigger
        # a checkpoint save — preserved from the original logic.
        prev_max = 0
    if score > prev_max:
        manager.save_model(model, step, ckpt_name)
    results.append({"step": step, "score": score})
    utils.write_jsonl(results, result_path, "w")
    if verbose:
        print(title)
        pprint(results)
        print()


def evaluate(args, manager, model, tokenizer, reward_generator, holdout_data):
    """Run periodic evaluation: validation reward on `holdout_data`, then
    token-F1 on each dataset in ``args.validation_datasets``; track the best
    score of each metric and checkpoint the model when it improves.

    The two tracking branches previously duplicated the same bookkeeping;
    both now go through ``_update_best_and_log``.
    """
    step = manager.step

    val_reward = evaluate_validation_reward(
        args, manager, model, tokenizer, reward_generator, holdout_data
    )
    _update_best_and_log(
        manager, model, step, val_reward,
        manager.dir / "val_rewards.jsonl",
        "best_val_reward",
        "Validation Rewards:",
        args.verbose,
    )

    for val_data_path in args.validation_datasets:
        val_data_path = Path(val_data_path)
        dataset_name = val_data_path.name.split(".jsonl")[0]
        dataset_score = evaluate_validation_dataset(
            args, manager, model, tokenizer, reward_generator, val_data_path
        )
        _update_best_and_log(
            manager, model, step, dataset_score,
            Path(manager.dir / f"val_data_results.{dataset_name}.jsonl"),
            f"best_on_{dataset_name}",
            f"Validation Dataset Results for {dataset_name}:",
            args.verbose,
        )
| |
|
| |
|
def main(args):
    """Program entry point.

    Fixes the global RNG seed for reproducibility, then launches training
    with :func:`evaluate` installed as the periodic evaluation callback.
    """
    utils.set_random_seed(0)
    setup_and_train(args, eval_func=evaluate)
| |
|
| |
|
if __name__ == '__main__':
    # CLI: collect flags, merge them with the JSON config file, run training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path to JSON config file")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dump", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument(
        "--fresh",
        action="store_true",
        help="delete model directory and start from scratch",
    )
    cli_args = parser.parse_args()
    main(load_config(cli_args))
| |
|