Auto-FineTune-Ops / preprocessing /augmentation.py
aneeb15's picture
Initial release of Auto-FineTune-Ops
d4398e6
"""
Augmentation Module (Optional)
================================
Lightweight synthetic data expansion stubs.
These are pure-Python approximations. For production quality,
integrate with an LLM API or NLP library.
"""
import random
import re
from dataclasses import dataclass
from typing import List
import pandas as pd
@dataclass
class AugmentationConfig:
    """Switches and sizing for the optional augmentation pass.

    ``enabled`` is the master switch; the four technique flags select
    which augmentation functions may be applied, and
    ``augmentation_factor`` sets how many extra samples are produced
    per input row.
    """

    # Master switch: when False no augmentation is performed.
    enabled: bool = False
    # Allow synonym-based paraphrasing.
    paraphrase: bool = False
    # Allow minor surface variations (punctuation, whitespace, ...).
    generate_variations: bool = False
    # Allow the back-translation stub.
    back_translate: bool = False
    # Allow the tone-rewriting stub.
    tone_rewrite: bool = False
    # How many extra copies to generate per sample.
    augmentation_factor: int = 1
# ---------------------------------------------------------------------------
# Synonym map for lightweight paraphrasing
#
# Keys are lower-case words; each value lists the replacement phrases that
# paraphrase_instruction() picks from at random. Lookups are made with the
# lower-cased word after stripping surrounding '.,!?;:' characters, so keys
# must stay lower-case and punctuation-free. Replacements may span several
# words (e.g. 'elaborate on', 'search for').
# ---------------------------------------------------------------------------
_SYNONYMS = {
'explain': ['describe', 'elaborate on', 'clarify', 'break down'],
'create': ['generate', 'produce', 'make', 'build'],
'write': ['compose', 'draft', 'author', 'pen'],
'list': ['enumerate', 'outline', 'itemize', 'catalog'],
'help': ['assist', 'aid', 'support', 'guide'],
'show': ['demonstrate', 'display', 'present', 'illustrate'],
'tell': ['inform', 'describe', 'narrate', 'share'],
'give': ['provide', 'supply', 'offer', 'deliver'],
'find': ['locate', 'discover', 'identify', 'search for'],
'use': ['utilize', 'employ', 'apply', 'leverage'],
'what': ['which', 'what exactly'],
'how': ['in what way', 'by what method'],
'important': ['crucial', 'essential', 'significant', 'vital'],
'good': ['excellent', 'great', 'effective', 'beneficial'],
'bad': ['poor', 'negative', 'harmful', 'detrimental'],
'big': ['large', 'significant', 'substantial', 'major'],
'small': ['minor', 'slight', 'modest', 'minimal'],
}
def paraphrase_instruction(text: str) -> str:
    """
    Simple synonym-based paraphrasing.

    Replaces one randomly chosen word that has an entry in ``_SYNONYMS``
    with a randomly chosen synonym. Everything else in the text is left
    exactly as it was: whitespace (newlines, tabs, repeated spaces) and
    any ``.,!?;:`` characters attached to either side of the replaced
    word are preserved.

    Args:
        text: The instruction to paraphrase.

    Returns:
        The paraphrased text. ``text`` is returned unchanged when it is
        not a string, when its stripped length is under five characters,
        or when it contains no word with a known synonym.
    """
    if not isinstance(text, str) or len(text.strip()) < 5:
        return text

    punctuation = '.,!?;:'

    # Splitting on a *capturing* group keeps the whitespace runs in the
    # result, so joining the pieces back reproduces the original layout
    # instead of collapsing everything to single spaces.
    pieces = re.split(r'(\s+)', text)

    candidates = []
    for i, piece in enumerate(pieces):
        key = piece.lower().strip(punctuation)
        # Whitespace runs and empty edge pieces never match a key.
        if key in _SYNONYMS:
            candidates.append((i, key))

    if not candidates:
        return text

    idx, key = random.choice(candidates)
    replacement = random.choice(_SYNONYMS[key])

    # Break the token into leading punctuation, the word itself, and
    # trailing punctuation so that e.g. "what?!" keeps both marks.
    token = pieces[idx]
    without_lead = token.lstrip(punctuation)
    leading = token[:len(token) - len(without_lead)]
    core = without_lead.rstrip(punctuation)
    trailing = without_lead[len(core):]

    # Preserve original casing (ALL-CAPS first, then initial capital).
    if len(core) > 1 and core.isupper():
        replacement = replacement.upper()
    elif core[0].isupper():
        replacement = replacement.capitalize()

    pieces[idx] = leading + replacement + trailing
    return ''.join(pieces)
def generate_variation(text: str) -> str:
    """
    Return a lightly altered copy of ``text``.

    Exactly one of the following transformations is picked at random:
    - replace the trailing run of ``.``/``!``/``?`` with a randomly
      chosen terminator (``.``, ``!``, ``?`` or nothing)
    - upper-case the first character
    - collapse every whitespace run to a single space and trim the ends
    - append ' Please be detailed.' about half of the time

    Non-string input, and strings whose stripped length is under five
    characters, are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    if len(text.strip()) < 5:
        return text

    def swap_terminator(t):
        # Drop any existing terminator run, then add a random one.
        return t.rstrip('.!?') + random.choice(['.', '!', '?', ''])

    def capitalize_first(t):
        if len(t) > 1:
            return t[0].upper() + t[1:]
        return t

    def squeeze_whitespace(t):
        return re.sub(r'\s+', ' ', t).strip()

    def request_detail(t):
        if random.random() > 0.5:
            return t + ' Please be detailed.'
        return t

    transforms = [
        swap_terminator,
        capitalize_first,
        squeeze_whitespace,
        request_detail,
    ]
    chosen = random.choice(transforms)
    return chosen(text)
def back_translate(text: str) -> str:
    """
    Placeholder for back-translation.

    A production implementation would translate the text into another
    language and back again. This stub performs no translation; it
    simply returns the result of ``paraphrase_instruction(text)``.
    """
    paraphrased = paraphrase_instruction(text)
    return paraphrased
def rewrite_tone(text: str, tone: str = "formal") -> str:
    """
    Stub for tone rewriting.

    Prepends a tone-specific prefix to instructions that begin with one
    of a small set of imperative verbs, lower-casing the first letter of
    the instruction so it reads as a continuation of the prefix.

    Args:
        text: The instruction to rewrite.
        tone: One of 'formal', 'casual', 'academic' or 'friendly'.
            Any other value leaves the text untouched.

    Returns:
        The prefixed instruction (with leading whitespace removed), or
        ``text`` unchanged when it is not a string, the tone is unknown,
        the text already starts with the prefix, or its first word is
        not a recognised action verb.
    """
    # Mirror the sibling augmenters: non-string input passes through.
    if not isinstance(text, str):
        return text

    tone_prefixes = {
        'formal': 'Please ',
        'casual': 'Hey, can you ',
        'academic': 'Kindly provide a detailed analysis of ',
        'friendly': 'I would really appreciate if you could ',
    }
    prefix = tone_prefixes.get(tone, '')
    if not prefix:
        # Unknown tone: nothing to prepend.
        return text

    # Work on the text without leading whitespace so that the first
    # character really is the first letter of the first word.
    body = text.lstrip()

    # Don't double-prefix
    if body.lower().startswith(prefix.lower().strip()):
        return text

    words = body.split(None, 1)
    if not words:
        return text

    # Simple approach: prepend tone prefix if the text starts with a
    # verb-like word
    action_words = {
        'explain', 'describe', 'write', 'create', 'list', 'show',
        'tell', 'give', 'find', 'help', 'make',
    }
    if words[0].lower() in action_words:
        return prefix + body[0].lower() + body[1:]
    return text
def augment_dataset(
    df: pd.DataFrame,
    col: str,
    config: AugmentationConfig,
) -> pd.DataFrame:
    """
    Apply augmentation to create additional samples.

    For every row of ``df``, ``config.augmentation_factor`` extra rows
    are produced. Each extra row is a copy of the source row whose
    ``col`` value has been passed (as a string) through one randomly
    chosen enabled augmentation method. Missing values in ``col`` are
    copied as-is rather than being turned into the text "nan"/"None".
    All other columns keep their original dtypes.

    Args:
        df: Source samples.
        col: Name of the text column to augment.
        config: Augmentation switches and factor.

    Returns:
        A new frame holding the original rows followed by the augmented
        rows, with a fresh 0..n-1 index. ``df`` itself is returned when
        augmentation is disabled, no method is enabled, the factor is
        below one, or ``df`` has no rows.

    Raises:
        KeyError: If ``col`` is not a column of a non-empty ``df`` when
            augmentation would otherwise run.
    """
    if not config.enabled:
        return df

    methods = []
    if config.paraphrase:
        methods.append(paraphrase_instruction)
    if config.generate_variations:
        methods.append(generate_variation)
    if config.back_translate:
        methods.append(back_translate)
    if config.tone_rewrite:
        methods.append(lambda t: rewrite_tone(t, "formal"))

    factor = config.augmentation_factor
    if not methods or factor < 1 or len(df) == 0:
        return df

    # Repeat each row `factor` times by position. Selecting with iloc
    # (instead of rebuilding rows from iterrows() Series) keeps every
    # column's dtype and also works when index labels are duplicated.
    positions = [pos for pos in range(len(df)) for _ in range(factor)]
    augmented = df.iloc[positions].copy()

    source = augmented[col]
    new_values = []
    for value, is_missing in zip(source, source.isna()):
        # Pick the method first so the sequence of random draws does
        # not depend on which cells happen to be missing.
        method = random.choice(methods)
        new_values.append(value if is_missing else method(str(value)))
    augmented[col] = new_values

    return pd.concat([df, augmented], ignore_index=True)