# sbasu2512's picture
# training on latest dataset
# a925c4f
import pandas as pd
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac
# Augmenters, declared in the order _try_augment tries them.

# 1) Primary: WordNet synonym replacement (good when WordNet has coverage).
aug_synonym = naw.SynonymAug(
    aug_src="wordnet",
    aug_p=0.15,
)

# 2) Fallback: random swap of word positions — needs no dictionary.
aug_swap = naw.RandomWordAug(
    action="swap",
    aug_p=0.2,
)

# 3) Second fallback: random character-level keyboard-typo simulation.
aug_char = nac.KeyboardAug(
    aug_char_p=0.1,
    aug_word_p=0.2,
)
def _try_augment(text: str) -> str | None:
"""Try augmenters in order, return first result that differs from original."""
for augmenter in [aug_synonym, aug_swap, aug_char]:
try:
result = augmenter.augment(text)
new_text = result[0] if isinstance(result, list) else result
if new_text.strip() != text.strip():
return new_text
except Exception:
continue
# All augmenters failed or returned same text β€” return a simple word-swap manually
words = text.split()
if len(words) >= 2:
words[0], words[-1] = words[-1], words[0]
return " ".join(words)
return None # single-word text, truly can't augment
def augment_class(df: pd.DataFrame, label_id: int, target_count: int) -> pd.DataFrame:
    """Oversample one class with augmented copies of its own texts.

    Works on a copy of ``df``; the caller's frame is not modified. Texts of
    the requested class are sampled one at a time and passed to
    ``_try_augment`` until the class reaches ``target_count`` rows or the
    attempt budget (ten attempts per missing row) is exhausted.

    Args:
        df: Frame with a ``text`` column and a ``label`` column castable to int.
        label_id: Integer label of the class to grow.
        target_count: Desired number of rows for that class.

    Returns:
        A new DataFrame. If the class already has at least ``target_count``
        rows, this is the copy with ``label`` cast to int. Otherwise it is the
        copy plus the generated rows, re-indexed from 0, with ``label`` cast
        to int and ``text`` cast to str.

    Raises:
        ValueError: If no row carries ``label_id``.
    """
    df = df.copy()
    df["label"] = df["label"].astype(int)

    class_df = df[df["label"] == label_id]
    already_have = len(class_df)
    if already_have == 0:
        raise ValueError(
            f"No rows found for label_id={label_id}. "
            f"Available labels: {df['label'].unique().tolist()} "
            f"(dtype: {df['label'].dtype})"
        )

    if already_have >= target_count:
        print(f" Label {label_id} already has {already_have} samples — skipping")
        return df

    needed = target_count - already_have
    print(f" Label {label_id}: {already_have} → {target_count} (generating {needed} samples)")

    # Only real text can be augmented: missing cells are dropped and the rest
    # coerced to str, so the augmenter never receives NaN or a non-string.
    source_texts = class_df["text"].dropna().astype(str)

    augmented_rows = []
    if source_texts.empty:
        print(f" Warning: label {label_id} has no usable text to augment")
    else:
        attempts = 0
        max_attempts = needed * 10  # some headlines will be hard to augment
        while len(augmented_rows) < needed and attempts < max_attempts:
            attempts += 1
            original = source_texts.sample(1).values[0]
            new_text = _try_augment(original)
            if new_text:
                augmented_rows.append({"text": new_text, "label": label_id})
        if len(augmented_rows) < needed:
            print(f" Warning: only generated {len(augmented_rows)}/{needed} after {max_attempts} attempts")

    if augmented_rows:
        result_df = pd.concat([df, pd.DataFrame(augmented_rows)], ignore_index=True)
    else:
        # Nothing to append: skip concatenating an empty frame, but still
        # return a 0-based index as the concat path would.
        result_df = df.reset_index(drop=True)

    result_df["label"] = result_df["label"].astype(int)
    result_df["text"] = result_df["text"].astype(str)
    print(f" Done. Distribution: {result_df['label'].value_counts().sort_index().to_dict()}")
    return result_df