# genatator-pipeline/tests/test_gene_memory_optimizations.py
# gettheworkdone
# Optimize gene-finding memory with sparse edge peaks and compact masks
# b428993
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import numpy as np
import genatator_core as core
from configuration_genatator_pipeline import GenatatorPipelineConfig
def to_tuple(intervals):
    """Reduce intervals to plain ``(start, end, strand)`` tuples.

    Args:
        intervals: Iterable of objects exposing ``start``, ``end`` and
            ``strand`` attributes.

    Returns:
        A list with one ``(start, end, strand)`` tuple per input item, in
        the same order as the input, so two result sets can be compared
        with ``==``.
    """
    triples = []
    for interval in intervals:
        triples.append((interval.start, interval.end, interval.strand))
    return triples
def test_dense_sparse_pairing_equivalence():
    """Check that dense-array pairing and peak-index pairing agree.

    Builds a 4 x 50 ``uint8`` array with a few cells set to 1, runs the
    dense entry point on the array, then runs the index-based entry point
    on the per-row nonzero positions, and requires identical
    ``(start, end, strand)`` results.
    """
    seq_len = 50
    # Columns set to 1 in rows 0..3 respectively.
    marked_columns = ([5, 10], [20, 25], [40], [30, 35])

    arr = np.zeros((4, seq_len), dtype=np.uint8)
    for row, columns in enumerate(marked_columns):
        arr[row, columns] = 1

    dense_result = core.find_tss_polya_pairs_right_left_only(
        arr, 'chr', window_size=30, k=2
    )

    # Sparse form: nonzero column indices of each row, in row order.
    row_indices = [np.where(arr[row])[0] for row in range(4)]
    sparse_result = core.find_tss_polya_pairs_from_peak_indices(
        row_indices[0],
        row_indices[1],
        row_indices[2],
        row_indices[3],
        seq_len,
        'chr',
        window_size=30,
        k=2,
    )

    assert to_tuple(dense_result) == to_tuple(sparse_result)
def test_filter_bool_equivalence():
    """Check that score-based and mask-based intragenic filters agree.

    The float tracks are thresholded by ``filter_intervals_by_intragenic``
    itself, while ``filter_intervals_by_intragenic_bool`` receives masks
    pre-computed as ``track > threshold``; both must keep the same
    intervals.
    """
    threshold = 0.5
    # Rising ramp for one track and its mirror image for the other.
    forward_track = np.linspace(0, 1, 20, dtype=np.float32)
    reverse_track = forward_track[::-1].copy()

    candidates = [
        core.TranscriptInterval('c', 0, 10, '+', []),
        core.TranscriptInterval('c', 10, 20, '-', []),
    ]

    from_scores = core.filter_intervals_by_intragenic(
        candidates, forward_track, reverse_track, threshold, 0.8
    )
    from_masks = core.filter_intervals_by_intragenic_bool(
        candidates, forward_track > threshold, reverse_track > threshold, 0.8
    )

    assert to_tuple(from_scores) == to_tuple(from_masks)
def test_config_param_alias():
    """Check the chunk-size default, its runtime export, and the alias.

    Verifies that ``gene_finding_global_chunk_size`` defaults to
    70,000,000, appears in ``to_runtime_defaults()``, and is populated when
    the config is constructed with the ``global_chunk_size`` keyword.
    """
    default_cfg = GenatatorPipelineConfig()
    assert default_cfg.gene_finding_global_chunk_size == 70_000_000

    runtime_defaults = default_cfg.to_runtime_defaults()
    assert 'gene_finding_global_chunk_size' in runtime_defaults

    # Constructing with the alias keyword must set the canonical attribute.
    aliased_cfg = GenatatorPipelineConfig(global_chunk_size=123)
    assert aliased_cfg.gene_finding_global_chunk_size == 123
def test_edge_chunk_calls(monkeypatch):
    """Check that edge-peak inference runs over several chunks.

    ``infer_token_classification_tracks_with_rc`` is replaced with a stub
    that records the length of every sequence it receives. For a
    25-character input the stub must be invoked at least three times, and
    the result must be a mapping with exactly the four expected keys whose
    ``tss_plus`` entry has dtype ``int64``.
    """
    seen_lengths = []

    def fake(**kwargs):
        # Stand-in for the model call: returns a (4, len) float32 array
        # that is all zeros except for two cells set on longer inputs.
        length = len(kwargs['sequence'])
        seen_lengths.append(length)
        tracks = np.zeros((4, length), dtype=np.float32)
        if length > 4:
            tracks[0, 1] = 1
            tracks[1, length - 2] = 1
        return tracks

    monkeypatch.setattr(core, 'infer_token_classification_tracks_with_rc', fake)

    # Minimal object exposing only a ``type`` attribute equal to 'cpu'.
    # NOTE(review): presumably stands in for a device object -- confirm
    # against the callee's signature.
    cpu_stub = type('D', (), {'type': 'cpu'})()

    # All arguments are positional; their meaning is defined by the callee,
    # so order and values are kept exactly as they were.
    out = core.infer_edge_peak_indices_with_global_chunks(
        'A' * 25, None, None, ['a', 'b', 'c', 'd'], (0, 1, 2, 3),
        10, 1, 0.5, 1, 1, False, None, 1, cpu_stub,
        False, False, 0.5, 0.0, 1, None, 'x',
    )

    assert len(seen_lengths) >= 3
    expected_keys = {'tss_plus', 'polya_plus', 'tss_minus', 'polya_minus'}
    assert set(out.keys()) == expected_keys
    assert out['tss_plus'].dtype == np.int64