Milad Alshomary committed
Commit · e5d9888
Parent(s): e7bcc02
updates

Files changed:
- prepare_data.py +23 -8
- utils/interp_space_utils.py +1 -1
- utils/ui.py +6 -0
prepare_data.py
CHANGED
@@ -44,19 +44,22 @@ def sample_ds(input_file, output_file, num_insts=10000, min_num_text_per_inst=0,
     df = pd.DataFrame(out_list)
     df.to_pickle(output_file)

-def get_reddit_data(input_path, random_seed=123, num_instances=50, num_documents_per_author=
+def get_reddit_data(input_path, random_seed=123, num_instances=50, num_documents_per_author=8, min_instance_len=10):

     df = pd.read_pickle(open(input_path, 'rb'))
+    df['fullText'] = df.fullText.map(lambda x: [d for d in x if len(d.split()) > min_instance_len])
+    df = df[df.fullText.str.len() > num_documents_per_author * 2]
+
     output_objs = []

-    for
+    for _, row in df.iterrows():

         # Get the current author's documents
         query_author_df = df[df.authorID == row['authorID']]
         # split the author's documents into two: query and correct author
-        author_documents = query_author_df.fullText.tolist()[0]
+        author_documents = [x for x in query_author_df.fullText.tolist()[0] if len(x.split()) > min_instance_len]

-        if len(author_documents)
+        if len(author_documents) <= num_documents_per_author * 2:
             continue

         query_documents = author_documents[:num_documents_per_author]
@@ -67,17 +70,29 @@ def get_reddit_data(input_path, random_seed=123, num_instances=50, num_documents
         other_authors_df = df[df.authorID != row['authorID']]
         other_two_authors = other_authors_df.sample(2, random_state=random_seed)

+        # output_objs.append({
+        #     "Q_authorID": str(row["authorID"]),
+        #     "Q_fullText": "\n\n".join(["Text:\n{}".format(d) for d in query_documents]),
+        #     "a0_authorID": str(other_two_authors.iloc[0]["authorID"]),
+        #     "a0_fullText": "\n\n".join(["Text:\n{}".format(d) for d in other_two_authors.iloc[0]["fullText"][:num_documents_per_author]]),
+        #     "a1_authorID": str(other_two_authors.iloc[1]["authorID"]),
+        #     "a1_fullText": "\n\n".join(["Text:\n{}".format(d) for d in other_two_authors.iloc[1]["fullText"][:num_documents_per_author]]),
+        #     "a2_authorID": str(row["authorID"]) + "_correct",
+        #     "a2_fullText": "\n\n".join(["Text:\n{}".format(d) for d in correct_documents]),
+        #     "gt_idx": 2
+        # })
         output_objs.append({
             "Q_authorID": str(row["authorID"]),
-            "Q_fullText": query_documents,
+            "Q_fullText": ["Text:\n{}".format(d) for d in query_documents],
             "a0_authorID": str(other_two_authors.iloc[0]["authorID"]),
-            "a0_fullText": other_two_authors.iloc[0]["fullText"][:num_documents_per_author],
+            "a0_fullText": ["Text:\n{}".format(d) for d in other_two_authors.iloc[0]["fullText"][:num_documents_per_author]],
             "a1_authorID": str(other_two_authors.iloc[1]["authorID"]),
-            "a1_fullText": other_two_authors.iloc[1]["fullText"][:num_documents_per_author],
+            "a1_fullText": ["Text:\n{}".format(d) for d in other_two_authors.iloc[1]["fullText"][:num_documents_per_author]],
             "a2_authorID": str(row["authorID"]) + "_correct",
-            "a2_fullText": correct_documents,
+            "a2_fullText": ["Text:\n{}".format(d) for d in correct_documents],
             "gt_idx": 2
         })
+        print( "Text:\n\n".join(query_documents))
         random_seed += 1  # Increment seed to get different authors for the next task
         if len(output_objs) >= num_instances:
             break
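For orientation, a minimal sketch of how the revised get_reddit_data could be exercised. The pickle path is hypothetical, and the sketch assumes the function returns the output_objs list it accumulates and that the pickled DataFrame has authorID and fullText (list of strings) columns, as the diff implies; none of this is part of the commit itself.

    # Hypothetical usage sketch; the path, import layout, and return value are assumptions.
    from prepare_data import get_reddit_data

    tasks = get_reddit_data(
        "data/reddit_authors.pkl",   # assumed path, not from the commit
        random_seed=123,
        num_instances=5,
        num_documents_per_author=8,
        min_instance_len=10,
    )

    # Each task pairs a query author with three candidates; per the diff, index 2 is the true author.
    for task in tasks:
        assert task["gt_idx"] == 2
        assert task["a2_authorID"].endswith("_correct")
        print(task["Q_authorID"], len(task["Q_fullText"]), "query documents")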
utils/interp_space_utils.py
CHANGED
@@ -61,7 +61,7 @@ def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd
     # Gather the input texts (preserves list-of-strings if any)
     #texts = background_corpus_df[text_clm].fillna("").tolist()
     author_texts = ['\n\n'.join(x) for x in clustered_authors_df.fullText.tolist()]
-
+    print('author_text at 0:{}'.format(author_texts[0]))
     print(f"Number of author_texts: {len(author_texts)}")

     # Create a reproducible JSON serialization of the texts
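The new debug print inspects the first joined author text. A tiny illustration of what '\n\n'.join(x) produces for one author's document list; the example documents are invented, only the join mirrors the code above.

    # Invented example documents; illustrates the join used in compute_g2v_features.
    author_docs = ["First post by the author.", "Second post by the same author."]
    author_text = '\n\n'.join(author_docs)
    print('author_text at 0:{}'.format(author_text))
    # The two documents come out as one string, separated by a blank line.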
utils/ui.py
CHANGED
@@ -159,6 +159,12 @@ def update_task_display(mode, iid, instances, background_df, mystery_file, cand1
     ]

 def task_HTML(mystery_text, candidate_texts, predicted_author, ground_truth_author):
+
+    # if any of the texts is a list of text then concatenate them
+    if isinstance(mystery_text, list):
+        mystery_text = "\n\n".join(["Text: {}".format(x) for x in mystery_text])
+        candidate_texts = ["\n\n".join(["Text: {}".format(t) for t in x]) for x in candidate_texts]
+
     header_html = f"""
     <div style="border:1px solid #ccc; padding:10px; margin-bottom:10px;">
     <h3>Here’s the mystery passage alongside three candidate texts—look for the green highlight to see the predicted author.</h3>
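To make the new guard in task_HTML concrete, a small sketch with invented inputs showing how list-valued texts are flattened before the HTML is built. The sample strings are not from the repository; only the isinstance guard mirrors the code added above.

    # Invented sample inputs; mirrors the list-handling guard added to task_HTML.
    mystery_text = ["First mystery snippet.", "Second mystery snippet."]
    candidate_texts = [["cand0 doc A", "cand0 doc B"], ["cand1 doc A"], ["cand2 doc A"]]

    if isinstance(mystery_text, list):
        mystery_text = "\n\n".join(["Text: {}".format(x) for x in mystery_text])
        candidate_texts = ["\n\n".join(["Text: {}".format(t) for t in x]) for x in candidate_texts]

    print(mystery_text)
    # Text: First mystery snippet.
    #
    # Text: Second mystery snippet.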