Spaces:
Runtime error
Runtime error
File size: 6,210 Bytes
e489264 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
from argparse import ArgumentParser
import os
import re
import time
from processors import FilesProcessor, get_text_distorter
from processes import (
CharsRemover,
LengthFilter,
LinesSplitter,
LoadFile,
NumbersFilter,
OOVFilter,
RepeatedCharsCollapsor,
# SoloCharFilter,
SpacesRemover,
ValidCharsKeeper,
WordsFilter,
WordsNumberFilter,
CharsNormalizer,
TokenizerLengthFilter,
)
from helpers import load_json, save_text_file
from typing import Union, List
from pathlib import Path
import constants
import pandas as pd
def get_paths(
        main_dir: Union[Path, str]
        ) -> List[Union[Path, str]]:
    """List the entries of a directory as joined paths.

    Args:
        main_dir: Directory whose entries are listed (not recursive).

    Returns:
        One ``os.path.join(main_dir, name)`` path per name returned by
        ``os.listdir``, in the order ``os.listdir`` yields them.
    """
    collected = []
    for entry in os.listdir(main_dir):
        collected.append(os.path.join(main_dir, entry))
    return collected
def get_path(
        file_path: Union[Path, str]
        ) -> List[Union[Path, str]]:
    """Wrap a single existing file path in a list.

    Args:
        file_path: Path that must point to a regular file.

    Returns:
        A one-element list containing ``file_path`` unchanged.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist or is not a
            regular file (e.g. it is a directory).
    """
    if not os.path.isfile(file_path):
        # Name the offending path so the failure is actionable.
        raise FileNotFoundError(f'No such file: {file_path}')
    return [file_path]
def get_file_processor(args):
    """Assemble the ``FilesProcessor`` used to clean the input files.

    The processes are constructed and handed over in a fixed order; the
    word list given to ``WordsFilter`` is loaded from
    ``args.execlude_words_files``.

    Args:
        args: Parsed command-line namespace. Reads ``execlude_words_files``,
            ``sep``, ``max_rep_chars``, ``min_words``, ``max_words``,
            ``min_len`` and ``max_len``.

    Returns:
        A ``FilesProcessor`` wrapping the ordered list of processes.
    """
    excluded_words = load_json(args.execlude_words_files)
    pipeline = [LoadFile()]
    # One splitter per configured separator, in the order they were given.
    pipeline.extend(LinesSplitter(sep=separator) for separator in args.sep)
    # SoloCharFilter and TokenizerLengthFilter are deliberately not part of
    # the pipeline (they were disabled in the original list).
    pipeline.extend([
        RepeatedCharsCollapsor(args.max_rep_chars),
        NumbersFilter(),
        WordsFilter(excluded_words),
        ValidCharsKeeper(constants.VALID_CHARS),
        SpacesRemover(),
        WordsNumberFilter(args.min_words, args.max_words),
        LengthFilter(args.min_len, args.max_len),
    ])
    return FilesProcessor(pipeline)
def post_process(data: List[List[str]]) -> List[str]:
    """Flatten the per-file results and drop duplicate lines.

    Args:
        data: One collection of lines per processed file.

    Returns:
        Every distinct line exactly once, in first-seen order.
    """
    # dict.fromkeys de-duplicates while keeping encounter order. A set would
    # yield a different order on each run (str hash randomization), which
    # makes any later positional split of the result irreproducible.
    unique_lines = dict.fromkeys(line for item in data for line in item)
    # NOTE: OOV filtering (OOVFilter) is currently disabled.
    return list(unique_lines)
# Matches one character of the punctuation class that has no digit directly
# before it and no digit directly after it.
# NOTE(review): the class appears to contain a whitespace character between
# "»" and "،" -- confirm whether that is intended.
clean_punctuation = re.compile(r"(?<!\d)[!.:;?،؛؟«» ،؛۔٫٪؟](?!\d)")


def remove_punctuation(text):
    """Delete punctuation from ``text`` unless it is adjacent to a digit.

    A character of the punctuation class is removed only when it is neither
    preceded nor followed by a digit, so "3.14", "5." and ".5" are left
    untouched while "a.b" becomes "ab".
    """
    stripped = re.sub(clean_punctuation, "", text)
    return stripped
def get_argparser():
    """Build the command-line parser for the cleaning/distortion script.

    Every option has a default, so the parser also works with no arguments.

    Returns:
        The configured ``ArgumentParser``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        # Only newline splitting is enabled by default.
        '--sep', default=['\n'], nargs='+', type=str,
        help='The seperator to be used to split the lines on'
    )
    parser.add_argument(
        '--min_len', default=5, type=int,
        help='The minimum line length to keep'
    )
    parser.add_argument(
        '--max_len', default=1020, type=int,
        help='The maximum line length to keep'
    )
    parser.add_argument(
        '--dist_run', default=False, action='store_true'
    )
    parser.add_argument(
        '--data_path', default='data/data.txt'
    )
    parser.add_argument(
        '--save_path', default='data/clean_data.txt'
    )
    parser.add_argument(
        # type=int so a value given on the command line is not left as str.
        '--max_rep_chars', default=2, type=int
    )
    parser.add_argument(
        '--execlude_words_files', default='data/words.json'
    )
    parser.add_argument(
        '--max_oov', default=100, type=int
    )
    parser.add_argument(
        '--min_words', default=3, type=int
    )
    parser.add_argument(
        '--max_words', default=100, type=int
    )
    parser.add_argument(
        # nargs/type make a command-line value parse into a list of floats;
        # without them it would be one string iterated character by
        # character by the caller.
        '--dist_ratios', default=[0.05, 0.1, 0.15], nargs='+', type=float,
        help='Distortion ratios; one distorted column is produced per ratio'
    )
    parser.add_argument(
        '--remove_punc', default=False, action='store_true',
        help='Remove punctuation of the distorted lines'
    )
    return parser
def main(args) -> None:
    """Clean the input file, distort the lines, and write ``data/data.csv``.

    Steps: run the file processor over ``args.data_path``, flatten and
    de-duplicate the result with ``post_process``, create one
    ``distorted_<ratio>`` column per entry of ``args.dist_ratios``, optionally
    strip punctuation from the distorted columns, and save the table.

    Args:
        args: Parsed command-line namespace. Reads ``data_path``,
            ``dist_run``, ``dist_ratios`` and ``remove_punc`` here, plus
            whatever ``get_file_processor`` reads.

    Raises:
        FileNotFoundError: Propagated from ``get_path`` when
            ``args.data_path`` is not an existing file.
    """
    fp = get_file_processor(args)
    files = get_path(args.data_path)
    print('Started!')
    start = time.time()
    if args.dist_run is True:
        print('dist run')
        data = fp.dist_run(files)
    else:
        data = fp.run(files)
    end = time.time()
    print(f'Files Processing completed in {end - start}')
    data = post_process(data)
    # The first half of the cleaned lines is passed to every distorter.
    # NOTE(review): how the distorter uses these sentences is not visible
    # here -- confirm against get_text_distorter.
    sentences = data[: len(data) // 2]
    print("Length of data after post processing", len(data))
    # Create the frame up front so an empty ``dist_ratios`` still produces a
    # CSV with the clean column instead of failing on a ``None`` frame.
    df = pd.DataFrame({'clean': data})
    for ratio in args.dist_ratios:
        distorter = get_text_distorter(ratio, sentences)
        # TODO: Don't touch 2 percent of sentences to keep the model from
        # having a high bias towards the noise
        df[f'distorted_{ratio}'] = [distorter.run(line) for line in data]
    if args.remove_punc is True:
        print("Removing punctuations for the distorted lines")
        for ratio in args.dist_ratios:
            df[f'distorted_{ratio}'] = df[f'distorted_{ratio}'].apply(
                remove_punctuation
            )
    # The output path is fixed; ``args.save_path`` is currently unused.
    df.to_csv('data/data.csv', encoding='utf-8')
if __name__ == '__main__':
    parser = get_argparser()
    args = parser.parse_args()
    main(args)

    # Build train.csv / test.csv from data/data.csv with awk/sed.
    # data.csv is written with the DataFrame index, so field $1 is the index,
    # $2 the clean line and $3..$5 the distorted columns (default ratios).
    # Each output row is "<distorted>,<clean>" under a "text,summary" header.
    # NOTE(review): awk splits on every comma, so this assumes the text
    # fields themselves contain no commas -- confirm.

    # Count lines with a context manager so the handle is closed, and read as
    # UTF-8 to match the encoding the CSV was written with.
    with open("data/data.csv", 'r', encoding='utf-8') as csv_file:
        num_lines = sum(1 for _ in csv_file)

    # The last 50000 lines of data.csv go to the test set; everything before
    # them (except the header line) goes to the training set.
    split_at = num_lines - 50000

    os.system("echo \"text,summary\" > train.csv")
    # Only the distorted column number differs between the commands.
    for column in (5, 4, 3):
        os.system(
            f"awk -F',' 'NR>1 && NR<={split_at} {{print ${column} \",\" $2}}'"
            " data/data.csv | sed 's/\"//g' >> train.csv"
        )

    os.system("echo \"text,summary\" > test.csv")
    for column in (5, 4, 3):
        os.system(
            f"awk -F',' 'NR>{split_at} {{print ${column} \",\" $2}}'"
            " data/data.csv | sed 's/\"//g' >> test.csv"
        )
|