# text_classificators/src/imbalance_handling.py
#
# Utilities for handling class imbalance in text classification:
# class weights, weighted losses (PyTorch / TensorFlow), resampling
# via imbalanced-learn, and text augmentation via nlpaug/transformers.
from typing import List, Tuple, Union, Dict, Optional, Any, Callable
import numpy as np
import pandas as pd
from collections import Counter
def compute_class_weights(y: Union[List, np.ndarray], method: str = "balanced") -> Union[Dict[int, float], None]:
    """Compute per-class weights for imbalanced classification.

    Implements the "balanced" heuristic directly:
    ``weight_c = n_samples / (n_classes * count_c)`` — numerically identical
    to sklearn's ``compute_class_weight('balanced', ...)`` but without the
    sklearn dependency.

    Args:
        y: Iterable of class labels.
        method: Only ``"balanced"`` is supported; any other value returns None.

    Returns:
        Mapping from class label to weight, or None for an unknown method.

    Raises:
        ValueError: If ``method == "balanced"`` and ``y`` is empty.
    """
    if method != "balanced":
        return None
    y_arr = np.asarray(y)
    if y_arr.size == 0:
        raise ValueError("y must be non-empty to compute class weights")
    classes, counts = np.unique(y_arr, return_counts=True)
    n_samples = y_arr.size
    n_classes = classes.size
    return {cls: n_samples / (n_classes * cnt) for cls, cnt in zip(classes, counts)}
def get_pytorch_weighted_loss(class_weights: Optional[Dict[int, float]] = None,
                              num_classes: Optional[int] = None) -> 'torch.nn.Module':
    """Build an (optionally class-weighted) PyTorch cross-entropy loss.

    Args:
        class_weights: Mapping class index -> weight. Keys are sorted so that
            tensor position ``i`` holds the weight for class ``i``.
        num_classes: Optional sanity check. Previously accepted but silently
            ignored; now, when given together with ``class_weights``, it must
            match the number of weight entries.

    Returns:
        ``nn.CrossEntropyLoss``, weighted when ``class_weights`` is provided.

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If ``num_classes`` disagrees with ``len(class_weights)``.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        raise ImportError("PyTorch not installed")
    if class_weights is None:
        return nn.CrossEntropyLoss()
    if num_classes is not None and num_classes != len(class_weights):
        raise ValueError(
            f"num_classes={num_classes} does not match len(class_weights)={len(class_weights)}"
        )
    weight_tensor = torch.tensor(
        [class_weights[k] for k in sorted(class_weights)], dtype=torch.float
    )
    return nn.CrossEntropyLoss(weight=weight_tensor)
def get_tensorflow_weighted_loss(class_weights: Optional[Dict[int, float]] = None) -> Callable:
    """Return a TensorFlow/Keras loss for sparse integer labels.

    With no (or empty) ``class_weights``, the Keras loss identifier string
    ``'sparse_categorical_crossentropy'`` is returned. Otherwise a closure is
    returned that scales each example's cross-entropy by the weight of its
    true class (weights ordered by sorted class index).
    """
    if not class_weights:
        return 'sparse_categorical_crossentropy'
    ordered_weights = [class_weights[k] for k in sorted(class_weights)]
    import tensorflow as tf

    def weighted_sparse_categorical_crossentropy(y_true, y_pred):
        labels = tf.cast(y_true, tf.int32)
        # One-hot-select the weight of each example's true class.
        one_hot = tf.one_hot(labels, depth=len(ordered_weights))
        per_example_weight = tf.reduce_sum(one_hot * ordered_weights, axis=1)
        base_loss = tf.keras.losses.sparse_categorical_crossentropy(labels, y_pred)
        return tf.reduce_mean(base_loss * per_example_weight)

    return weighted_sparse_categorical_crossentropy
def apply_sampling(
    X: np.ndarray,
    y: np.ndarray,
    method: str = "random_under",
    random_state: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """Resample (X, y) with an imbalanced-learn sampler.

    Args:
        X: Feature matrix.
        y: Label vector.
        method: One of 'random_under', 'random_over', 'smote', 'adasyn'.
        random_state: Seed forwarded to the chosen sampler.

    Returns:
        The resampled (X, y) pair.

    Raises:
        ValueError: If ``method`` is not a recognized sampler name.
    """
    from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
    from imblearn.under_sampling import RandomUnderSampler
    sampler_classes = {
        "random_under": RandomUnderSampler,
        "random_over": RandomOverSampler,
        "smote": SMOTE,
        "adasyn": ADASYN,
    }
    if method not in sampler_classes:
        raise ValueError("method must be one of: random_under, random_over, smote, adasyn")
    sampler = sampler_classes[method](random_state=random_state)
    return sampler.fit_resample(X, y)
def augment_texts(
    texts: List[str],
    labels: List[Any],
    augmentation_type: str = "synonym",
    aug_p: float = 0.1,
    lang: str = "ru",  # language code
    model_name: Optional[str] = None,
    num_aug: int = 1,
    random_state: int = 42
) -> Tuple[List[str], List[Any]]:
    """Produce augmented copies of `texts` (with matching labels) via nlpaug.

    For every input text, `num_aug` augmented variants are generated with the
    strategy selected by `augmentation_type`:

    - "synonym": SynonymAug for English; for other languages a multilingual
      BERT ContextualWordEmbsAug (substitute) is used instead.
    - "insert" / "delete" / "swap": RandomWordAug with the matching action.
    - "eda": NOTE(review) — despite the name, this only applies AntonymAug,
      not the full EDA recipe (synonym/insert/swap/delete); confirm intent.
    - "back_trans": round-trip translation via transformers pipelines; on any
      setup failure it falls back to multilingual ContextualWordEmbsAug.
    - "llm": not implemented (raises NotImplementedError).

    Returns a (augmented_texts, augmented_labels) pair containing ONLY the
    augmented copies, not the originals.

    Raises:
        ImportError: If nlpaug is not installed.
        NotImplementedError: For augmentation_type == "llm".
        ValueError: For an unknown augmentation_type.

    NOTE(review): `random_state` is accepted but never used — augmentation is
    not reproducible; consider seeding nlpaug/numpy with it.
    """
    try:
        import nlpaug.augmenter.word as naw
        import nlpaug.augmenter.sentence as nas
    except ImportError:
        raise ImportError("Install nlpaug: pip install nlpaug")
    augmented_texts = []
    augmented_labels = []
    if augmentation_type == "synonym":
        if lang == "en":
            # WordNet-based synonym replacement (English only).
            aug = naw.SynonymAug(aug_p=aug_p, aug_max=None)
        else:
            # Non-English: contextual substitution with multilingual BERT.
            aug = naw.ContextualWordEmbsAug(
                model_path='bert-base-multilingual-cased',
                action="substitute",
                aug_p=aug_p,
                device='cpu'
            )
    elif augmentation_type == "insert":
        aug = naw.RandomWordAug(action="insert", aug_p=aug_p)
    elif augmentation_type == "delete":
        aug = naw.RandomWordAug(action="delete", aug_p=aug_p)
    elif augmentation_type == "swap":
        aug = naw.RandomWordAug(action="swap", aug_p=aug_p)
    elif augmentation_type == "eda":
        # NOTE(review): only antonym replacement, not the full EDA method.
        aug = naw.AntonymAug()
    elif augmentation_type == "back_trans":
        if not model_name:
            # Default round-trip models: source->pivot, then pivot->source.
            if lang == "ru":
                model_name = "Helsinki-NLP/opus-mt-ru-en"
                back_model = "Helsinki-NLP/opus-mt-en-ru"
            else:
                model_name = "Helsinki-NLP/opus-mt-en-ru"
                back_model = "Helsinki-NLP/opus-mt-ru-en"
        else:
            # NOTE(review): a caller-supplied model is reused for BOTH
            # directions, which cannot translate back — likely a bug; a
            # second parameter for the return model seems intended.
            back_model = model_name
        try:
            from transformers import pipeline
            translator1 = pipeline("translation", model=model_name, tokenizer=model_name)
            translator2 = pipeline("translation", model=back_model, tokenizer=back_model)
            def back_translate(text):
                # Best-effort: on any per-text failure, keep the original.
                try:
                    trans = translator1(text)[0]['translation_text']
                    back = translator2(trans)[0]['translation_text']
                    return back
                except Exception:
                    return text
            # Each text is back-translated num_aug times (identical results,
            # since translation here is deterministic per call).
            augmented = [back_translate(t) for t in texts for _ in range(num_aug)]
            labels_aug = [l for l in labels for _ in range(num_aug)]
            return augmented, labels_aug
        except Exception as e:
            # Pipeline setup failed (e.g. no transformers / no model files):
            # fall back to contextual substitution and use the common loop below.
            print(f"Back-translation failed: {e}. Falling back to synonym augmentation.")
            aug = naw.ContextualWordEmbsAug(model_path='bert-base-multilingual-cased', aug_p=aug_p)
    elif augmentation_type == "llm":
        raise NotImplementedError("LLM-controlled augmentation requires external API (e.g., OpenAI, YandexGPT)")
    else:
        raise ValueError("Unknown augmentation_type")
    # Common path: apply the chosen word-level augmenter num_aug times per text.
    for text, label in zip(texts, labels):
        for _ in range(num_aug):
            try:
                aug_text = aug.augment(text)
                # nlpaug >= 1.1 returns a list; take the first variant.
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]
                augmented_texts.append(aug_text)
                augmented_labels.append(label)
            except Exception as e:
                # Best-effort: keep the original text if augmentation fails.
                augmented_texts.append(text)
                augmented_labels.append(label)
    return augmented_texts, augmented_labels
def balance_text_dataset(
    texts: List[str],
    labels: List[Any],
    strategy: str = "augmentation",
    minority_classes: Optional[List[Any]] = None,
    augmentation_type: str = "synonym",
    sampling_method: str = "smote",
    lang: str = "ru",
    embedding_func: Optional[Callable] = None,
    class_weights: bool = False,
    random_state: int = 42
) -> Union[
    Tuple[List[str], List[Any]],  # for augmentation
    Tuple[np.ndarray, np.ndarray, Optional[Dict]]  # for sampling + weights
]:
    """Balance a text dataset by augmentation, resampling, or both.

    Strategies:
    - "augmentation": augment minority-class texts (via ``augment_texts``) so
      their count approaches the majority count; returns
      (balanced_texts, balanced_labels) with originals plus augmented copies.
    - "sampling": embed every text with ``embedding_func`` and resample the
      embedding matrix (via ``apply_sampling``); returns
      (X_res, y_res, weights) where weights is a class-weight dict or None.
    - "both": augmentation first, then sampling on the augmented set (or the
      augmented pair alone if no ``embedding_func`` is given).

    Args:
        texts: Input documents.
        labels: Class label per document.
        strategy: "augmentation", "sampling", or "both".
        minority_classes: Classes to oversample; defaults to those with the
            smallest count.
        augmentation_type: Forwarded to ``augment_texts``.
        sampling_method: Forwarded to ``apply_sampling``.
        lang: Language code forwarded to ``augment_texts``.
        embedding_func: text -> vector; required for "sampling".
        class_weights: If True, also compute balanced class weights.
        random_state: Seed forwarded to sampling/augmentation.

    Raises:
        ValueError: For empty labels, unknown strategy, missing
            ``embedding_func`` with "sampling", or ``minority_classes``
            absent from ``labels``.
    """
    label_counts = Counter(labels)
    if not label_counts:
        raise ValueError("labels must be non-empty")
    if minority_classes is None:
        min_count = min(label_counts.values())
        minority_classes = [lbl for lbl, cnt in label_counts.items() if cnt == min_count]
    else:
        # BUG FIX: min_count was previously assigned only in the branch above,
        # so caller-supplied minority_classes caused a NameError at the
        # num_aug computation below. Compute it from the given classes
        # (ignoring classes with zero occurrences to avoid division by zero).
        present_counts = [label_counts[lbl] for lbl in minority_classes if label_counts[lbl] > 0]
        if not present_counts:
            raise ValueError("none of minority_classes occur in labels")
        min_count = min(present_counts)
    if strategy == "augmentation":
        minority_texts = [t for t, l in zip(texts, labels) if l in minority_classes]
        minority_labels = [l for l in labels if l in minority_classes]
        aug_texts, aug_labels = augment_texts(
            minority_texts, minority_labels,
            augmentation_type=augmentation_type,
            lang=lang,
            # Roughly close the gap to the majority class: one original copy
            # already exists, hence the "- 1".
            num_aug=max(1, int((max(label_counts.values()) / min_count)) - 1),
            random_state=random_state
        )
        balanced_texts = texts + aug_texts
        balanced_labels = labels + aug_labels
        return balanced_texts, balanced_labels
    elif strategy == "sampling":
        if embedding_func is None:
            raise ValueError("embedding_func is required for sampling strategy")
        X_embed = np.array([embedding_func(t) for t in texts])
        X_res, y_res = apply_sampling(X_embed, np.array(labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights
    elif strategy == "both":
        aug_texts, aug_labels = balance_text_dataset(
            texts, labels, strategy="augmentation", minority_classes=minority_classes,
            augmentation_type=augmentation_type, lang=lang, random_state=random_state
        )
        if embedding_func is None:
            # No embeddings available: return the augmented pair as-is.
            return aug_texts, aug_labels
        X_embed = np.array([embedding_func(t) for t in aug_texts])
        X_res, y_res = apply_sampling(X_embed, np.array(aug_labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights
    else:
        raise ValueError("strategy must be 'augmentation', 'sampling', or 'both'")