X1A committed on
Commit
4681f9e
·
1 Parent(s): d08239e

Delete utils.py

Browse files
Files changed (1) hide show
  1. utils.py +0 -64
utils.py DELETED
@@ -1,64 +0,0 @@
1
- import os
2
- import json
3
- from pathlib import Path
4
- from loguru import logger
5
- from typing import Callable, Iterable, List
6
-
7
-
8
def lmap(f: Callable, x: Iterable) -> List:
    """Apply *f* to every element of *x* and return the results as a list.

    Equivalent to ``list(map(f, x))``.
    """
    return [f(item) for item in x]
11
-
12
def write_txt_file(ordered_tgt, path):
    """Write each string in *ordered_tgt* to *path*, one per line.

    Args:
        ordered_tgt: iterable of strings; each is written followed by "\n".
        path: destination file path (truncated if it already exists).

    Fix: the original opened the file and flushed it but never closed the
    handle, leaking the file descriptor. A ``with`` block guarantees the
    file is flushed and closed even if a write raises.
    """
    with Path(path).open("w") as f:
        for ln in ordered_tgt:
            f.write(ln + "\n")
17
-
18
-
19
def save_json(content, path, indent=4, **json_dump_kwargs):
    """Serialize *content* as JSON to *path*.

    Keys are always written in sorted order; *indent* controls
    pretty-printing and any extra keyword arguments are forwarded
    to :func:`json.dump`.
    """
    with open(path, "w") as fp:
        json.dump(content, fp, indent=indent, sort_keys=True, **json_dump_kwargs)
22
-
23
-
24
def handle_metrics(split, metrics, output_dir):
    """
    Log and save metrics

    Args:
        - split: one of train, val, test
        - metrics: metrics dict
        - output_dir: where to save the metrics
    """
    # Emit a header, then one line per metric in deterministic (sorted) order.
    logger.info(f"***** {split} metrics *****")
    sorted_keys = sorted(metrics.keys())
    for key in sorted_keys:
        logger.info(f"  {key} = {metrics[key]}")
    # Persist alongside the run, e.g. results_val.json.
    save_json(metrics, os.path.join(output_dir, f"results_{split}.json"))
38
-
39
-
40
def delete_checkpoints(model_dir):
    """Delete every ``checkpoint-*`` subdirectory under *model_dir*.

    A directory entry is treated as a checkpoint when the text before its
    first "-" is exactly "checkpoint" (e.g. ``checkpoint-500``).

    Fix: the original placed ``import shutil`` loose in the middle of the
    module, between function definitions; the import now lives at function
    scope so this block is self-contained and the module body stays clean.
    """
    import shutil

    checkpoints = [
        folder
        for folder in os.listdir(model_dir)
        if folder.split("-")[0] == "checkpoint"
    ]
    logger.info(f"Deleting checkpoints.\n{checkpoints}")
    for checkpoint in checkpoints:
        shutil.rmtree(os.path.join(model_dir, checkpoint))
46
-
47
-
48
- import jieba
49
- from functools import partial
50
- from transformers import BertTokenizer
51
-
52
class T5PegasusTokenizer(BertTokenizer):
    """BertTokenizer variant that pre-segments text with jieba.

    Words produced by jieba that already exist in the vocabulary are kept
    whole; anything out-of-vocabulary falls back to BERT's WordPiece
    tokenization.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # jieba word segmentation with the HMM new-word model disabled.
        self.pre_tokenizer = partial(jieba.cut, HMM=False)

    def _tokenize(self, text, *arg, **kwargs):
        tokens = []
        for word in self.pre_tokenizer(text):
            if word in self.vocab:
                tokens.append(word)
            else:
                tokens.extend(super()._tokenize(word))
        return tokens