depscreen / ml /scripts /validate_dapt_corpus.py
halsabbah's picture
style: apply ruff format to pass CI format check
95974bc
"""
DAPT corpus validation — run BEFORE training on collected Reddit data.
Checks:
1. Data leakage: DAPT corpus must NOT overlap with ReDSM5 labeled data
2. Deduplication: exact and near-duplicate removal
3. Language quality: filter non-English, bot/spam, too-short posts
4. Corpus statistics: length distribution, vocabulary
5. Provenance documentation
Usage:
python validate_dapt_corpus.py
"""
import hashlib
import json
import re
from collections import Counter
from pathlib import Path
import pandas as pd
def load_redsm5_texts(redsm5_dir: Path) -> set[str]:
    """Gather every labeled ReDSM5 string used as the leakage reference set.

    Reads ``redsm5_posts.csv`` (column ``text``) and ``redsm5_annotations.csv``
    (column ``sentence_text``) from ``redsm5_dir``. Missing values are dropped and
    each remaining value is stripped and lower-cased before being added.

    Args:
        redsm5_dir: Directory that contains the two ReDSM5 CSV files.

    Returns:
        The set of normalized post texts and annotated sentences.
    """
    posts_frame = pd.read_csv(redsm5_dir / "redsm5_posts.csv")
    annotations_frame = pd.read_csv(redsm5_dir / "redsm5_annotations.csv")

    reference_texts: set[str] = set()
    # Posts first, then annotated sentences; both land in the same set.
    for frame, column in ((posts_frame, "text"), (annotations_frame, "sentence_text")):
        reference_texts.update(value.strip().lower() for value in frame[column].dropna())
    return reference_texts
def normalize_for_dedup(text: str) -> str:
    """Normalize text for near-duplicate detection.

    Lower-cases the text, removes every character that is neither a word
    character nor whitespace, then collapses whitespace runs to a single space
    and trims the ends.

    Punctuation is removed *before* whitespace is collapsed: stripping a
    free-standing punctuation mark (``"a - b"``) leaves two adjacent spaces, and
    a trailing one (``"a !"``) leaves a dangling space, so collapsing first would
    make such texts compare unequal to their punctuation-free twins.

    Args:
        text: Raw chunk text.

    Returns:
        The normalized comparison key.
    """
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)  # remove punctuation
    text = re.sub(r"\s+", " ", text)
    return text.strip()
# The same character repeated 11 or more times in a row.
_REPEATED_CHAR_RE = re.compile(r"(.)\1{10,}")

# Promotional phrases. Word boundaries keep a phrase from firing inside
# unrelated words (e.g. "use code" inside "because code", "subscribe to"
# inside "subscribe tomorrow").
# NOTE(review): "discount" is still a bare substring and therefore also matches
# "discounting"/"discounted" — confirm that is acceptable for this corpus.
_PROMO_RE = re.compile(
    r"\bcheck out my\b"
    r"|\bsubscribe to\b"
    r"|\buse code\b"
    r"|discount"
    r"|\bbuy now\b"
    r"|\bclick here\b"
    r"|\bfree trials?\b"
    r"|www\.\S+\.(?:com|org|net)",
    re.IGNORECASE,
)


def detect_bot_spam(text: str) -> bool:
    """Heuristic bot/spam detection.

    A text is flagged when any of the following holds:

    * one character is repeated 11 or more times consecutively;
    * it has more than 20 alphabetic characters and over 80% of them are
      upper-case;
    * it contains one of the promotional phrases or a ``www.<...>.com/.org/.net``
      style address (case-insensitive).

    Args:
        text: Chunk text to inspect.

    Returns:
        ``True`` if the text looks like bot or spam content, else ``False``.
    """
    if _REPEATED_CHAR_RE.search(text):
        return True

    letters = [ch for ch in text if ch.isalpha()]
    if len(letters) > 20 and sum(ch.isupper() for ch in letters) / len(letters) > 0.8:
        return True

    return _PROMO_RE.search(text) is not None
# Small set of very frequent English words used as a proxy for "is English".
# "don" is present so that contractions such as "don't" count once the part
# after the apostrophe is dropped.
_COMMON_ENGLISH_WORDS = frozenset(
    """
    the be to of and a in that have i it for not on with he as you do at
    this but his by from they we say her she or an will my one all would
    there their what so up out if about who get which go me when make can
    like time no just him know take people into year your good some could
    them see other than then now look only come its over think also back
    after use two how our work first well way even new want because any
    these give day most us am been was is are don feel really life much
    very still never going help being too need had did has
    """.split()
)

# Characters trimmed from both ends of a token before the word-list lookup
# (ASCII punctuation plus curly quotes, ellipsis, en and em dash).
_TOKEN_EDGE_CHARS = "\"'`.,!?;:()[]{}<>*_~-\u2018\u2019\u201c\u201d\u2026\u2013\u2014"


def _lookup_form(token: str) -> str:
    """Return the form of a whitespace token used for the common-word lookup."""
    core = token.strip(_TOKEN_EDGE_CHARS)
    # Keep only the part before an apostrophe: "don't" -> "don", "it's" -> "it".
    return core.replace("\u2019", "'").partition("'")[0]


def detect_non_english(text: str) -> bool:
    """Simple heuristic: flag text in which <15% of words are common English words.

    The text is lower-cased and split on whitespace. Each token is compared
    against a small set of very frequent English words after surrounding
    punctuation and any apostrophe suffix are removed, so "the," and "don't"
    count as hits. Every whitespace token still counts towards the total, so
    this normalization can only raise the ratio, never lower it.

    Args:
        text: Chunk text to inspect.

    Returns:
        ``True`` if the share of common English words is below 0.15.
        ``False`` otherwise, including when the text has fewer than five
        words and is therefore too short to judge.
    """
    words = text.lower().split()
    if len(words) < 5:
        return False  # too short to judge

    common_count = sum(1 for word in words if _lookup_form(word) in _COMMON_ENGLISH_WORDS)
    ratio = common_count / len(words)
    return ratio < 0.15  # very low overlap with English
def _find_leaked_chunks(chunks: list[str], reference_texts: set[str]) -> list[tuple[int, str]]:
    """Return ``(raw chunk index, preview)`` for every chunk overlapping the labeled data."""
    # Only references of 30+ characters take part in the substring check; filter
    # them once instead of re-testing the length for every (chunk, reference) pair.
    long_refs = [ref for ref in reference_texts if len(ref) >= 30]

    leaked = []
    for i, chunk in enumerate(chunks):
        chunk_lower = chunk.strip().lower()
        # Exact match
        if chunk_lower in reference_texts:
            leaked.append((i, chunk[:100]))
            continue
        # Substring match — a labeled text appears inside the DAPT chunk
        hit = next((ref for ref in long_refs if ref in chunk_lower), None)
        if hit is not None:
            leaked.append((i, f"contains: '{hit[:80]}...'"))
    return leaked


def _keep_first_by_key(chunks: list[str], key) -> list[str]:
    """Return the chunks whose ``key(chunk)`` value has not been seen earlier, in order."""
    seen = set()
    kept = []
    for chunk in chunks:
        marker = key(chunk)
        if marker not in seen:
            seen.add(marker)
            kept.append(chunk)
    return kept


def _length_bucket(length: int) -> str:
    """Map a character count to its length-distribution bucket label."""
    if length < 100:
        return "<100"
    if length < 300:
        return "100-300"
    if length < 500:
        return "300-500"
    if length < 1000:
        return "500-1000"
    return "1000+"


def main() -> int:
    """Validate and clean the DAPT corpus, then write the cleaned corpus and a report.

    Reads ``data/dapt_corpus/dapt_corpus.txt`` (chunks separated by blank lines)
    relative to the directory two levels above this file, and the ReDSM5 CSVs
    from ``redsm5/`` four levels above it. Runs, in order: the leakage check
    against ReDSM5, exact and near-duplicate removal, bot/spam removal, the
    short-text and non-English filters, and corpus statistics. Writes
    ``dapt_corpus_cleaned.txt`` and ``validation_report.json`` next to the input.

    NOTE(review): leaked chunks are reported and fail the verdict but are not
    removed from the cleaned corpus — confirm this is intended.

    Returns:
        ``0`` when no leakage was detected, ``1`` otherwise.

    Raises:
        FileNotFoundError: If the corpus file or a ReDSM5 CSV does not exist.
    """
    project_root = Path(__file__).parent.parent.parent.parent
    base_dir = Path(__file__).parent.parent
    corpus_dir = base_dir / "data" / "dapt_corpus"
    redsm5_dir = project_root / "redsm5"
    corpus_file = corpus_dir / "dapt_corpus.txt"
    report_file = corpus_dir / "validation_report.json"

    print("=" * 60)
    print("DAPT CORPUS VALIDATION")
    print("=" * 60)

    # Load corpus. The encoding is explicit so the result does not depend on the
    # platform's default locale encoding.
    with open(corpus_file, encoding="utf-8") as f:
        raw = f.read()
    chunks = [c.strip() for c in raw.split("\n\n") if c.strip()]
    print(f"\nRaw chunks loaded: {len(chunks)}")

    # ── CHECK 1: Data Leakage ──────────────────────────────────────────────
    print("\n── CHECK 1: Data Leakage Against ReDSM5 ──")
    redsm5_texts = load_redsm5_texts(redsm5_dir)
    print(f" ReDSM5 reference texts: {len(redsm5_texts)}")
    leaked = _find_leaked_chunks(chunks, redsm5_texts)
    if leaked:
        print(f" LEAKAGE DETECTED: {len(leaked)} chunks overlap with ReDSM5!")
        for idx, preview in leaked[:5]:
            print(f" chunk {idx}: {preview}")
        if len(leaked) > 5:
            print(f" ... and {len(leaked) - 5} more")
    else:
        print(" NO LEAKAGE — zero overlap with ReDSM5 labeled data")

    # ── CHECK 2: Exact Deduplication ───────────────────────────────────────
    print("\n── CHECK 2: Deduplication ──")
    unique_chunks = _keep_first_by_key(chunks, lambda c: hashlib.md5(c.encode()).hexdigest())
    exact_dupes = len(chunks) - len(unique_chunks)
    print(f" Exact duplicates: {exact_dupes}")
    print(f" Unique chunks: {len(unique_chunks)}")

    # Near-duplicate detection (normalized text)
    final_chunks = _keep_first_by_key(unique_chunks, normalize_for_dedup)
    near_dupes = len(unique_chunks) - len(final_chunks)
    print(f" Near-duplicates (after normalization): {near_dupes}")
    print(f" After dedup: {len(final_chunks)} chunks")

    # ── CHECK 3: Bot/Spam Detection ────────────────────────────────────────
    print("\n── CHECK 3: Bot/Spam Detection ──")
    spam_chunks = []
    clean_chunks = []
    for chunk in final_chunks:
        (spam_chunks if detect_bot_spam(chunk) else clean_chunks).append(chunk)
    spam_count = len(spam_chunks)
    print(f" Bot/spam detected: {spam_count}")
    for example in spam_chunks[:3]:
        print(f" example: '{example[:100]}...'")
    print(f" After spam removal: {len(clean_chunks)} chunks")

    # ── CHECK 4: Language Quality ──────────────────────────────────────────
    print("\n── CHECK 4: Language Quality ──")
    non_english = 0
    too_short = 0
    quality_chunks = []
    for chunk in clean_chunks:
        if len(chunk) < 30:
            too_short += 1
        elif detect_non_english(chunk):
            non_english += 1
        else:
            quality_chunks.append(chunk)
    print(f" Too short (<30 chars): {too_short}")
    print(f" Non-English (heuristic): {non_english}")
    print(f" After quality filter: {len(quality_chunks)} chunks")

    # ── CHECK 5: Corpus Statistics ─────────────────────────────────────────
    print("\n── CHECK 5: Corpus Statistics ──")
    n_final = len(quality_chunks)
    lengths = [len(c) for c in quality_chunks]
    total_chars = sum(lengths)
    total_words = sum(len(c.split()) for c in quality_chunks)
    estimated_tokens = int(total_words / 0.75)
    # Per-chunk averages are undefined for an empty result; report 0 instead of
    # dividing by zero.
    avg_chars = total_chars / n_final if n_final else 0.0
    avg_words = total_words / n_final if n_final else 0.0

    print(f" Final chunks: {n_final}")
    print(f" Total characters: {total_chars:,}")
    print(f" Total words: {total_words:,}")
    print(f" Estimated tokens (~0.75 words/token): {estimated_tokens:,}")
    if n_final:
        print(f" Avg chars/chunk: {avg_chars:.0f}")
        print(f" Avg words/chunk: {avg_words:.0f}")
        print(f" Median chars: {sorted(lengths)[n_final // 2]}")
        print(f" Min chars: {min(lengths)}, Max chars: {max(lengths)}")

        # Length distribution buckets
        buckets = Counter(_length_bucket(n) for n in lengths)
        print("\n Length distribution:")
        for bucket in ["<100", "100-300", "300-500", "500-1000", "1000+"]:
            count = buckets.get(bucket, 0)
            pct = count / n_final * 100
            bar = "#" * int(pct / 2)
            print(f" {bucket:>10}: {count:>6} ({pct:>5.1f}%) {bar}")
    else:
        print(" WARNING: no chunks survived filtering; per-chunk statistics skipped")

    # ── Save cleaned corpus ────────────────────────────────────────────────
    print("\n── Saving cleaned corpus ──")
    cleaned_file = corpus_dir / "dapt_corpus_cleaned.txt"
    with open(cleaned_file, "w", encoding="utf-8") as f:
        f.write("\n\n".join(quality_chunks))

    # Save validation report
    report = {
        "raw_chunks": len(chunks),
        "leakage_detected": len(leaked),
        "exact_duplicates_removed": exact_dupes,
        "near_duplicates_removed": near_dupes,
        "spam_removed": spam_count,
        "too_short_removed": too_short,
        "non_english_removed": non_english,
        "final_chunks": n_final,
        "total_characters": total_chars,
        "total_words": total_words,
        "estimated_tokens": estimated_tokens,
        "avg_chars_per_chunk": round(avg_chars),
        "avg_words_per_chunk": round(avg_words),
        "validation_passed": len(leaked) == 0,
    }
    with open(report_file, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f" Cleaned corpus: {cleaned_file}")
    print(f" Validation report: {report_file}")

    # ── VERDICT ────────────────────────────────────────────────────────────
    print(f"\n{'=' * 60}")
    if leaked:
        print("VERDICT: FAIL — Data leakage detected. Must remove leaked chunks before DAPT.")
    else:
        removed = len(chunks) - n_final
        removed_pct = removed / len(chunks) * 100 if chunks else 0.0
        print(f"VERDICT: PASS — No leakage. Removed {removed} chunks ({removed_pct:.1f}%) for quality.")
    print(f"Cleaned corpus: {n_final} chunks, {total_words:,} words, ~{estimated_tokens:,} tokens")
    print("=" * 60)
    return 1 if leaked else 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so that a caller
    # (shell script, CI job) can react to it; a None result exits with status 0.
    raise SystemExit(main())