v2: retrained on scaled+deduplicated benchmark (host 0.995/0.993, engineered 0.909/0.874, 5-class 0.69); add read-filter CLI
ba12ed6 verified | #!/usr/bin/env python | |
| """Reference-free human/host read filter built on dna-origin-classifier. | |
| Reads FASTA/FASTQ (optionally gzipped), scores each read with the host head of the | |
| closed-form classifier, and either depletes host reads (pathogen enrichment) or removes | |
| human reads (privacy). No alignment, no reference database; numpy + safetensors only. | |
| Examples: | |
| python dna_filter.py reads.fastq.gz --mode deplete-host --out nonhost.fasta | |
| python dna_filter.py reads.fastq --mode scrub-human --out scrubbed.fasta --report calls.tsv | |
| python dna_filter.py reads.fasta --mode classify --report calls.tsv | |
| """ | |
| import argparse, gzip, sys, os | |
| from model import DnaOriginClassifier, CLASSES | |
def opener(path):
    """Open a sequence file for text reading, decompressing gzip transparently.

    The file is treated as gzip when its name ends in ``.gz``; the comparison
    is case-insensitive so ``reads.FASTQ.GZ`` is decompressed rather than
    being read as plain text.

    Args:
        path: Filesystem path, as ``str`` or any ``os.PathLike`` object.

    Returns:
        A text-mode file object. The caller is responsible for closing it
        (normally by using it in a ``with`` statement).

    Raises:
        OSError: If the file cannot be opened.
    """
    # os.fspath lets callers pass pathlib.Path objects, which have no .endswith().
    name = os.fspath(path)
    if name.lower().endswith(".gz"):
        return gzip.open(name, "rt")
    return open(name)
def _record_id(header):
    """Return the first whitespace-delimited token after the marker char, or ''."""
    fields = header[1:].split()
    return fields[0] if fields else ""


def _next_nonblank(f):
    """Return the next line that is not whitespace-only, or '' at end of file."""
    while True:
        line = f.readline()
        if not line or line.strip():
            return line


def _iter_fastq(header, f):
    """Yield (id, seq) from four-line FASTQ records, starting at `header`."""
    while header:
        seq_line = f.readline()
        if not seq_line:
            # Truncated file: a header with no sequence line is not a record.
            return
        # readline() returns "" only at EOF, so a bare "\n" here is a genuine
        # zero-length read and must not end the parse.
        yield _record_id(header), seq_line.strip()
        f.readline()  # '+' separator line
        f.readline()  # quality line
        # Tolerate blank lines between records and at the end of the file.
        header = _next_nonblank(f)


def _iter_fasta(header, f):
    """Yield (id, seq) from FASTA records, joining wrapped sequence lines."""
    rid, chunks = _record_id(header), []
    for line in f:
        if line.startswith(">"):
            yield rid, "".join(chunks)
            rid, chunks = _record_id(line), []
        else:
            chunks.append(line.strip())
    # Always emit the final record so a sequence-less last record is treated
    # the same way as a sequence-less record in the middle of the file.
    yield rid, "".join(chunks)


def read_seqs(path):
    """Yield ``(id, seq)`` pairs from a FASTA or FASTQ file, gzip-aware.

    The format is chosen from the first non-blank line: a leading ``@`` means
    FASTQ, anything else is parsed as FASTA. The id is the first
    whitespace-delimited token of the header (empty string if the header has
    no text). FASTQ records are assumed to be exactly four lines each;
    multi-line FASTQ is not supported. FASTA sequences may be wrapped over
    several lines.

    Every record is yielded, including records whose sequence is empty. An
    empty or all-blank file yields nothing.

    Args:
        path: Path to the input file; opened via ``opener``.

    Yields:
        Tuples of ``(read_id, sequence)`` as strings.
    """
    with opener(path) as f:
        first = _next_nonblank(f)
        if not first:
            return
        if first.startswith("@"):
            yield from _iter_fastq(first, f)
        else:
            yield from _iter_fasta(first, f)
def main(argv=None):
    """Run the read-filter command line.

    Loads the classifier, scores every read of the input with
    ``clf.host_score``, and calls a read "human" when its score is greater
    than or equal to ``--threshold``. If ``--report`` is given, one TSV row
    per read is written (id, score, call, and the ``clf.classify`` result).
    If ``--out`` is given, every read not called human is written as FASTA;
    this happens in all modes, although the summary line only mentions the
    output file outside ``classify`` mode. A one-line summary is printed to
    stderr at the end.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` makes argparse read ``sys.argv[1:]``.
    """
    import contextlib  # local import: only needed for handle management here

    ap = argparse.ArgumentParser(description="Reference-free human/host read filter")
    ap.add_argument("input", help="FASTA/FASTQ file (.fa/.fq/.fasta/.fastq, optionally .gz)")
    ap.add_argument("--mode", choices=["deplete-host", "scrub-human", "classify"], default="deplete-host")
    ap.add_argument("--model", default=os.path.join(os.path.dirname(__file__), "model.safetensors"))
    ap.add_argument("--threshold", type=float, default=0.0,
                    help="host_score decision boundary (higher = more human). Default 0.0.")
    ap.add_argument("--out", default=None, help="filtered sequences (FASTA). Omit in classify mode.")
    ap.add_argument("--report", default=None, help="per-read TSV: id, host_score, call, origin")
    args = ap.parse_args(argv)

    clf = DnaOriginClassifier(args.model)

    n = kept = human = 0
    # ExitStack closes whichever optional handles were opened, even if
    # parsing or scoring raises part-way through the input.
    with contextlib.ExitStack() as stack:
        out_f = stack.enter_context(open(args.out, "w")) if args.out else None
        rep_f = stack.enter_context(open(args.report, "w")) if args.report else None
        if rep_f:
            rep_f.write("read_id\thost_score\tcall\torigin\n")

        for rid, seq in read_seqs(args.input):
            n += 1
            hs = clf.host_score(seq)
            is_human = hs >= args.threshold
            human += int(is_human)
            if rep_f:
                rep_f.write(f"{rid}\t{hs:.4f}\t{'human' if is_human else 'non-host'}\t{clf.classify(seq)}\n")
            # deplete-host and scrub-human both emit the non-human reads
            if out_f and not is_human:
                out_f.write(f">{rid}\n{seq}\n")
                kept += 1

    msg = f"reads={n} human={human} non-host={n-human}"
    if args.out and args.mode != "classify":
        label = 'non-host (pathogen-enriched)' if args.mode == 'deplete-host' else 'non-human (privacy-scrubbed)'
        msg += f" | wrote {kept} {label} reads -> {args.out}"
    print(msg, file=sys.stderr)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()