Joshua Lochner committed · Commit 508e8b2 · 1 Parent(s): 2115d78

Fix classifier preprocessing

Browse files
- src/preprocess.py (+29 -39)
src/preprocess.py
CHANGED
@@ -9,7 +9,7 @@ import segment
 from tqdm import tqdm
 from dataclasses import dataclass, field
 from transformers import HfArgumentParser
-from shared import
+from shared import extract_sponsor_matches_from_text, ACTION_OPTIONS, CATEGORIES, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
 import csv
 import re
 import random

@@ -418,9 +418,9 @@ class PreprocessArguments:
     num_jobs: int = field(
         default=4, metadata={'help': 'Number of transcripts to download in parallel'})

-    overwrite: bool = field(
-        default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
-    )
+    # overwrite: bool = field(
+    #     default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
+    # )

     do_generate: bool = field(
         default=False, metadata={'help': 'Generate labelled data.'}

@@ -538,11 +538,11 @@ def main():
     # TODO process all valid possible items and then do filtering only later
     @lru_cache(maxsize=1)
     def read_db():
-        if not preprocess_args.overwrite and os.path.exists(processed_db_path):
-            logger.info(
-                'Using cached processed database (use `--overwrite` to avoid this behaviour).')
-            with open(processed_db_path) as fp:
-                return json.load(fp)
+        # if not preprocess_args.overwrite and os.path.exists(processed_db_path):
+        #     logger.info(
+        #         'Using cached processed database (use `--overwrite` to avoid this behaviour).')
+        #     with open(processed_db_path) as fp:
+        #         return json.load(fp)
         logger.info('Processing raw database')
         db = {}
 

@@ -916,11 +916,8 @@ def main():
     # Output training, testing and validation data
     for name, items in splits.items():
         outfile = os.path.join(dataset_args.data_dir, name)
-
-            with open(outfile, 'w', encoding='utf-8') as fp:
-                fp.writelines(items)
-        else:
-            logger.info(f'Skipping {name}')
+        with open(outfile, 'w', encoding='utf-8') as fp:
+            fp.writelines(items)
 
     classifier_splits = {
         dataset_args.c_train_file: train_data,

@@ -933,31 +930,24 @@ def main():
     # Output training, testing and validation data
     for name, items in classifier_splits.items():
         outfile = os.path.join(dataset_args.data_dir, name)
-
-
-
-
-
-
-
-
-
-
-            '
-
-
-
-
-
-
-
-            })
-
-                for labelled_item in labelled_items:
-                    print(json.dumps(labelled_item), file=fp)
-
-        else:
-            logger.info(f'Skipping {name}')
+        with open(outfile, 'w', encoding='utf-8') as fp:
+            for item in items:
+                parsed_item = json.loads(item)  # TODO add uuid
+
+                matches = extract_sponsor_matches_from_text(parsed_item['extracted'])
+
+                if matches:
+                    for match in matches:
+                        print(json.dumps({
+                            'text': match['text'],
+                            'label': CATEGORIES.index(match['category'])
+                        }), file=fp)
+                else:
+                    print(json.dumps({
+                        'text': parsed_item['text'],
+                        'label': none_category
+                    }), file=fp)
+
 
     logger.info('Write')
     # Save excess items