phanerozoic committed on
Commit
ba12ed6
·
verified ·
1 Parent(s): e40004a

v2: retrained on scaled+deduplicated benchmark (host 0.995/0.993, engineered 0.909/0.874, 5-class 0.69); add read-filter CLI

Browse files
Files changed (1) hide show
  1. dna_filter.py +80 -0
dna_filter.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Reference-free human/host read filter built on dna-origin-classifier.
3
+
4
+ Reads FASTA/FASTQ (optionally gzipped), scores each read with the host head of the
5
+ closed-form classifier, and either depletes host reads (pathogen enrichment) or removes
6
+ human reads (privacy). No alignment, no reference database; numpy + safetensors only.
7
+
8
+ Examples:
9
+ python dna_filter.py reads.fastq.gz --mode deplete-host --out nonhost.fasta
10
+ python dna_filter.py reads.fastq --mode scrub-human --out scrubbed.fasta --report calls.tsv
11
+ python dna_filter.py reads.fasta --mode classify --report calls.tsv
12
+ """
13
+ import argparse, gzip, sys, os
14
+ from model import DnaOriginClassifier, CLASSES
15
+
16
def opener(path):
    """Open *path* for text reading, decompressing when the name ends in ``.gz``.

    The decision is made purely from the file-name suffix; the file content is
    not inspected. The returned object is a text-mode file handle usable as a
    context manager.
    """
    if path.endswith(".gz"):
        # "rt" makes gzip hand back decoded text lines, like plain open() does.
        return gzip.open(path, "rt")
    return open(path)
18
+
19
def _read_id(header):
    """Return the first whitespace-delimited token after the '@'/'>' marker.

    A header that carries nothing but the marker yields an empty id instead of
    raising IndexError.
    """
    fields = header[1:].split()
    return fields[0] if fields else ""


def read_seqs(path):
    """Yield (id, seq) from FASTA or FASTQ, gzip-aware.

    The format is sniffed from the first non-blank line: a leading '@' selects
    FASTQ (fixed four-line records); anything else is parsed as FASTA, where
    sequence lines are concatenated until the next '>' header.

    Behaviour on irregular input:
      * Blank lines before the first record and between FASTQ records are
        skipped (previously they raised IndexError while parsing the id).
      * A zero-length FASTQ read is skipped and parsing continues with the
        next record (previously it silently ended the whole parse).
      * A FASTQ record truncated at end-of-file is not yielded.
      * A final FASTA header with no following lines is not yielded.
    """
    with opener(path) as f:
        # Find the first line with content; readline() returns "" only at EOF.
        first = f.readline()
        while first and not first.strip():
            first = f.readline()
        if not first:
            return

        if first[0] == "@":
            header = first
            while header:
                if header.strip():
                    seq = f.readline().strip()
                    # Consume the '+' separator and the quality string so the
                    # next readline() lands on the following header. Reading in
                    # a fixed stride also keeps quality lines that start with
                    # '@' from being mistaken for headers.
                    f.readline()
                    f.readline()
                    if seq:
                        yield _read_id(header), seq
                header = f.readline()
        else:
            rid = _read_id(first)
            chunks = []
            for line in f:
                if line.startswith(">"):
                    yield rid, "".join(chunks)
                    rid = _read_id(line)
                    chunks = []
                else:
                    chunks.append(line.strip())
            if chunks:
                yield rid, "".join(chunks)
43
+
44
def main():
    """Command-line entry point: score every read and write the requested outputs.

    Parses the arguments, loads the classifier from ``--model``, then streams
    the input through ``clf.host_score``. A read is called human when its score
    is at or above ``--threshold``. When ``--report`` is given, one TSV row per
    read is written (including ``clf.classify(seq)``); when ``--out`` is given,
    every read NOT called human is written as FASTA. A one-line summary goes to
    stderr at the end.

    The output and report files are closed even if scoring or parsing raises,
    so rows already written are flushed to disk and no handle is leaked.
    """
    ap = argparse.ArgumentParser(description="Reference-free human/host read filter")
    ap.add_argument("input", help="FASTA/FASTQ file (.fa/.fq/.fasta/.fastq, optionally .gz)")
    ap.add_argument("--mode", choices=["deplete-host", "scrub-human", "classify"], default="deplete-host")
    ap.add_argument("--model", default=os.path.join(os.path.dirname(__file__), "model.safetensors"))
    ap.add_argument("--threshold", type=float, default=0.0,
                    help="host_score decision boundary (higher = more human). Default 0.0.")
    ap.add_argument("--out", default=None, help="filtered sequences (FASTA). Omit in classify mode.")
    ap.add_argument("--report", default=None, help="per-read TSV: id, host_score, call, origin")
    args = ap.parse_args()

    clf = DnaOriginClassifier(args.model)

    n = kept = human = 0
    out_f = rep_f = None
    try:
        # Opened inside the try so that a failure opening the report still
        # closes an already-open output file.
        out_f = open(args.out, "w") if args.out else None
        rep_f = open(args.report, "w") if args.report else None
        if rep_f:
            rep_f.write("read_id\thost_score\tcall\torigin\n")

        for rid, seq in read_seqs(args.input):
            n += 1
            hs = clf.host_score(seq)
            is_human = hs >= args.threshold
            human += int(is_human)
            if rep_f:
                rep_f.write(f"{rid}\t{hs:.4f}\t{'human' if is_human else 'non-host'}\t{clf.classify(seq)}\n")
            # deplete-host and scrub-human both emit the non-human reads.
            # NOTE(review): --out is honoured in classify mode too; only the
            # summary line below is mode-dependent.
            if out_f and not is_human:
                out_f.write(f">{rid}\n{seq}\n")
                kept += 1
    finally:
        if out_f:
            out_f.close()
        if rep_f:
            rep_f.close()

    msg = f"reads={n} human={human} non-host={n-human}"
    if args.out and args.mode != "classify":
        msg += f" | wrote {kept} {'non-host (pathogen-enriched)' if args.mode=='deplete-host' else 'non-human (privacy-scrubbed)'} reads -> {args.out}"
    print(msg, file=sys.stderr)
78
+
79
if __name__ == "__main__":
    # Run the CLI only when executed as a script, so importing this module
    # has no side effects.
    main()