# text_classificators/src/imbalance_handling.py
# Utilities for handling class imbalance in text classification:
# class weights, framework-specific weighted losses, resampling,
# and text augmentation.
from typing import List, Tuple, Union, Dict, Optional, Any, Callable
import numpy as np
import pandas as pd
from collections import Counter
def compute_class_weights(y: Union[List, np.ndarray], method: str = "balanced") -> Union[Dict[int, float], None]:
    """Compute per-class weights for imbalanced classification.

    Parameters
    ----------
    y : array-like of class labels.
    method : only ``"balanced"`` is supported; any other value returns
        ``None`` (kept for backward compatibility with existing callers).

    Returns
    -------
    dict mapping class label -> weight, or ``None`` for unknown methods.

    Notes
    -----
    Uses the same formula as sklearn's ``compute_class_weight('balanced')``,
    ``n_samples / (n_classes * count(class))``, computed directly with numpy
    so the heavy sklearn dependency is no longer required here.
    """
    if method != "balanced":
        return None
    y_arr = np.asarray(y)
    classes, counts = np.unique(y_arr, return_counts=True)
    weights = y_arr.size / (classes.size * counts.astype(float))
    return dict(zip(classes, weights))
def get_pytorch_weighted_loss(class_weights: Optional[Dict[int, float]] = None,
                              num_classes: Optional[int] = None) -> 'torch.nn.Module':
    """Build a PyTorch ``CrossEntropyLoss``, optionally class-weighted.

    Parameters
    ----------
    class_weights : mapping of class index -> weight; ``None`` gives an
        unweighted loss.
    num_classes : total number of classes. Previously this parameter was
        accepted but ignored; it now sizes the weight vector so classes
        absent from ``class_weights`` default to a weight of 1.0. When
        ``None``, the size is inferred as ``max(class index) + 1``.

    Returns
    -------
    ``torch.nn.CrossEntropyLoss`` instance.

    Raises
    ------
    ImportError if PyTorch is not installed.
    ValueError if a class index falls outside ``[0, num_classes)``.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        raise ImportError("PyTorch not installed")
    if class_weights is None:
        return nn.CrossEntropyLoss()
    # Size the weight tensor explicitly: the old sorted-keys approach built a
    # wrong-sized tensor whenever class indices were non-contiguous.
    n = num_classes if num_classes is not None else max(class_weights.keys()) + 1
    weight_tensor = torch.ones(n, dtype=torch.float)
    for cls, w in class_weights.items():
        if not 0 <= cls < n:
            raise ValueError(f"class index {cls} out of range for num_classes={n}")
        weight_tensor[cls] = w
    return nn.CrossEntropyLoss(weight=weight_tensor)
def get_tensorflow_weighted_loss(class_weights: Optional[Dict[int, float]] = None) -> Callable:
    """Return a loss suitable for Keras ``model.compile``.

    With no weights (``None`` or empty dict) this is simply the string name
    of the built-in sparse categorical cross-entropy. With weights it is a
    closure that scales each sample's loss by the weight of its true class.
    """
    if not class_weights:
        return 'sparse_categorical_crossentropy'
    # Order weights by class index so position i holds class i's weight.
    ordered_weights = [class_weights[c] for c in sorted(class_weights)]
    import tensorflow as tf

    def weighted_sparse_categorical_crossentropy(y_true, y_pred):
        labels = tf.cast(y_true, tf.int32)
        one_hot = tf.one_hot(labels, depth=len(ordered_weights))
        # Select each sample's class weight via the one-hot mask.
        sample_weights = tf.reduce_sum(one_hot * ordered_weights, axis=1)
        base_loss = tf.keras.losses.sparse_categorical_crossentropy(labels, y_pred)
        return tf.reduce_mean(base_loss * sample_weights)

    return weighted_sparse_categorical_crossentropy
def apply_sampling(
    X: np.ndarray,
    y: np.ndarray,
    method: str = "random_under",
    random_state: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """Resample ``(X, y)`` with an imbalanced-learn sampler.

    Parameters
    ----------
    X : 2-D feature matrix.
    y : 1-D label array aligned with ``X``.
    method : one of ``random_under``, ``random_over``, ``smote``, ``adasyn``.
    random_state : seed forwarded to the sampler for reproducibility.

    Returns
    -------
    ``(X_resampled, y_resampled)``.

    Raises
    ------
    ValueError for an unknown ``method`` (checked *before* importing
    imblearn, so callers get the intended error even when the optional
    dependency is missing).
    ImportError if imbalanced-learn is not installed.
    """
    if method not in ("random_under", "random_over", "smote", "adasyn"):
        raise ValueError("method must be one of: random_under, random_over, smote, adasyn")
    try:
        from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
        from imblearn.under_sampling import RandomUnderSampler
    except ImportError:
        raise ImportError("imbalanced-learn is required: pip install imbalanced-learn")
    samplers = {
        "random_under": RandomUnderSampler,
        "random_over": RandomOverSampler,
        "smote": SMOTE,
        "adasyn": ADASYN,
    }
    sampler = samplers[method](random_state=random_state)
    X_res, y_res = sampler.fit_resample(X, y)
    return X_res, y_res
def augment_texts(
    texts: List[str],
    labels: List[Any],
    augmentation_type: str = "synonym",
    aug_p: float = 0.1,
    lang: str = "ru",  # ISO language code; only "en" vs. non-"en" is actually distinguished below
    model_name: Optional[str] = None,
    num_aug: int = 1,
    random_state: int = 42
) -> Tuple[List[str], List[Any]]:
    """Generate augmented copies of ``texts`` using nlpaug (or back-translation).

    Each input text yields ``num_aug`` augmented variants paired with the
    original label. On any per-text augmentation failure the original text is
    kept, so the output always has ``len(texts) * num_aug`` entries.

    Parameters
    ----------
    texts, labels : parallel lists of input texts and their labels.
    augmentation_type : one of "synonym", "insert", "delete", "swap",
        "eda", "back_trans", "llm".
    aug_p : fraction of words each augmenter is allowed to change.
    lang : language code; selects WordNet synonyms for "en", a multilingual
        BERT substitution model otherwise, and the translation direction for
        "back_trans".
    model_name : optional HF model for the forward translation in
        "back_trans"; NOTE(review): when supplied, the same model is used for
        the backward pass too — confirm that is intended.
    num_aug : number of augmented variants per input text.
    random_state : NOTE(review): accepted but never used to seed the
        augmenters, so results are not reproducible — confirm intent.

    Returns
    -------
    ``(augmented_texts, augmented_labels)`` — only the new variants, not the
    originals.

    Raises
    ------
    ImportError if nlpaug is missing; NotImplementedError for "llm";
    ValueError for an unknown ``augmentation_type``.
    """
    try:
        import nlpaug.augmenter.word as naw
        import nlpaug.augmenter.sentence as nas  # imported as an availability check; not used below
    except ImportError:
        raise ImportError("Install nlpaug: pip install nlpaug")
    augmented_texts = []
    augmented_labels = []
    if augmentation_type == "synonym":
        if lang == "en":
            # English: WordNet-based synonym replacement.
            aug = naw.SynonymAug(aug_p=aug_p, aug_max=None)
        else:
            # Other languages: contextual word substitution with multilingual BERT.
            aug = naw.ContextualWordEmbsAug(
                model_path='bert-base-multilingual-cased',
                action="substitute",
                aug_p=aug_p,
                device='cpu'
            )
    elif augmentation_type == "insert":
        aug = naw.RandomWordAug(action="insert", aug_p=aug_p)
    elif augmentation_type == "delete":
        aug = naw.RandomWordAug(action="delete", aug_p=aug_p)
    elif augmentation_type == "swap":
        aug = naw.RandomWordAug(action="swap", aug_p=aug_p)
    elif augmentation_type == "eda":
        # NOTE(review): "eda" maps to antonym replacement only, not the full
        # EDA recipe (synonym/insert/swap/delete) — confirm this is intended.
        aug = naw.AntonymAug()
    elif augmentation_type == "back_trans":
        # Round-trip translation: source -> pivot language -> source.
        if not model_name:
            if lang == "ru":
                model_name = "Helsinki-NLP/opus-mt-ru-en"
                back_model = "Helsinki-NLP/opus-mt-en-ru"
            else:
                model_name = "Helsinki-NLP/opus-mt-en-ru"
                back_model = "Helsinki-NLP/opus-mt-ru-en"
        else:
            back_model = model_name
        try:
            from transformers import pipeline
            translator1 = pipeline("translation", model=model_name, tokenizer=model_name)
            translator2 = pipeline("translation", model=back_model, tokenizer=back_model)
            def back_translate(text):
                # Best-effort: fall back to the original text on any failure.
                try:
                    trans = translator1(text)[0]['translation_text']
                    back = translator2(trans)[0]['translation_text']
                    return back
                except Exception:
                    return text
            augmented = [back_translate(t) for t in texts for _ in range(num_aug)]
            labels_aug = [l for l in labels for _ in range(num_aug)]
            # Back-translation returns early; the generic loop below is skipped.
            return augmented, labels_aug
        except Exception as e:
            # Pipeline setup failed (e.g. no model download) — degrade to
            # contextual substitution instead of aborting.
            print(f"Back-translation failed: {e}. Falling back to synonym augmentation.")
            aug = naw.ContextualWordEmbsAug(model_path='bert-base-multilingual-cased', aug_p=aug_p)
    elif augmentation_type == "llm":
        raise NotImplementedError("LLM-controlled augmentation requires external API (e.g., OpenAI, YandexGPT)")
    else:
        raise ValueError("Unknown augmentation_type")
    # Generic path: apply the selected augmenter num_aug times per text,
    # keeping the original text whenever augmentation raises.
    for text, label in zip(texts, labels):
        for _ in range(num_aug):
            try:
                aug_text = aug.augment(text)
                # Newer nlpaug versions return a list; normalize to a string.
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]
                augmented_texts.append(aug_text)
                augmented_labels.append(label)
            except Exception as e:
                augmented_texts.append(text)
                augmented_labels.append(label)
    return augmented_texts, augmented_labels
def balance_text_dataset(
    texts: List[str],
    labels: List[Any],
    strategy: str = "augmentation",
    minority_classes: Optional[List[Any]] = None,
    augmentation_type: str = "synonym",
    sampling_method: str = "smote",
    lang: str = "ru",
    embedding_func: Optional[Callable] = None,
    class_weights: bool = False,
    random_state: int = 42
) -> Union[
    Tuple[List[str], List[Any]],  # "augmentation" (and "both" without embedding_func)
    Tuple[np.ndarray, np.ndarray, Optional[Dict]]  # "sampling" / "both" with embeddings
]:
    """Balance a labelled text dataset by augmentation, resampling, or both.

    Parameters
    ----------
    texts, labels : parallel lists of input texts and their labels.
    strategy : "augmentation" (augment minority-class texts up toward the
        majority count), "sampling" (embed all texts and resample), or
        "both" (augment first, then resample the embeddings).
    minority_classes : classes to augment; defaults to the class(es) with
        the lowest count.
    augmentation_type, lang : forwarded to ``augment_texts``.
    sampling_method : forwarded to ``apply_sampling``.
    embedding_func : text -> vector, required for "sampling"; optional for
        "both" (without it, "both" degrades to augmentation only).
    class_weights : when True, also return balanced class weights for the
        resampled labels.
    random_state : seed forwarded to augmentation/sampling.

    Returns
    -------
    ``(texts, labels)`` for augmentation, or ``(X, y, weights)`` for the
    sampling-based strategies.

    Raises
    ------
    ValueError for an unknown strategy, a missing ``embedding_func`` on the
    "sampling" strategy, or ``minority_classes`` absent from ``labels``.
    """
    label_counts = Counter(labels)
    if minority_classes is None:
        min_count = min(label_counts.values())
        minority_classes = [lbl for lbl, cnt in label_counts.items() if cnt == min_count]
    else:
        # BUGFIX: min_count was previously assigned only on the
        # minority_classes-is-None path, so passing explicit minority_classes
        # with strategy="augmentation" raised NameError at the num_aug
        # computation below.
        present_counts = [label_counts[c] for c in minority_classes if c in label_counts]
        if not present_counts:
            raise ValueError("none of the given minority_classes appear in labels")
        min_count = min(present_counts)
    if strategy == "augmentation":
        minority_texts = [t for t, l in zip(texts, labels) if l in minority_classes]
        minority_labels = [l for l in labels if l in minority_classes]
        aug_texts, aug_labels = augment_texts(
            minority_texts, minority_labels,
            augmentation_type=augmentation_type,
            lang=lang,
            # Enough copies to roughly match the majority class size.
            num_aug=max(1, int((max(label_counts.values()) / min_count)) - 1),
            random_state=random_state
        )
        balanced_texts = texts + aug_texts
        balanced_labels = labels + aug_labels
        return balanced_texts, balanced_labels
    elif strategy == "sampling":
        if embedding_func is None:
            raise ValueError("embedding_func is required for sampling strategy")
        X_embed = np.array([embedding_func(t) for t in texts])
        X_res, y_res = apply_sampling(X_embed, np.array(labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights
    elif strategy == "both":
        # Augment first (recursive call returns a 2-tuple), then resample the
        # embeddings of the augmented corpus if an embedding function exists.
        aug_texts, aug_labels = balance_text_dataset(
            texts, labels, strategy="augmentation", minority_classes=minority_classes,
            augmentation_type=augmentation_type, lang=lang, random_state=random_state
        )
        if embedding_func is None:
            return aug_texts, aug_labels
        X_embed = np.array([embedding_func(t) for t in aug_texts])
        X_res, y_res = apply_sampling(X_embed, np.array(aug_labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights
    else:
        raise ValueError("strategy must be 'augmentation', 'sampling', or 'both'")