Auto-FineTune-Ops / preprocessing /deduplication.py
aneeb15's picture
Initial release of Auto-FineTune-Ops
d4398e6
"""
Deduplication Module
======================
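Exact and near-duplicate ("semantic") deduplication of text rows.
The semantic pass uses TF-IDF cosine similarity, so it detects lexical
(word-overlap) near-duplicates rather than true paraphrases.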
"""
from dataclasses import dataclass
from typing import List, Optional
import pandas as pd
import numpy as np
@dataclass()
class DeduplicationConfig:
    """Switches and tuning knobs read by ``apply_deduplication``.

    Field order is part of the positional constructor signature, so new
    fields should be appended rather than inserted.
    """

    # Drop rows whose text-column value exactly matches an earlier row.
    remove_exact: bool = True
    # Run the TF-IDF cosine-similarity near-duplicate pass (off by default).
    remove_semantic: bool = False
    # A row whose cosine similarity to an already-kept row is >= this value
    # is dropped by the semantic pass.
    semantic_threshold: float = 0.9
def remove_exact_duplicates(
    df: pd.DataFrame,
    col: str,
) -> pd.DataFrame:
    """Drop every row whose ``col`` value already appeared in an earlier row.

    The first occurrence of each value survives (missing values count as
    equal to each other), and the surviving rows get a fresh 0..n-1 index.
    """
    # True for the first time each value is seen, False for repeats.
    first_seen = ~df.duplicated(subset=[col], keep="first").to_numpy()
    survivors = df.loc[first_seen]
    return survivors.reset_index(drop=True)
def remove_semantic_duplicates(
    df: pd.DataFrame,
    col: str,
    threshold: float = 0.90,
) -> pd.DataFrame:
    """
    Remove near-duplicate rows using TF-IDF cosine similarity.

    Rows are scanned in order with a greedy pass: the first row is always
    kept, and each later row is kept only if its cosine similarity to every
    row kept so far is strictly below ``threshold``. A row is therefore
    compared against earlier *kept* rows, not against rows that were
    already dropped.

    Missing values in ``col`` are treated as empty strings. Because each row
    is compared with all kept rows, the cost grows quadratically with the
    number of rows.

    Args:
        df: Input frame.
        col: Name of the text column to compare.
        threshold: Cosine-similarity cut-off; a row whose similarity to some
            kept row is ``>= threshold`` is dropped.

    Returns:
        The surviving rows in their original order, with a fresh 0..n-1
        index. When deduplication cannot run (fewer than two rows,
        scikit-learn not installed, or no usable vocabulary) every row is
        returned, re-indexed the same way so callers always get a
        consistently indexed result.
    """
    if len(df) < 2:
        return df.reset_index(drop=True)
    try:
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity
    except ImportError:
        # Best-effort step: keep every row, but make the skip visible rather
        # than silently doing nothing.
        import warnings

        warnings.warn(
            "scikit-learn is not installed; skipping semantic deduplication.",
            RuntimeWarning,
            stacklevel=2,
        )
        return df.reset_index(drop=True)

    texts = df[col].fillna('').astype(str).tolist()

    # Build TF-IDF matrix (one row per input row).
    vectorizer = TfidfVectorizer(max_features=5000, stop_words='english')
    try:
        tfidf_matrix = vectorizer.fit_transform(texts)
    except ValueError:
        # scikit-learn raises ValueError here, e.g. when every document is
        # empty or consists only of stop words (empty vocabulary); there is
        # nothing to compare, so keep all rows.
        return df.reset_index(drop=True)

    # Greedy scan: compare row i only against rows kept so far.
    keep_indices = [0]
    for i in range(1, len(texts)):
        sim = cosine_similarity(
            tfidf_matrix[i:i + 1],
            tfidf_matrix[keep_indices],
        )
        if sim.max() < threshold:
            keep_indices.append(i)
    return df.iloc[keep_indices].reset_index(drop=True)
def apply_deduplication(
    df: pd.DataFrame,
    col: str,
    config: DeduplicationConfig,
) -> pd.DataFrame:
    """Run each deduplication pass that ``config`` switches on, in sequence.

    The exact pass (if enabled) runs first, so the semantic pass only sees
    rows that are already unique on ``col``. With both passes disabled the
    input frame is returned unchanged.
    """
    result = df
    if config.remove_exact:
        result = remove_exact_duplicates(result, col)
    if config.remove_semantic:
        result = remove_semantic_duplicates(
            result,
            col,
            threshold=config.semantic_threshold,
        )
    return result