| import pandas as pd |
| import re |
| import os |
| import logging |
| import unicodedata |
| import random |
| from typing import List, Tuple |
| from difflib import SequenceMatcher |
| from collections import Counter |
|
|
| |
| |
| |
|
|
| DATA_FILES = ["tamil.csv", "english.csv", "tanglish.csv"] |
| SUPPORTED_ENCODINGS = [ |
| "utf-8-sig", |
| "utf-16", |
| "utf-16le", |
| "utf-8", |
| "latin1" |
| ] |
|
|
| SIMILARITY_THRESHOLD = 0.97 |
| MAX_TEXT_LENGTH = 512 |
| ENABLE_AUGMENTATION = True |
| AUGMENT_MULTIPLIER = 2 |
| TARGET_NON_TOXIC_RATIO = 0.7 |
|
|
|
|
| |
| |
| |
|
|
| logging.basicConfig( |
| level=logging.INFO, |
| format="%(asctime)s - %(levelname)s - %(message)s" |
| ) |
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
| |
| |
| def read_csv_with_fallback(file_path: str): |
| if file_path.endswith(".csv"): |
| for enc in SUPPORTED_ENCODINGS: |
| try: |
| df = pd.read_csv( |
| file_path, |
| encoding=enc, |
| engine="python", |
| sep="\t" |
| ) |
| logger.info(f"Read {file_path} with encoding {enc}") |
| return df |
| except Exception as e: |
| logger.warning(f"Encoding {enc} failed: {e}") |
| continue |
|
|
| elif file_path.endswith(".xlsx"): |
| df = pd.read_excel(file_path, engine="openpyxl") |
| logger.info(f"Read {file_path} as Excel") |
| return df |
|
|
| raise ValueError(f"Unsupported file format: {file_path}") |
| |
| |
| |
|
|
| def normalize_unicode(text: str) -> str: |
| return unicodedata.normalize("NFKC", text) |
|
|
|
|
| def normalize_repetitions(text: str) -> str: |
| |
| return re.sub(r"(.)\1{2,}", r"\1\1", text) |
|
|
|
|
| def reverse_leetspeak(text: str) -> str: |
| mapping = { |
| "0": "o", |
| "1": "i", |
| "3": "e", |
| "@": "a", |
| "$": "s", |
| "!": "i" |
| } |
| return "".join(mapping.get(c, c) for c in text) |
|
|
|
|
| def clean_text(text: str) -> str: |
| if not isinstance(text, str): |
| return "" |
|
|
| text = normalize_unicode(text) |
| text = text.lower() |
| text = reverse_leetspeak(text) |
|
|
| text = re.sub(r"http\S+", " URL ", text) |
| text = re.sub(r"@\w+", " USER ", text) |
|
|
| text = normalize_repetitions(text) |
|
|
| text = re.sub(r"\s+", " ", text).strip() |
|
|
| if len(text) > MAX_TEXT_LENGTH: |
| text = text[:MAX_TEXT_LENGTH] |
|
|
| return text |
|
|
|
|
| |
| |
| |
|
|
| def is_similar(a: str, b: str, threshold: float = SIMILARITY_THRESHOLD) -> bool: |
| if abs(len(a) - len(b)) > 10: |
| return False |
| return SequenceMatcher(None, a, b).ratio() > threshold |
|
|
|
|
| def remove_near_duplicates(df: pd.DataFrame) -> pd.DataFrame: |
| texts = df["cleaned_text"].tolist() |
| keep_indices = [] |
|
|
| for i, text in enumerate(texts): |
| duplicate_found = False |
| for j in keep_indices: |
| if is_similar(text, texts[j]): |
| duplicate_found = True |
| break |
| if not duplicate_found: |
| keep_indices.append(i) |
|
|
| return df.iloc[keep_indices].reset_index(drop=True) |
|
|
|
|
| |
| |
| |
|
|
| def random_case(text: str) -> str: |
| return "".join( |
| c.upper() if random.random() > 0.5 else c.lower() for c in text |
| ) |
|
|
|
|
| def insert_punctuation(text: str) -> str: |
| if len(text) < 3: |
| return text |
| pos = random.randint(1, len(text) - 2) |
| return text[:pos] + "." + text[pos:] |
|
|
|
|
| def char_drop(text: str) -> str: |
| if len(text) < 4: |
| return text |
| pos = random.randint(0, len(text) - 1) |
| return text[:pos] + text[pos+1:] |
|
|
|
|
| def generate_adversarial_variants(text: str) -> List[str]: |
| variants = [] |
| transforms = [random_case, insert_punctuation, char_drop] |
|
|
| for _ in range(AUGMENT_MULTIPLIER): |
| t = random.choice(transforms) |
| variants.append(t(text)) |
|
|
| return variants |
|
|
|
|
| def augment_dataset(df: pd.DataFrame) -> pd.DataFrame: |
| augmented_rows = [] |
|
|
| for _, row in df.iterrows(): |
| if row["label"] in ["toxic", "severe_toxic"]: |
| variants = generate_adversarial_variants(row["cleaned_text"]) |
| for v in variants: |
| augmented_rows.append({ |
| "text": v, |
| "label": row["label"], |
| "cleaned_text": clean_text(v) |
| }) |
|
|
| if augmented_rows: |
| df_aug = pd.DataFrame(augmented_rows) |
| df = pd.concat([df, df_aug], ignore_index=True) |
|
|
| return df |
|
|
|
|
| |
| |
| |
|
|
| def rebalance_classes(df: pd.DataFrame) -> pd.DataFrame: |
| counts = df["label"].value_counts().to_dict() |
|
|
| non_toxic_df = df[df["label"] == "non_toxic"] |
| toxic_df = df[df["label"] != "non_toxic"] |
|
|
| desired_non_toxic = int(len(df) * TARGET_NON_TOXIC_RATIO) |
|
|
| if len(non_toxic_df) > desired_non_toxic: |
| non_toxic_df = non_toxic_df.sample(desired_non_toxic, random_state=42) |
|
|
| df = pd.concat([non_toxic_df, toxic_df]).sample(frac=1, random_state=42) |
|
|
| return df.reset_index(drop=True) |
|
|
|
|
| |
| |
| |
|
|
| def load_and_preprocess_data(base_path: str = ".", augment: bool = True) -> pd.DataFrame: |
| dfs = [] |
| for filename in DATA_FILES: |
| file_path = os.path.join(base_path, filename) |
| if not os.path.exists(file_path): |
| logger.warning(f"File not found: {file_path}") |
| continue |
|
|
| df = read_csv_with_fallback(file_path) |
|
|
| df.columns = ["text", "label"] + list(df.columns[2:]) |
| df = df[["text", "label"]] |
|
|
| dfs.append(df) |
| logger.info(f"Loaded {filename}: {df.shape[0]} rows") |
|
|
| if not dfs: |
| raise ValueError("No dataset files loaded.") |
|
|
| df = pd.concat(dfs, ignore_index=True) |
|
|
| logger.info(f"Initial rows: {len(df)}") |
|
|
| df.dropna(subset=["text", "label"], inplace=True) |
| df.drop_duplicates(subset=["text"], inplace=True) |
|
|
| df["cleaned_text"] = df["text"].apply(clean_text) |
| df = df[df["cleaned_text"] != ""] |
|
|
| logger.info(f"After basic cleaning: {len(df)}") |
|
|
| df = remove_near_duplicates(df) |
| logger.info(f"After near-duplicate removal: {len(df)}") |
|
|
| if augment and ENABLE_AUGMENTATION: |
| df = augment_dataset(df) |
| logger.info(f"After augmentation: {len(df)}") |
| df = rebalance_classes(df) |
|
|
| logger.info("Final label distribution:") |
| logger.info(df["label"].value_counts()) |
|
|
| return df |
|
|
|
|
| |
| |
| |
|
|
| if __name__ == "__main__": |
| df = load_and_preprocess_data(".", augment=False) |
| print(df.sample(5)) |