# Auto-FineTune-Ops / preprocessing / quality_filters.py
# aneeb15's picture
# Initial release of Auto-FineTune-Ops
# d4398e6
"""
Quality Filters Module
========================
Filter samples by word count, profanity, language,
and low-quality response detection.
"""
from dataclasses import dataclass, field
from typing import List, Optional
import re
import pandas as pd
@dataclass
class QualityFilterConfig:
    """Switches and thresholds read by ``apply_quality_filters``.

    Every filter is disabled by default, so a default-constructed config
    leaves a DataFrame untouched.
    """

    # Inclusive word-count bounds. A value of 0 turns that bound off; the
    # word-count filter runs when at least one of the two is positive.
    min_word_count: int = 0
    max_word_count: int = 0

    # Drop rows containing a word from the built-in profanity list.
    profanity_filter: bool = False

    # Drop rows whose detected language is not in ``allowed_languages``.
    language_filter: bool = False
    # A factory is used so each instance gets its own list.
    allowed_languages: List[str] = field(default_factory=lambda: ["en"])

    # Drop short or placeholder-looking responses; ``min_quality_length`` is
    # the minimum length, in characters, of the stripped text.
    remove_low_quality: bool = False
    min_quality_length: int = 20
# ---------------------------------------------------------------------------
# Built-in profanity vocabulary (intentionally small; extend as needed).
# Entries are lowercase because contains_profanity lowercases its input
# before comparing whole words against this set.
# ---------------------------------------------------------------------------
_PROFANITY_WORDS = {
    'ass',
    'bastard',
    'bitch',
    'cock',
    'crap',
    'damn',
    'dick',
    'fuck',
    'piss',
    'shit',
    'slut',
    'whore',
}
# Filler / placeholder phrases that mark a response as low quality.
# is_low_quality compares these against the start of the stripped,
# lowercased response, so every entry must be lowercase.
_GENERIC_RESPONSES = [
    "i don't know", "i am not sure", "no comment",
    "n/a", "none", "null",
    "test", "asdf", "lorem ipsum",
    "placeholder", "todo", "tbd",
]
def _word_count(text: str) -> int:
    """Return the number of whitespace-separated tokens in ``text``.

    Non-string values (``None``, ``NaN``, numbers, ...) count as zero words.
    """
    if not isinstance(text, str):
        return 0
    return len(text.split())


def filter_by_word_count(
    df: pd.DataFrame,
    col: str,
    min_words: int = 0,
    max_words: int = 0,
) -> pd.DataFrame:
    """Keep rows whose word count in ``col`` lies within the given bounds.

    Args:
        df: Input frame; it is not modified.
        col: Name of the text column to measure.
        min_words: Inclusive lower bound; values <= 0 disable it.
        max_words: Inclusive upper bound; values <= 0 disable it.

    Returns:
        A new DataFrame holding the surviving rows, re-indexed from 0.
    """
    counts = df[col].apply(_word_count)
    # Combine both bounds into a single mask that shares df's index. Filtering
    # in two passes and re-selecting the counts by label in between breaks when
    # the index contains duplicate labels (e.g. after pd.concat), because a
    # label lookup then returns every row carrying that label.
    keep = pd.Series(True, index=df.index)
    if min_words > 0:
        keep = keep & (counts >= min_words)
    if max_words > 0:
        keep = keep & (counts <= max_words)
    return df[keep].reset_index(drop=True)
def contains_profanity(text: str) -> bool:
    """Report whether ``text`` uses any word from the profanity list.

    The text is lowercased and split into whole words, so a listed word only
    matches when it appears as a complete word. Non-string input is treated
    as clean.
    """
    if isinstance(text, str):
        tokens = re.findall(r'\b\w+\b', text.lower())
        # True as soon as one token is a member of the profanity set.
        return not _PROFANITY_WORDS.isdisjoint(tokens)
    return False
def filter_profanity(
    df: pd.DataFrame,
    col: str,
) -> pd.DataFrame:
    """Return ``df`` without the rows whose ``col`` text contains profanity.

    The result is re-indexed from 0; the input frame is left unchanged.
    """
    flagged = df[col].apply(contains_profanity)
    clean_rows = df[~flagged]
    return clean_rows.reset_index(drop=True)
def detect_language(text: str) -> str:
    """
    Detect the language of a text string.

    Returns the language code reported by ``langdetect`` (e.g., 'en', 'fr',
    'de'). Falls back to 'unknown' when the input is not a string, has fewer
    than 10 non-padding characters, ``langdetect`` is not installed, or
    detection fails for any reason.
    """
    # Reject unusable input before touching the optional dependency.
    if not isinstance(text, str) or len(text.strip()) < 10:
        return 'unknown'
    try:
        from langdetect import DetectorFactory, detect

        # langdetect is randomized by default, so the same text can yield
        # different codes between runs. Pin the seed for reproducible
        # filtering, but respect a seed the application already chose.
        if DetectorFactory.seed is None:
            DetectorFactory.seed = 0
        return detect(text)
    except Exception:
        # Deliberate best-effort: a missing package (ImportError) or a
        # detection error must not abort the whole filtering pass.
        return 'unknown'
def filter_by_language(
    df: pd.DataFrame,
    col: str,
    allowed_langs: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Drop rows whose detected language is not in ``allowed_langs``.

    Rows whose language is reported as 'unknown' are kept as well, so text
    that ``detect_language`` cannot classify is never discarded by this
    filter.

    Args:
        df: Input frame; it is not modified.
        col: Name of the text column to inspect.
        allowed_langs: Language codes to keep. ``None`` means ``['en']``.

    Returns:
        A new DataFrame holding the surviving rows, re-indexed from 0.
    """
    if allowed_langs is None:
        allowed_langs = ['en']
    langs = df[col].apply(detect_language)
    mask = langs.isin(allowed_langs) | (langs == 'unknown')
    return df[mask].reset_index(drop=True)
def is_low_quality(text: str, min_len: int = 20) -> bool:
    """
    Check if a response is low-quality.

    A response is low-quality when it is not a string, when its stripped
    length is below ``min_len`` characters, or when it opens with one of the
    generic/placeholder phrases (compared case-insensitively).

    Args:
        text: Response to inspect.
        min_len: Minimum acceptable length of the stripped text.

    Returns:
        True if the response should be discarded, False otherwise.
    """
    if not isinstance(text, str):
        return True
    text_stripped = text.strip()
    if len(text_stripped) < min_len:
        return True
    text_lower = text_stripped.lower()
    for phrase in _GENERIC_RESPONSES:
        if not text_lower.startswith(phrase):
            continue
        # Only count the phrase when it ends at a word boundary. A bare
        # prefix test would also reject legitimate answers starting with
        # "Nonetheless", "Testing" or "Nullable". The slice is empty when
        # the text equals the phrase, which also counts as a match.
        following = text_lower[len(phrase):len(phrase) + 1]
        if not (following.isalnum() or following == '_'):
            return True
    return False
def filter_low_quality(
    df: pd.DataFrame,
    col: str,
    min_len: int = 20,
) -> pd.DataFrame:
    """Return ``df`` without the rows that ``is_low_quality`` rejects.

    ``min_len`` is forwarded as the minimum acceptable response length. The
    result is re-indexed from 0; the input frame is left unchanged.
    """
    def _rejected(value):
        return is_low_quality(value, min_len)

    rejected = df[col].apply(_rejected)
    survivors = df[~rejected]
    return survivors.reset_index(drop=True)
def apply_quality_filters(
    df: pd.DataFrame,
    col: str,
    config: QualityFilterConfig,
) -> pd.DataFrame:
    """Run every filter that ``config`` enables over column ``col``.

    Filters run in a fixed order: word count, profanity, language, then
    low-quality detection. When nothing is enabled the input frame is
    returned as-is.
    """
    low, high = config.min_word_count, config.max_word_count
    # Each entry pairs an "is this filter enabled?" flag with the step to run.
    steps = (
        (
            low > 0 or high > 0,
            lambda frame: filter_by_word_count(frame, col, low, high),
        ),
        (
            config.profanity_filter,
            lambda frame: filter_profanity(frame, col),
        ),
        (
            config.language_filter,
            lambda frame: filter_by_language(
                frame, col, config.allowed_languages
            ),
        ),
        (
            config.remove_low_quality,
            lambda frame: filter_low_quality(
                frame, col, config.min_quality_length
            ),
        ),
    )
    for enabled, run_step in steps:
        if enabled:
            df = run_step(df)
    return df