bio-nexus-api / app /tools /sequencing.py
Samad14's picture
Upload app/tools/sequencing.py with huggingface_hub
c599f93 verified
Raw
History Blame Contribute Delete
14.4 kB
import asyncio
import logging
import os
import re
import shutil
import tempfile
from typing import Any
import httpx
from app.tools.base import BaseTool
logger = logging.getLogger(__name__)
BIN_DIR = os.path.join(os.path.dirname(__file__), "..", "bin")
MINIMAP2_PATH = shutil.which("minimap2") or os.path.join(BIN_DIR, "minimap2")
MINIMAP2_URL = "https://github.com/lh3/minimap2/releases/download/v2.28/minimap2-2.28_x64-linux.tar.bz2"
PIPELINE_TIMEOUT = 600
REFERENCE_URLS = {
"sars-cov-2": "https://hgdownload.soe.ucsc.edu/goldenPath/wuhCor1/bigZips/wuhCor1.fa.gz",
"lambda": "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=nuccore&id=NC_001416&rettype=fasta&retmode=text",
}
SMALL_REFERENCE = "sars-cov-2"
MAX_FASTQ_SIZE = 50 * 1024 * 1024
REF_CACHE_DIR = os.path.join(os.path.dirname(__file__), "..", "data", "references")
async def _ensure_minimap2() -> str:
if os.path.exists(MINIMAP2_PATH) and os.access(MINIMAP2_PATH, os.X_OK):
return MINIMAP2_PATH
dest = MINIMAP2_PATH
os.makedirs(BIN_DIR, exist_ok=True)
logger.info("Downloading minimap2 binary ...")
async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
r = await client.get(MINIMAP2_URL)
r.raise_for_status()
import tarfile, io
with tarfile.open(fileobj=io.BytesIO(r.content)) as tar:
for member in tar.getmembers():
if member.name.endswith("minimap2"):
f = tar.extractfile(member)
if f:
with open(dest, "wb") as out:
out.write(f.read())
break
os.chmod(dest, 0o755)
return dest
def _generate_synthetic_fastq(ref_seq: str, num_reads: int = 100, read_len: int = 100) -> str:
import random
ref = "".join(line.strip().upper() for line in ref_seq.splitlines() if not line.startswith(">"))
if len(ref) < read_len:
ref = ref * ((read_len // len(ref)) + 1)
lines: list[str] = []
for i in range(num_reads):
start = random.randint(0, len(ref) - read_len)
seq = ref[start:start + read_len]
mut_rate = 0.01
seq = "".join(
random.choice("ACGT") if random.random() < mut_rate else b
for b in seq
)
qual = "".join(chr(33 + min(40, random.randint(20, 40))) for _ in range(read_len))
lines.append(f"@read{i + 1}")
lines.append(seq)
lines.append("+")
lines.append(qual)
return "\n".join(lines)
def _parse_fastq_quality(fastq_path: str) -> dict:
total_reads = 0
total_bases = 0
gc_count = 0
at_count = 0
q_scores: list[int] = []
read_lengths: list[int] = []
seen_seqs: dict[str, int] = {}
line_no = 0
with open(fastq_path) as f:
for line in f:
line_no += 1
if line_no % 4 == 1:
total_reads += 1
elif line_no % 4 == 2:
seq = line.strip()
l = len(seq)
read_lengths.append(l)
total_bases += l
gc_count += seq.count("G") + seq.count("C") + seq.count("g") + seq.count("c")
at_count += seq.count("A") + seq.count("T") + seq.count("a") + seq.count("t")
seen_seqs[seq] = seen_seqs.get(seq, 0) + 1
elif line_no % 4 == 0:
qual = line.strip()
for ch in qual:
q_scores.append(ord(ch) - 33)
if total_reads == 0:
return {"error": "Empty FASTQ file", "total_reads": 0}
mean_q = sum(q_scores) / len(q_scores) if q_scores else 0
min_q = min(q_scores) if q_scores else 0
max_q = max(q_scores) if q_scores else 0
q20 = sum(1 for q in q_scores if q >= 20) / len(q_scores) * 100 if q_scores else 0
q30 = sum(1 for q in q_scores if q >= 30) / len(q_scores) * 100 if q_scores else 0
gc_pct = gc_count / (gc_count + at_count) * 100 if (gc_count + at_count) > 0 else 0
avg_len = sum(read_lengths) / len(read_lengths) if read_lengths else 0
overrepresented = sorted(seen_seqs.items(), key=lambda x: -x[1])[:10]
overrep_pct = [(s, c, c / total_reads * 100) for s, c in overrepresented]
return {
"total_reads": total_reads,
"total_bases": total_bases,
"avg_read_length": round(avg_len, 1),
"min_read_length": min(read_lengths) if read_lengths else 0,
"max_read_length": max(read_lengths) if read_lengths else 0,
"gc_percent": round(gc_pct, 2),
"mean_quality": round(mean_q, 2),
"min_quality": min_q,
"max_quality": max_q,
"q20_percent": round(q20, 2),
"q30_percent": round(q30, 2),
"overrepresented_sequences": [
{"sequence": s[:50], "count": c, "percent": round(p, 2)}
for s, c, p in overrep_pct
],
}
def _parse_sam_for_variants(sam_path: str, reference_seq: str) -> list[dict]:
ref_lines = reference_seq.splitlines()
ref = "".join(line.strip().upper() for line in ref_lines if not line.startswith(">"))
pileup: dict[int, dict[str, int]] = {}
depth_by_pos: dict[int, int] = {}
with open(sam_path) as f:
for line in f:
if line.startswith("@"):
continue
parts = line.strip().split("\t")
if len(parts) < 6:
continue
flag = int(parts[1])
if flag & 4:
continue
pos = int(parts[3])
cigar = parts[5]
seq = parts[9]
genome_pos = pos - 1
ops = re.findall(r"(\d+)([MIDNSHPX=])", cigar)
offset = 0
for length, op in ops:
l = int(length)
if op == "M":
for i in range(l):
p = genome_pos + i
if p < len(ref):
base = seq[offset + i].upper() if offset + i < len(seq) else "N"
if p not in pileup:
pileup[p] = {"A": 0, "C": 0, "G": 0, "T": 0, "N": 0, "del": 0, "ins": 0}
depth_by_pos[p] = depth_by_pos.get(p, 0) + 1
if base in pileup[p]:
pileup[p][base] += 1
else:
pileup[p]["N"] += 1
offset += l
elif op == "I":
offset += l
elif op == "D":
for i in range(l):
p = genome_pos + i
if p not in pileup:
pileup[p] = {"A": 0, "C": 0, "G": 0, "T": 0, "N": 0, "del": 0, "ins": 0}
pileup[p]["del"] += 1
elif op in ("S", "H"):
if op == "S":
offset += l
min_depth = 2
min_alt_freq = 0.2
variants: list[dict] = []
for pos in sorted(pileup.keys()):
counts = pileup[pos]
depth = depth_by_pos.get(pos, sum(counts.values()) - counts.get("del", 0) - counts.get("ins", 0))
if depth < min_depth:
continue
ref_base = ref[pos].upper() if pos < len(ref) else "N"
total = sum(counts.get(b, 0) for b in "ACGTN")
if total == 0:
continue
for base in "ACGT":
if base == ref_base:
continue
alt_count = counts.get(base, 0)
freq = alt_count / total
if freq >= min_alt_freq:
variants.append({
"pos": pos + 1, "ref": ref_base, "alt": base,
"depth": depth, "alt_count": alt_count, "freq": round(freq, 4),
})
variants.sort(key=lambda v: -v["freq"])
return variants[:50]
def _generate_report(qc: dict, variants: list[dict], ref_name: str) -> dict:
total_variants = len(variants)
snv_count = sum(1 for v in variants if len(v["ref"]) == 1 and len(v["alt"]) == 1)
avg_depth = round(sum(v["depth"] for v in variants) / total_variants, 1) if total_variants else 0
return {
"reference": ref_name,
"qc_summary": {
"total_reads": qc.get("total_reads", 0),
"total_bases": qc.get("total_bases", 0),
"mean_quality": qc.get("mean_quality", 0),
"q30_percent": qc.get("q30_percent", 0),
"gc_percent": qc.get("gc_percent", 0),
},
"variant_summary": {
"total_variants": total_variants,
"snv_count": snv_count,
"avg_depth": avg_depth,
},
"variants": variants,
}
async def _download_fastq(url: str, dest: str) -> str:
async with httpx.AsyncClient(timeout=120, follow_redirects=True) as client:
async with client.stream("GET", url) as r:
r.raise_for_status()
content_length = int(r.headers.get("content-length", 0))
if content_length > MAX_FASTQ_SIZE:
raise ValueError(f"FASTQ too large: {content_length} bytes (max {MAX_FASTQ_SIZE})")
with open(dest, "wb") as f:
async for chunk in r.aiter_bytes():
f.write(chunk)
return dest
async def _download_reference(ref_name: str, dest_dir: str | None = None) -> str:
url = REFERENCE_URLS.get(ref_name)
if not url:
raise ValueError(f"Unknown reference genome: {ref_name}")
cache_dir = dest_dir or REF_CACHE_DIR
os.makedirs(cache_dir, exist_ok=True)
fa_path = os.path.join(cache_dir, f"{ref_name}.fa")
if os.path.exists(fa_path) and os.path.getsize(fa_path) > 0:
logger.info(f"Using cached reference {ref_name} ({os.path.getsize(fa_path)} bytes)")
return fa_path
async with httpx.AsyncClient(timeout=120, follow_redirects=True) as client:
r = await client.get(url)
r.raise_for_status()
data = r.content
if url.endswith(".gz"):
import gzip
data = gzip.decompress(data)
with open(fa_path, "wb") as f:
f.write(data)
return fa_path
class SequencingPipeline(BaseTool):
name = "sequencing"
async def run(self, input: dict) -> dict:
fastq_url = input.get("fastq_url", "").strip()
reference = input.get("reference", SMALL_REFERENCE).strip().lower()
if not fastq_url:
return {"error": "fastq_url is required"}
tmpdir = tempfile.mkdtemp(prefix="seqpipe_")
try:
ref_path = await _download_reference(reference)
with open(ref_path) as f:
ref_content = f.read()
fastq_path = os.path.join(tmpdir, "input.fastq")
synthetic = fastq_url.lower() in ("synthetic", "demo", "test")
fastq_source = "synthetic"
if synthetic:
logger.info("Generating synthetic FASTQ reads")
fastq_data = _generate_synthetic_fastq(ref_content, num_reads=500, read_len=100)
with open(fastq_path, "w") as f:
f.write(fastq_data)
else:
fastq_source = "url"
try:
await asyncio.wait_for(_download_fastq(fastq_url, fastq_path), timeout=120)
except Exception:
logger.info("FASTQ download failed, generating synthetic reads from reference")
fastq_source = "synthetic"
fastq_data = _generate_synthetic_fastq(ref_content, num_reads=500, read_len=100)
with open(fastq_path, "w") as f:
f.write(fastq_data)
qc = _parse_fastq_quality(fastq_path)
if "error" in qc:
return {"error": qc["error"], "step": "qc"}
mm2_path = await asyncio.wait_for(_ensure_minimap2(), timeout=120)
sam_path = os.path.join(tmpdir, "aln.sam")
minimap2_proc = await asyncio.create_subprocess_exec(
mm2_path, "-ax", "sr", ref_path, fastq_path,
"-o", sam_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
mm_stdout, mm_stderr = await asyncio.wait_for(minimap2_proc.communicate(), timeout=300)
except asyncio.TimeoutError:
minimap2_proc.kill()
await minimap2_proc.communicate()
return {"error": "Alignment timed out after 5 minutes", "step": "align"}
if minimap2_proc.returncode != 0 or not os.path.exists(sam_path):
err = mm_stderr.decode("utf-8", errors="replace")[:500] if mm_stderr else ""
return {"error": f"minimap2 failed (exit {minimap2_proc.returncode}): {err}", "step": "align"}
aln_stats = {"mapped_reads": 0, "unmapped_reads": 0, "total_alignments": 0}
with open(sam_path) as f:
for line in f:
if line.startswith("@"):
continue
aln_stats["total_alignments"] += 1
parts = line.strip().split("\t", maxsplit=2)
if len(parts) >= 2:
flag = int(parts[1])
if flag & 4:
aln_stats["unmapped_reads"] += 1
else:
aln_stats["mapped_reads"] += 1
variants = _parse_sam_for_variants(sam_path, ref_content)
report = _generate_report(qc, variants, reference)
return {
"reference": reference,
"fastq_source": fastq_source,
"qc": qc,
"alignment": aln_stats,
"variants": variants[:20],
"report": report,
"steps_completed": ["qc", "align", "variants", "report"],
}
except ValueError as e:
return {"error": str(e)}
except httpx.HTTPStatusError as e:
return {"error": f"Download failed (HTTP {e.response.status_code})"}
except asyncio.TimeoutError:
return {"error": "Pipeline timed out"}
except Exception as e:
logger.exception("Sequencing pipeline failed")
return {"error": f"Pipeline failed: {e}"}
finally:
shutil.rmtree(tmpdir, ignore_errors=True)