import argparse
import glob
import json

import numpy as np
import pandas as pd
def sample_ds(input_file, output_file, num_insts=10000, min_num_text_per_inst=0, max_num_text_per_inst=3):
    """
    Sample up to `num_insts` author instances from a JSONL file and save them
    as a pandas DataFrame pickle.

    Example calls:
    sample_ds('/mnt/swordfish-pool2/nikhil/raw_all/test_queries.jsonl', '/mnt/swordfish-pool2/milad/hiatus-data/reddit_cluster_test.pkl',
              num_insts=10000,
              min_num_text_per_inst=3,
              max_num_text_per_inst=10)
    sample_ds('/mnt/swordfish-pool2/nikhil/raw_all/data.jsonl', '/mnt/swordfish-pool2/milad/hiatus-data/reddit_cluster_training.pkl',
              num_insts=10000,
              min_num_text_per_inst=3,
              max_num_text_per_inst=10)
    """
    out_list = []
    with open(input_file) as f:
        # Take the first `num_insts` qualifying records from the file.
        for line in f:
            if len(out_list) >= num_insts:
                break
            json_obj = json.loads(line)
            # Skip authors with too few texts.
            if len(json_obj['syms']) < min_num_text_per_inst:
                continue
            out_list.append({
                'fullText': json_obj['syms'][:max_num_text_per_inst],
                'authorID': json_obj['author_id'],
            })
    df = pd.DataFrame(out_list)
    df.to_pickle(output_file)
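# A minimal usage sketch for sample_ds, assuming each line of the input JSONL
# is a record of the form below (field names taken from the code; the paths
# here are hypothetical placeholders, not the real dataset locations):
#
#   {"author_id": "u123", "syms": ["first post text", "second post text"]}
#
#   sample_ds('raw/data.jsonl', 'reddit_cluster_training.pkl',
#             num_insts=1000, min_num_text_per_inst=3, max_num_text_per_inst=10)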
def get_reddit_data(input_path, random_seed=123, num_instances=100, num_documents_per_author=8, min_instance_len=10):
    df = pd.read_pickle(input_path)
    # Drop documents that are too short, then keep only authors with enough
    # documents to fill both a query set and a "correct answer" set.
    df['fullText'] = df.fullText.map(lambda x: [d for d in x if len(d.split()) > min_instance_len])
    df = df[df.fullText.str.len() > num_documents_per_author * 2]
output_objs = []
    for _, row in df.iterrows():
        # Split the current author's documents into two sets: the query set
        # and the "correct author" candidate set. (Documents shorter than
        # min_instance_len were already filtered out above.)
        author_documents = row['fullText']
        if len(author_documents) <= num_documents_per_author * 2:
            continue
        query_documents = author_documents[:num_documents_per_author]
        correct_documents = author_documents[num_documents_per_author:]
        # Sample two *other* authors as distractor candidates
        other_authors_df = df[df.authorID != row['authorID']]
        other_two_authors = other_authors_df.sample(2, random_state=random_seed).copy()
        # Make sure all authors have an equivalent number of texts
        min_found_texts = min([len(correct_documents), len(query_documents)] + [len(x) for x in other_two_authors.fullText.tolist()])
        query_documents = query_documents[:min_found_texts]
        correct_documents = correct_documents[:min_found_texts]
        other_two_authors['fullText'] = other_two_authors.fullText.apply(lambda x: x[:min_found_texts])
output_objs.append({
"Q_authorID": str(row["authorID"]),
"Q_fullText": ["Text:\n{}".format(d) for d in query_documents],
"a0_authorID": str(other_two_authors.iloc[0]["authorID"]),
"a0_fullText": ["Text:\n{}".format(d) for d in other_two_authors.iloc[0]["fullText"][:num_documents_per_author]],
"a1_authorID": str(other_two_authors.iloc[1]["authorID"]),
"a1_fullText": ["Text:\n{}".format(d) for d in other_two_authors.iloc[1]["fullText"][:num_documents_per_author]],
"a2_authorID": str(row["authorID"]) + "_correct",
"a2_fullText": ["Text:\n{}".format(d) for d in correct_documents],
"gt_idx": 2
})
random_seed += 1 # Increment seed to get different authors for the next task
if len(output_objs) >= num_instances:
break
return output_objs
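# A minimal sketch of consuming the tasks built above, assuming a pickle
# produced by sample_ds exists at the (hypothetical) path below. Each task
# holds one query author and three candidate authors a0/a1/a2, with gt_idx
# always pointing at the held-out "correct" candidate:
#
#   tasks = get_reddit_data('reddit_cluster_training.pkl', num_instances=10)
#   for task in tasks:
#       assert task['gt_idx'] == 2
#       print(task['Q_authorID'], len(task['Q_fullText']), len(task['a2_fullText']))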
def get_iarapa_pilot_data(input_path):
    for data_point in glob.glob(input_path + '*/'):
        candidates_file = glob.glob(data_point + '/data/*_candidates.jsonl')[0]
        queries_file = glob.glob(data_point + '/data/*_queries.jsonl')[0]
        ground_truth_file = glob.glob(data_point + '/groundtruth/*_groundtruth.npy')[0]
        q_labels_file = glob.glob(data_point + '/groundtruth/*_query-labels.txt')[0]
        c_labels_file = glob.glob(data_point + '/groundtruth/*_candidate-labels.txt')[0]
        candidates_df = pd.read_json(candidates_file, lines=True)
        queries_df = pd.read_json(queries_file, lines=True)
        # Each document lists its author(s); take the first as the canonical ID.
        queries_df['authorID'] = queries_df.authorIDs.apply(lambda x: x[0])
        candidates_df['authorID'] = candidates_df.authorSetIDs.apply(lambda x: x[0])
        # Collapse to one row per author holding all of that author's texts.
        queries_df = queries_df.groupby('authorID').agg({'fullText': lambda x: list(x)}).reset_index()
        candidates_df = candidates_df.groupby('authorID').agg({'fullText': lambda x: list(x)}).reset_index()
        ground_truth_assignment = np.load(ground_truth_file)
        # Label files have one bracketed/quoted author ID per line; strip the
        # surrounding characters and drop the trailing empty line.
        candidate_authors = [a[2:-3] for a in open(c_labels_file).read().split('\n')][:-1]
        query_authors = [a[2:-3] for a in open(q_labels_file).read().split('\n')][:-1]
        yield query_authors, candidate_authors, queries_df, candidates_df, ground_truth_assignment
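# A minimal sketch of iterating the generator above, assuming input_path is a
# directory containing one subdirectory per pilot release (the path is a
# hypothetical placeholder, and the shape of the ground-truth array is an
# assumption, not verified here):
#
#   for q_auth, c_auth, q_df, c_df, gt in get_iarapa_pilot_data('pilot_data/'):
#       print(len(q_auth), len(c_auth), gt.shape)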
def main():
"""
Main entry point for the script.
"""
parser = argparse.ArgumentParser(description="Prepare Reddit data for author attribution tasks.")
parser.add_argument("input_path", type=str, help="Path to the input pandas DataFrame pickle file.")
parser.add_argument("output_path", type=str, help="Path to save the output JSON file.")
parser.add_argument("--random_seed", type=int, default=123, help="Random seed for sampling.")
parser.add_argument("--num_docs", type=int, default=5, help="Number of documents per author for query and correct sets.")
args = parser.parse_args()
print(f"Processing data from: {args.input_path}")
output_data = get_reddit_data(
input_path=args.input_path,
random_seed=args.random_seed,
num_documents_per_author=args.num_docs
)
print(f"Saving {len(output_data)} tasks to: {args.output_path}")
with open(args.output_path, 'w') as f:
json.dump(output_data, f, indent=4)
print("Done.")
if __name__ == "__main__":
main()