dna-origin-classifier / dna_filter.py
phanerozoic's picture
v2: retrained on scaled+deduplicated benchmark (host 0.995/0.993, engineered 0.909/0.874, 5-class 0.69); add read-filter CLI
ba12ed6 verified
#!/usr/bin/env python
"""Reference-free human/host read filter built on dna-origin-classifier.
Reads FASTA/FASTQ (optionally gzipped), scores each read with the host head of the
closed-form classifier, and either depletes host reads (pathogen enrichment) or removes
human reads (privacy). No alignment, no reference database; numpy + safetensors only.
Examples:
python dna_filter.py reads.fastq.gz --mode deplete-host --out nonhost.fasta
python dna_filter.py reads.fastq --mode scrub-human --out scrubbed.fasta --report calls.tsv
python dna_filter.py reads.fasta --mode classify --report calls.tsv
"""
import argparse, gzip, sys, os
from model import DnaOriginClassifier, CLASSES
# Leading two bytes of every gzip (including bgzip/BGZF) stream.
_GZIP_MAGIC = b"\x1f\x8b"


def opener(path):
    """Open *path* for text reading, transparently decompressing gzip input.

    Compression is detected from the gzip magic number rather than the file
    name, so an upper-case ``.GZ`` suffix or a missing suffix still works.

    Args:
        path: filesystem path of a plain-text or gzip-compressed file.

    Returns:
        A text-mode file object; the caller is responsible for closing it.
    """
    with open(path, "rb") as probe:
        is_gzip = probe.read(2) == _GZIP_MAGIC
    return gzip.open(path, "rt") if is_gzip else open(path)


def _header_id(line):
    """Return the first whitespace-delimited token after the '>'/'@' marker ("" if none)."""
    fields = line[1:].split()
    return fields[0] if fields else ""


def _next_nonblank(lines):
    """Return the next line of *lines* with non-whitespace content, or "" at end of input."""
    return next((ln for ln in lines if ln.strip()), "")


def read_seqs(path):
    """Yield ``(read_id, sequence)`` pairs from a FASTA or FASTQ file (gzip-aware).

    The format is chosen from the first non-blank line: ``@`` means FASTQ and
    ``>`` means FASTA. The read ID is the first whitespace-delimited token of
    the header line, or the empty string if the header has none.

    FASTQ is parsed as strict 4-line records (header, sequence, '+', quality).
    Multi-line FASTQ is reported as malformed rather than silently mis-parsed.
    Blank lines between records are ignored. Records with an empty sequence
    are yielded (in both formats) instead of terminating the parse.

    Raises:
        ValueError: if the first non-blank line is neither a FASTA nor a FASTQ
            header, or a FASTQ record does not follow the 4-line layout.
    """
    with opener(path) as f:
        lines = iter(f)
        first = _next_nonblank(lines)
        if not first:
            return  # empty or whitespace-only file

        if first.startswith("@"):
            header = first
            while header:
                if not header.startswith("@"):
                    raise ValueError(
                        f"{path}: malformed FASTQ, expected '@' header, got {header[:50]!r}"
                    )
                seq = next(lines, "").strip()
                sep = next(lines, "")
                # A missing '+' line means the record layout is not 4-line FASTQ;
                # continuing would emit quality strings as sequences.
                if sep and not sep.startswith("+"):
                    raise ValueError(
                        f"{path}: malformed FASTQ, expected '+' line, got {sep[:50]!r}"
                    )
                next(lines, None)  # quality string, not used
                yield _header_id(header), seq
                header = _next_nonblank(lines)

        elif first.startswith(">"):
            rid, chunks = _header_id(first), []
            for line in lines:
                if line.startswith(">"):
                    yield rid, "".join(chunks)
                    rid, chunks = _header_id(line), []
                else:
                    chunks.append(line.strip())
            # Emit the final record even if its sequence is empty, matching how
            # empty records earlier in the file are handled.
            yield rid, "".join(chunks)

        else:
            raise ValueError(f"{path}: not FASTA or FASTQ (first line {first[:50]!r})")
# Summary-line description of what was written to --out, per --mode.
_WRITE_LABELS = {
    "deplete-host": "non-host (pathogen-enriched)",
    "scrub-human": "non-human (privacy-scrubbed)",
    "classify": "non-host",
}


def _build_parser():
    """Return the argparse parser for the dna_filter command line."""
    ap = argparse.ArgumentParser(description="Reference-free human/host read filter")
    ap.add_argument("input", help="FASTA/FASTQ file (.fa/.fq/.fasta/.fastq, optionally .gz)")
    ap.add_argument("--mode", choices=["deplete-host", "scrub-human", "classify"], default="deplete-host")
    ap.add_argument("--model", default=os.path.join(os.path.dirname(__file__), "model.safetensors"))
    ap.add_argument("--threshold", type=float, default=0.0,
                    help="host_score decision boundary (higher = more human). Default 0.0.")
    ap.add_argument("--out", default=None,
                    help="filtered (non-human) sequences as FASTA; optional in classify mode.")
    ap.add_argument("--report", default=None, help="per-read TSV: id, host_score, call, origin")
    return ap


def main():
    """Command-line entry point: score every read and write the filtered output.

    Each read from ``args.input`` is scored with ``clf.host_score``. A read is
    called human when its score is >= ``--threshold``. In every mode, reads
    called non-human are written to ``--out`` (if given) as FASTA. ``--report``
    (if given) receives one TSV row per read, including ``clf.classify`` output.
    A one-line summary is printed to stderr.
    """
    from contextlib import ExitStack

    args = _build_parser().parse_args()
    clf = DnaOriginClassifier(args.model)

    n = kept = human = 0
    # ExitStack guarantees both output files are flushed and closed even if
    # parsing or scoring raises part-way through the input.
    with ExitStack() as stack:
        out_f = stack.enter_context(open(args.out, "w")) if args.out else None
        rep_f = stack.enter_context(open(args.report, "w")) if args.report else None
        if rep_f is not None:
            rep_f.write("read_id\thost_score\tcall\torigin\n")

        for rid, seq in read_seqs(args.input):
            n += 1
            hs = clf.host_score(seq)
            is_human = hs >= args.threshold
            human += int(is_human)
            if rep_f is not None:
                call = "human" if is_human else "non-host"
                rep_f.write(f"{rid}\t{hs:.4f}\t{call}\t{clf.classify(seq)}\n")
            # deplete-host and scrub-human both emit the non-human reads
            if out_f is not None and not is_human:
                out_f.write(f">{rid}\n{seq}\n")
                kept += 1

    msg = f"reads={n} human={human} non-host={n - human}"
    if args.out:
        # Report the write in every mode, so classify-mode output is never silent.
        msg += f" | wrote {kept} {_WRITE_LABELS[args.mode]} reads -> {args.out}"
    elif args.mode != "classify":
        msg += " | no --out given, filtered reads not written"
    print(msg, file=sys.stderr)


if __name__ == "__main__":
    main()