Spaces:
Build error
Build error
| from typing import List, Tuple, Union, Dict, Optional, Any, Callable | |
| import numpy as np | |
| import pandas as pd | |
| from collections import Counter | |
def compute_class_weights(y: Union[List, np.ndarray], method: str = "balanced") -> Union[Dict[int, float], None]:
    """Compute per-class weights for an imbalanced label vector.

    Uses the "balanced" heuristic (identical to scikit-learn's
    ``compute_class_weight('balanced', ...)``):

        weight[c] = n_samples / (n_classes * count[c])

    so rarer classes receive proportionally larger weights.

    Args:
        y: 1-D sequence or array of class labels.
        method: Weighting scheme; only "balanced" is supported.

    Returns:
        Mapping from class label to its weight, or None when *method*
        is not "balanced" (callers rely on the None sentinel meaning
        "no weighting requested").
    """
    if method != "balanced":
        # Unknown method -> no weighting; preserved for backward compatibility.
        return None
    y_arr = np.asarray(y)
    # unique() returns labels sorted ascending alongside their frequencies.
    classes, counts = np.unique(y_arr, return_counts=True)
    weights = y_arr.shape[0] / (classes.shape[0] * counts)
    return {cls: float(w) for cls, w in zip(classes, weights)}
def get_pytorch_weighted_loss(class_weights: Optional[Dict[int, float]] = None,
                              num_classes: Optional[int] = None) -> 'torch.nn.Module':
    """Build a (optionally class-weighted) PyTorch cross-entropy loss.

    Args:
        class_weights: Mapping from class index to weight. May be sparse:
            classes absent from the mapping default to weight 1.0.
        num_classes: Total number of classes. When given, the weight tensor
            is padded to this length; otherwise the length is inferred as
            ``max(class_weights) + 1``.

    Returns:
        ``nn.CrossEntropyLoss`` — weighted when *class_weights* is a
        non-empty dict, unweighted otherwise.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        raise ImportError("PyTorch not installed")
    # Empty/None weights -> plain loss (mirrors the TF helper's `if not class_weights`).
    if not class_weights:
        return nn.CrossEntropyLoss()
    # CrossEntropyLoss indexes `weight` by class id, so the tensor must be
    # positionally aligned: build a dense vector, defaulting missing ids to 1.0.
    # (Previously a sparse dict like {0: 1.0, 2: 3.0} produced a misaligned
    # 2-element tensor.)
    length = num_classes if num_classes is not None else max(class_weights) + 1
    weight_tensor = torch.tensor(
        [class_weights.get(i, 1.0) for i in range(length)],
        dtype=torch.float,
    )
    return nn.CrossEntropyLoss(weight=weight_tensor)
def get_tensorflow_weighted_loss(class_weights: Optional[Dict[int, float]] = None) -> Callable:
    """Build a Keras-compatible sparse categorical cross-entropy loss.

    Args:
        class_weights: Mapping from class index to weight. When falsy
            (None or empty), the built-in loss identifier string is
            returned instead of a custom callable.

    Returns:
        Either the string ``'sparse_categorical_crossentropy'`` or a
        closure computing the per-sample-weighted mean loss.
    """
    if not class_weights:
        return 'sparse_categorical_crossentropy'
    # Weights ordered by class index; position i corresponds to class i.
    ordered_weights = [class_weights[k] for k in sorted(class_weights)]
    import tensorflow as tf

    def weighted_sparse_categorical_crossentropy(y_true, y_pred):
        # One-hot dot weight-vector picks each sample's class weight.
        labels = tf.cast(y_true, tf.int32)
        one_hot = tf.one_hot(labels, depth=len(ordered_weights))
        sample_weights = tf.reduce_sum(one_hot * ordered_weights, axis=1)
        base_loss = tf.keras.losses.sparse_categorical_crossentropy(labels, y_pred)
        return tf.reduce_mean(base_loss * sample_weights)

    return weighted_sparse_categorical_crossentropy
def apply_sampling(
    X: np.ndarray,
    y: np.ndarray,
    method: str = "random_under",
    random_state: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """Resample (X, y) with imbalanced-learn to balance class frequencies.

    Args:
        X: Feature matrix of shape (n_samples, n_features).
        y: Label vector of shape (n_samples,).
        method: One of "random_under", "random_over", "smote", "adasyn".
        random_state: Seed forwarded to the sampler for reproducibility.

    Returns:
        Tuple of resampled (X, y).

    Raises:
        ValueError: If *method* is not a recognised sampler name.
        ImportError: If imbalanced-learn is not installed.
    """
    # Validate the method BEFORE importing imblearn, so a typo'd method
    # yields the intended ValueError even when imblearn is absent.
    valid_methods = ("random_under", "random_over", "smote", "adasyn")
    if method not in valid_methods:
        raise ValueError("method must be one of: random_under, random_over, smote, adasyn")
    from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
    from imblearn.under_sampling import RandomUnderSampler
    sampler_cls = {
        "random_under": RandomUnderSampler,
        "random_over": RandomOverSampler,
        "smote": SMOTE,
        "adasyn": ADASYN,
    }[method]
    sampler = sampler_cls(random_state=random_state)
    X_res, y_res = sampler.fit_resample(X, y)
    return X_res, y_res
def augment_texts(
    texts: List[str],
    labels: List[Any],
    augmentation_type: str = "synonym",
    aug_p: float = 0.1,
    lang: str = "ru",  # language code
    model_name: Optional[str] = None,
    num_aug: int = 1,
    random_state: int = 42
) -> Tuple[List[str], List[Any]]:
    """Produce *num_aug* augmented variants of each text, label-preserving.

    Strategies (selected by *augmentation_type*):
      - "synonym": WordNet synonyms for English; multilingual-BERT
        contextual substitution for any other language.
      - "insert" / "delete" / "swap": random word-level edits via nlpaug.
      - "eda": builds AntonymAug (NOTE(review): name suggests full EDA,
        but only antonym replacement is performed — confirm intent).
      - "back_trans": round-trip translation via HuggingFace MarianMT
        models; returns early on success, falls back to contextual
        substitution on any failure.
      - "llm": not implemented.

    Returns a tuple (augmented_texts, augmented_labels) containing ONLY
    the augmented copies (originals are not included).

    Raises:
        ImportError: If nlpaug is not installed.
        NotImplementedError: For augmentation_type == "llm".
        ValueError: For an unrecognised augmentation_type.

    NOTE(review): *random_state* is accepted but never used to seed any
    augmenter — results are not reproducible; confirm whether seeding
    was intended.
    """
    try:
        import nlpaug.augmenter.word as naw
        import nlpaug.augmenter.sentence as nas
    except ImportError:
        raise ImportError("Install nlpaug: pip install nlpaug")
    augmented_texts = []
    augmented_labels = []
    if augmentation_type == "synonym":
        if lang == "en":
            # WordNet-based synonym replacement (English-only resource).
            aug = naw.SynonymAug(aug_p=aug_p, aug_max=None)
        else:
            # Non-English: contextual substitution with multilingual BERT.
            aug = naw.ContextualWordEmbsAug(
                model_path='bert-base-multilingual-cased',
                action="substitute",
                aug_p=aug_p,
                device='cpu'
            )
    elif augmentation_type == "insert":
        aug = naw.RandomWordAug(action="insert", aug_p=aug_p)
    elif augmentation_type == "delete":
        aug = naw.RandomWordAug(action="delete", aug_p=aug_p)
    elif augmentation_type == "swap":
        aug = naw.RandomWordAug(action="swap", aug_p=aug_p)
    elif augmentation_type == "eda":
        # See docstring: only antonym substitution, despite the "eda" name.
        aug = naw.AntonymAug()
    elif augmentation_type == "back_trans":
        if not model_name:
            # Default MarianMT pairs: forward then reverse direction.
            if lang == "ru":
                model_name = "Helsinki-NLP/opus-mt-ru-en"
                back_model = "Helsinki-NLP/opus-mt-en-ru"
            else:
                model_name = "Helsinki-NLP/opus-mt-en-ru"
                back_model = "Helsinki-NLP/opus-mt-ru-en"
        else:
            # NOTE(review): an explicitly supplied model is used for BOTH
            # directions — a translation model is usually one-directional;
            # confirm this is intended.
            back_model = model_name
        try:
            from transformers import pipeline
            translator1 = pipeline("translation", model=model_name, tokenizer=model_name)
            translator2 = pipeline("translation", model=back_model, tokenizer=back_model)
            def back_translate(text):
                # Best-effort: any per-text failure returns the original text.
                try:
                    trans = translator1(text)[0]['translation_text']
                    back = translator2(trans)[0]['translation_text']
                    return back
                except Exception:
                    return text
            augmented = [back_translate(t) for t in texts for _ in range(num_aug)]
            labels_aug = [l for l in labels for _ in range(num_aug)]
            # Back-translation path returns here, skipping the generic loop below.
            return augmented, labels_aug
        except Exception as e:
            # Pipeline construction failed (e.g. models unavailable):
            # degrade to contextual word substitution.
            print(f"Back-translation failed: {e}. Falling back to synonym augmentation.")
            aug = naw.ContextualWordEmbsAug(model_path='bert-base-multilingual-cased', aug_p=aug_p)
    elif augmentation_type == "llm":
        raise NotImplementedError("LLM-controlled augmentation requires external API (e.g., OpenAI, YandexGPT)")
    else:
        raise ValueError("Unknown augmentation_type")
    # Generic path: apply the selected augmenter num_aug times per text;
    # per-text failures fall back to the untouched original.
    for text, label in zip(texts, labels):
        for _ in range(num_aug):
            try:
                aug_text = aug.augment(text)
                # Newer nlpaug versions return a list; take the first variant.
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]
                augmented_texts.append(aug_text)
                augmented_labels.append(label)
            except Exception as e:
                augmented_texts.append(text)
                augmented_labels.append(label)
    return augmented_texts, augmented_labels
def balance_text_dataset(
    texts: List[str],
    labels: List[Any],
    strategy: str = "augmentation",
    minority_classes: Optional[List[Any]] = None,
    augmentation_type: str = "synonym",
    sampling_method: str = "smote",
    lang: str = "ru",
    embedding_func: Optional[Callable] = None,
    class_weights: bool = False,
    random_state: int = 42
) -> Union[
    Tuple[List[str], List[Any]],  # for augmentation
    Tuple[np.ndarray, np.ndarray, Optional[Dict]]  # for sampling + weights
]:
    """Balance a text classification dataset.

    Strategies:
      - "augmentation": text-augment the minority classes up to roughly
        the majority count; returns (texts, labels) lists.
      - "sampling": embed texts with *embedding_func*, then resample the
        vectors; returns (X, y, weights).
      - "both": augmentation first, then (if *embedding_func* given)
        sampling on the augmented set.

    Args:
        texts: Input documents.
        labels: Parallel class labels.
        strategy: "augmentation", "sampling", or "both".
        minority_classes: Classes to augment; inferred as the rarest
            class(es) when None.
        augmentation_type: Forwarded to ``augment_texts``.
        sampling_method: Forwarded to ``apply_sampling``.
        lang: Language code forwarded to ``augment_texts``.
        embedding_func: text -> vector callable; required for "sampling".
        class_weights: If True, also return balanced class weights.
        random_state: Seed forwarded to downstream helpers.

    Raises:
        ValueError: On empty labels, unknown strategy, or a missing
            *embedding_func* for the "sampling" strategy.
    """
    if not labels:
        raise ValueError("labels must be non-empty")
    label_counts = Counter(labels)
    # Bug fix: min_count was previously computed ONLY when minority_classes
    # was None, causing a NameError at the num_aug computation below for
    # callers that pass explicit minority_classes.
    if minority_classes is None:
        min_count = min(label_counts.values())
        minority_classes = [lbl for lbl, cnt in label_counts.items() if cnt == min_count]
    else:
        present = [label_counts[c] for c in minority_classes if c in label_counts]
        # Fall back to the global minimum if none of the named classes occur.
        min_count = min(present) if present else min(label_counts.values())
    if strategy == "augmentation":
        minority_texts = [t for t, l in zip(texts, labels) if l in minority_classes]
        minority_labels = [l for l in labels if l in minority_classes]
        # num_aug chosen so minority counts approach the majority count.
        aug_texts, aug_labels = augment_texts(
            minority_texts, minority_labels,
            augmentation_type=augmentation_type,
            lang=lang,
            num_aug=max(1, int((max(label_counts.values()) / min_count)) - 1),
            random_state=random_state
        )
        balanced_texts = texts + aug_texts
        balanced_labels = labels + aug_labels
        return balanced_texts, balanced_labels
    elif strategy == "sampling":
        if embedding_func is None:
            raise ValueError("embedding_func is required for sampling strategy")
        X_embed = np.array([embedding_func(t) for t in texts])
        X_res, y_res = apply_sampling(X_embed, np.array(labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights
    elif strategy == "both":
        aug_texts, aug_labels = balance_text_dataset(
            texts, labels, strategy="augmentation", minority_classes=minority_classes,
            augmentation_type=augmentation_type, lang=lang, random_state=random_state
        )
        if embedding_func is None:
            # No embedder: sampling step is impossible, return augmented lists.
            return aug_texts, aug_labels
        X_embed = np.array([embedding_func(t) for t in aug_texts])
        X_res, y_res = apply_sampling(X_embed, np.array(aug_labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights
    else:
        raise ValueError("strategy must be 'augmentation', 'sampling', or 'both'")