import argparse
import glob
import json

import numpy as np
import pandas as pd

def sample_ds(input_file, output_file, num_insts=10000, min_num_text_per_inst=0, max_num_text_per_inst=3):
    """
    Sample up to num_insts author records from a JSONL file and pickle them as
    a DataFrame with 'fullText' (a capped list of texts) and 'authorID' columns.

    Example calls:

    sample_ds('/mnt/swordfish-pool2/nikhil/raw_all/test_queries.jsonl',
              '/mnt/swordfish-pool2/milad/hiatus-data/reddit_cluster_test.pkl',
              num_insts=10000, min_num_text_per_inst=3, max_num_text_per_inst=10)

    sample_ds('/mnt/swordfish-pool2/nikhil/raw_all/data.jsonl',
              '/mnt/swordfish-pool2/milad/hiatus-data/reddit_cluster_training.pkl',
              num_insts=10000, min_num_text_per_inst=3, max_num_text_per_inst=10)
    """
    out_list = []
    with open(input_file) as f:
        for _ in range(num_insts):
            line = f.readline()
            if not line:  # fewer than num_insts records in the file
                break
            json_obj = json.loads(line)
            # Skip authors with too few texts.
            if len(json_obj['syms']) < min_num_text_per_inst:
                continue
            out_list.append({
                'fullText': json_obj['syms'][:max_num_text_per_inst],
                'authorID': json_obj['author_id']
            })
    pd.DataFrame(out_list).to_pickle(output_file)
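
# For reference, a minimal sketch of the JSONL record shape sample_ds expects,
# inferred from the fields parsed above; the values are made up:
#
#   {"author_id": "author_123",
#    "syms": ["first post text ...", "second post text ...", "third post text ..."]}
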
def get_reddit_data(input_path, random_seed=123, num_instances=100, num_documents_per_author=8, min_instance_len=10):
    """
    Build multiple-choice author-verification tasks from a DataFrame pickled by
    sample_ds. Each task pairs a query author's documents with three candidate
    document sets: two from distractor authors and one held-out set from the
    query author (always at index 2, recorded in 'gt_idx').
    """
    df = pd.read_pickle(input_path)

    # Drop documents shorter than min_instance_len words, then keep only
    # authors with more than twice num_documents_per_author documents, so each
    # author can supply both a query set and a held-out correct-answer set.
    df['fullText'] = df.fullText.map(lambda x: [d for d in x if len(d.split()) > min_instance_len])
    df = df[df.fullText.str.len() > num_documents_per_author * 2]

    output_objs = []
    for _, row in df.iterrows():
        author_documents = row['fullText']

        # Split the author's documents into a query set and a held-out set
        # that serves as the correct answer.
        query_documents = author_documents[:num_documents_per_author]
        correct_documents = author_documents[num_documents_per_author:]

        # Sample two distractor authors.
        other_authors_df = df[df.authorID != row['authorID']]
        other_two_authors = other_authors_df.sample(2, random_state=random_seed).copy()

        # Truncate every document list to a common length so all candidates
        # offer the same number of texts.
        min_found_texts = min([len(correct_documents), len(query_documents)]
                              + [len(x) for x in other_two_authors.fullText.tolist()])
        query_documents = query_documents[:min_found_texts]
        correct_documents = correct_documents[:min_found_texts]
        other_two_authors['fullText'] = other_two_authors.fullText.apply(lambda x: x[:min_found_texts])

        output_objs.append({
            "Q_authorID": str(row["authorID"]),
            "Q_fullText": ["Text:\n{}".format(d) for d in query_documents],
            "a0_authorID": str(other_two_authors.iloc[0]["authorID"]),
            "a0_fullText": ["Text:\n{}".format(d) for d in other_two_authors.iloc[0]["fullText"][:num_documents_per_author]],
            "a1_authorID": str(other_two_authors.iloc[1]["authorID"]),
            "a1_fullText": ["Text:\n{}".format(d) for d in other_two_authors.iloc[1]["fullText"][:num_documents_per_author]],
            "a2_authorID": str(row["authorID"]) + "_correct",
            "a2_fullText": ["Text:\n{}".format(d) for d in correct_documents],
            "gt_idx": 2
        })

        # Advance the seed so each task draws different distractors.
        random_seed += 1
        if len(output_objs) >= num_instances:
            break

    return output_objs
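
# A minimal usage sketch for get_reddit_data (the pickle path is a placeholder):
#
#   tasks = get_reddit_data('reddit_cluster_test.pkl', num_instances=10,
#                           num_documents_per_author=5)
#   print(tasks[0]['Q_authorID'], tasks[0]['gt_idx'])  # gt_idx is always 2
#
# Note that the correct candidate is always a2; consumers that care about
# position bias may want to shuffle the candidates.
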
def get_iarpa_pilot_data(input_path):
    """
    Iterate over the IARPA pilot data points found under input_path, yielding
    (query_authors, candidate_authors, queries_df, candidates_df,
    ground_truth_assignment) for each one.
    """
    for data_point in glob.glob(input_path + '*/'):
        candidates_file = glob.glob(data_point + '/data/*_candidates.jsonl')[0]
        queries_file = glob.glob(data_point + '/data/*_queries.jsonl')[0]
        ground_truth_file = glob.glob(data_point + '/groundtruth/*_groundtruth.npy')[0]
        q_labels_file = glob.glob(data_point + '/groundtruth/*_query-labels.txt')[0]
        c_labels_file = glob.glob(data_point + '/groundtruth/*_candidate-labels.txt')[0]

        candidates_df = pd.read_json(candidates_file, lines=True)
        queries_df = pd.read_json(queries_file, lines=True)

        # Each document lists its author ID(s); take the first entry as the label.
        queries_df['authorID'] = queries_df.authorIDs.apply(lambda x: x[0])
        candidates_df['authorID'] = candidates_df.authorSetIDs.apply(lambda x: x[0])

        # Collapse to one row per author, with fullText as a list of documents.
        queries_df = queries_df.groupby('authorID').agg({'fullText': list}).reset_index()
        candidates_df = candidates_df.groupby('authorID').agg({'fullText': list}).reset_index()

        ground_truth_assignment = np.load(ground_truth_file)
        # The label files hold one quoted author ID per line; strip the quoting
        # characters and drop the trailing empty line.
        candidate_authors = [a[2:-3] for a in open(c_labels_file).read().split('\n')][:-1]
        query_authors = [a[2:-3] for a in open(q_labels_file).read().split('\n')][:-1]

        yield query_authors, candidate_authors, queries_df, candidates_df, ground_truth_assignment
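
# A minimal usage sketch for get_iarpa_pilot_data (the directory layout is
# assumed from the glob patterns above; the path is a placeholder):
#
#   for q_authors, c_authors, q_df, c_df, gt in get_iarpa_pilot_data('pilot_data/'):
#       print(len(q_authors), 'queries,', len(c_authors), 'candidates,', gt.shape)
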
def main():
    """
    Main entry point for the script.
    """
    parser = argparse.ArgumentParser(description="Prepare Reddit data for author attribution tasks.")
    parser.add_argument("input_path", type=str, help="Path to the input pandas DataFrame pickle file.")
    parser.add_argument("output_path", type=str, help="Path to save the output JSON file.")
    parser.add_argument("--random_seed", type=int, default=123, help="Random seed for sampling.")
    parser.add_argument("--num_docs", type=int, default=5, help="Number of documents per author for query and correct sets.")

    args = parser.parse_args()

    print(f"Processing data from: {args.input_path}")
    output_data = get_reddit_data(
        input_path=args.input_path,
        random_seed=args.random_seed,
        num_documents_per_author=args.num_docs
    )

    print(f"Saving {len(output_data)} tasks to: {args.output_path}")
    with open(args.output_path, 'w') as f:
        json.dump(output_data, f, indent=4)

    print("Done.")

if __name__ == "__main__":
    main()