"""Utilities for mitigating class imbalance in (text) classification datasets.

Provides class-weight computation, weighted loss builders for PyTorch and
TensorFlow, embedding-space resampling via imbalanced-learn, and text
augmentation via nlpaug, plus a high-level `balance_text_dataset` entry point.
Heavy third-party dependencies (sklearn, torch, tensorflow, imblearn, nlpaug,
transformers) are imported lazily inside the functions that need them.
"""

from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd  # NOTE(review): unused here; kept for backward compatibility


def compute_class_weights(
    y: Union[List, np.ndarray],
    method: str = "balanced",
) -> Union[Dict[int, float], None]:
    """Compute per-class weights inversely proportional to class frequency.

    Args:
        y: Iterable of class labels.
        method: Only ``"balanced"`` is supported; any other value returns None.

    Returns:
        Mapping from class label to weight, or None if ``method`` is not
        ``"balanced"``.
    """
    if method != "balanced":
        return None
    # Lazy import: sklearn is only required on this code path.
    from sklearn.utils.class_weight import compute_class_weight

    classes = np.unique(y)
    weights = compute_class_weight('balanced', classes=classes, y=y)
    return dict(zip(classes, weights))


def get_pytorch_weighted_loss(
    class_weights: Optional[Dict[int, float]] = None,
    num_classes: Optional[int] = None,
) -> 'torch.nn.Module':
    """Build a (optionally class-weighted) PyTorch CrossEntropyLoss.

    Args:
        class_weights: Mapping class index -> weight. Keys are sorted so the
            resulting weight tensor is ordered by class index.
        num_classes: Unused; kept for interface compatibility.

    Returns:
        ``torch.nn.CrossEntropyLoss``, weighted when ``class_weights`` given.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        raise ImportError("PyTorch not installed")

    if class_weights is None:
        return nn.CrossEntropyLoss()
    weight_tensor = torch.tensor(
        [class_weights[k] for k in sorted(class_weights)], dtype=torch.float
    )
    return nn.CrossEntropyLoss(weight=weight_tensor)


def get_tensorflow_weighted_loss(
    class_weights: Optional[Dict[int, float]] = None,
) -> Callable:
    """Build a class-weighted sparse categorical cross-entropy for Keras.

    Args:
        class_weights: Mapping class index -> weight. If falsy, the plain
            Keras loss identifier string is returned (backward compatible).

    Returns:
        Either the string ``'sparse_categorical_crossentropy'`` or a callable
        ``loss(y_true, y_pred)`` returning a scalar mean weighted loss.
    """
    if not class_weights:
        # Keras accepts this string as a loss identifier.
        return 'sparse_categorical_crossentropy'

    weight_list = [class_weights[k] for k in sorted(class_weights)]
    import tensorflow as tf

    # Hoist the per-class weights into a graph constant with an explicit
    # float dtype instead of multiplying by a raw Python list each call.
    weight_const = tf.constant(weight_list, dtype=tf.float32)
    depth = len(weight_list)

    def weighted_sparse_categorical_crossentropy(y_true, y_pred):
        y_true = tf.cast(y_true, tf.int32)
        y_true_one_hot = tf.one_hot(y_true, depth=depth)
        # axis=-1 selects the class axis regardless of label rank
        # (identical to axis=1 for the usual (batch,) labels).
        weights = tf.reduce_sum(y_true_one_hot * weight_const, axis=-1)
        unweighted_losses = tf.keras.losses.sparse_categorical_crossentropy(
            y_true, y_pred
        )
        return tf.reduce_mean(unweighted_losses * weights)

    return weighted_sparse_categorical_crossentropy


def apply_sampling(
    X: np.ndarray,
    y: np.ndarray,
    method: str = "random_under",
    random_state: int = 42,
) -> Tuple[np.ndarray, np.ndarray]:
    """Resample (X, y) with imbalanced-learn to balance class counts.

    Args:
        X: Feature matrix, shape (n_samples, n_features).
        y: Label array, shape (n_samples,).
        method: One of ``random_under``, ``random_over``, ``smote``, ``adasyn``.
        random_state: Seed forwarded to the chosen sampler.

    Returns:
        Tuple of resampled ``(X_res, y_res)``.

    Raises:
        ValueError: If ``method`` is not recognized.
    """
    from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
    from imblearn.under_sampling import RandomUnderSampler

    samplers = {
        "random_under": RandomUnderSampler,
        "random_over": RandomOverSampler,
        "smote": SMOTE,
        "adasyn": ADASYN,
    }
    if method not in samplers:
        raise ValueError("method must be one of: random_under, random_over, smote, adasyn")
    sampler = samplers[method](random_state=random_state)
    X_res, y_res = sampler.fit_resample(X, y)
    return X_res, y_res


def augment_texts(
    texts: List[str],
    labels: List[Any],
    augmentation_type: str = "synonym",
    aug_p: float = 0.1,
    lang: str = "ru",  # language code
    model_name: Optional[str] = None,
    num_aug: int = 1,
    random_state: int = 42,
) -> Tuple[List[str], List[Any]]:
    """Generate augmented copies of ``texts`` using nlpaug.

    Args:
        texts: Input texts to augment.
        labels: Labels aligned with ``texts``; replicated for each augmented copy.
        augmentation_type: One of ``synonym``, ``insert``, ``delete``, ``swap``,
            ``eda``, ``back_trans``, ``llm``.
        aug_p: Fraction of words to perturb.
        lang: Language code; non-English "synonym" uses multilingual BERT.
        model_name: Optional translation model for ``back_trans``.
        num_aug: Number of augmented copies per input text.
        random_state: NOTE(review): currently not propagated to nlpaug (no
            seed API is used here); kept for interface compatibility.

    Returns:
        Tuple ``(augmented_texts, augmented_labels)`` of equal length.

    Raises:
        ImportError: If nlpaug is missing.
        NotImplementedError: For ``augmentation_type == "llm"``.
        ValueError: For an unknown ``augmentation_type``.
    """
    try:
        import nlpaug.augmenter.word as naw
        import nlpaug.augmenter.sentence as nas
    except ImportError:
        raise ImportError("Install nlpaug: pip install nlpaug")

    augmented_texts = []
    augmented_labels = []

    if augmentation_type == "synonym":
        if lang == "en":
            aug = naw.SynonymAug(aug_p=aug_p, aug_max=None)
        else:
            # No WordNet synonyms for most non-English languages; fall back
            # to contextual substitution with multilingual BERT.
            aug = naw.ContextualWordEmbsAug(
                model_path='bert-base-multilingual-cased',
                action="substitute",
                aug_p=aug_p,
                device='cpu'
            )
    elif augmentation_type == "insert":
        aug = naw.RandomWordAug(action="insert", aug_p=aug_p)
    elif augmentation_type == "delete":
        aug = naw.RandomWordAug(action="delete", aug_p=aug_p)
    elif augmentation_type == "swap":
        aug = naw.RandomWordAug(action="swap", aug_p=aug_p)
    elif augmentation_type == "eda":
        # NOTE(review): AntonymAug is only one EDA-style operation, not the
        # full EDA recipe — confirm this simplification is intended.
        aug = naw.AntonymAug()
    elif augmentation_type == "back_trans":
        if not model_name:
            if lang == "ru":
                model_name = "Helsinki-NLP/opus-mt-ru-en"
                back_model = "Helsinki-NLP/opus-mt-en-ru"
            else:
                model_name = "Helsinki-NLP/opus-mt-en-ru"
                back_model = "Helsinki-NLP/opus-mt-ru-en"
        else:
            # NOTE(review): using the same model for both directions looks
            # suspicious — a proper round trip needs the reverse model.
            back_model = model_name
        try:
            from transformers import pipeline
            translator1 = pipeline("translation", model=model_name, tokenizer=model_name)
            translator2 = pipeline("translation", model=back_model, tokenizer=back_model)

            def back_translate(text):
                # Best effort: on any translation failure, keep the original.
                try:
                    trans = translator1(text)[0]['translation_text']
                    back = translator2(trans)[0]['translation_text']
                    return back
                except Exception:
                    return text

            augmented = [back_translate(t) for t in texts for _ in range(num_aug)]
            labels_aug = [l for l in labels for _ in range(num_aug)]
            return augmented, labels_aug
        except Exception as e:
            print(f"Back-translation failed: {e}. Falling back to synonym augmentation.")
            aug = naw.ContextualWordEmbsAug(model_path='bert-base-multilingual-cased', aug_p=aug_p)
    elif augmentation_type == "llm":
        raise NotImplementedError("LLM-controlled augmentation requires external API (e.g., OpenAI, YandexGPT)")
    else:
        raise ValueError("Unknown augmentation_type")

    for text, label in zip(texts, labels):
        for _ in range(num_aug):
            # Best effort per text: keep the original on augmenter failure
            # so output length always equals len(texts) * num_aug.
            try:
                aug_text = aug.augment(text)
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]
                augmented_texts.append(aug_text)
                augmented_labels.append(label)
            except Exception:
                augmented_texts.append(text)
                augmented_labels.append(label)

    return augmented_texts, augmented_labels


def balance_text_dataset(
    texts: List[str],
    labels: List[Any],
    strategy: str = "augmentation",
    minority_classes: Optional[List[Any]] = None,
    augmentation_type: str = "synonym",
    sampling_method: str = "smote",
    lang: str = "ru",
    embedding_func: Optional[Callable] = None,
    class_weights: bool = False,
    random_state: int = 42,
) -> Union[
    Tuple[List[str], List[Any]],                 # for augmentation
    Tuple[np.ndarray, np.ndarray, Optional[Dict]]  # for sampling + weights
]:
    """Balance a labeled text dataset by augmentation, resampling, or both.

    Args:
        texts: Input texts.
        labels: Labels aligned with ``texts``.
        strategy: ``"augmentation"``, ``"sampling"``, or ``"both"``.
        minority_classes: Classes to augment; defaults to the least frequent.
        augmentation_type: Forwarded to :func:`augment_texts`.
        sampling_method: Forwarded to :func:`apply_sampling`.
        lang: Language code for augmentation.
        embedding_func: text -> vector; required for the sampling strategies.
        class_weights: If True, also return balanced class weights.
        random_state: Seed for the sampling step.

    Returns:
        ``(texts, labels)`` for "augmentation"; ``(X_res, y_res, weights)``
        for "sampling"/"both" (weights is None unless ``class_weights``).

    Raises:
        ValueError: On unknown ``strategy`` or missing ``embedding_func``.
    """
    label_counts = Counter(labels)
    if minority_classes is None:
        min_count = min(label_counts.values())
        minority_classes = [lbl for lbl, cnt in label_counts.items() if cnt == min_count]
    else:
        # BUG FIX: min_count was previously undefined when minority_classes
        # was supplied explicitly (NameError in the num_aug computation).
        # Guard against a class absent from labels (count 0) with `or 1`.
        min_count = min((label_counts[c] for c in minority_classes), default=1) or 1

    if strategy == "augmentation":
        minority_texts = [t for t, l in zip(texts, labels) if l in minority_classes]
        minority_labels = [l for l in labels if l in minority_classes]
        aug_texts, aug_labels = augment_texts(
            minority_texts, minority_labels,
            augmentation_type=augmentation_type,
            lang=lang,
            # Enough copies to roughly reach the majority count.
            num_aug=max(1, int((max(label_counts.values()) / min_count)) - 1),
            random_state=random_state
        )
        balanced_texts = texts + aug_texts
        balanced_labels = labels + aug_labels
        return balanced_texts, balanced_labels

    elif strategy == "sampling":
        if embedding_func is None:
            raise ValueError("embedding_func is required for sampling strategy")
        X_embed = np.array([embedding_func(t) for t in texts])
        X_res, y_res = apply_sampling(X_embed, np.array(labels),
                                      method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights

    elif strategy == "both":
        aug_texts, aug_labels = balance_text_dataset(
            texts, labels,
            strategy="augmentation",
            minority_classes=minority_classes,
            augmentation_type=augmentation_type,
            lang=lang,
            random_state=random_state
        )
        if embedding_func is None:
            # Without embeddings the sampling half cannot run; return the
            # augmented dataset only (preserves original behavior).
            return aug_texts, aug_labels
        X_embed = np.array([embedding_func(t) for t in aug_texts])
        X_res, y_res = apply_sampling(X_embed, np.array(aug_labels),
                                      method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights

    else:
        raise ValueError("strategy must be 'augmentation', 'sampling', or 'both'")