#!/usr/bin/env python
"""Reference-free human/host read filter built on dna-origin-classifier.

Reads FASTA/FASTQ (optionally gzipped), scores each read with the host head
of the closed-form classifier, and either depletes host reads (pathogen
enrichment) or removes human reads (privacy). Whenever --out is given, the
reads scoring below --threshold are written; --mode only changes how the
summary line describes them. No alignment, no reference database;
numpy + safetensors only.

Examples:
  python dna_filter.py reads.fastq.gz --mode deplete-host --out nonhost.fasta
  python dna_filter.py reads.fastq --mode scrub-human --out scrubbed.fasta --report calls.tsv
  python dna_filter.py reads.fasta --mode classify --report calls.tsv
"""
import argparse
import contextlib
import gzip
import os
import sys

from model import DnaOriginClassifier, CLASSES


def opener(path):
    """Open *path* for text reading, decompressing it on the fly if it ends in ``.gz``."""
    return gzip.open(path, "rt") if path.endswith(".gz") else open(path)


def _read_id(header):
    """Return the first whitespace-delimited token after the ``>``/``@`` marker.

    Returns "" when the header carries no ID, instead of raising IndexError.
    """
    fields = header[1:].split()
    return fields[0] if fields else ""


def _next_nonblank(f):
    """Return the next line of *f* containing non-whitespace, or "" at end of file."""
    while True:
        line = f.readline()
        if not line or line.strip():
            return line


def _read_fastq(f, header):
    """Yield ``(id, seq)`` from 4-line FASTQ records, starting at *header*.

    Zero-length reads are yielded as ``(id, "")`` rather than ending the parse.
    Blank lines between records are skipped.
    """
    while header:
        if not header.startswith("@"):
            # Catches multi-line FASTQ and corrupt input instead of mis-parsing it.
            raise ValueError(f"malformed FASTQ record header: {header[:60]!r}")
        seq_line = f.readline()
        if not seq_line:
            # Header at EOF with no sequence line: truncated record, nothing to yield.
            return
        f.readline()  # '+' separator line
        f.readline()  # quality line
        yield _read_id(header), seq_line.strip()
        header = _next_nonblank(f)


def _read_fasta(f, header):
    """Yield ``(id, seq)`` from FASTA records, joining wrapped sequence lines.

    Every header produces a record, including a final header with no sequence.
    """
    rid, chunks = _read_id(header), []
    for line in f:
        if line.startswith(">"):
            yield rid, "".join(chunks)
            rid, chunks = _read_id(line), []
        else:
            chunks.append(line.strip())
    yield rid, "".join(chunks)


def read_seqs(path):
    """Yield (id, seq) from FASTA or FASTQ, gzip-aware.

    The format is sniffed from the first non-blank line: ``@`` selects FASTQ
    (4-line records), ``>`` selects FASTA (sequences may wrap across lines).
    An empty file yields nothing.

    Raises:
        ValueError: if the first non-blank line starts with neither ``@`` nor
            ``>``, or a later FASTQ record header does not start with ``@``.
    """
    with opener(path) as f:
        first = _next_nonblank(f)
        if not first:
            return
        if first.startswith("@"):
            yield from _read_fastq(f, first)
        elif first.startswith(">"):
            yield from _read_fasta(f, first)
        else:
            raise ValueError(f"{path}: not FASTA/FASTQ (first line: {first[:60]!r})")


def main():
    """Parse CLI arguments, score every read, and write the filtered FASTA and/or TSV report.

    A read is called human when ``host_score >= --threshold``. When --out is
    given, all other reads are written to it as FASTA. A one-line summary goes
    to stderr.
    """
    ap = argparse.ArgumentParser(description="Reference-free human/host read filter")
    ap.add_argument("input", help="FASTA/FASTQ file (.fa/.fq/.fasta/.fastq, optionally .gz)")
    ap.add_argument("--mode", choices=["deplete-host", "scrub-human", "classify"], default="deplete-host")
    ap.add_argument("--model", default=os.path.join(os.path.dirname(__file__), "model.safetensors"))
    ap.add_argument("--threshold", type=float, default=0.0,
                    help="host_score decision boundary (higher = more human). Default 0.0.")
    ap.add_argument("--out", default=None, help="filtered sequences (FASTA). Omit in classify mode.")
    ap.add_argument("--report", default=None, help="per-read TSV: id, host_score, call, origin")
    args = ap.parse_args()

    clf = DnaOriginClassifier(args.model)
    n = kept = human = 0
    # ExitStack closes (and flushes) whichever outputs were opened, even if
    # parsing or classification raises partway through the input.
    with contextlib.ExitStack() as stack:
        out_f = stack.enter_context(open(args.out, "w")) if args.out else None
        rep_f = stack.enter_context(open(args.report, "w")) if args.report else None
        if rep_f:
            rep_f.write("read_id\thost_score\tcall\torigin\n")
        for rid, seq in read_seqs(args.input):
            n += 1
            hs = clf.host_score(seq)
            is_human = hs >= args.threshold
            human += int(is_human)
            if rep_f:
                rep_f.write(f"{rid}\t{hs:.4f}\t{'human' if is_human else 'non-host'}\t{clf.classify(seq)}\n")
            # deplete-host and scrub-human both emit the non-human reads
            if out_f and not is_human:
                out_f.write(f">{rid}\n{seq}\n")
                kept += 1

    msg = f"reads={n} human={human} non-host={n-human}"
    if args.out and args.mode != "classify":
        msg += f" | wrote {kept} {'non-host (pathogen-enriched)' if args.mode=='deplete-host' else 'non-human (privacy-scrubbed)'} reads -> {args.out}"
    print(msg, file=sys.stderr)


if __name__ == "__main__":
    main()