from __future__ import annotations  # lets `str | None` annotations work on Python < 3.10

import pandas as pd
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac

# Primary: WordNet synonym replacement. Only helps for words WordNet covers, and it
# needs the NLTK WordNet data (plus a POS-tagger model) downloaded beforehand.
aug_synonym = naw.SynonymAug(aug_src='wordnet', aug_p=0.15)

# Fallback: randomly swap word positions. Needs no dictionary, but cannot change a
# single-word text.
aug_swap = naw.RandomWordAug(action='swap', aug_p=0.2)

# Second fallback: simulate keyboard typos at character level. Usually produces a
# change, though words shorter than KeyboardAug's min_char are left untouched.
aug_char = nac.KeyboardAug(aug_char_p=0.1, aug_word_p=0.2)
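
# Illustrative sketch (hypothetical helper, NOT nlpaug's implementation): a rough,
# dependency-free approximation of what RandomWordAug(action='swap') does, assumed
# here to mean exchanging neighbouring words with some probability. It takes a seed
# so its output is reproducible. Nothing below calls it.
import random
from typing import Optional


def _random_swap_sketch(text: str, p: float = 0.2, seed: Optional[int] = None) -> str:
    """Swap each word with its right-hand neighbour with probability p."""
    rng = random.Random(seed)
    words = text.split()
    i = 0
    while i < len(words) - 1:
        if rng.random() < p:
            words[i], words[i + 1] = words[i + 1], words[i]
            i += 2  # skip the word just moved so it is not swapped straight back
        else:
            i += 1
    return " ".join(words)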

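def _try_augment(text: str) -> str | None:
    """Try each augmenter in order and return the first result that differs from the original.

    Returns None when no changed text can be produced (e.g. a single-word text).
    """
    for augmenter in [aug_synonym, aug_swap, aug_char]:
        try:
            result = augmenter.augment(text)
        except Exception:
            continue  # e.g. missing WordNet data: fall through to the next augmenter
        # Newer nlpaug releases return a list of strings; older ones return a string
        if isinstance(result, list):
            if not result:
                continue
            result = result[0]
        if result.strip() != text.strip():
            return result
    # Every augmenter failed or returned the original text: swap first and last words manually
    words = text.split()
    if len(words) >= 2 and words[0] != words[-1]:
        words[0], words[-1] = words[-1], words[0]
        return " ".join(words)
    return None  # single-word text (or identical first/last words): cannot augment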
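
# Illustrative sketch (hypothetical helper): the same "first augmenter that changes
# the text wins" pattern as _try_augment, written against plain callables so it can
# be exercised without nlpaug or WordNet installed. Each callable may return a
# string, a list of strings, or raise; any failure falls through to the next one.
from typing import Callable, Iterable, List, Optional, Union

AugmentFn = Callable[[str], Union[str, List[str]]]


def _first_changed_sketch(text: str, augmenters: Iterable[AugmentFn]) -> Optional[str]:
    for augment in augmenters:
        try:
            result = augment(text)
        except Exception:
            continue  # treat any augmenter error as "no result"
        if isinstance(result, list):
            if not result:
                continue
            result = result[0]
        if result.strip() != text.strip():
            return result
    return None  # nothing produced a changed text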

def augment_class(df: pd.DataFrame, label_id: int, target_count: int) -> pd.DataFrame:
    df = df.copy()
    df["label"] = df["label"].astype(int)

    class_df = df[df["label"] == label_id]

    if len(class_df) == 0:
        raise ValueError(
            f"No rows found for label_id={label_id}. "
            f"Available labels: {df['label'].unique().tolist()} "
            f"(dtype: {df['label'].dtype})"
        )

    already_have = len(class_df)
    if already_have >= target_count:
        print(f"  Label {label_id} already has {already_have} samples — skipping")
        return df

    needed = target_count - already_have
    print(f"  Label {label_id}: {already_have} -> {target_count} (generating {needed} samples)")

    augmented_rows = []
    attempts = 0
    max_attempts = needed * 10  # allow up to 10 tries per needed sample; some headlines are hard to augment

    while len(augmented_rows) < needed and attempts < max_attempts:
        attempts += 1
        original = class_df["text"].sample(1).values[0]
        new_text = _try_augment(original)
        if new_text:
            augmented_rows.append({"text": new_text, "label": label_id})

    if len(augmented_rows) < needed:
        print(f"  Warning: only generated {len(augmented_rows)}/{needed} after {max_attempts} attempts")

    result_df = pd.concat([df, pd.DataFrame(augmented_rows)], ignore_index=True)
    result_df["label"] = result_df["label"].astype(int)
    result_df["text"] = result_df["text"].astype(str)

    print(f"  Done. Distribution: {result_df['label'].value_counts().sort_index().to_dict()}")
    return result_df
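
# Illustrative sketch (hypothetical helper): augment_class balances one label at a
# time, so balancing a whole dataset means calling it once per minority label. This
# computes how many synthetic rows each label would need to reach the majority-class
# count, using only the standard library.
from collections import Counter
from typing import Dict, Hashable, Iterable


def _needed_per_label_sketch(labels: Iterable[Hashable]) -> Dict[Hashable, int]:
    counts = Counter(labels)
    if not counts:
        return {}
    target = max(counts.values())
    return {label: target - n for label, n in counts.items() if n < target}


# Possible usage with augment_class (assumes a DataFrame `train_df` with "text" and
# "label" columns; kept as a comment so importing this module has no side effects):
#     target = int(train_df["label"].value_counts().max())
#     for label in _needed_per_label_sketch(train_df["label"].astype(int)):
#         train_df = augment_class(train_df, label_id=label, target_count=target)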