# EMOTIA / models/advanced/data_augmentation.py
# Uploaded by Manav2op using huggingface_hub (commit 25d0747, verified).
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.transforms import functional as TF
import torchaudio
import torchaudio.transforms as AT
import numpy as np
import random
from PIL import Image
import librosa
class AdvancedDataAugmentation:
    """
    Advanced data augmentation pipeline for multi-modal training.

    Provides independent, randomized augmentations for face images
    (``augment_vision``), raw waveforms (``augment_audio``) and transcripts
    (``augment_text``). Randomness is drawn from Python's ``random`` module,
    NumPy's global RNG and torch's global RNG (inside the torchvision
    transforms), so all three must be seeded for reproducible output.
    """

    def __init__(self):
        # RandomErasing only accepts tensors; ``_erase`` adapts it for PIL input.
        self._eraser = T.RandomErasing(p=0.1, scale=(0.02, 0.1), ratio=(0.3, 3.3))
        # Vision augmentations
        self.vision_transforms = T.Compose([
            T.RandomApply([T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)], p=0.3),
            T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.1),
            T.RandomApply([T.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1))], p=0.2),
            T.RandomHorizontalFlip(p=0.1),
            # Placing RandomErasing here directly fails on PIL images (it reads
            # ``img.shape``), so it is routed through ``_erase``.
            T.RandomApply([T.Lambda(self._erase)], p=0.1),
        ])
        # Audio augmentations: default rate used by augment_audio.
        self.audio_sample_rate = 16000

    def _erase(self, img):
        """Apply ``self._eraser`` to a tensor, or to a PIL image via a uint8 tensor round trip."""
        if isinstance(img, torch.Tensor):
            return self._eraser(img)
        # pil_to_tensor keeps uint8 values, so pixels outside the erased
        # region come back unchanged.
        return TF.to_pil_image(self._eraser(TF.pil_to_tensor(img)))

    def augment_vision(self, image):
        """
        Apply randomized vision augmentations to one image.

        Args:
            image: PIL image, or ``np.ndarray`` accepted by ``Image.fromarray``.

        Returns:
            The augmented PIL image.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Apply standard augmentations
        augmented = self.vision_transforms(image)
        # Simulate different lighting conditions
        if random.random() < 0.1:
            augmented = TF.adjust_gamma(augmented, random.uniform(0.8, 1.2))
        # Additive Gaussian pixel noise (sigma = 5 intensity levels)
        if random.random() < 0.1:
            img_array = np.array(augmented)
            noise = np.random.normal(0, 5, img_array.shape)
            augmented = Image.fromarray(np.clip(img_array + noise, 0, 255).astype(np.uint8))
        return augmented

    def augment_audio(self, audio, sample_rate=None):
        """
        Apply one randomly chosen audio augmentation and fix the length to 3 s.

        Each augmentation is enabled independently with its own probability.
        One candidate is then drawn uniformly from the original waveform plus
        the enabled augmentations.

        Args:
            audio: 1-D waveform as ``np.ndarray`` or ``torch.Tensor``.
            sample_rate: Sampling rate in Hz; defaults to ``self.audio_sample_rate``.

        Returns:
            ``torch.float32`` tensor of exactly ``int(sample_rate * 3)`` samples
            (truncated at the end or zero-padded at the end).
        """
        if sample_rate is None:
            sample_rate = self.audio_sample_rate
        if isinstance(audio, torch.Tensor):
            # detach()/cpu() so grad-tracking or GPU tensors can be converted.
            audio = audio.detach().cpu().numpy()

        # Candidates are zero-argument callables, so the expensive librosa
        # effects only run for the candidate that is actually selected. All
        # random draws still happen eagerly, in a fixed order.
        candidates = [lambda: audio]
        # Time stretching
        if random.random() < 0.3:
            rate = random.uniform(0.8, 1.2)
            candidates.append(lambda: librosa.effects.time_stretch(audio, rate=rate))
        # Pitch shifting (-2..+2 semitones)
        if random.random() < 0.3:
            steps = random.randint(-2, 2)
            candidates.append(lambda: librosa.effects.pitch_shift(audio, sr=sample_rate, n_steps=steps))
        # Add background noise
        if random.random() < 0.2:
            noisy = audio + np.random.normal(0, 0.01, audio.shape)
            candidates.append(lambda: noisy)
        # Volume perturbation
        if random.random() < 0.3:
            scaled = audio * random.uniform(0.7, 1.3)
            candidates.append(lambda: scaled)
        # Random cropping/padding to 2.5-4.0 seconds
        if random.random() < 0.2:
            crop_length = int(sample_rate * random.uniform(2.5, 4.0))
            if len(audio) > crop_length:
                start = random.randint(0, len(audio) - crop_length)
                cropped = audio[start:start + crop_length]
            else:
                cropped = np.pad(audio, (0, crop_length - len(audio)), 'constant')
            candidates.append(lambda: cropped)

        # Select one augmentation or original
        selected = random.choice(candidates)()

        # Ensure consistent length (3 seconds); int() so float rates still slice.
        target_length = int(sample_rate * 3)
        if len(selected) > target_length:
            selected = selected[:target_length]
        elif len(selected) < target_length:
            selected = np.pad(selected, (0, target_length - len(selected)), 'constant')
        return torch.tensor(selected, dtype=torch.float32)

    def augment_text(self, text, tokenizer):
        """
        Randomly augment a transcript.

        Args:
            text: Input string.
            tokenizer: Not used; kept for interface compatibility.

        Returns:
            ``text`` unchanged, a word-shuffled version (placeholder for real
            synonym replacement), or a version with random words deleted.
            Texts with three or fewer words are always returned unchanged.
        """
        candidates = [text]
        # Synonym replacement placeholder: without a synonym dictionary this
        # only shuffles the word order (half of the time).
        if random.random() < 0.2:
            words = text.split()
            if len(words) > 3:
                if random.random() < 0.5:
                    random.shuffle(words)
                candidates.append(' '.join(words))
        # Backtranslation augmentation would go here (requires translation models)
        # Random deletion: keep each word with probability 0.9.
        if random.random() < 0.1:
            words = text.split()
            if len(words) > 3:
                kept_words = [w for w in words if random.random() < 0.9]
                if kept_words:
                    candidates.append(' '.join(kept_words))
        return random.choice(candidates)
class AdvancedPreprocessingPipeline:
    """
    Advanced preprocessing pipeline with quality checks and normalization.

    Turns raw face crops, waveforms and transcripts into model-ready inputs.
    ``preprocess_face`` and ``preprocess_audio`` return ``None`` when the input
    fails the corresponding quality check.
    """

    def __init__(self, target_face_size=(224, 224), target_audio_length=3.0):
        # Passed to PIL ``resize``, i.e. interpreted as (width, height).
        self.target_face_size = target_face_size
        self.target_audio_length = target_audio_length  # seconds
        self.sample_rate = 16000
        # Quality thresholds
        self.min_face_confidence = 0.7  # not read by any method in this class
        self.min_audio_snr = 10.0  # dB

    def preprocess_face(self, face_image, bbox=None, landmarks=None):
        """
        Quality-check, align, resize and normalize a face crop.

        Args:
            face_image: PIL image or uint8 ``np.ndarray``.
            bbox: Not used.
            landmarks: Passed to ``_align_face`` (currently a no-op) when given.

        Returns:
            3-channel float tensor normalized with CLIP's mean/std, or ``None``
            if the image fails ``_check_face_quality``.
        """
        # Quality check
        if not self._check_face_quality(face_image):
            return None
        # Convert to PIL if needed
        if isinstance(face_image, np.ndarray):
            face_image = Image.fromarray(face_image)
        # Force three channels: grayscale or RGBA input would otherwise give a
        # 1- or 4-channel tensor that the 3-channel Normalize below rejects.
        face_image = face_image.convert('RGB')
        # Face alignment using landmarks if available
        if landmarks is not None:
            face_image = self._align_face(face_image, landmarks)
        # Resize and convert to a [0, 1] float tensor
        face_image = face_image.resize(self.target_face_size, Image.BILINEAR)
        face_tensor = TF.to_tensor(face_image)
        # Normalize with CLIP's preprocessing statistics (not ImageNet's).
        normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                std=[0.26862954, 0.26130258, 0.27577711])
        return normalize(face_tensor)

    def preprocess_audio(self, audio_path_or_array, sample_rate=None):
        """
        Load, quality-check, trim, peak-normalize and fix the length of audio.

        Args:
            audio_path_or_array: A path (``str`` or ``os.PathLike``) loaded with
                ``librosa.load`` at ``self.sample_rate``, or an in-memory
                waveform (``np.ndarray``, ``torch.Tensor`` or sequence).
            sample_rate: Rate of an in-memory waveform; defaults to
                ``self.sample_rate``. Ignored for paths.

        Returns:
            ``torch.float32`` tensor of ``int(self.sample_rate * self.target_audio_length)``
            samples, or ``None`` if the quality check fails. Longer clips are
            randomly cropped; shorter ones are zero-padded at the end.
        """
        import os

        # Load audio
        if isinstance(audio_path_or_array, (str, os.PathLike)):
            audio, sr = librosa.load(audio_path_or_array, sr=self.sample_rate)
        else:
            audio = audio_path_or_array
            if isinstance(audio, torch.Tensor):
                audio = audio.detach().cpu().numpy()
            audio = np.asarray(audio)
            if np.issubdtype(audio.dtype, np.signedinteger):
                # Integer PCM -> [-1, 1) floats: librosa requires floating input,
                # and the fixed VAD threshold only makes sense on that scale.
                scale = float(2 ** (8 * audio.itemsize - 1))
                audio = audio.astype(np.float32) / scale
            elif not np.issubdtype(audio.dtype, np.floating):
                audio = audio.astype(np.float32)
            sr = sample_rate or self.sample_rate
        # Resample if needed
        if sr != self.sample_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)
        # Quality check
        if not self._check_audio_quality(audio):
            return None
        # Voice activity detection (simple energy-based)
        audio = self._voice_activity_detection(audio)
        # Normalize audio
        audio = self._normalize_audio(audio)
        # Ensure consistent length
        target_samples = int(self.sample_rate * self.target_audio_length)
        if len(audio) > target_samples:
            # Random crop
            start = random.randint(0, len(audio) - target_samples)
            audio = audio[start:start + target_samples]
        elif len(audio) < target_samples:
            # Pad with zeros
            audio = np.pad(audio, (0, target_samples - len(audio)), 'constant')
        return torch.tensor(audio, dtype=torch.float32)

    def preprocess_text(self, text, tokenizer, max_length=128):
        """
        Clean ``text`` and tokenize it.

        Returns whatever ``tokenizer(...)`` returns when called with fixed-length
        padding/truncation to ``max_length`` and ``return_tensors='pt'``.
        """
        text = self._clean_text(text)
        return tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def _check_face_quality(self, face_image):
        """
        Reject faces that are too small, too dark/bright or too flat.

        Applies to uint8-range ``np.ndarray`` and PIL images (converted with
        ``np.asarray``). Other input types are not checked and pass.

        Returns:
            False if height or width < 64, mean intensity is outside [30, 225],
            or intensity std < 10; otherwise True.
        """
        if isinstance(face_image, np.ndarray):
            pixels = face_image
        elif isinstance(face_image, Image.Image):
            # preprocess_face accepts PIL images, so they get the same checks.
            pixels = np.asarray(face_image)
        else:
            return True
        # Check resolution
        if pixels.shape[0] < 64 or pixels.shape[1] < 64:
            return False
        # Check brightness
        brightness = np.mean(pixels)
        if brightness < 30 or brightness > 225:
            return False
        # Check contrast
        return not np.std(pixels) < 10

    def _check_audio_quality(self, audio, window=100):
        """
        Heuristic SNR check.

        "Noise" is the residual after a ``window``-sample moving average and
        "signal" is the mean power. Returns ``snr_db >= self.min_audio_snr``.
        Clips shorter than ``window`` samples (including empty ones) and
        all-zero clips fail.
        """
        audio = np.asarray(audio, dtype=np.float64)
        # np.convolve(mode='same') returns max(len(a), len(v)) samples, so a
        # clip shorter than the kernel would not broadcast against itself.
        if audio.size < window:
            return False
        signal_power = np.mean(audio ** 2)
        if signal_power == 0:
            return False  # silent clip; avoids log10(0)
        smoothed = np.convolve(audio, np.ones(window) / window, mode='same')
        noise_power = np.var(audio - smoothed)
        snr = 10 * np.log10(signal_power / (noise_power + 1e-10))
        return bool(snr >= self.min_audio_snr)

    def _align_face(self, face_image, landmarks):
        """
        Placeholder for landmark-based face alignment.

        Currently returns ``face_image`` unchanged; ``landmarks`` is ignored.
        """
        return face_image

    def _voice_activity_detection(self, audio, threshold=0.01, frame_length=1024, hop_length=512):
        """
        Trim leading and trailing low-energy audio (simple energy-based VAD).

        Frames whose RMS exceeds ``threshold`` count as active, and the audio is
        cut to the span from the first to the last active frame. Quiet gaps
        between active frames are kept. If no frame is active, the input is
        returned unchanged.
        """
        energy = librosa.feature.rms(y=audio, frame_length=frame_length, hop_length=hop_length)[0]
        active = np.flatnonzero(energy > threshold)
        if active.size == 0:
            return audio
        # librosa.feature.rms centers frames by default (center=True), so frame t
        # covers samples [t*hop - frame_length//2, t*hop + frame_length//2).
        half = frame_length // 2
        start = max(0, int(active[0]) * hop_length - half)
        end = int(active[-1]) * hop_length + half
        return audio[start:end]

    def _normalize_audio(self, audio):
        """Peak-normalize to max |sample| == 1; empty or silent input is returned unchanged."""
        if np.size(audio) == 0:
            return audio  # np.max would raise on an empty array
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak
        return audio

    def _clean_text(self, text):
        """
        Clean and normalize text.

        Curly quotes become ASCII quotes, characters other than word characters,
        whitespace and ``.,!?'"-`` are removed, whitespace runs collapse to a
        single space, and the result is lower-cased.
        """
        import re

        # Normalize curly quotes first; the character filter below would
        # otherwise delete them before they could be converted.
        text = text.translate(str.maketrans({
            '\u201c': '"', '\u201d': '"',
            '\u2018': "'", '\u2019': "'",
        }))
        # Remove special characters but keep punctuation
        text = re.sub(r'[^\w\s.,!?\'"-]', '', text)
        # Collapse whitespace after removal so deleted symbols leave no double spaces
        text = ' '.join(text.split())
        return text.lower()