| | """Advanced Data Augmentation for Training"""
|
| |
|
import builtins
import copy
import json
import keyword
import logging
import random
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
|
| |
|
# Module-level logger, named after this module so its level/handlers can be
# configured independently through the logging hierarchy.
logger = logging.getLogger(name=__name__)
|
| |
|
| |
|
# Default augmentation method names, in their default order.
_DEFAULT_ENABLED_METHODS = (
    "synonym_replacement",
    "back_translation",
    "code_perturbation",
    "paraphrasing",
    "noise_injection",
)

# Default per-method probability values, as (method name, value) pairs.
_DEFAULT_METHOD_PROBABILITIES = (
    ("synonym_replacement", 0.3),
    ("back_translation", 0.2),
    ("code_perturbation", 0.4),
    ("paraphrasing", 0.3),
    ("noise_injection", 0.1),
)


@dataclass
class AugmentationConfig:
    """Configuration for data augmentation.

    Each instance gets fresh copies of the default list and dict, so
    mutating one config never changes the defaults seen by another. The
    default tables are stored as tuples so they cannot be mutated either.
    """

    # Names of augmentation methods that may be applied to a sample.
    enabled_methods: List[str] = field(
        default_factory=lambda: list(_DEFAULT_ENABLED_METHODS)
    )
    # Per-method probability values, keyed by method name.
    probabilities: Dict[str, float] = field(
        default_factory=lambda: dict(_DEFAULT_METHOD_PROBABILITIES)
    )
    # Maximum number of augmentations per sample.
    # NOTE(review): not read by any code visible alongside this class; confirm use.
    max_augmentations_per_sample: int = 2
|
| |
|
| |
|
class AugmentationMethod(ABC):
    """Abstract interface shared by every augmentation strategy.

    Concrete strategies report whether they apply to a sample
    (``can_augment``) and produce the augmented variant (``augment``).
    """

    @abstractmethod
    def can_augment(self, sample: Dict[str, Any]) -> bool:
        """Return True when this strategy can be applied to ``sample``."""

    @abstractmethod
    def augment(self, sample: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the augmented sample, or None when augmentation failed."""
|
| |
|
| |
|
class SynonymReplacement(AugmentationMethod):
    """Replace known words with randomly chosen synonyms.

    Each whole word found in ``self.synonyms`` (looked up lowercased) is
    replaced with probability ``replacement_prob``. Whitespace, line breaks
    and punctuation are preserved. A word starting with a capital letter
    gets a capitalized synonym. For conversation samples, every message's
    ``content`` is processed independently. The input sample is never
    modified.
    """

    # A word is a maximal run of letters (no digits/underscores): "good,"
    # yields "good", while "good2" and "good_x" are not treated as "good".
    _WORD_RE = re.compile(r"\b[^\W\d_]+\b")

    def __init__(
        self,
        replacement_prob: float = 0.1,
        synonyms: Optional[Dict[str, List[str]]] = None,
    ):
        """Create the augmenter.

        Args:
            replacement_prob: Probability of replacing each eligible word.
            synonyms: Optional table mapping a lowercase word to candidate
                replacements. Defaults to a small built-in English table.
        """
        self.replacement_prob = replacement_prob
        if synonyms is None:
            synonyms = {
                "good": ["excellent", "great", "fine", "quality", "superb"],
                "bad": ["poor", "terrible", "awful", "inferior", "subpar"],
                "big": ["large", "huge", "enormous", "massive", "giant"],
                "small": ["tiny", "little", "miniature", "compact", "petite"],
                "fast": ["quick", "rapid", "speedy", "swift", "expedited"],
                "slow": ["sluggish", "leisurely", "unhurried", "gradual", "delayed"],
                "create": ["build", "generate", "produce", "develop", "construct"],
                "use": ["utilize", "employ", "apply", "leverage", "harness"],
                "find": ["discover", "locate", "detect", "identify", "uncover"],
                "improve": ["enhance", "upgrade", "optimize", "refine", "better"],
            }
        self.synonyms = synonyms

    def can_augment(self, sample: Dict[str, Any]) -> bool:
        """Return True if the sample's text has more than 10 whitespace-separated words."""
        text = self._extract_text(sample)
        return len(text.split()) > 10

    def augment(self, sample: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return a deep copy of ``sample`` with some words swapped for synonyms.

        Returns:
            The augmented copy tagged with ``"augmentation"``, or None if no
            word was replaced.
        """
        # Deep copy: message dicts inside "conversations" must not be shared
        # with the caller's sample, or editing them would corrupt the original.
        augmented = copy.deepcopy(sample)
        if not self._map_text(augmented, self._replace_words):
            return None
        augmented["augmentation"] = "synonym_replacement"
        return augmented

    def _replace_words(self, text: str) -> str:
        """Return ``text`` with eligible words randomly replaced by synonyms."""

        def _swap(match: "re.Match") -> str:
            word = match.group(0)
            options = self.synonyms.get(word.lower())
            # `options` guard avoids random.choice([]) on an empty entry.
            if random.random() < self.replacement_prob and options:
                synonym = random.choice(options)
                return synonym.capitalize() if word[0].isupper() else synonym
            return word

        return self._WORD_RE.sub(_swap, text)

    def _extract_text(self, sample: Dict[str, Any]) -> str:
        """Return the sample's text.

        Uses the space-joined string contents of the dict messages when
        ``sample["conversations"]`` is a list. Otherwise uses
        ``sample["text"]``, falling back to ``sample["content"]``. Non-string
        values yield "".
        """
        conv = sample.get("conversations")
        if isinstance(conv, list):
            return " ".join(
                msg["content"]
                for msg in conv
                if isinstance(msg, dict) and isinstance(msg.get("content"), str)
            )
        text = sample.get("text", sample.get("content", ""))
        return text if isinstance(text, str) else ""

    def _map_text(
        self,
        sample: Dict[str, Any],
        transform: Callable[[str], str],
        first_only: bool = False,
    ) -> bool:
        """Apply ``transform`` in place to each text field of ``sample``.

        The fields are the ones ``_extract_text`` reads: each dict message's
        ``content`` when ``conversations`` is a list, otherwise ``text`` (or
        ``content`` if there is no ``text`` key). Non-string and empty
        values are skipped.

        Args:
            sample: Sample to modify in place.
            transform: Function mapping a field's text to its new text.
            first_only: Stop after the first field that changes.

        Returns:
            True if at least one field changed.
        """
        conv = sample.get("conversations")
        if isinstance(conv, list):
            slots = [msg for msg in conv if isinstance(msg, dict)]
            key = "content"
        else:
            slots = [sample]
            key = "text" if "text" in sample else "content"

        changed = False
        for slot in slots:
            value = slot.get(key)
            if not isinstance(value, str) or not value:
                continue
            new_value = transform(value)
            if new_value != value:
                slot[key] = new_value
                changed = True
                if first_only:
                    break
        return changed
|
| |
|
| |
|
class CodePerturbation(AugmentationMethod):
    """Perturb code while trying to preserve its behaviour.

    Code is taken from ``sample["code"]`` or, failing that, from the body of
    the first fenced (```) block in the conversation messages. Perturbations
    are tried in random order until one actually changes the code. Only that
    code is rewritten: prose around the fence and the fence's language tag
    are kept. The input sample is never modified.
    """

    _FENCE = "```"
    # Optional info string (language tag such as "python" or "c++") that
    # ends the opening fence line. "```x\ny = 1```" is ambiguous and is read
    # as tag "x".
    _INFO_STRING_RE = re.compile(r"[\w+#.-]*\n")
    # "name = ..." at the start of a line (excluding "name == ...").
    _ASSIGN_TARGET_RE = re.compile(r"^[ \t]*([A-Za-z_]\w*)[ \t]*=(?!=)", re.MULTILINE)
    # Loop variable in "for name in ...".
    _FOR_TARGET_RE = re.compile(r"\bfor[ \t]+([A-Za-z_]\w*)[ \t]+in\b")
    # Keyword argument "f(name=...)" / ", name=..."; renaming it would break calls.
    _KWARG_RE = re.compile(r"[(,]\s*([A-Za-z_]\w*)\s*=(?!=)")
    # A whole identifier that is not an attribute access (not after "." or a word char).
    _IDENTIFIER_RE = re.compile(r"(?<![.\w])[A-Za-z_]\w*")
    # Python keywords and builtins are never renamed.
    _RESERVED_NAMES = frozenset(keyword.kwlist) | frozenset(dir(builtins))

    def __init__(self):
        self.perturbations = [
            self._rename_variables,
            self._reorder_statements,
            self._add_redundant_parentheses,
            self._change_loop_style,
            self._add_comments,
        ]

    def can_augment(self, sample: Dict[str, Any]) -> bool:
        """Return True if the sample has a ``code`` key or a message containing a fence."""
        return "code" in sample or any(
            self._FENCE in str(msg.get("content", ""))
            for msg in self._messages(sample)
        )

    def augment(self, sample: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return a deep copy of ``sample`` whose code has been perturbed.

        Returns:
            The augmented copy, tagged ``"code_perturbation:<method name>"``,
            or None if there is no code or no perturbation changed it.
        """
        code = self._extract_code(sample)
        if not isinstance(code, str) or not code:
            return None

        # Some perturbations are placeholders or random no-ops, so try them
        # all in random order instead of giving up after one failed attempt.
        order = list(self.perturbations)
        random.shuffle(order)
        for perturbation in order:
            new_code = perturbation(code)
            if new_code != code:
                break
        else:
            return None

        # Deep copy so message dicts are not shared with the caller's sample.
        augmented = copy.deepcopy(sample)
        self._replace_code(augmented, new_code)
        augmented["augmentation"] = f"code_perturbation:{perturbation.__name__}"
        return augmented

    @staticmethod
    def _messages(sample: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return the dict messages of ``sample["conversations"]``.

        Returns [] when ``conversations`` is absent or is not a list/tuple.
        """
        conv = sample.get("conversations")
        if not isinstance(conv, (list, tuple)):
            return []
        return [msg for msg in conv if isinstance(msg, dict)]

    def _code_span(self, content: str) -> Optional[Tuple[int, int]]:
        """Locate the body of the first fenced code block in ``content``.

        The span excludes the fences, an optional language tag on the
        opening fence line, and surrounding whitespace. Without a closing
        fence, the body runs to the end of ``content``.

        Returns:
            ``(start, end)`` offsets into ``content``, or None if there is
            no fence.
        """
        opening = content.find(self._FENCE)
        if opening < 0:
            return None
        start = opening + len(self._FENCE)
        closing = content.find(self._FENCE, start)
        end = len(content) if closing < 0 else closing

        info = self._INFO_STRING_RE.match(content, start, end)
        if info is not None:
            start = info.end()
        while start < end and content[start].isspace():
            start += 1
        while end > start and content[end - 1].isspace():
            end -= 1
        return start, end

    def _extract_code(self, sample: Dict[str, Any]) -> str:
        """Return ``sample["code"]`` or the first fenced block's body ("" if none)."""
        if "code" in sample:
            return sample["code"]
        for msg in self._messages(sample):
            content = msg.get("content", "")
            if isinstance(content, str):
                span = self._code_span(content)
                if span is not None:
                    return content[span[0]:span[1]]
        return ""

    def _replace_code(self, sample: Dict[str, Any], new_code: str) -> None:
        """Write ``new_code`` back where ``_extract_code`` found the code.

        Only the first fenced block's body is replaced. Surrounding text,
        the fences and the language tag are kept.
        """
        if "code" in sample:
            sample["code"] = new_code
            return
        for msg in self._messages(sample):
            content = msg.get("content", "")
            if isinstance(content, str):
                span = self._code_span(content)
                if span is not None:
                    start, end = span
                    msg["content"] = content[:start] + new_code + content[end:]
                    return

    def _rename_variables(self, code: str) -> str:
        """Consistently rename simple assigned variables to ``var_NNNN`` names.

        Candidates are names assigned at the start of a line (``name = ...``)
        or bound by ``for name in``. The following are skipped:

        * keywords, builtins and dunder names;
        * single-character names;
        * names also used as keyword arguments (``f(name=...)``).

        Replacement is a single pass over whole identifiers that skips
        attribute access (``obj.name``), so one name is never rewritten
        inside a longer identifier. New names are unique and avoid every
        identifier already in ``code``.

        Limitation: identifiers inside string literals and comments are
        renamed as well.
        """
        targets = set(self._ASSIGN_TARGET_RE.findall(code))
        targets.update(self._FOR_TARGET_RE.findall(code))
        targets.difference_update(self._KWARG_RE.findall(code))
        # Sorted so that results are reproducible under a fixed random seed.
        candidates = sorted(
            name
            for name in targets
            if len(name) > 1
            and name not in self._RESERVED_NAMES
            and not (name.startswith("__") and name.endswith("__"))
        )
        if not candidates:
            return code

        taken = set(re.findall(r"[A-Za-z_]\w*", code))
        mapping = {}
        for name in candidates:
            new_name = f"var_{random.randint(1000, 9999)}"
            while new_name in taken:
                new_name = f"var_{random.randint(1000, 9999)}"
            taken.add(new_name)
            mapping[name] = new_name

        return self._IDENTIFIER_RE.sub(
            lambda match: mapping.get(match.group(0), match.group(0)), code
        )

    def _reorder_statements(self, code: str) -> str:
        """Placeholder for reordering independent statements; returns ``code`` unchanged."""
        return code

    def _add_redundant_parentheses(self, code: str) -> str:
        """Placeholder for adding redundant parentheses; returns ``code`` unchanged."""
        return code

    def _change_loop_style(self, code: str) -> str:
        """Placeholder for for/while loop conversion; returns ``code`` unchanged."""
        return code

    def _add_comments(self, code: str) -> str:
        """Insert ``# TODO: Explain this line`` after ~20% of code lines.

        Blank lines, comment lines and lines ending in a backslash
        continuation are skipped. Inserted comments reuse the indentation of
        the line they follow.
        """
        output = []
        for line in code.split("\n"):
            output.append(line)
            stripped = line.strip()
            if (
                stripped
                and not stripped.startswith("#")
                and not stripped.endswith("\\")
                and random.random() < 0.2
            ):
                indent = line[: len(line) - len(line.lstrip())]
                output.append(f"{indent}# TODO: Explain this line")
        return "\n".join(output)
|
| |
|
| |
|
class BackTranslation(AugmentationMethod):
    """Simulate back-translation by re-framing the text with a rephrasing cue.

    No translation model is involved. One of ``paraphrase_templates`` is
    chosen at random and formatted around the first non-empty text field of
    the sample. For conversations, that is the first message with text; the
    other messages are untouched. The input sample is never modified.
    """

    def __init__(self, templates: Optional[List[str]] = None):
        """Create the augmenter.

        Args:
            templates: Optional format strings whose only placeholder is
                ``{text}``. Defaults to a few "in other words"-style prefixes.
        """
        if templates is None:
            templates = [
                "In other words, {text}",
                "To put it differently, {text}",
                "That is to say, {text}",
                "Alternatively, {text}",
            ]
        self.paraphrase_templates = list(templates)

    def can_augment(self, sample: Dict[str, Any]) -> bool:
        """Return True if the sample's text is longer than 50 characters."""
        text = self._extract_text(sample)
        return len(text) > 50

    def augment(self, sample: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return a deep copy of ``sample`` with its first text field re-framed.

        Returns:
            The augmented copy tagged ``"back_translation"``, or None if
            there was no text to change.
        """
        template = random.choice(self.paraphrase_templates)
        # Deep copy so message dicts are not shared with the caller's sample.
        augmented = copy.deepcopy(sample)
        changed = self._map_text(
            augmented, lambda text: template.format(text=text), first_only=True
        )
        if not changed:
            return None
        augmented["augmentation"] = "back_translation"
        return augmented

    def _extract_text(self, sample: Dict[str, Any]) -> str:
        """Return the sample's text.

        Uses the space-joined string contents of the dict messages when
        ``sample["conversations"]`` is a list. Otherwise uses
        ``sample["text"]``, falling back to ``sample["content"]``. Non-string
        values yield "".
        """
        conv = sample.get("conversations")
        if isinstance(conv, list):
            return " ".join(
                msg["content"]
                for msg in conv
                if isinstance(msg, dict) and isinstance(msg.get("content"), str)
            )
        text = sample.get("text", sample.get("content", ""))
        return text if isinstance(text, str) else ""

    def _map_text(
        self,
        sample: Dict[str, Any],
        transform: Callable[[str], str],
        first_only: bool = False,
    ) -> bool:
        """Apply ``transform`` in place to each text field of ``sample``.

        The fields are the ones ``_extract_text`` reads: each dict message's
        ``content`` when ``conversations`` is a list, otherwise ``text`` (or
        ``content`` if there is no ``text`` key). Non-string and empty
        values are skipped.

        Args:
            sample: Sample to modify in place.
            transform: Function mapping a field's text to its new text.
            first_only: Stop after the first field that changes.

        Returns:
            True if at least one field changed.
        """
        conv = sample.get("conversations")
        if isinstance(conv, list):
            slots = [msg for msg in conv if isinstance(msg, dict)]
            key = "content"
        else:
            slots = [sample]
            key = "text" if "text" in sample else "content"

        changed = False
        for slot in slots:
            value = slot.get(key)
            if not isinstance(value, str) or not value:
                continue
            new_value = transform(value)
            if new_value != value:
                slot[key] = new_value
                changed = True
                if first_only:
                    break
        return changed
|
| |
|
| |
|
class Paraphrasing(AugmentationMethod):
    """Paraphrase text by swapping one common word for an alternative.

    Patterns are tried in random order; the first one that occurs in the
    sample has its first occurrence replaced, keeping an initial capital
    letter. The input sample is never modified.
    """

    def __init__(self):
        # (regex for the word, candidate replacements)
        self.paraphrase_patterns = [
            (r"\b(is)\b", ["represents", "constitutes", "means"]),
            (r"\b(has)\b", ["contains", "possesses", "includes"]),
            (r"\b(use)\b", ["utilize", "employ", "leverage"]),
            (r"\b(make)\b", ["create", "build", "produce"]),
            (r"\b(find)\b", ["discover", "locate", "identify"]),
        ]

    def can_augment(self, sample: Dict[str, Any]) -> bool:
        """Return True if the sample's text is longer than 30 characters."""
        text = self._extract_text(sample)
        return len(text) > 30

    def augment(self, sample: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return a deep copy of ``sample`` with one word paraphrased.

        Returns:
            The augmented copy tagged ``"paraphrasing"``, or None if no
            pattern occurs in the text.
        """
        text = self._extract_text(sample)
        candidates = list(self.paraphrase_patterns)
        random.shuffle(candidates)
        for pattern, replacements in candidates:
            if re.search(pattern, text, flags=re.IGNORECASE):
                break
        else:
            return None

        replacement = random.choice(replacements)
        # Deep copy so message dicts are not shared with the caller's sample.
        augmented = copy.deepcopy(sample)
        # Fields are visited in the same order _extract_text joins them, so
        # this rewrites the same first occurrence the search found.
        changed = self._map_text(
            augmented,
            lambda segment: self._replace_first(segment, pattern, replacement),
            first_only=True,
        )
        if not changed:
            return None
        augmented["augmentation"] = "paraphrasing"
        return augmented

    @staticmethod
    def _replace_first(text: str, pattern: str, replacement: str) -> str:
        """Replace the first case-insensitive match of ``pattern`` with ``replacement``.

        The replacement is capitalized when the matched word starts with an
        uppercase letter.
        """

        def _match_case(match: "re.Match") -> str:
            return replacement.capitalize() if match.group(0)[0].isupper() else replacement

        return re.sub(pattern, _match_case, text, count=1, flags=re.IGNORECASE)

    def _extract_text(self, sample: Dict[str, Any]) -> str:
        """Return the sample's text.

        Uses the space-joined string contents of the dict messages when
        ``sample["conversations"]`` is a list. Otherwise uses
        ``sample["text"]``, falling back to ``sample["content"]``. Non-string
        values yield "".
        """
        conv = sample.get("conversations")
        if isinstance(conv, list):
            return " ".join(
                msg["content"]
                for msg in conv
                if isinstance(msg, dict) and isinstance(msg.get("content"), str)
            )
        text = sample.get("text", sample.get("content", ""))
        return text if isinstance(text, str) else ""

    def _map_text(
        self,
        sample: Dict[str, Any],
        transform: Callable[[str], str],
        first_only: bool = False,
    ) -> bool:
        """Apply ``transform`` in place to each text field of ``sample``.

        The fields are the ones ``_extract_text`` reads: each dict message's
        ``content`` when ``conversations`` is a list, otherwise ``text`` (or
        ``content`` if there is no ``text`` key). Non-string and empty
        values are skipped.

        Args:
            sample: Sample to modify in place.
            transform: Function mapping a field's text to its new text.
            first_only: Stop after the first field that changes.

        Returns:
            True if at least one field changed.
        """
        conv = sample.get("conversations")
        if isinstance(conv, list):
            slots = [msg for msg in conv if isinstance(msg, dict)]
            key = "content"
        else:
            slots = [sample]
            key = "text" if "text" in sample else "content"

        changed = False
        for slot in slots:
            value = slot.get(key)
            if not isinstance(value, str) or not value:
                continue
            new_value = transform(value)
            if new_value != value:
                slot[key] = new_value
                changed = True
                if first_only:
                    break
        return changed
|
| |
|
| |
|
class NoiseInjection(AugmentationMethod):
    """Replace random longer tokens with placeholder noise tokens.

    Each whitespace-delimited token longer than 3 characters is replaced
    with probability ``noise_prob`` by one of ``noise_tokens``. The original
    whitespace, including line breaks, is preserved. For conversation
    samples, every message's ``content`` is processed. The input sample is
    never modified.
    """

    # A token is a maximal run of non-whitespace; whitespace between tokens is kept.
    _TOKEN_RE = re.compile(r"\S+")

    def __init__(self, noise_prob: float = 0.01):
        """Create the augmenter.

        Args:
            noise_prob: Probability of replacing each eligible token.
        """
        self.noise_prob = noise_prob
        self.noise_tokens = ["[MASK]", "<noise>", "...", "[UNK]"]

    def can_augment(self, sample: Dict[str, Any]) -> bool:
        """Return True if the sample's text is longer than 20 characters."""
        text = self._extract_text(sample)
        return len(text) > 20

    def augment(self, sample: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return a deep copy of ``sample`` with noise tokens injected.

        Returns:
            The augmented copy tagged ``"noise_injection"``, or None if no
            token was replaced.
        """
        # Deep copy so message dicts are not shared with the caller's sample.
        augmented = copy.deepcopy(sample)
        if not self._map_text(augmented, self._inject_noise):
            return None
        augmented["augmentation"] = "noise_injection"
        return augmented

    def _inject_noise(self, text: str) -> str:
        """Return ``text`` with eligible tokens randomly replaced by noise tokens."""

        def _maybe_replace(match: "re.Match") -> str:
            token = match.group(0)
            if random.random() < self.noise_prob and len(token) > 3:
                return random.choice(self.noise_tokens)
            return token

        return self._TOKEN_RE.sub(_maybe_replace, text)

    def _extract_text(self, sample: Dict[str, Any]) -> str:
        """Return the sample's text.

        Uses the space-joined string contents of the dict messages when
        ``sample["conversations"]`` is a list. Otherwise uses
        ``sample["text"]``, falling back to ``sample["content"]``. Non-string
        values yield "".
        """
        conv = sample.get("conversations")
        if isinstance(conv, list):
            return " ".join(
                msg["content"]
                for msg in conv
                if isinstance(msg, dict) and isinstance(msg.get("content"), str)
            )
        text = sample.get("text", sample.get("content", ""))
        return text if isinstance(text, str) else ""

    def _map_text(
        self,
        sample: Dict[str, Any],
        transform: Callable[[str], str],
        first_only: bool = False,
    ) -> bool:
        """Apply ``transform`` in place to each text field of ``sample``.

        The fields are the ones ``_extract_text`` reads: each dict message's
        ``content`` when ``conversations`` is a list, otherwise ``text`` (or
        ``content`` if there is no ``text`` key). Non-string and empty
        values are skipped.

        Args:
            sample: Sample to modify in place.
            transform: Function mapping a field's text to its new text.
            first_only: Stop after the first field that changes.

        Returns:
            True if at least one field changed.
        """
        conv = sample.get("conversations")
        if isinstance(conv, list):
            slots = [msg for msg in conv if isinstance(msg, dict)]
            key = "content"
        else:
            slots = [sample]
            key = "text" if "text" in sample else "content"

        changed = False
        for slot in slots:
            value = slot.get(key)
            if not isinstance(value, str) or not value:
                continue
            new_value = transform(value)
            if new_value != value:
                slot[key] = new_value
                changed = True
                if first_only:
                    break
        return changed
|
| |
|
| |
|
class DataAugmenter:
    """Manage several augmentation methods and apply one of them per sample.

    Candidates are the names in ``config.enabled_methods`` that have a
    registered method whose ``can_augment`` accepts the sample. One
    candidate is drawn using ``config.probabilities`` as relative weights.
    If it produces nothing, the remaining candidates are tried in the same
    weighted way.
    """

    def __init__(self, config: Optional["AugmentationConfig"] = None):
        """Create the augmenter.

        Args:
            config: Augmentation settings. Defaults to ``AugmentationConfig()``.
        """
        self.config = config if config is not None else AugmentationConfig()
        self.methods: Dict[str, AugmentationMethod] = {
            "synonym_replacement": SynonymReplacement(),
            "back_translation": BackTranslation(),
            "code_perturbation": CodePerturbation(),
            "paraphrasing": Paraphrasing(),
            "noise_injection": NoiseInjection(),
        }

    def _candidate_weights(self, sample: Dict[str, Any]) -> Dict[str, float]:
        """Map each enabled, registered, applicable method name to its selection weight.

        Names missing from ``config.probabilities`` get weight 1.0. Methods
        with a weight <= 0 are treated as disabled.
        """
        weights: Dict[str, float] = {}
        for name in self.config.enabled_methods:
            method = self.methods.get(name)
            if method is None or name in weights or not method.can_augment(sample):
                continue
            weight = self.config.probabilities.get(name, 1.0)
            if weight > 0:
                weights[name] = weight
        return weights

    def augment(self, sample: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Apply one weighted-random applicable method to ``sample``.

        Returns:
            The augmented sample with ``"augmentation_applied"`` set to the
            method name, or None if no candidate method produced a result.
        """
        weights = self._candidate_weights(sample)
        while weights:
            names = list(weights)
            name = random.choices(names, weights=[weights[n] for n in names])[0]
            augmented = self.methods[name].augment(sample)
            if augmented is not None:
                augmented["augmentation_applied"] = name
                return augmented
            # This method could not change the sample; try the others.
            del weights[name]
        return None

    def augment_batch(
        self,
        batch: List[Dict[str, Any]],
        augmentation_ratio: float = 0.1,
    ) -> List[Dict[str, Any]]:
        """Return the batch with augmented samples interleaved.

        Every original sample is kept. With probability
        ``augmentation_ratio``, an augmented variant is appended right after
        it (when augmentation succeeds).
        """
        augmented_batch = []
        for sample in batch:
            augmented_batch.append(sample)
            if random.random() < augmentation_ratio:
                augmented = self.augment(sample)
                if augmented is not None:
                    augmented_batch.append(augmented)
        return augmented_batch
|
| |
|
| |
|
def augment_sample(
    sample: Dict[str, Any],
    methods: List[str],
    max_augmentations: int = 2,
) -> List[Dict[str, Any]]:
    """Return ``sample`` followed by up to ``max_augmentations`` augmented variants.

    A fresh ``DataAugmenter`` limited to ``methods`` is asked for an
    augmentation ``max_augmentations`` times. Attempts that yield nothing
    (a falsy result) are dropped, so the list can be shorter than
    ``max_augmentations + 1``.
    """
    augmenter = DataAugmenter(AugmentationConfig(enabled_methods=methods))
    attempts = (augmenter.augment(sample) for _ in range(max_augmentations))
    return [sample, *filter(None, attempts)]
|
| |
|