| """ |
| Generate line data for line retrieval task. |
| |
| Usage: |
| python3 gen_data.py --number 1000 |
| """ |
|
|
| import argparse |
| import json |
| from collections import defaultdict |
|
|
| import numpy as np |
| from tqdm import tqdm |
|
|
|
|
def generate_lines(random_words, num_lines, redirect_ratio):
    """Build one synthetic line-retrieval test case.

    Every line gets a unique label made of two randomly chosen words joined
    by "-" and a zero-padded six-digit REGISTER_CONTENT value. When
    ``redirect_ratio`` is positive, that fraction of the lines is rewritten
    to say "the same as Line <other label>", which creates chains that must
    be followed to find the final value. Chains that loop back on themselves
    are detected and reported as rings.

    Args:
        random_words: 1-D collection of words passed to ``np.random.choice``
            (assumed to be strings, since they are joined with "-").
        num_lines: Number of lines to generate.
        redirect_ratio: Fraction of lines turned into redirects; values
            ``<= 0`` disable redirects.

    Returns:
        A dict with the keys ``prefix``, ``suffix`` (the prompt text, with
        three few-shot examples filled in), ``lines`` (rendered line text),
        ``indices`` (line labels), ``values`` (resolved value per line),
        ``links`` (per line, the positions visited while following its
        redirects, or ``None`` when the line runs into a ring),
        ``group_by_num_hoops`` (ring-free line positions grouped by
        ``len(links) + 1``) and ``contains_ring`` (sorted positions of lines
        whose redirect chain runs into a ring).

    Raises:
        ValueError: If ``num_lines`` exceeds the number of distinct labels
            that ``random_words`` can produce.
        IndexError: If no ring-free, non-redirect line exists to serve as a
            few-shot example.
    """
    prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
    suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resolving the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"

    # Upper bound on distinct "<word>-<word>" labels. Beyond it the
    # rejection-sampling loop below could never find an unused label.
    max_labels = len(set(random_words)) ** 2
    if num_lines > max_labels:
        raise ValueError(
            f"cannot create {num_lines} unique line labels from "
            f"{max_labels} possible word pairs"
        )

    # --- Raw lines -------------------------------------------------------
    # None is pre-seeded so the while loop runs at least once per line.
    visited_indices = {None}

    lines = []
    redirects = []
    indices = []
    values = []
    for _ in tqdm(range(num_lines)):
        line_index = None
        while line_index in visited_indices:
            line_index = "-".join(np.random.choice(random_words, size=(2,)))
        visited_indices.add(line_index)

        # ``high`` is exclusive, so values range over 000000..999998.
        line_value = f"{np.random.randint(low=0, high=999999):06}"

        lines.append(f"Line {line_index}: The REGISTER_CONTENT is {line_value}.")
        redirects.append(None)
        indices.append(line_index)
        values.append(line_value)

    # --- Redirects -------------------------------------------------------
    if redirect_ratio > 0:
        num_redirect_lines = int(len(lines) * redirect_ratio)
        redirect_indices = np.random.choice(
            np.arange(len(lines)), size=(num_redirect_lines,), replace=False
        )
        # Positions are normalised to built-in ints so that ``links`` can be
        # serialised with ``json`` whatever scalar type NumPy hands back.
        for source in map(int, redirect_indices):
            # The target is drawn from the first min(2 * source + 100,
            # num_lines) lines; it may be the line itself or another
            # redirect, which is how chains and rings arise.
            target_idx = int(np.random.choice(min(source * 2 + 100, num_lines)))
            lines[source] = (
                f"Line {indices[source]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
            )
            redirects[source] = target_idx

    # --- Follow redirect chains and detect rings -------------------------
    links = [[] for _ in range(num_lines)]
    contains_ring = set()
    for start in range(num_lines):
        if redirects[start] is None:
            continue

        chain = []
        cur = start
        seen = set()
        while redirects[cur] is not None:
            seen.add(cur)
            chain.append(redirects[cur])
            cur = redirects[cur]

            if cur in seen:
                contains_ring.add(start)
                chain = None
                break
        # NOTE: for a line that runs into a ring the value copied here is
        # not meaningful; such lines are identified by ``contains_ring``.
        values[start] = values[cur]
        links[start] = chain

    # --- Group ring-free lines by chain length ---------------------------
    group_by_num_hoops = defaultdict(list)
    for i in range(num_lines):
        if i in contains_ring:
            continue
        group_by_num_hoops[len(links[i]) + 1].append(i)

    for num_links in sorted(group_by_num_hoops):
        print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")

    # --- Few-shot examples -----------------------------------------------
    # Candidates are ordered by the largest line position involved (the line
    # itself or anything on its chain).
    def last_position_involved(candidate):
        return max([candidate] + links[candidate])

    hoop1_candidates = sorted(group_by_num_hoops[1], key=last_position_involved)
    hoop2_candidates = sorted(group_by_num_hoops[2], key=last_position_involved)

    if not hoop1_candidates:
        raise IndexError(
            "no ring-free, non-redirect line is available for the few-shot examples"
        )

    if hoop2_candidates:
        shots = [(hoop1_candidates, 5), (hoop2_candidates, 0), (hoop2_candidates, 1)]
    else:
        shots = [(hoop1_candidates, 5), (hoop1_candidates, 1), (hoop1_candidates, 10)]

    for slot, (pool, position) in enumerate(shots):
        # Clamp to the last candidate so small inputs reuse an example
        # instead of failing with an out-of-range index.
        chosen = pool[min(position, len(pool) - 1)]
        suffix = suffix.replace(f"__idx{slot}__", indices[chosen]).replace(
            f"__val{slot}__", values[chosen]
        )

    return {
        "prefix": prefix,
        "suffix": suffix,
        "lines": lines,
        "indices": indices,
        "values": values,
        "links": links,
        "group_by_num_hoops": group_by_num_hoops,
        "contains_ring": sorted(contains_ring),
    }
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read the word list, generate one test case
    # with a fixed seed, and write it to a JSON file in the working directory.
    parser = argparse.ArgumentParser()
    # Required: without it the generator would receive None as the line count.
    parser.add_argument(
        "--number", type=int, required=True, help="number of lines to generate"
    )
    parser.add_argument(
        "--redirect-ratio",
        type=float,
        default=0.0,
        help="fraction of lines that redirect to another line",
    )
    args = parser.parse_args()

    num_lines = args.number

    random_words_filename = "random_words.json"
    # Context manager so the input file is closed as soon as it is parsed.
    with open(random_words_filename, "r") as fin:
        random_words = json.load(fin)

    # Fixed seed keeps the generated data reproducible across runs.
    np.random.seed(42)
    obj = generate_lines(random_words, num_lines, args.redirect_ratio)

    # Separate names for the path and the file object (previously both were
    # called ``fout``).
    out_path = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
    with open(out_path, "w") as fout:
        json.dump(obj, fout, indent=2)
|
|