# Text_classification / preprocess_data.py
# Hosting-page residue kept for provenance (not part of the program):
#   Tantawi's picture
#   Upload 13 files
#   f2a4578 verified
import argparse
import os
import sys
import pandas as pd
import numpy as np
def standardize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with normalized column names.

    Each label is converted to ``str``, stripped of leading/trailing
    whitespace, lowercased, and has its spaces replaced with underscores.

    Args:
        df: Frame whose column labels should be normalized. It is not
            modified; the renaming is applied to a copy.

    Returns:
        A copy of ``df`` carrying the normalized column labels.
    """
    renamed = df.copy()
    # Coerce with str() first: labels are not guaranteed to be strings
    # (e.g. integer headers), and .strip() on those raises AttributeError.
    renamed.columns = [
        str(label).strip().lower().replace(" ", "_") for label in df.columns
    ]
    return renamed
def drop_invalid_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Drop rows with a missing target or with no feature information.

    The first column is treated as the target; every remaining column is a
    feature. A row is removed when its target is NaN, or when every one of
    its feature cells is either NaN or equal to 0.

    Args:
        df: Frame with the target in column 0. It is not modified.

    Returns:
        The surviving rows of ``df`` with their original index labels.
    """
    target_col = df.columns[0]
    labelled = df[df[target_col].notna()]

    features = labelled.iloc[:, 1:]
    # Test each cell instead of summing the row: a row sum of 0 also matches
    # rows whose non-zero values cancel out (e.g. -1 and 1), and summing
    # raises TypeError when a row mixes text and numeric cells.
    is_blank = features.isna() | (features == 0)
    return labelled.loc[~is_blank.all(axis=1)]
def remove_constant_and_sparse_features(
    df: pd.DataFrame, min_positive_frac: float = 0.0005
) -> pd.DataFrame:
    """Remove feature columns that are constant or almost never positive.

    The first column is treated as the target and is always kept. Every
    other column is dropped when it has at most one distinct non-null value,
    or when the fraction of rows with a value greater than 0 (NaN counted as
    0) is below ``min_positive_frac``.

    Args:
        df: Frame with the target in column 0. It is not modified.
        min_positive_frac: Minimum fraction of rows that must hold a value
            greater than 0 for a column to be kept.

    Returns:
        A new frame with the target followed by the surviving feature
        columns, in their original order.
    """
    kept_positions = [0]  # the target column is always retained
    # Walk columns by position, not by label: with duplicated labels a
    # label lookup returns a DataFrame and the checks below would raise.
    for position in range(1, df.shape[1]):
        series = df.iloc[:, position]
        if series.nunique(dropna=True) <= 1:
            continue
        try:
            pos_frac = (series.fillna(0) > 0).mean()
        except (TypeError, ValueError):
            # Values that cannot be compared with 0 (e.g. text) are not
            # subject to the sparsity rule, so the column is kept.
            pos_frac = 1.0
        if pos_frac < min_positive_frac:
            continue
        kept_positions.append(position)
    return df.iloc[:, kept_positions]
def impute_missing(df: pd.DataFrame) -> pd.DataFrame:
    """Fill missing feature values with 0 while leaving the target untouched.

    Column 0 is the target and is copied through as is, including any NaN.
    All later columns have their NaN cells replaced by 0.
    """
    label_name = df.columns[0]
    label_frame = df[[label_name]]
    filled_features = df.iloc[:, 1:].fillna(0)
    return pd.concat([label_frame, filled_features], axis=1)
def limit_classes(df: pd.DataFrame, min_samples: int = 5) -> pd.DataFrame:
    """Discard rows whose class occurs fewer than ``min_samples`` times.

    The class label is read from the first column. Rows belonging to a
    sufficiently frequent class are returned with their original index.
    """
    label_name = df.columns[0]
    labels = df[label_name]
    frequencies = labels.value_counts()
    frequent_labels = frequencies.loc[lambda freq: freq >= min_samples].index
    row_mask = labels.isin(frequent_labels)
    return df[row_mask]
def main() -> None:
    """Run the preprocessing pipeline from the command line.

    Reads the CSV given by ``--input``, then applies, in order: column-name
    standardization, invalid-row removal, constant/sparse feature removal,
    zero imputation of missing feature values, and removal of classes with
    fewer than 5 samples. The result is written to ``--output`` (default
    ``cleaned_dataset.csv``) without the index column.

    Exits with status 1 if the input path is not an existing file.
    """
    parser = argparse.ArgumentParser(description="Preprocess disease-symptom CSV for training.")
    parser.add_argument("--input", required=True, help="Path to raw CSV")
    parser.add_argument("--output", default="cleaned_dataset.csv", help="Path to save cleaned CSV")
    args = parser.parse_args()

    # isfile rather than exists: a directory would pass an existence check
    # and only fail later, inside read_csv, with a less helpful error.
    if not os.path.isfile(args.input):
        # Diagnostics go to stderr so they are not mixed into piped stdout.
        print(f"❌ Input CSV not found: {args.input}", file=sys.stderr)
        sys.exit(1)

    print("Loading CSV...")
    df = pd.read_csv(args.input)
    print(f"Raw shape: {df.shape}")

    print("Standardizing column names...")
    df = standardize_columns(df)

    print("Dropping invalid/empty rows...")
    df = drop_invalid_rows(df)
    print(f"After row cleanup: {df.shape}")

    print("Removing constant and sparse features...")
    df = remove_constant_and_sparse_features(df)
    print(f"After feature cleanup: {df.shape}")

    print("Imputing missing values (0 for symptoms)...")
    df = impute_missing(df)

    print("Limiting classes with very few samples...")
    df = limit_classes(df, min_samples=5)
    print(f"After class filtering: {df.shape}")

    print(f"Saving cleaned CSV to: {args.output}")
    # Create the destination directory up front so the finished result is
    # not lost to a missing-folder error at the final step.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df.to_csv(args.output, index=False)
    print("Done.")
# Run the pipeline only when executed as a script, so importing this module
# for its helper functions has no side effects.
if __name__ == "__main__":
    main()