# ALIGN-Sim/src/SentencePerturbation/sentence_perturbation.py
import os
import random
import sys

import pandas as pd

sys.path.insert(0, "/home/yash/ALIGN-SIM/src")
from utils import mkdir_p, full_path, read_data
from SentencePerturbation.word_replacer import WordReplacer, WordSwapping
from perturbation_args import get_args
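# The two helpers imported from word_replacer are used below with the following
# interfaces. This summary is inferred from their call sites in this file, not from
# their definitions, which live in SentencePerturbation/word_replacer.py:
#   WordReplacer().sentence_replacement(sentence, n, mode) -> str
#       returns a copy of `sentence` with n words replaced, where mode is
#       "synonyms" or "antonyms"
#   WordSwapping.random_swap(sentence, n) -> str
#       returns a copy of `sentence` with n random word-position swaps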
def perturb_sentences(dataset_name: str, task: str, target_lang: str = "en", output_dir: str = "./data/perturbed_dataset/", sample_size: int = 3500, save: bool = False) -> None:
    """
    Create a perturbed copy of a paraphrase dataset and optionally save it as a CSV.
    Args:
        dataset_name (str): One of ["MRPC", "PAWS", "QQP"].
        task (str): One of ["Synonym", "Antonym", "Jumbling", "Paraphrase"].
        target_lang (str, optional): Language code used in the output path. Defaults to "en".
        output_dir (str, optional): Root directory for the perturbed datasets. Defaults to "./data/perturbed_dataset/".
        sample_size (int, optional): Number of sentence pairs to sample. Defaults to 3500.
        save (bool, optional): Write the result to CSV when True. Defaults to False.
    """
print("--------------------------------------")
output_csv = full_path(os.path.join(output_dir, target_lang, task, f"{dataset_name}_{task}_perturbed_{target_lang}.csv"))
if os.path.exists(output_csv):
print(f"File already exists at: {output_csv}")
return
# TODO: make it compatible with other language datasets
print("Loading dataset...")
data = read_data(dataset_name)
if "Unnamed: 0" in data.columns:
data.drop("Unnamed: 0", axis=1, inplace=True)
if "idx" in data.columns:
data.drop("idx", axis=1, inplace=True)
print(f"Loaded {dataset_name} dataset")
print("--------------------------------------")
# Initialize WordReplacer
replacer = WordReplacer()
# set seed
random.seed(42)
    # New dataframe to hold the perturbed sentences
    perturbed_data = pd.DataFrame(columns=["original_sentence"])
if task in ["Syn","syn","Synonym"]:
print("Creating Synonym perturbed data...")
sample_data = sampling(data, task, sample_size)
perturbed_data["original_sentence"] = sample_data.sentence1
perturbed_data["perturb_n1"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 1, "synonyms"))
perturbed_data["perturb_n2"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 2, "synonyms"))
perturbed_data["perturb_n3"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 3, "synonyms"))
assert perturbed_data.shape[1] == 4, "Perturbed data size mismatch"
if task in ["paraphrase","Paraphrase","para"]:
print("Creating Paraphrase perturbed data...")
# shuffling the negative samples
# we also want equal number of positive and negative samples
perturbed_data = sampling(data, task, sample_size) # balance data
perturbed_data["original_sentence"] = perturbed_data.sentence1
perturbed_data["paraphrased_sentence"] = perturbed_data.sentence2
assert perturbed_data.shape[1] == 3, "Perturbed data size mismatch" # original_sentence, paraphrased, label
if task in ["Anto","anto","Antonym"]:
print("Creating Antonym perturbed data...")
pos_pairs = sampling(data, task, sample_size)
# Apply antonym replacement
perturbed_data["original_sentence"] = pos_pairs.sentence1
perturbed_data["paraphrased_sentence"] = pos_pairs.sentence2
perturbed_data["perturb_n1"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 1, "antonyms"))
assert perturbed_data.shape[1] == 3, "Perturbed data size mismatch"
# Apply jumbling
if task in ["jumbling", "Jumbling","jumb"]:
print("Creating Jumbling perturbed data...")
pos_pairs = sampling(data, task, sample_size)
perturbed_data["original_sentence"] = pos_pairs.sentence1
perturbed_data["paraphrased_sentence"] = pos_pairs.sentence2
perturbed_data["perturb_n1"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,1))
perturbed_data["perturb_n2"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,2))
perturbed_data["perturb_n3"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,3))
assert perturbed_data.shape[1] == 5, "Perturbed data size mismatch"
# Save to CSV
if save:
perturbed_data.to_csv(mkdir_p(output_csv), index=False)
print("--------------------------------------")
print(f"Saved at: {output_csv}")
print("--------------------------------------")
def sampling(data: pd.DataFrame, task :str, sample_size: int, random_state: int = 42):
"""
Combines two sampling strategies:
1. sampled_data: Samples from the dataset by first taking all positive pairs and then,
if needed, filling the remainder with negative pairs.
2. balanced_data: Constructs a dataset with roughly equal positive and negative pairs,
adjusting the numbers if one group is underrepresented.
Returns:
sampled_data (pd.DataFrame): Dataset sampled by filling negatives if positives are insufficient.
positive_data (pd.DataFrame): All positive samples (label == 1).
balanced_data (pd.DataFrame): Dataset balanced between positive and negative pairs.
"""
# Split the data into positive and negative pairs
positive_data = data[data["label"] == 1]
negative_data = data[data["label"] == 0]
if task in ["Anto","anto","Antonym","jumbling", "Jumbling","jumb"]:
return positive_data
    # ----- Positive-pair sampling (also checks that the sample size is satisfiable) -----
    if sample_size is None or sample_size > len(positive_data):
        # If no sample size is given, or it exceeds the available positive pairs,
        # use all positive pairs.
        sampled_data = positive_data.copy()
else:
# Otherwise, randomly sample the specified number of rows.
sampled_data = positive_data.sample(n=sample_size, random_state=random_state)
if task in ["Syn","syn","Synonym"]:
return sampled_data
    # ----- Sampling for the Paraphrase criterion -----
    # Shuffle sentence2 within the negative pairs so each negative row becomes an
    # unrelated (non-paraphrase) pair
negative_data = negative_data.reset_index(drop=True)
shuffled_sentence2 = negative_data["sentence2"].sample(frac=1, random_state=random_state).reset_index(drop=True)
negative_data["sentence2"] = shuffled_sentence2
    if sample_size is None:
        # No sample size given: use everything from both groups
        pos_sample_size = len(positive_data)
        neg_sample_size = len(negative_data)
else:
# Determine ideal sample size per group (half of total sample size)
half_size = sample_size // 2
pos_available = len(positive_data)
neg_available = len(negative_data)
pos_sample_size = min(half_size, pos_available)
neg_sample_size = min(half_size, neg_available)
# If there is a remainder, add extra samples from the group with more available data.
total_sampled = pos_sample_size + neg_sample_size
remainder = sample_size - total_sampled
if remainder > 0:
if (pos_available - pos_sample_size) >= (neg_available - neg_sample_size):
pos_sample_size += remainder
else:
neg_sample_size += remainder
# Sample from each group
sampled_positive = positive_data.sample(n=pos_sample_size, random_state=random_state)
sampled_negative = negative_data.sample(n=neg_sample_size, random_state=random_state)
    # Make the labels explicit: 1 = paraphrase pair, 0 = shuffled non-paraphrase pair
sampled_positive["label"] = 1
sampled_negative["label"] = 0
# Combine and shuffle the resulting dataset
balanced_data = pd.concat([sampled_positive, sampled_negative]).sample(frac=1, random_state=random_state).reset_index(drop=True)
if task in ["paraphrase","Paraphrase","para"]:
return balanced_data
# return sampled_data, positive_data, balanced_data
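# Minimal sketch of how sampling() behaves on a toy frame with the column layout
# assumed throughout this file (sentence1, sentence2, label). The helper below is
# purely illustrative and is never called by the pipeline.
def _sampling_example():
    toy = pd.DataFrame({
        "sentence1": ["a b c", "d e f", "g h i", "j k l"],
        "sentence2": ["a b d", "d e g", "x y z", "m n o"],
        "label": [1, 1, 0, 0],
    })
    pos_only = sampling(toy, "anto", sample_size=None)  # all label == 1 rows
    syn_rows = sampling(toy, "syn", sample_size=2)      # up to 2 positive rows
    balanced = sampling(toy, "para", sample_size=4)     # ~50/50 pos/neg, negatives shuffled
    return pos_only, syn_rows, balanced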
if __name__ == "__main__":
    # For testing: use a fixed configuration when running under a debugger
if sys.gettrace() is not None:
config = {
"dataset_name": "mrpc",
"task": "syn",
"target_lang": "en",
"output_dir": "./data/perturbed_dataset/",
"save": True
}
else:
args = get_args()
config = {
"dataset_name": args.dataset_name,
"task": args.task,
"target_lang": args.target_lang,
"output_dir": args.output_dir,
"save": args.save,
"sample_size": args.sample_size
}
perturb_sentences(**config)
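# Example command-line invocation. The flag names below are assumptions based on the
# keys read from get_args() above; check perturbation_args.py for the exact argument
# definitions:
#   python sentence_perturbation.py --dataset_name mrpc --task syn --target_lang en \
#       --output_dir ./data/perturbed_dataset/ --sample_size 3500 --save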