from __future__ import annotations import math import csv import threading import time import uuid from pathlib import Path from typing import Any from fastapi import FastAPI, File, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel BASE_DIR = Path(__file__).resolve().parent WEB_DIR = BASE_DIR / "web" JOB_DIR = BASE_DIR / "jobs" JOB_DIR.mkdir(exist_ok=True) app = FastAPI(title="Vydra Inference API") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) app.mount("/assets", StaticFiles(directory=WEB_DIR / "assets"), name="assets") vydra_module = None jobs: dict[str, dict[str, Any]] = {} jobs_lock = threading.Lock() warmup_status = {"status": "pending", "error": None} class InferenceRequest(BaseModel): sequence_or_fasta: str def get_vydra(): global vydra_module if vydra_module is None: import inference_gradio vydra_module = inference_gradio return vydra_module def warmup_models(): warmup_status["status"] = "loading" warmup_status["error"] = None try: vydra = get_vydra() vydra.load_prost_t5() warmup_status["status"] = "ready" print("Vydra warmup complete: MLP weights and ProstT5 loaded.") except Exception as exc: warmup_status["status"] = "error" warmup_status["error"] = str(exc) print(f"Vydra warmup failed: {exc}") @app.on_event("startup") def startup_warmup(): threading.Thread(target=warmup_models, daemon=True).start() PART_ASSETS = { "baseplate": "assets/Parts/Baseplate.webp", "collar": "assets/Parts/Collar.webp", "head_tail_joining": "assets/Parts/Head-Tail_Joining.webp", "major_capsid": "assets/Parts/Major_Capsid.webp", "major_tail": "assets/Parts/Major_Tail.webp", "minor_capsid": "assets/Parts/Minor_Capsid.webp", "minor_tail": "assets/Parts/Minor_Tail.webp", "portal": "assets/Parts/Portal.webp", "tail_fiber": "assets/Parts/Tail_Fiber.webp", "tail_sheath": "assets/Parts/Tail_Sheath.webp", "non_structural": "assets/Parts/Non_Struct.webp", "viral_eukaryote": "assets/Parts/Viral_Euk.webp", "cellular": "assets/Parts/Cellular.webp", "bacteria": "assets/Parts/Cell_Bacteria.webp", "eukariota": "assets/Parts/Cell_Euk.webp", "archaea": "assets/Parts/Arhaea.webp", "viral": "assets/Parts/Viral.webp", "sequence": "assets/Parts/Seq_Protein.webp", } STRUCTURAL_LABELS = { "baseplate": "Baseplate", "collar": "Collar", "head_tail_joining": "Head-tail joining", "major_capsid": "Major capsid", "major_tail": "Major tail", "minor_capsid": "Minor capsid", "minor_tail": "Minor tail", "portal": "Portal", "tail_fiber": "Tail fiber", "tail_sheath": "Tail sheath", } def safe_float(value: Any, default: float = 0.0) -> float: try: out = float(value) if math.isnan(out): return default return out except (TypeError, ValueError): return default def title_label(value: str) -> str: return value.replace("_", " ").replace("-", " ").title() def prediction_description(class_key: str, path: list[str]) -> str: descriptions = { "major_capsid": "Proteina estructural mayor de capside, componente principal de la cabeza del fago.", "tail_fiber": "Proteina asociada al reconocimiento del hospedero y union inicial del fago.", "portal": "Componente del portal de empaquetamiento y salida del ADN viral.", "tail_sheath": "Componente contractil asociado a la cola del fago.", "non_structural": "Proteina de fago sin rol estructural directo; se reporta contexto por similitud.", } return descriptions.get(class_key, f"Prediccion final en la ruta {' / '.join(path)}.") def class_icon_for(path: list[str], class_key: str) -> str: if class_key == "non_structural" or "non_structural" in path: return "assets/Non_Structural_Phage.webp" if class_key in STRUCTURAL_LABELS or "structural" in path: return "assets/Structural_Phage.webp" return "assets/protein.webp" def visual_asset_for(path: list[str], class_key: str) -> str: if class_key in PART_ASSETS: return PART_ASSETS[class_key] if "non_structural" in path: return PART_ASSETS["non_structural"] if "phage" in path: return PART_ASSETS["viral"] if "cellular" in path: return PART_ASSETS["cellular"] return PART_ASSETS["sequence"] def hierarchy_items(path: list[str], confidences: dict[str, float]) -> list[dict[str, Any]]: colors = ["coral", "green", "cyan", "blue", "violet"] scores = list(confidences.values()) items = [] for index, label in enumerate(path[:5]): score = scores[index] if index < len(scores) else scores[-1] if scores else 0.0 items.append({ "level": f"Nivel {index + 1}", "label": title_label(label), "score": safe_float(score), "color": colors[index % len(colors)], }) return items def cluster_payload(row: dict[str, Any]) -> dict[str, Any]: annotation = row.get("cluster_annotation") or {} cluster_id = row.get("cluster_id", -1) cluster_label = f"C{cluster_id}" if isinstance(cluster_id, int) and cluster_id >= 0 else "N/A" functions = [] biological_label = row.get("cluster_biological_label") or "" if biological_label and biological_label != "unannotated": functions.append(biological_label) if annotation.get("best_taxonomic_level"): functions.append(f"Nivel: {annotation['best_taxonomic_level']}") if annotation.get("confidence"): functions.append(f"Confianza: {annotation['confidence']}") if not functions: functions = ["Contexto por similitud", "FAISS nearest neighbor", "Cluster biologico"] organisms = [ annotation.get("dominant_virus_name"), annotation.get("dominant_family"), annotation.get("dominant_genus"), annotation.get("dominant_host"), ] organisms = [str(item) for item in organisms if item] if not organisms: organisms = [row.get("cluster_nearest_seq_id") or "Referencia mas cercana N/A"] return { "id": cluster_label, "size": int(safe_float(annotation.get("n_sequences"), 0)), "avg_similarity": safe_float(row.get("cluster_similarity"), 0.0), "functions": functions[:3], "organisms": organisms[:4], "nearest_seq_id": row.get("cluster_nearest_seq_id") or "", "threshold": row.get("cluster_similarity_threshold") or "", "annotation": annotation, } def to_service_result(record: Any, prediction: dict[str, Any], elapsed_seconds: float) -> dict[str, Any]: path = [str(item) for item in prediction["path"]] final_key = str(path[-1]) if path else "sequence" if final_key == "OtherDedup50": final_key = "non_structural" class_label = STRUCTURAL_LABELS.get(final_key, title_label(final_key)) route_label = " - ".join(title_label(item) for item in path[:3]) if path else "Unknown" confidence_values = list(prediction.get("confidences", {}).values()) confidence = safe_float(confidence_values[-1] if confidence_values else 0.0) confidence_label = "Muy alta" if confidence >= 0.9 else "Alta" if confidence >= 0.8 else "Media" if confidence >= 0.6 else "Baja" row = {"seq_id": record.seq_id, "sequence": record.sequence, **prediction} return { "query_id": record.seq_id, "elapsed_seconds": elapsed_seconds, "status": "Prediccion completa", "prediction": { "route_label": route_label, "class_key": final_key, "class_label": class_label, "class_icon": class_icon_for(path, final_key), "visual_asset": visual_asset_for(path, final_key), "description": prediction_description(final_key, path), }, "confidence": {"value": confidence, "label": confidence_label}, "hierarchy": hierarchy_items(path, prediction.get("confidences", {})), "cluster": cluster_payload(row), } def result_csv_row(result: dict[str, Any]) -> dict[str, Any]: prediction = result["prediction"] confidence = result["confidence"] cluster = result["cluster"] return { "query_id": result["query_id"], "elapsed_seconds": result["elapsed_seconds"], "route_label": prediction["route_label"], "class_key": prediction["class_key"], "class_label": prediction["class_label"], "confidence": confidence["value"], "confidence_label": confidence["label"], "cluster_id": cluster["id"], "cluster_size": cluster["size"], "cluster_avg_similarity": cluster["avg_similarity"], "cluster_nearest_seq_id": cluster.get("nearest_seq_id", ""), "cluster_threshold": cluster.get("threshold", ""), "hierarchy": " > ".join(item["label"] for item in result["hierarchy"]), "cluster_functions": " | ".join(cluster.get("functions", [])), "cluster_organisms": " | ".join(cluster.get("organisms", [])), } def write_results_csv(job_id: str, results: list[dict[str, Any]]) -> str: path = JOB_DIR / f"{job_id}_results.csv" rows = [result_csv_row(result) for result in results] with path.open("w", encoding="utf-8", newline="") as handle: writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) return str(path) def set_job(job_id: str, **updates): with jobs_lock: jobs[job_id].update(updates) def infer_records(records: list[Any], job_id: str | None = None) -> dict[str, Any]: vydra = get_vydra() if not records: raise HTTPException(status_code=400, detail="No valid protein sequence found.") start = time.perf_counter() def progress(processed: int, total: int): if job_id: set_job(job_id, processed=processed, total=total, stage="prostt5") embeddings = vydra.embed_records(records, progress_callback=progress) elapsed = time.perf_counter() - start results = [] for record, embedding in zip(records, embeddings): prediction = vydra.run_hierarchy(embedding) results.append(to_service_result(record, prediction, elapsed)) csv_path = write_results_csv(job_id or uuid.uuid4().hex, results) return {"result": results[0], "results": results, "csv_path": csv_path} def infer_text(sequence_or_fasta: str) -> dict[str, Any]: vydra = get_vydra() records = vydra.parse_fasta_text(sequence_or_fasta or "") output = infer_records(records) return {"result": output["result"], "results": output["results"]} def run_job(job_id: str, records: list[Any]): try: output = infer_records(records, job_id=job_id) set_job( job_id, status="complete", stage="complete", processed=len(records), results=output["results"], result=output["result"], csv_path=output["csv_path"], ) except Exception as exc: set_job(job_id, status="error", stage="error", error=str(exc)) def start_job_from_text(sequence_or_fasta: str) -> dict[str, Any]: vydra = get_vydra() records = vydra.parse_fasta_text(sequence_or_fasta or "") if not records: raise HTTPException(status_code=400, detail="No valid protein sequence found.") job_id = uuid.uuid4().hex with jobs_lock: jobs[job_id] = { "status": "running", "stage": "queued", "total": len(records), "processed": 0, "results": [], "result": None, "csv_path": None, "error": None, } thread = threading.Thread(target=run_job, args=(job_id, records), daemon=True) thread.start() return {"job_id": job_id, "total": len(records), "processed": 0, "status": "running"} @app.get("/") def index(): return FileResponse(WEB_DIR / "index.html") @app.get("/styles.css") def styles(): return FileResponse(WEB_DIR / "styles.css") @app.get("/script.js") def script(): return FileResponse(WEB_DIR / "script.js") @app.get("/health") def health(): return {"ok": True, "warmup": warmup_status} @app.post("/api/infer") def api_infer(request: InferenceRequest): return infer_text(request.sequence_or_fasta) @app.post("/api/jobs") def api_start_job(request: InferenceRequest): return start_job_from_text(request.sequence_or_fasta) @app.post("/api/infer-file") async def api_infer_file(file: UploadFile = File(...)): raw = await file.read() text = raw.decode("utf-8", errors="replace") return infer_text(text) @app.post("/api/jobs-file") async def api_start_job_file(file: UploadFile = File(...)): raw = await file.read() text = raw.decode("utf-8", errors="replace") return start_job_from_text(text) @app.get("/api/jobs/{job_id}") def api_job_status(job_id: str): with jobs_lock: job = jobs.get(job_id) if not job: raise HTTPException(status_code=404, detail="Job not found") payload = dict(job) payload["download_url"] = f"/api/jobs/{job_id}/download" if payload.get("csv_path") else None return payload @app.get("/api/jobs/{job_id}/download") def api_job_download(job_id: str): with jobs_lock: job = jobs.get(job_id) if not job or not job.get("csv_path"): raise HTTPException(status_code=404, detail="CSV not found") path = Path(job["csv_path"]) if not path.exists(): raise HTTPException(status_code=404, detail="CSV not found") return FileResponse(path, media_type="text/csv", filename="vydra_results.csv")