Spaces:
Running
Running
| import asyncio | |
| import logging | |
| import os | |
| import re | |
| import shutil | |
| import tempfile | |
| from typing import Any | |
| import httpx | |
| from app.tools.base import BaseTool | |
logger = logging.getLogger(__name__)

_MODULE_DIR = os.path.dirname(__file__)

# Local directory that holds a downloaded minimap2 binary.
BIN_DIR = os.path.join(_MODULE_DIR, "..", "bin")
# A minimap2 found on PATH wins; otherwise use (or later download into) BIN_DIR.
MINIMAP2_PATH = shutil.which("minimap2") or os.path.join(BIN_DIR, "minimap2")
# Prebuilt x86-64 Linux release archive fetched when no binary is available.
MINIMAP2_URL = (
    "https://github.com/lh3/minimap2/releases/download/v2.28/"
    "minimap2-2.28_x64-linux.tar.bz2"
)
# Overall pipeline time budget (presumably seconds).
PIPELINE_TIMEOUT = 600

# Reference name -> FASTA download URL (the SARS-CoV-2 file is gzip-compressed).
REFERENCE_URLS = {
    "sars-cov-2": (
        "https://hgdownload.soe.ucsc.edu/goldenPath/wuhCor1/bigZips/"
        "wuhCor1.fa.gz"
    ),
    "lambda": (
        "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
        "?db=nuccore&id=NC_001416&rettype=fasta&retmode=text"
    ),
}
# Reference used when the caller does not name one.
SMALL_REFERENCE = "sars-cov-2"

# Upper bound on a downloaded FASTQ, in bytes (50 MiB).
MAX_FASTQ_SIZE = 50 << 20

# On-disk cache for decompressed reference FASTA files.
REF_CACHE_DIR = os.path.join(_MODULE_DIR, "..", "data", "references")
async def _ensure_minimap2() -> str:
    """Return the path of an executable minimap2, downloading it if needed.

    If ``MINIMAP2_PATH`` already exists and is executable it is returned
    as-is. Otherwise the release tarball at ``MINIMAP2_URL`` is downloaded,
    the regular file named ``minimap2`` is extracted to ``MINIMAP2_PATH``
    and marked executable (0o755).

    Raises:
        httpx.HTTPStatusError: the release download returned an error status
            (other httpx transport errors may also propagate).
        RuntimeError: the archive contains no regular file named ``minimap2``.
    """
    if os.path.exists(MINIMAP2_PATH) and os.access(MINIMAP2_PATH, os.X_OK):
        return MINIMAP2_PATH
    import io
    import tarfile

    dest = MINIMAP2_PATH
    os.makedirs(BIN_DIR, exist_ok=True)
    logger.info("Downloading minimap2 binary ...")
    async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
        r = await client.get(MINIMAP2_URL)
        r.raise_for_status()

    binary: bytes | None = None
    with tarfile.open(fileobj=io.BytesIO(r.content)) as tar:
        for member in tar.getmembers():
            # Match the executable exactly: a bare endswith("minimap2") test
            # could also hit a directory entry, for which extractfile() is None.
            if member.isfile() and os.path.basename(member.name) == "minimap2":
                fh = tar.extractfile(member)
                if fh is not None:
                    binary = fh.read()
                    break
    if binary is None:
        # Previously this surfaced as a confusing FileNotFoundError from chmod.
        raise RuntimeError(f"No minimap2 executable found in {MINIMAP2_URL}")

    # Write to a temp file in the same directory and rename it into place so a
    # half-written binary is never visible at ``dest`` (e.g. to a concurrent
    # worker that checks os.X_OK above).
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(dest) or ".", prefix=".minimap2-")
    try:
        with os.fdopen(fd, "wb") as out:
            out.write(binary)
        os.chmod(tmp_path, 0o755)
        os.replace(tmp_path, dest)
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    return dest
| def _generate_synthetic_fastq(ref_seq: str, num_reads: int = 100, read_len: int = 100) -> str: | |
| import random | |
| ref = "".join(line.strip().upper() for line in ref_seq.splitlines() if not line.startswith(">")) | |
| if len(ref) < read_len: | |
| ref = ref * ((read_len // len(ref)) + 1) | |
| lines: list[str] = [] | |
| for i in range(num_reads): | |
| start = random.randint(0, len(ref) - read_len) | |
| seq = ref[start:start + read_len] | |
| mut_rate = 0.01 | |
| seq = "".join( | |
| random.choice("ACGT") if random.random() < mut_rate else b | |
| for b in seq | |
| ) | |
| qual = "".join(chr(33 + min(40, random.randint(20, 40))) for _ in range(read_len)) | |
| lines.append(f"@read{i + 1}") | |
| lines.append(seq) | |
| lines.append("+") | |
| lines.append(qual) | |
| return "\n".join(lines) | |
| def _parse_fastq_quality(fastq_path: str) -> dict: | |
| total_reads = 0 | |
| total_bases = 0 | |
| gc_count = 0 | |
| at_count = 0 | |
| q_scores: list[int] = [] | |
| read_lengths: list[int] = [] | |
| seen_seqs: dict[str, int] = {} | |
| line_no = 0 | |
| with open(fastq_path) as f: | |
| for line in f: | |
| line_no += 1 | |
| if line_no % 4 == 1: | |
| total_reads += 1 | |
| elif line_no % 4 == 2: | |
| seq = line.strip() | |
| l = len(seq) | |
| read_lengths.append(l) | |
| total_bases += l | |
| gc_count += seq.count("G") + seq.count("C") + seq.count("g") + seq.count("c") | |
| at_count += seq.count("A") + seq.count("T") + seq.count("a") + seq.count("t") | |
| seen_seqs[seq] = seen_seqs.get(seq, 0) + 1 | |
| elif line_no % 4 == 0: | |
| qual = line.strip() | |
| for ch in qual: | |
| q_scores.append(ord(ch) - 33) | |
| if total_reads == 0: | |
| return {"error": "Empty FASTQ file", "total_reads": 0} | |
| mean_q = sum(q_scores) / len(q_scores) if q_scores else 0 | |
| min_q = min(q_scores) if q_scores else 0 | |
| max_q = max(q_scores) if q_scores else 0 | |
| q20 = sum(1 for q in q_scores if q >= 20) / len(q_scores) * 100 if q_scores else 0 | |
| q30 = sum(1 for q in q_scores if q >= 30) / len(q_scores) * 100 if q_scores else 0 | |
| gc_pct = gc_count / (gc_count + at_count) * 100 if (gc_count + at_count) > 0 else 0 | |
| avg_len = sum(read_lengths) / len(read_lengths) if read_lengths else 0 | |
| overrepresented = sorted(seen_seqs.items(), key=lambda x: -x[1])[:10] | |
| overrep_pct = [(s, c, c / total_reads * 100) for s, c in overrepresented] | |
| return { | |
| "total_reads": total_reads, | |
| "total_bases": total_bases, | |
| "avg_read_length": round(avg_len, 1), | |
| "min_read_length": min(read_lengths) if read_lengths else 0, | |
| "max_read_length": max(read_lengths) if read_lengths else 0, | |
| "gc_percent": round(gc_pct, 2), | |
| "mean_quality": round(mean_q, 2), | |
| "min_quality": min_q, | |
| "max_quality": max_q, | |
| "q20_percent": round(q20, 2), | |
| "q30_percent": round(q30, 2), | |
| "overrepresented_sequences": [ | |
| {"sequence": s[:50], "count": c, "percent": round(p, 2)} | |
| for s, c, p in overrep_pct | |
| ], | |
| } | |
| def _parse_sam_for_variants(sam_path: str, reference_seq: str) -> list[dict]: | |
| ref_lines = reference_seq.splitlines() | |
| ref = "".join(line.strip().upper() for line in ref_lines if not line.startswith(">")) | |
| pileup: dict[int, dict[str, int]] = {} | |
| depth_by_pos: dict[int, int] = {} | |
| with open(sam_path) as f: | |
| for line in f: | |
| if line.startswith("@"): | |
| continue | |
| parts = line.strip().split("\t") | |
| if len(parts) < 6: | |
| continue | |
| flag = int(parts[1]) | |
| if flag & 4: | |
| continue | |
| pos = int(parts[3]) | |
| cigar = parts[5] | |
| seq = parts[9] | |
| genome_pos = pos - 1 | |
| ops = re.findall(r"(\d+)([MIDNSHPX=])", cigar) | |
| offset = 0 | |
| for length, op in ops: | |
| l = int(length) | |
| if op == "M": | |
| for i in range(l): | |
| p = genome_pos + i | |
| if p < len(ref): | |
| base = seq[offset + i].upper() if offset + i < len(seq) else "N" | |
| if p not in pileup: | |
| pileup[p] = {"A": 0, "C": 0, "G": 0, "T": 0, "N": 0, "del": 0, "ins": 0} | |
| depth_by_pos[p] = depth_by_pos.get(p, 0) + 1 | |
| if base in pileup[p]: | |
| pileup[p][base] += 1 | |
| else: | |
| pileup[p]["N"] += 1 | |
| offset += l | |
| elif op == "I": | |
| offset += l | |
| elif op == "D": | |
| for i in range(l): | |
| p = genome_pos + i | |
| if p not in pileup: | |
| pileup[p] = {"A": 0, "C": 0, "G": 0, "T": 0, "N": 0, "del": 0, "ins": 0} | |
| pileup[p]["del"] += 1 | |
| elif op in ("S", "H"): | |
| if op == "S": | |
| offset += l | |
| min_depth = 2 | |
| min_alt_freq = 0.2 | |
| variants: list[dict] = [] | |
| for pos in sorted(pileup.keys()): | |
| counts = pileup[pos] | |
| depth = depth_by_pos.get(pos, sum(counts.values()) - counts.get("del", 0) - counts.get("ins", 0)) | |
| if depth < min_depth: | |
| continue | |
| ref_base = ref[pos].upper() if pos < len(ref) else "N" | |
| total = sum(counts.get(b, 0) for b in "ACGTN") | |
| if total == 0: | |
| continue | |
| for base in "ACGT": | |
| if base == ref_base: | |
| continue | |
| alt_count = counts.get(base, 0) | |
| freq = alt_count / total | |
| if freq >= min_alt_freq: | |
| variants.append({ | |
| "pos": pos + 1, "ref": ref_base, "alt": base, | |
| "depth": depth, "alt_count": alt_count, "freq": round(freq, 4), | |
| }) | |
| variants.sort(key=lambda v: -v["freq"]) | |
| return variants[:50] | |
| def _generate_report(qc: dict, variants: list[dict], ref_name: str) -> dict: | |
| total_variants = len(variants) | |
| snv_count = sum(1 for v in variants if len(v["ref"]) == 1 and len(v["alt"]) == 1) | |
| avg_depth = round(sum(v["depth"] for v in variants) / total_variants, 1) if total_variants else 0 | |
| return { | |
| "reference": ref_name, | |
| "qc_summary": { | |
| "total_reads": qc.get("total_reads", 0), | |
| "total_bases": qc.get("total_bases", 0), | |
| "mean_quality": qc.get("mean_quality", 0), | |
| "q30_percent": qc.get("q30_percent", 0), | |
| "gc_percent": qc.get("gc_percent", 0), | |
| }, | |
| "variant_summary": { | |
| "total_variants": total_variants, | |
| "snv_count": snv_count, | |
| "avg_depth": avg_depth, | |
| }, | |
| "variants": variants, | |
| } | |
async def _download_fastq(url: str, dest: str) -> str:
    """Stream the FASTQ at ``url`` into ``dest``, enforcing ``MAX_FASTQ_SIZE``.

    The cap is checked against the ``Content-Length`` header when the server
    sends a numeric one, and always against the bytes actually received,
    because the header may be absent (chunked transfer) or not match the
    decoded body size.

    Returns:
        ``dest``.

    Raises:
        ValueError: the file is larger than ``MAX_FASTQ_SIZE`` bytes.
        httpx.HTTPStatusError: the server answered with an error status.
    """
    async with httpx.AsyncClient(timeout=120, follow_redirects=True) as client:
        async with client.stream("GET", url) as r:
            r.raise_for_status()
            declared = r.headers.get("content-length", "")
            if declared.isdigit() and int(declared) > MAX_FASTQ_SIZE:
                raise ValueError(f"FASTQ too large: {int(declared)} bytes (max {MAX_FASTQ_SIZE})")
            received = 0
            with open(dest, "wb") as f:
                async for chunk in r.aiter_bytes():
                    received += len(chunk)
                    if received > MAX_FASTQ_SIZE:
                        raise ValueError(
                            f"FASTQ too large: more than {MAX_FASTQ_SIZE} bytes received "
                            f"(max {MAX_FASTQ_SIZE})"
                        )
                    f.write(chunk)
    return dest
async def _download_reference(ref_name: str, dest_dir: str | None = None) -> str:
    """Return a local FASTA path for ``ref_name``, downloading and caching it once.

    Args:
        ref_name: key of ``REFERENCE_URLS``.
        dest_dir: cache directory; defaults to ``REF_CACHE_DIR``.

    Returns:
        Path to ``<cache_dir>/<ref_name>.fa`` holding uncompressed FASTA.
        An existing non-empty file there is reused without downloading.

    Raises:
        ValueError: ``ref_name`` is unknown, or the download is not FASTA.
        httpx.HTTPStatusError: the server answered with an error status.
    """
    url = REFERENCE_URLS.get(ref_name)
    if not url:
        raise ValueError(f"Unknown reference genome: {ref_name}")
    cache_dir = dest_dir or REF_CACHE_DIR
    os.makedirs(cache_dir, exist_ok=True)
    fa_path = os.path.join(cache_dir, f"{ref_name}.fa")
    if os.path.exists(fa_path) and os.path.getsize(fa_path) > 0:
        logger.info("Using cached reference %s (%d bytes)", ref_name, os.path.getsize(fa_path))
        return fa_path

    async with httpx.AsyncClient(timeout=120, follow_redirects=True) as client:
        r = await client.get(url)
        r.raise_for_status()
    data = r.content
    # Detect gzip by magic bytes rather than the URL suffix; the body may
    # already be decoded if the server used a gzip Content-Encoding.
    if data[:2] == b"\x1f\x8b":
        import gzip
        data = gzip.decompress(data)
    # A cached file is trusted forever once non-empty, so refuse to cache
    # anything that does not look like FASTA.
    if not data.lstrip().startswith(b">"):
        raise ValueError(f"Downloaded reference {ref_name!r} is not FASTA data")

    # Publish atomically: an interrupted write must never leave a truncated
    # file that would pass the cache check above.
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, prefix=f".{ref_name}-", suffix=".part")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp_path, fa_path)
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    return fa_path
class SequencingPipeline(BaseTool):
    """Tool that runs FASTQ QC, minimap2 alignment and naive SNV calling.

    ``run`` expects a dict with:
      * ``fastq_url`` (required): URL of a FASTQ file, or one of
        ``"synthetic"``/``"demo"``/``"test"`` to simulate reads from the
        reference.
      * ``reference`` (optional): a key of ``REFERENCE_URLS``; defaults to
        ``SMALL_REFERENCE``.

    Failures are returned as ``{"error": ...}`` dicts (some with a ``"step"``
    key) rather than raised.
    """

    name = "sequencing"

    async def run(self, input: dict) -> dict:
        """Run the whole pipeline and return QC, alignment and variant results.

        The pipeline is bounded by ``PIPELINE_TIMEOUT``. All scratch files live
        in a temporary directory that is removed on exit.
        """
        # Tolerate None / non-string values instead of crashing outside the try.
        fastq_url = str(input.get("fastq_url") or "").strip()
        reference = str(input.get("reference") or SMALL_REFERENCE).strip().lower()
        if not fastq_url:
            return {"error": "fastq_url is required"}

        tmpdir = tempfile.mkdtemp(prefix="seqpipe_")
        try:
            return await asyncio.wait_for(
                self._run_steps(fastq_url, reference, tmpdir),
                timeout=PIPELINE_TIMEOUT,
            )
        except ValueError as e:
            return {"error": str(e)}
        except httpx.HTTPStatusError as e:
            return {"error": f"Download failed (HTTP {e.response.status_code})"}
        except asyncio.TimeoutError:
            return {"error": "Pipeline timed out"}
        except Exception as e:
            logger.exception("Sequencing pipeline failed")
            return {"error": f"Pipeline failed: {e}"}
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

    async def _run_steps(self, fastq_url: str, reference: str, tmpdir: str) -> dict:
        """Execute reference fetch, QC, alignment, variant calling and report in ``tmpdir``."""
        ref_path = await _download_reference(reference)
        with open(ref_path) as f:
            ref_content = f.read()

        fastq_path = os.path.join(tmpdir, "input.fastq")
        fastq_source, download_error = await self._prepare_fastq(fastq_url, fastq_path, ref_content)

        # File parsing is CPU-bound; keep it off the event loop.
        qc = await asyncio.to_thread(_parse_fastq_quality, fastq_path)
        if "error" in qc:
            return {"error": qc["error"], "step": "qc"}

        mm2_path = await asyncio.wait_for(_ensure_minimap2(), timeout=120)
        sam_path = os.path.join(tmpdir, "aln.sam")
        align_error = await self._align(mm2_path, ref_path, fastq_path, sam_path)
        if align_error is not None:
            return {"error": align_error, "step": "align"}

        aln_stats = await asyncio.to_thread(self._alignment_stats, sam_path)
        variants = await asyncio.to_thread(_parse_sam_for_variants, sam_path, ref_content)
        report = _generate_report(qc, variants, reference)
        result = {
            "reference": reference,
            "fastq_source": fastq_source,
            "qc": qc,
            "alignment": aln_stats,
            "variants": variants[:20],
            "report": report,
            "steps_completed": ["qc", "align", "variants", "report"],
        }
        if download_error is not None:
            # Tell the caller why their data was replaced by synthetic reads.
            result["fastq_download_error"] = download_error
        return result

    @staticmethod
    async def _prepare_fastq(fastq_url: str, fastq_path: str, ref_content: str) -> tuple[str, str | None]:
        """Write reads to ``fastq_path``; return ``(source, download_error)``.

        ``source`` is ``"url"`` or ``"synthetic"``. A failed download falls
        back to synthetic reads, and ``download_error`` then describes why.
        """

        def write_synthetic() -> None:
            data = _generate_synthetic_fastq(ref_content, num_reads=500, read_len=100)
            with open(fastq_path, "w") as f:
                f.write(data)

        if fastq_url.lower() in ("synthetic", "demo", "test"):
            logger.info("Generating synthetic FASTQ reads")
            write_synthetic()
            return "synthetic", None
        try:
            await asyncio.wait_for(_download_fastq(fastq_url, fastq_path), timeout=120)
        except Exception as e:
            # Deliberate best-effort fallback; keep the cause visible.
            logger.warning("FASTQ download failed (%s), generating synthetic reads from reference", e)
            write_synthetic()
            return "synthetic", str(e) or type(e).__name__
        return "url", None

    @staticmethod
    async def _align(mm2_path: str, ref_path: str, fastq_path: str, sam_path: str) -> str | None:
        """Run ``minimap2 -ax sr`` writing SAM to ``sam_path``.

        Returns None on success or an error message. The child process is
        killed and reaped if it is still running when this coroutine exits,
        including on timeout or cancellation.
        """
        proc = await asyncio.create_subprocess_exec(
            mm2_path, "-ax", "sr", ref_path, fastq_path,
            "-o", sam_path,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            _, stderr = await asyncio.wait_for(proc.communicate(), timeout=300)
        except asyncio.TimeoutError:
            return "Alignment timed out after 5 minutes"
        finally:
            if proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass
                await proc.wait()
        if proc.returncode != 0 or not os.path.exists(sam_path):
            err = stderr.decode("utf-8", errors="replace")[:500] if stderr else ""
            return f"minimap2 failed (exit {proc.returncode}): {err}"
        return None

    @staticmethod
    def _alignment_stats(sam_path: str) -> dict:
        """Count SAM records and mapped/unmapped reads.

        ``total_alignments`` counts every non-header record. Secondary (0x100)
        and supplementary (0x800) records are not counted as extra mapped
        reads.
        """
        stats = {"mapped_reads": 0, "unmapped_reads": 0, "total_alignments": 0}
        with open(sam_path) as f:
            for line in f:
                if line.startswith("@") or not line.strip():
                    continue
                stats["total_alignments"] += 1
                fields = line.split("\t", 2)
                if len(fields) < 2:
                    continue
                flag = int(fields[1])
                if flag & 0x4:
                    stats["unmapped_reads"] += 1
                elif not (flag & (0x100 | 0x800)):
                    stats["mapped_reads"] += 1
        return stats