"""Dataset loading and on-disk caching helpers for the supported tasks."""
import logging
import os
import random
import time

import numpy as np
import torch
from torch.utils.data import TensorDataset

# Wildcard kept on purpose: the read_*_examples and convert_*_to_features
# helpers used below come from this module. It stays the last import so
# name-shadowing order is the same as before.
from _utils import *

logger = logging.getLogger(__name__)


# -----------------------------
# SHARED HELPERS
# -----------------------------
def _ids_tensor(features, attr):
    """Stack attribute ``attr`` of every feature into one ``torch.long`` tensor."""
    return torch.tensor([getattr(f, attr) for f in features], dtype=torch.long)


def _load_cached(cache_fn):
    """Log and load a previously saved cache file.

    NOTE(review): ``torch.load`` unpickles arbitrary objects, so only load
    cache files this pipeline wrote itself. Recent PyTorch releases default
    to ``weights_only=True``, which may refuse pickled dataset objects --
    confirm against the PyTorch version in use.
    """
    logger.info("Load cache data from %s", cache_fn)
    return torch.load(cache_fn)


# -----------------------------
# GENERATION DATA
# -----------------------------
def load_and_cache_gen_data(args, filename, pool, tokenizer, split_tag, only_src=False, is_sample=False):
    """Load examples for a generation task and build (or reuse) its tensors.

    Args:
        args: run configuration; reads ``data_num``, ``cache_path``, ``task``,
            ``data_type`` and ``local_rank``.
        filename: path (or comma-joined paths) handed to ``read_examples``.
        pool: unused; features are converted in-process. Kept so existing
            call sites keep working.
        tokenizer: forwarded to the feature converter.
        split_tag: split name; ``'train'`` and ``'test'`` are special-cased.
        only_src: when true, target ids are omitted (non-``jit`` tasks).
        is_sample: when true, work on a random subset of at most 10000
            examples and neither read nor write the cache.

    Returns:
        ``(examples, data)`` where ``data`` is a ``TensorDataset`` (or the
        object previously saved in the cache file).
    """
    data_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num
    src_tag = '_src' if only_src else ''
    cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + src_tag + data_tag)

    examples = read_examples(filename, args.data_num, args.task)
    logger.info("Processing %d examples", len(examples))
    if args.data_num != -1:
        logger.warning("⚠ Running in small-data mode: %d samples", args.data_num)

    if is_sample:
        examples = random.sample(examples, min(10000, len(examples)))
    if split_tag == 'train':
        calc_stats(examples, tokenizer, is_tokenize=True)
    else:
        calc_stats(examples)

    # ---------------- CACHE ----------------
    # A zero-byte file (e.g. an interrupted save) is treated as a miss.
    cache_usable = os.path.exists(cache_fn) and os.path.getsize(cache_fn) > 0
    if cache_usable and not is_sample:
        return examples, _load_cached(cache_fn)

    logger.info("Creating cache data into %s", cache_fn)
    tuple_examples = [(example, idx, tokenizer, args, split_tag)
                      for idx, example in enumerate(examples)]

    # ---------------- NO MULTIPROCESSING ----------------
    features = [convert_examples_to_features(x) for x in tuple_examples]
    all_source_ids = _ids_tensor(features, 'source_ids')
    source_only = split_tag == 'test' or only_src

    if args.task == 'jit':
        # For this task the label tensor is built from each feature's ``url``.
        data = TensorDataset(
            all_source_ids,
            _ids_tensor(features, 'tag_ids'),
            _ids_tensor(features, 'target_ids'),
            _ids_tensor(features, 'url'),
        )
    elif args.data_type in ['s1', 's2']:
        all_tag_ids = _ids_tensor(features, 'tag_ids')
        if source_only:
            data = TensorDataset(all_source_ids, all_tag_ids)
        else:
            data = TensorDataset(all_source_ids, all_tag_ids, _ids_tensor(features, 'target_ids'))
    elif source_only:
        data = TensorDataset(all_source_ids)
    else:
        data = TensorDataset(all_source_ids, _ids_tensor(features, 'target_ids'))

    if args.local_rank in [-1, 0] and not is_sample:
        torch.save(data, cache_fn)
    return examples, data


# -----------------------------
# CLONE DATA
# -----------------------------
def load_and_cache_clone_data(args, filename, pool, tokenizer, split_tag, is_sample=False):
    """Load clone-detection examples and build (or reuse) their tensors.

    Args:
        args: run configuration; reads ``data_num``, ``cache_path``, ``task``
            and ``local_rank``.
        filename: path handed to ``read_examples``.
        pool: unused; kept for call-site compatibility.
        tokenizer: forwarded to the feature converter.
        split_tag: split name used in the cache file name.
        is_sample: when true, work on a random 10% of the examples and
            bypass the cache in both directions.

    Returns:
        ``(examples, data)`` with ``data`` holding source ids and labels.
    """
    # Build the suffix first: the previous one-liner applied the conditional
    # to ``split_tag + '_all'`` as a whole, dropping ``split_tag`` from the
    # name whenever ``data_num != -1``.
    data_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num
    cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + data_tag)

    examples = read_examples(filename, args.data_num, args.task)
    if is_sample:
        examples = random.sample(examples, int(len(examples) * 0.1))
    calc_stats(examples, tokenizer, is_tokenize=True)

    # Sampled runs must not touch the cache: loading would pair the sampled
    # examples with full-size tensors, and saving would overwrite the full
    # cache with a 10% subset.
    if os.path.exists(cache_fn) and not is_sample:
        return examples, _load_cached(cache_fn)

    tuple_examples = [(example, idx, tokenizer, args) for idx, example in enumerate(examples)]
    features = [convert_clone_examples_to_features(x) for x in tuple_examples]
    data = TensorDataset(_ids_tensor(features, 'source_ids'), _ids_tensor(features, 'label'))

    if args.local_rank in [-1, 0] and args.data_num == -1 and not is_sample:
        torch.save(data, cache_fn)
    return examples, data


# -----------------------------
# DEFECT DATA
# -----------------------------
def load_and_cache_defect_data(args, filename, pool, tokenizer, split_tag, is_sample=False):
    """Load defect-detection examples and build (or reuse) their tensors.

    Args:
        args: run configuration; reads ``data_num``, ``cache_path``, ``task``
            and ``local_rank``.
        filename: path handed to ``read_examples``.
        pool: unused; kept for call-site compatibility.
        tokenizer: forwarded to the feature converter.
        split_tag: split name; also the cache file name inside
            ``args.cache_path``.
        is_sample: when true, work on a random 10% of the examples and
            bypass the cache in both directions.

    Returns:
        ``(examples, data)`` with ``data`` holding source ids and labels.
    """
    cache_fn = os.path.join(args.cache_path, split_tag)

    examples = read_examples(filename, args.data_num, args.task)
    if is_sample:
        examples = random.sample(examples, int(len(examples) * 0.1))
    calc_stats(examples, tokenizer, is_tokenize=True)

    # Same rule as the clone loader: a sampled run never reads or writes
    # the cache, so the examples and tensors always describe the same rows.
    if os.path.exists(cache_fn) and not is_sample:
        return examples, _load_cached(cache_fn)

    tuple_examples = [(example, idx, tokenizer, args) for idx, example in enumerate(examples)]
    features = [convert_defect_examples_to_features(x) for x in tuple_examples]
    data = TensorDataset(_ids_tensor(features, 'source_ids'), _ids_tensor(features, 'label'))

    if args.local_rank in [-1, 0] and args.data_num == -1 and not is_sample:
        torch.save(data, cache_fn)
    return examples, data


# -----------------------------
# MULTI-GEN DATA
# -----------------------------
def load_and_cache_multi_gen_data(args, pool, tokenizer, split_tag, only_src=False, is_sample=False):
    """Load every task/sub-task of the multi-task setup into one dict.

    Side effect: ``args.task`` and ``args.sub_task`` are overwritten while
    iterating and are left at the last combination processed.

    Args:
        args: run configuration; reads ``cache_path``, ``data_dir``,
            ``data_num`` and ``local_rank``.
        pool: unused; kept for call-site compatibility.
        tokenizer: forwarded to the feature converter.
        split_tag: split name; also the cache file name.
        only_src: when true, each dataset holds source ids only.
        is_sample: when true, cap each sub-task at 5000 random examples and
            bypass the cache.

    Returns:
        Dict mapping ``"<task>_<sub_task>"`` (or just ``"<task>"`` when the
        task has no sub-tasks) to ``(examples, TensorDataset)``.
    """
    cache_fn = os.path.join(args.cache_path, split_tag)
    if os.path.exists(cache_fn) and not is_sample:
        return _load_cached(cache_fn)

    sub_tasks_by_task = {
        'summarize': ['ruby', 'r', 'go', 'typescript', 'swift', 'php'],
        'translate': ['swift-cs', 'cs-swift'],
        'refine': ['small', 'medium'],
    }
    task_list = ['summarize', 'translate', 'refine', 'concode', 'defect']

    examples_data_dict = {}
    for task in task_list:
        # Tasks without sub-tasks use the 'none' placeholder.
        sub_tasks = sub_tasks_by_task.get(task, ['none'])
        args.task = task
        for sub_task in sub_tasks:
            args.sub_task = sub_task
            filename = get_filenames(args.data_dir, task, sub_task, split_tag)
            examples = read_examples(filename, args.data_num, task)
            if is_sample:
                examples = random.sample(examples, min(5000, len(examples)))
            if split_tag == 'train':
                calc_stats(examples, tokenizer, is_tokenize=True)
            else:
                calc_stats(examples)

            tuple_examples = [(ex, idx, tokenizer, args, split_tag)
                              for idx, ex in enumerate(examples)]
            features = [convert_examples_to_features(x) for x in tuple_examples]
            all_source_ids = _ids_tensor(features, 'source_ids')
            if only_src:
                data = TensorDataset(all_source_ids)
            else:
                data = TensorDataset(all_source_ids, _ids_tensor(features, 'target_ids'))

            key = f"{task}_{sub_task}" if sub_task != 'none' else task
            examples_data_dict[key] = (examples, data)

    if args.local_rank in [-1, 0] and not is_sample:
        torch.save(examples_data_dict, cache_fn)
        logger.info("Save data into %s", cache_fn)
    return examples_data_dict


# -----------------------------
# FILE HELPERS
# -----------------------------
def get_filenames(data_root, task, sub_task, split=''):
    """Return the data file path(s) for a task / sub-task.

    Paired-file tasks (``refine``, ``translate``) yield a single string with
    the two paths joined by a comma.

    Args:
        data_root: root directory of the datasets.
        task: task name (see the branches below).
        sub_task: sub-task name; ignored by tasks that have none.
        split: ``'train'``, ``'dev'`` or ``'test'`` to get one path; any
            other value returns all three.

    Returns:
        One path string, or the tuple ``(train_fn, dev_fn, test_fn)``.

    Raises:
        ValueError: if ``task`` is not a known task name.
    """
    if task == 'concode':
        data_dir = f'{data_root}/{task}'
        train_fn = f'{data_dir}/train.json'
        dev_fn = f'{data_dir}/dev.json'
        test_fn = f'{data_dir}/test.json'
    elif task == 'jit':
        data_dir = f'{data_root}/{task}/{sub_task}'
        train_fn = f'{data_dir}/train.jsonl'
        test_fn = f'{data_dir}/test.jsonl'
        # This task has no separate validation file; dev reuses the test file.
        dev_fn = test_fn
    elif task in ('summarize', 'mcmd', 'mcmd-nt', 'mcmd-nl'):
        data_dir = f'{data_root}/{task}/{sub_task}'
        train_fn = f'{data_dir}/train.jsonl'
        dev_fn = f'{data_dir}/valid.jsonl'
        test_fn = f'{data_dir}/test.jsonl'
    elif task == 'refine':
        data_dir = f'{data_root}/{task}/{sub_task}'
        train_fn = f'{data_dir}/train.buggy-fixed.buggy,{data_dir}/train.buggy-fixed.fixed'
        dev_fn = f'{data_dir}/valid.buggy-fixed.buggy,{data_dir}/valid.buggy-fixed.fixed'
        test_fn = f'{data_dir}/test.buggy-fixed.buggy,{data_dir}/test.buggy-fixed.fixed'
    elif task == 'translate':
        data_dir = f'{data_root}/translate'
        # Both directions read the same file pair; only the order differs.
        if sub_task == 'cs-swift':
            first, second = 'cs', 'swift'
        else:
            first, second = 'swift', 'cs'
        train_fn = f'{data_dir}/train.swift-cs.txt.{first},{data_dir}/train.swift-cs.txt.{second}'
        dev_fn = f'{data_dir}/valid.swift-cs.txt.{first},{data_dir}/valid.swift-cs.txt.{second}'
        test_fn = f'{data_dir}/test.swift-cs.txt.{first},{data_dir}/test.swift-cs.txt.{second}'
    elif task == 'clone':
        data_dir = f'{data_root}/clone'
        train_fn = f'{data_dir}/train.txt'
        dev_fn = f'{data_dir}/valid.txt'
        test_fn = f'{data_dir}/test.txt'
    elif task == 'defect':
        data_dir = f'{data_root}/defect'
        train_fn = f'{data_dir}/train.jsonl'
        dev_fn = f'{data_dir}/valid.jsonl'
        test_fn = f'{data_dir}/test.jsonl'
    else:
        # Previously this fell through and failed later with an opaque
        # UnboundLocalError on train_fn.
        raise ValueError(f"Unknown task: {task!r}")

    if split == 'train':
        return train_fn
    if split == 'dev':
        return dev_fn
    if split == 'test':
        return test_fn
    return train_fn, dev_fn, test_fn


# -----------------------------
# REMAINING FUNCTIONS
# -----------------------------
def read_examples(filename, data_num, task):
    """Dispatch to the reader registered for ``task``.

    Raises:
        KeyError: if ``task`` has no registered reader.
    """
    read_example_dict = {
        'summarize': read_summarize_examples,
        'refine': read_refine_examples,
        'translate': read_translate_examples,
        'concode': read_concode_examples,
        'clone': read_clone_examples,
        'defect': read_defect_examples,
        'jit': read_jit_examples,
        "mcmd": read_mcmd_examples,
        "mcmd-nt": read_mcmd_examples,
        "mcmd-nl": read_mcmd_examples,
    }
    reader = read_example_dict[task]
    return reader(filename, data_num)


def calc_stats(examples, tokenizer=None, is_tokenize=False):
    """Log the example count and average whitespace-token lengths.

    ``tokenizer`` and ``is_tokenize`` are accepted for call-site
    compatibility but are not used by this implementation.
    """
    src_lens = [len(ex.source.split()) for ex in examples]
    trg_lens = [len(str(ex.target).split()) for ex in examples]
    # np.mean([]) warns and yields nan; report 0.00 for an empty set instead.
    avg_src = np.mean(src_lens) if src_lens else 0.0
    avg_trg = np.mean(trg_lens) if trg_lens else 0.0
    logger.info(
        "Read %d examples, avg src len: %.2f, avg trg len: %.2f",
        len(examples), avg_src, avg_trg
    )


def get_elapse_time(t0):
    """Format the time elapsed since ``t0`` as ``"<h>h<m>m"`` or ``"<m>m"``."""
    elapse_time = time.time() - t0
    # ``>=`` so that exactly one hour is rendered as "1h0m", not "60m".
    if elapse_time >= 3600:
        hours = int(elapse_time // 3600)
        minutes = int((elapse_time % 3600) // 60)
        return f"{hours}h{minutes}m"
    return f"{int(elapse_time // 60)}m"