Upload scripts/analyze_class_dist.py
Browse files
scripts/analyze_class_dist.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Analyze TE family distribution from gzipped FASTA headers only (no sequences)."""
import gzip
import json
import os
import sys
from collections import Counter

# Dataset coordinates on the Hugging Face Hub.
REPO_ID = "vedatonuryilmaz/te_hg38"
FASTA_NAME = "te_rmsk_hg38_fa.gz"

# Where the JSON summary is written (relative to the working directory).
OUTPUT_PATH = "class_distribution.json"

# Class/family labels that are excluded from the target counts.
NON_TARGET = {"Simple_repeat", "Low_complexity", "Unknown", "Other", "rRNA", "tRNA", "snRNA", "scRNA", "srpRNA", "RNA", "RC", "Satellite", "ARTEFACT"}

# Minimum number of records for a family to be recommended as its own class.
MIN = 100

# Size thresholds shown in the console report / stored in the JSON summary.
REPORT_THRESHOLDS = [10, 50, 100, 500, 1000, 5000, 10000, 100000]
SAVED_THRESHOLDS = [10, 50, 100, 500, 1000, 5000]

# A progress line is printed every time this many headers have been read.
PROGRESS_EVERY = 500000


def download_fasta():
    """Download the gzipped FASTA from the Hub and return its local path.

    The third-party import lives here so that the counting helpers can be
    imported (and tested) without ``huggingface_hub`` being installed.
    """
    from huggingface_hub import hf_hub_download

    print("Downloading te_rmsk_hg38_fa.gz...")
    path = hf_hub_download(REPO_ID, FASTA_NAME,
                           repo_type="dataset", token=True)
    print(f"Downloaded to: {path}")
    return path


def count_headers(lines, progress_every=PROGRESS_EVERY):
    """Tally class and family labels from FASTA header lines.

    Only lines starting with ``>`` are considered. A header is split on ``|``;
    the second field is taken as the class and the third as the family. Headers
    with fewer than four fields are counted as malformed and otherwise ignored.
    Headers whose class or family is in ``NON_TARGET`` are counted as
    non-target and left out of both counters.

    Args:
        lines: Iterable of text lines (e.g. an open text-mode file).
        progress_every: Print a progress line each time this many headers have
            been read; a falsy value disables progress output.

    Returns:
        ``(total, nontarget, malformed, class_counts, family_counts)`` where
        the last two are ``collections.Counter`` objects.
    """
    class_counts = Counter()
    family_counts = Counter()
    total = nontarget = malformed = 0

    for line in lines:
        if not line.startswith(">"):
            continue
        total += 1
        parts = line[1:].strip().split("|")
        if len(parts) < 4:
            # Tracked separately so these headers do not inflate the
            # target-record count derived from the totals.
            malformed += 1
        elif parts[1] in NON_TARGET or parts[2] in NON_TARGET:
            nontarget += 1
        else:
            class_counts[parts[1]] += 1
            family_counts[parts[2]] += 1
        # Checked for every header, so a milestone is never skipped just
        # because that particular record was excluded.
        if progress_every and total % progress_every == 0:
            print(f" ... {total:,} records, {len(family_counts)} target families")

    return total, nontarget, malformed, class_counts, family_counts


def threshold_stats(family_counts, threshold):
    """Return ``(n_families, n_records)`` for families with >= *threshold* records."""
    kept = [n for n in family_counts.values() if n >= threshold]
    return len(kept), sum(kept)


def recommended_families(family_counts, minimum=MIN):
    """Return ``{family: count}`` for families with at least *minimum* records."""
    return {fam: n for fam, n in family_counts.items() if n >= minimum}


def print_report(total, nontarget, malformed, class_counts, family_counts):
    """Print the human-readable distribution report to stdout."""
    # Records that actually landed in the counters; used as the percentage
    # denominator (clamped to 1 to avoid dividing by zero on empty input).
    target = total - nontarget - malformed
    denom = max(1, target)

    print("\n=== RESULTS ===")
    print(f"Total records: {total:,}")
    print(f"Non-TE records (excluded): {nontarget:,}")
    if malformed:
        print(f"Malformed headers (skipped): {malformed:,}")
    print(f"Target TE records: {target:,}")
    print(f"Target families: {len(family_counts)}")
    print(f"Target classes (repClass): {len(class_counts)}")

    print("\n--- repClass ---")
    for cls, n in class_counts.most_common():
        print(f" {cls:20s}: {n:>10,}")

    print("\n--- Size thresholds ---")
    for t in REPORT_THRESHOLDS:
        n, s = threshold_stats(family_counts, t)
        print(f" >= {t:>6}: {n:>5} families, {s:>12,} records ({100*s/denom:.1f}%)")

    print("\n--- Top 50 families ---")
    for fam, n in family_counts.most_common(50):
        pct = 100 * n / denom
        print(f" {fam:40s}: {n:>9,} ({pct:.1f}%)")

    print("\n--- Tail ---")
    tail = [n for n in family_counts.values() if n < 100]
    print(f"Families with <100 samples: {len(tail)} (total: {sum(tail):,})")
    tail10 = [n for n in family_counts.values() if n < 10]
    print(f"Families with <10 samples: {len(tail10)} (total: {sum(tail10):,})")

    # Recommended split
    good = recommended_families(family_counts)
    rubble = sum(n for n in family_counts.values() if n < MIN)
    print(f"\n--- Recommendation: min={MIN} per family ---")
    print(f" Good families: {len(good)}")
    print(f" Total records in good: {sum(good.values()):,}")
    print(f" Rubble records (group as 'other'): {rubble:,}")
    print(f" Result: {len(good) + 1} classes ({len(good)} real + 1 bug bucket)")


def build_summary(total, nontarget, malformed, class_counts, family_counts):
    """Build the JSON-serializable summary dictionary."""
    good = recommended_families(family_counts)
    thresholds = {}
    for t in SAVED_THRESHOLDS:
        n, s = threshold_stats(family_counts, t)
        thresholds[str(t)] = {"families": n, "records": s}
    return {
        "total_records": total,
        "nontarget_excluded": nontarget,
        "malformed_headers": malformed,
        "target_families": len(family_counts),
        "repClass_counts": dict(class_counts.most_common()),
        "family_counts": dict(family_counts.most_common()),
        "thresholds": thresholds,
        "recommended_families_min100": list(good.keys()),
        "recommended_num_classes": len(good) + 1,
    }


def main():
    """Download the FASTA, count its headers, print the report, save the JSON."""
    path = download_fasta()
    with gzip.open(path, "rt") as handle:
        stats = count_headers(handle)

    # Results
    print_report(*stats)

    # Save
    with open(OUTPUT_PATH, "w", encoding="utf-8") as out_file:
        json.dump(build_summary(*stats), out_file, indent=2)
    print("\nSaved class_distribution.json")


if __name__ == "__main__":
    main()