"""
preprocessing.py β€” Text cleaning and combined_text creation for topic modelling pipeline.
Produces two text columns:
- combined_text_raw : original casing (Title + Abstract) β†’ used for SPECTER2 embeddings
- combined_text_clean : lowercased, normalised β†’ used for keyword extraction
Also performs:
- DOI-based exact deduplication
- Fuzzy title deduplication (difflib similarity >= 0.85)
- Filtering of rows whose combined text is < 100 characters
- Dataset overview stats (total, cleaned, duplicates removed, missing abstracts)
"""
import re
from typing import Tuple

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
# ─── TEXT HELPERS ─────────────────────────────────────────────────────────────
def _normalize_whitespace(text: str) -> str:
"""Collapse multiple spaces/newlines into a single space and strip."""
return re.sub(r"\s+", " ", text).strip()
def clean_text_raw(title: str, abstract: str) -> str:
"""
Combine title + abstract preserving original casing.
Used for SPECTER2 embeddings (case-sensitive model).
"""
title = _normalize_whitespace(title) if isinstance(title, str) else ""
abstract = _normalize_whitespace(abstract) if isinstance(abstract, str) else ""
return (title + " " + abstract).strip()
def clean_text_lower(title: str, abstract: str) -> str:
"""
Combine title + abstract, lowercase and lightly normalise.
Preserves hyphens and slashes common in science (covid-19, RNA/DNA).
Used for keyword extraction (KeyBERT).
"""
title = _normalize_whitespace(title).lower() if isinstance(title, str) else ""
abstract = _normalize_whitespace(abstract).lower() if isinstance(abstract, str) else ""
combined = (title + " " + abstract).strip()
# Remove characters that are not word chars, whitespace, hyphens, or slashes
combined = re.sub(r"[^\w\s\-/]", " ", combined)
return re.sub(r"\s+", " ", combined).strip()
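# Illustrative behaviour of the two cleaners (a sketch on made-up strings, not
# pipeline data):
#
#     clean_text_raw("COVID-19  Study", "We analyse RNA/DNA (n=42)!")
#         -> "COVID-19 Study We analyse RNA/DNA (n=42)!"
#     clean_text_lower("COVID-19  Study", "We analyse RNA/DNA (n=42)!")
#         -> "covid-19 study we analyse rna/dna n 42"
#
# Hyphens and slashes survive the lowercased variant; other punctuation does not.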
# ─── DEDUPLICATION ────────────────────────────────────────────────────────────
def _deduplicate(df: pd.DataFrame) -> Tuple[pd.DataFrame, int]:
    """
    Remove duplicate papers using:
      1. Exact DOI match (drop subsequent duplicates where the DOI is non-empty)
      2. Fuzzy title similarity >= 0.85 (cosine similarity over TF-IDF
         character n-grams)

    Assumes df["DOI"] is a string column with "" for missing DOIs.
    Returns (deduplicated_df, n_removed).
    """
original_len = len(df)
    # --- Step 1: exact DOI deduplication.
    # Treat very short values (blanks and index-style placeholders) as "no DOI";
    # such rows are never deduplicated by DOI.
    real_doi_mask = df["DOI"].str.strip().str.len() > 3
    doi_dupes = df[real_doi_mask].duplicated(subset=["DOI"], keep="first")
    drop_idx = set(df[real_doi_mask][doi_dupes].index.tolist())
    # --- Step 2: fuzzy title deduplication on the remaining rows.
    remaining = df[~df.index.isin(drop_idx)].reset_index(drop=False)
    titles = [str(t).lower().strip() for t in remaining["Title"].tolist()]

    fuzzy_drop = set()
    if len(titles) > 1:
        # TF-IDF over character n-grams is fast and robust to small edits
        # (punctuation, pluralisation, minor typos) in near-identical titles.
        vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4), min_df=1)
        tfidf_matrix = vectorizer.fit_transform(titles)

        # Rows are L2-normalised by default, so the dot product is cosine similarity.
        similarity_matrix = tfidf_matrix.dot(tfidf_matrix.T).tocoo()

        # Only the upper triangle (i < j) matters: keep the first occurrence (i)
        # and drop the later one (j), unless i itself was already dropped.
        for i, j, v in zip(similarity_matrix.row, similarity_matrix.col, similarity_matrix.data):
            if i < j and v >= 0.85 and i not in fuzzy_drop:
                fuzzy_drop.add(j)

    # Map positional indices back to the original DataFrame index
    # (preserved in the "index" column by reset_index above).
    for j in fuzzy_drop:
        drop_idx.add(remaining.iloc[j]["index"])
deduped = df[~df.index.isin(drop_idx)].reset_index(drop=True)
return deduped, original_len - len(deduped)
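# Minimal sketch of the fuzzy-matching core above, on two invented titles (the
# titles are assumptions for illustration, not real data):
#
#     pair = ["Deep learning for protein folding",
#             "Deep Learning for Protein Folding."]
#     m = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4)).fit_transform(
#         [t.lower().strip() for t in pair]
#     )
#     print((m @ m.T).toarray()[0, 1])  # close to 1.0, well above the 0.85 cutoff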
# ─── MAIN ENTRY POINT ─────────────────────────────────────────────────────────
def load_and_preprocess(filepath: str) -> Tuple[pd.DataFrame, dict]:
    """
    Load a CSV of papers, clean and deduplicate it, and build the combined
    text columns. Returns (df, stats), where stats summarises each step.
    """
    print("\n========== PREPROCESSING STARTED ==========\n")
# ── Load CSV
print("[Step 1] Loading dataset...")
df = pd.read_csv(filepath)
print(f"[INFO] Loaded {len(df)} rows")
df.columns = [c.strip() for c in df.columns]
print(f"[INFO] Columns detected: {list(df.columns)}\n")
# ── Required columns check
print("[Step 2] Validating required columns...")
required = {"Title", "Abstract"}
missing_cols = required - set(df.columns)
if missing_cols:
raise ValueError(f"CSV is missing required columns: {missing_cols}")
print("[OK] Required columns present\n")
stats: dict = {"total": len(df)}
# ── Missing abstracts
print("[Step 3] Checking missing abstracts...")
missing_abstracts = int(df["Abstract"].isna().sum())
stats["missing_abstracts"] = missing_abstracts
print(f"[INFO] Missing abstracts: {missing_abstracts}\n")
# ── Drop missing titles
print("[Step 4] Cleaning missing titles...")
before = len(df)
df = df.dropna(subset=["Title"]).copy()
df["Abstract"] = df["Abstract"].fillna("")
print(f"[INFO] Dropped {before - len(df)} rows with missing titles")
print(f"[INFO] Remaining rows: {len(df)}\n")
stats["after_drop_title"] = len(df)
# ── DOI handling
print("[Step 5] Processing DOI column...")
doi_col = None
for candidate in ["DOI", "doi", "Document Object Identifier"]:
if candidate in df.columns:
doi_col = candidate
break
if doi_col is None:
raise ValueError("CSV must contain a DOI column. None found.")
elif doi_col != "DOI":
df = df.rename(columns={doi_col: "DOI"})
df["DOI"] = df["DOI"].fillna("").astype(str)
print(f"[INFO] Sample DOIs: {df['DOI'].head(3).tolist()}\n")
# ── Deduplication
print("[Step 6] Deduplication...")
before = len(df)
df, n_dupes = _deduplicate(df)
stats["duplicates_removed"] = n_dupes
print(f"[INFO] Removed {n_dupes} duplicates")
print(f"[INFO] Remaining rows: {len(df)}\n")
# ── Build combined text
print("[Step 7] Building combined text columns...")
df["combined_text_raw"] = df.apply(
lambda r: clean_text_raw(r["Title"], r["Abstract"]), axis=1
)
df["combined_text_clean"] = df.apply(
lambda r: clean_text_lower(r["Title"], r["Abstract"]), axis=1
)
print("[INFO] Sample combined_text_raw:")
print(df["combined_text_raw"].head(2).tolist(), "\n")
# ── Filter short text
print("[Step 8] Filtering short text entries (<100 chars)...")
before = len(df)
df = df[df["combined_text_raw"].str.len() >= 100].reset_index(drop=True)
removed = before - len(df)
print(f"[INFO] Removed {removed} short-text papers")
print(f"[INFO] Remaining rows: {len(df)}\n")
stats["final_count"] = len(df)
# ── Final validation
print("[Step 9] Final validation...")
if len(df) < 50:
raise ValueError(
f"Dataset too small after preprocessing: {len(df)} papers. Need at least 50."
)
print("\n========== PREPROCESSING COMPLETE ==========\n")
    print("[SUMMARY]")
print(f"Total input: {stats['total']}")
print(f"Missing abstracts: {stats['missing_abstracts']}")
print(f"Duplicates removed: {stats['duplicates_removed']}")
print(f"Final dataset size: {stats['final_count']}\n")
return (
df[["DOI", "Title", "Abstract", "combined_text_raw", "combined_text_clean"]],
stats,
)
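# ─── SELF-TEST ────────────────────────────────────────────────────────────────
# Minimal runnable sketch: `python preprocessing.py <csv>`. The "papers.csv"
# fallback is a hypothetical default, not a file shipped with this repo.
if __name__ == "__main__":
    import sys

    csv_path = sys.argv[1] if len(sys.argv) > 1 else "papers.csv"
    cleaned_df, run_stats = load_and_preprocess(csv_path)
    print(f"[DONE] {run_stats['final_count']} papers ready for embedding")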