# clouds125's picture
# Upload folder using huggingface_hub
# 2d8ff8e verified
from torch.utils.data import TensorDataset
import numpy as np
import logging
import os
import random
import torch
import time
from _utils import *
# Module-level logger used by the data loaders and calc_stats in this file.
logger = logging.getLogger(__name__)
# -----------------------------
# GENERATION DATA
# -----------------------------
def load_and_cache_gen_data(args, filename, pool, tokenizer, split_tag, only_src=False, is_sample=False):
    """Read generation examples and build (or load a cached) ``TensorDataset``.

    Args:
        args: run configuration; reads ``data_num``, ``cache_path``, ``task``,
            ``local_rank`` and (for non-'jit' tasks) ``data_type``.
        filename: path(s) handed to the task-specific reader.
        pool: unused here; kept for call-site compatibility.
        tokenizer: forwarded to feature conversion and stats logging.
        split_tag: split name ('train', 'test', ...); also part of the cache name.
        only_src: outside the 'jit' task, drop target ids from the dataset.
        is_sample: subsample up to 10000 examples and bypass the cache
            (it is neither read nor written).

    Returns:
        ``(examples, data)``. The columns of ``data`` depend on ``args.task``,
        ``args.data_type``, ``split_tag`` and ``only_src``:
        'jit' -> (source, tag, target, url-as-label);
        data_type 's1'/'s2' -> (source, tag[, target]);
        otherwise -> (source[, target]). Target ids are omitted for the
        'test' split or when ``only_src`` is set.
    """
    size_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num
    src_tag = '_src' if only_src else ''
    cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + src_tag + size_tag)

    examples = read_examples(filename, args.data_num, args.task)
    logger.info("Processing %d examples", len(examples))
    if args.data_num != -1:
        logger.warning("⚠ Running in small-data mode: %d samples", args.data_num)
    if is_sample:
        examples = random.sample(examples, min(10000, len(examples)))
    if split_tag == 'train':
        calc_stats(examples, tokenizer, is_tokenize=True)
    else:
        calc_stats(examples)

    # An empty file (e.g. from an interrupted save) is treated as "no cache".
    cache_usable = os.path.exists(cache_fn) and os.path.getsize(cache_fn) > 0
    if cache_usable and not is_sample:
        logger.info("Load cache data from %s", cache_fn)
        return examples, torch.load(cache_fn)

    logger.info("Creating cache data into %s", cache_fn)
    features = [
        convert_examples_to_features((example, idx, tokenizer, args, split_tag))
        for idx, example in enumerate(examples)
    ]

    def as_long_tensor(field):
        # Collect one attribute across all features into a single int64 tensor.
        return torch.tensor([getattr(feat, field) for feat in features], dtype=torch.long)

    source_ids = as_long_tensor('source_ids')
    keep_target = split_tag != 'test' and not only_src
    if args.task == 'jit':
        # 'jit' always carries tags, targets and the feature `url` field (as labels),
        # independent of split_tag / only_src.
        tensors = [
            source_ids,
            as_long_tensor('tag_ids'),
            as_long_tensor('target_ids'),
            as_long_tensor('url'),
        ]
    elif args.data_type in ['s1', 's2']:
        tensors = [source_ids, as_long_tensor('tag_ids')]
        if keep_target:
            tensors.append(as_long_tensor('target_ids'))
    else:
        tensors = [source_ids]
        if keep_target:
            tensors.append(as_long_tensor('target_ids'))
    data = TensorDataset(*tensors)

    # Only the main process writes the cache, and never for a sampled subset.
    if args.local_rank in [-1, 0] and not is_sample:
        torch.save(data, cache_fn)
    return examples, data
# -----------------------------
# CLONE DATA
# -----------------------------
def load_and_cache_clone_data(args, filename, pool, tokenizer, split_tag, is_sample=False):
    """Read clone-detection examples and return them with a ``TensorDataset``.

    Args:
        args: run configuration; reads ``cache_path``, ``data_num``, ``task``
            and ``local_rank``.
        filename: path handed to the task reader.
        pool: unused here; kept for call-site compatibility.
        tokenizer: forwarded to feature conversion and stats logging.
        split_tag: split name; part of the cache file name.
        is_sample: keep a random 10% of the examples and bypass the cache
            (a sampled subset must never be paired with, or written over,
            the full cached dataset).

    Returns:
        ``(examples, data)`` where a freshly built ``data`` is
        ``TensorDataset(source_ids, labels)``.
    """
    data_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num
    # Concatenate outside the conditional so every split gets its own cache file
    # (a bare `a + b if c else d` would drop split_tag whenever data_num != -1).
    cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + data_tag)

    examples = read_examples(filename, args.data_num, args.task)
    if is_sample:
        examples = random.sample(examples, int(len(examples) * 0.1))
    calc_stats(examples, tokenizer, is_tokenize=True)

    if not is_sample and os.path.exists(cache_fn):
        logger.info("Load cache data from %s", cache_fn)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
        # refuse to unpickle a TensorDataset — confirm against the pinned torch version.
        return examples, torch.load(cache_fn)

    features = [
        convert_clone_examples_to_features((example, idx, tokenizer, args))
        for idx, example in enumerate(examples)
    ]
    all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
    all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    data = TensorDataset(all_source_ids, all_labels)

    # Persist only full, unsampled data, and only from the main process.
    if args.local_rank in [-1, 0] and args.data_num == -1 and not is_sample:
        torch.save(data, cache_fn)
    return examples, data
# -----------------------------
# DEFECT DATA
# -----------------------------
def load_and_cache_defect_data(args, filename, pool, tokenizer, split_tag, is_sample=False):
    """Read defect-detection examples and return them with a ``TensorDataset``.

    Args:
        args: run configuration; reads ``cache_path``, ``data_num``, ``task``
            and ``local_rank``.
        filename: path handed to the task reader.
        pool: unused here; kept for call-site compatibility.
        tokenizer: forwarded to feature conversion and stats logging.
        split_tag: split name; used directly as the cache file name.
        is_sample: keep a random 10% of the examples.

    Returns:
        ``(examples, data)`` where a freshly built ``data`` is
        ``TensorDataset(source_ids, labels)``.
    """
    cache_fn = os.path.join(args.cache_path, split_tag)
    examples = read_examples(filename, args.data_num, args.task)
    if is_sample:
        examples = random.sample(examples, int(len(examples) * 0.1))
    calc_stats(examples, tokenizer, is_tokenize=True)

    # The cache file name encodes neither data_num nor sampling, so it can only
    # describe the full split: read and write it in that configuration only,
    # otherwise `examples` and `data` would disagree (or the full cache would be
    # overwritten by a subset).
    cacheable = args.data_num == -1 and not is_sample
    if cacheable and os.path.exists(cache_fn):
        logger.info("Load cache data from %s", cache_fn)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
        # refuse to unpickle a TensorDataset — confirm against the pinned torch version.
        return examples, torch.load(cache_fn)

    features = [
        convert_defect_examples_to_features((example, idx, tokenizer, args))
        for idx, example in enumerate(examples)
    ]
    all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
    all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    data = TensorDataset(all_source_ids, all_labels)

    # Only the main process writes the cache.
    if args.local_rank in [-1, 0] and cacheable:
        torch.save(data, cache_fn)
    return examples, data
# -----------------------------
# MULTI-GEN DATA
# -----------------------------
def load_and_cache_multi_gen_data(args, pool, tokenizer, split_tag, only_src=False, is_sample=False):
    """Build one ``(examples, TensorDataset)`` pair per task/sub-task.

    Args:
        args: run configuration; reads ``cache_path``, ``data_dir``,
            ``data_num`` and ``local_rank``. ``task``/``sub_task`` are set
            temporarily while each sub-task is processed and restored on exit.
        pool: unused here; kept for call-site compatibility.
        tokenizer: forwarded to feature conversion and stats logging.
        split_tag: split name passed to ``get_filenames``; base of the cache name.
        only_src: build source-only datasets (no target ids).
        is_sample: subsample up to 5000 examples per sub-task and bypass the
            cache (neither read nor written).

    Returns:
        Dict mapping ``'<task>_<sub_task>'`` (or ``'<task>'`` when the sub-task
        is ``'none'``) to ``(examples, data)``.
    """
    # The default configuration keeps the plain `<cache_path>/<split_tag>` name.
    # Source-only or size-limited builds get distinct names so they can never be
    # loaded as (or overwrite) the full dataset.
    cache_name = split_tag + ('_src' if only_src else '')
    if args.data_num != -1:
        cache_name += '_%d' % args.data_num
    cache_fn = os.path.join(args.cache_path, cache_name)
    if os.path.exists(cache_fn) and not is_sample:
        logger.info("Load cache data from %s", cache_fn)
        return torch.load(cache_fn)

    sub_tasks_by_task = {
        'summarize': ['ruby', 'r', 'go', 'typescript', 'swift', 'php'],
        'translate': ['swift-cs', 'cs-swift'],
        'refine': ['small', 'medium'],
        'concode': ['none'],
        'defect': ['none'],
    }
    examples_data_dict = {}
    # args.task / args.sub_task are set per sub-task, presumably because `args` is
    # consulted by convert_examples_to_features. The caller's values are restored
    # afterwards, so this path leaves `args` exactly as the cache-hit path does.
    saved_task = getattr(args, 'task', None)
    saved_sub_task = getattr(args, 'sub_task', None)
    try:
        for task, sub_tasks in sub_tasks_by_task.items():
            args.task = task
            for sub_task in sub_tasks:
                args.sub_task = sub_task
                filename = get_filenames(args.data_dir, task, sub_task, split_tag)
                examples = read_examples(filename, args.data_num, task)
                if is_sample:
                    examples = random.sample(examples, min(5000, len(examples)))
                if split_tag == 'train':
                    calc_stats(examples, tokenizer, is_tokenize=True)
                else:
                    calc_stats(examples)
                features = [
                    convert_examples_to_features((ex, idx, tokenizer, args, split_tag))
                    for idx, ex in enumerate(examples)
                ]
                all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
                if only_src:
                    data = TensorDataset(all_source_ids)
                else:
                    all_target_ids = torch.tensor([f.target_ids for f in features], dtype=torch.long)
                    data = TensorDataset(all_source_ids, all_target_ids)
                key = f"{task}_{sub_task}" if sub_task != 'none' else task
                examples_data_dict[key] = (examples, data)
    finally:
        args.task, args.sub_task = saved_task, saved_sub_task

    # Only the main process writes the cache, and never for a sampled build.
    if args.local_rank in [-1, 0] and not is_sample:
        torch.save(examples_data_dict, cache_fn)
        logger.info("Save data into %s", cache_fn)
    return examples_data_dict
# -----------------------------
# FILE HELPERS (UNCHANGED)
# -----------------------------
def get_filenames(data_root, task, sub_task, split=''):
    """Return the data file path(s) for ``task``/``sub_task`` under ``data_root``.

    Args:
        data_root: root directory containing one sub-directory per task.
        task: one of 'concode', 'jit', 'summarize', 'refine', 'translate',
            'clone', 'defect', 'mcmd', 'mcmd-nt', 'mcmd-nl'.
        sub_task: sub-directory name for jit/summarize/refine/mcmd*, the
            direction ('cs-swift' or anything else = swift->cs) for translate;
            ignored by concode, clone and defect.
        split: 'train', 'dev' or 'test' selects a single path; any other
            value returns the ``(train, dev, test)`` triple.

    Returns:
        A path string, or a 3-tuple of path strings. For refine and translate
        each entry is a comma-joined ``source,target`` pair of files.

    Raises:
        ValueError: if ``task`` is not a known task.
    """
    splits = ('train', 'valid', 'test')
    if task == 'concode':
        data_dir = f'{data_root}/{task}'
        train_fn, dev_fn, test_fn = (f'{data_dir}/{s}.json' for s in ('train', 'dev', 'test'))
    elif task == 'jit':
        data_dir = f'{data_root}/{task}/{sub_task}'
        train_fn = f'{data_dir}/train.jsonl'
        test_fn = f'{data_dir}/test.jsonl'
        dev_fn = test_fn  # jit reuses the test file as its dev split
    elif task in ('summarize', 'mcmd', 'mcmd-nt', 'mcmd-nl'):
        data_dir = f'{data_root}/{task}/{sub_task}'
        train_fn, dev_fn, test_fn = (f'{data_dir}/{s}.jsonl' for s in splits)
    elif task == 'refine':
        data_dir = f'{data_root}/{task}/{sub_task}'
        train_fn, dev_fn, test_fn = (
            f'{data_dir}/{s}.buggy-fixed.buggy,{data_dir}/{s}.buggy-fixed.fixed'
            for s in splits
        )
    elif task == 'translate':
        data_dir = f'{data_root}/translate'
        # Both directions read the same files; sub_task only picks which side is the source.
        src_ext, tgt_ext = ('cs', 'swift') if sub_task == 'cs-swift' else ('swift', 'cs')
        train_fn, dev_fn, test_fn = (
            f'{data_dir}/{s}.swift-cs.txt.{src_ext},{data_dir}/{s}.swift-cs.txt.{tgt_ext}'
            for s in splits
        )
    elif task == 'clone':
        data_dir = f'{data_root}/clone'
        train_fn, dev_fn, test_fn = (f'{data_dir}/{s}.txt' for s in splits)
    elif task == 'defect':
        data_dir = f'{data_root}/defect'
        train_fn, dev_fn, test_fn = (f'{data_dir}/{s}.jsonl' for s in splits)
    else:
        # Fail loudly instead of the UnboundLocalError an unmatched task used to cause.
        raise ValueError(f'Unknown task: {task!r}')

    if split == 'train':
        return train_fn
    elif split == 'dev':
        return dev_fn
    elif split == 'test':
        return test_fn
    else:
        return train_fn, dev_fn, test_fn
# -----------------------------
# REMAINING FUNCTIONS (UNCHANGED)
# -----------------------------
def read_examples(filename, data_num, task):
    """Load the examples for ``task`` by dispatching to its reader.

    Every reader is called as ``reader(filename, data_num)``. The reader
    functions are presumably provided by ``from _utils import *``. The three
    mcmd variants share ``read_mcmd_examples``. An unknown ``task`` raises
    ``KeyError``.
    """
    readers = {
        'summarize': read_summarize_examples,
        'refine': read_refine_examples,
        'translate': read_translate_examples,
        'concode': read_concode_examples,
        'clone': read_clone_examples,
        'defect': read_defect_examples,
        'jit': read_jit_examples,
    }
    readers.update(dict.fromkeys(("mcmd", "mcmd-nt", "mcmd-nl"), read_mcmd_examples))
    reader = readers[task]
    return reader(filename, data_num)
def calc_stats(examples, tokenizer=None, is_tokenize=False):
    """Log the example count and mean whitespace-token length of sources/targets.

    Targets are passed through ``str`` before splitting, so non-string targets
    are measured by their string form. ``tokenizer`` and ``is_tokenize`` are
    accepted for call-site compatibility but are not used here.
    """
    pair_lengths = [
        (len(ex.source.split()), len(str(ex.target).split()))
        for ex in examples
    ]
    source_lengths = [src for src, _ in pair_lengths]
    target_lengths = [trg for _, trg in pair_lengths]
    logger.info(
        "Read %d examples, avg src len: %.2f, avg trg len: %.2f",
        len(examples),
        np.mean(source_lengths),
        np.mean(target_lengths)
    )
def get_elapse_time(t0):
    """Format the wall-clock time elapsed since ``t0`` (a ``time.time()`` stamp).

    Returns ``"<h>h<m>m"`` once more than an hour has passed, otherwise
    ``"<m>m"``. Partial minutes are truncated.
    """
    seconds = time.time() - t0
    if seconds > 3600:
        hours, remainder = divmod(seconds, 3600)
        return "{}h{}m".format(int(hours), int(remainder // 60))
    return "{}m".format(int(seconds // 60))