# -*- coding: utf-8 -*-
"""FastAPI service exposing the f-gene boundary predictor.

Intended location: ``app/main.py``. Originally exported from the Colab
notebook ``predict_app.ipynb``:
https://colab.research.google.com/drive/18tTfrNXDQWf7MfzUe4SaNZe2yftvvJsn

Endpoints
---------
POST /predict
    Validate a DNA sequence (plus optional ground truth), run
    ``GenePredictor`` on it and return the predicted regions, the model
    confidence and, when ground truth was supplied, accuracy metrics.
GET /health
    Liveness probe; does not exercise the model.
"""

import logging
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from model.predictor import GenePredictor

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Characters accepted in an input sequence (after upper-casing).
_VALID_BASES = frozenset('ACTGN')

app = FastAPI(
    title="F Gene Prediction API",
    description="API for predicting f gene start and end positions in DNA sequences",
)


class SequenceInput(BaseModel):
    """Request body for ``POST /predict``.

    Ground truth is optional and may be given either as
    ``ground_truth_labels`` (a comma-separated string of 0/1, one value per
    base) or as the pair ``ground_truth_start`` / ``ground_truth_end``.
    A non-empty ``ground_truth_labels`` takes precedence over the pair.
    """

    sequence: str
    ground_truth_labels: Optional[str] = None
    ground_truth_start: Optional[int] = None
    ground_truth_end: Optional[int] = None


class PredictionResponse(BaseModel):
    """Response body for ``POST /predict``."""

    regions: list  # as returned by predictor.extract_gene_regions
    confidence: float
    metrics: Optional[dict] = None  # only set when ground truth was supplied
    message: str


# Load the model once at import time. If it cannot be loaded the service must
# not start, so the failure is logged (with traceback) and re-raised.
try:
    predictor = GenePredictor(model_path='model/best_boundary_aware_model.pth')
except Exception as exc:
    logger.exception("Failed to initialize predictor: %s", exc)
    raise


def _validate_sequence(raw: str) -> str:
    """Return *raw* stripped and upper-cased, after checking it is valid DNA.

    Raises:
        HTTPException: 400 if the sequence is empty or contains characters
            other than A, C, T, G, N.
    """
    sequence = raw.strip().upper()
    if not sequence:
        raise HTTPException(status_code=400, detail="Sequence cannot be empty")
    if not _VALID_BASES.issuperset(sequence):
        raise HTTPException(
            status_code=400,
            detail="Sequence contains invalid characters. Only A, C, T, G, N allowed",
        )
    return sequence


def _parse_label_string(raw: str, seq_len: int) -> List[int]:
    """Parse a comma-separated 0/1 label string with one entry per base.

    Raises:
        HTTPException: 400 if an entry is not an integer, the count differs
            from *seq_len*, or any value is not 0 or 1.
    """
    try:
        labels = [int(x) for x in raw.split(',')]
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail="Invalid labels format. Use comma-separated 0s and 1s",
        ) from exc
    if len(labels) != seq_len:
        raise HTTPException(
            status_code=400,
            detail=f"Labels length ({len(labels)}) must match sequence length ({seq_len})",
        )
    if not all(x in (0, 1) for x in labels):
        raise HTTPException(status_code=400, detail="Labels must be 0 or 1")
    return labels


def _labels_from_coordinates(start: int, end: int, seq_len: int):
    """Build per-base labels from a (start, end) pair via the predictor.

    The pair is rejected unless ``0 <= start < end <= seq_len``.

    Raises:
        HTTPException: 400 if the bounds check fails or
            ``predictor.labels_from_coordinates`` raises ``ValueError``.
    """
    if start < 0 or end > seq_len or start >= end:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid coordinates: start={start}, end={end}",
        )
    try:
        return predictor.labels_from_coordinates(seq_len, start, end)
    except ValueError as exc:
        raise HTTPException(
            status_code=400, detail="Invalid start/end coordinates"
        ) from exc


def _resolve_ground_truth(input_data: SequenceInput, seq_len: int):
    """Return ground-truth labels from the request, or ``None`` if absent.

    Non-empty ``ground_truth_labels`` wins; otherwise both coordinates must
    be present. NOTE(review): a lone start or end is silently ignored, so no
    metrics are computed — consider rejecting it with a 400.
    """
    if input_data.ground_truth_labels:
        return _parse_label_string(input_data.ground_truth_labels, seq_len)
    start = input_data.ground_truth_start
    end = input_data.ground_truth_end
    if start is not None and end is not None:
        return _labels_from_coordinates(start, end, seq_len)
    return None


@app.post("/predict", response_model=PredictionResponse)
async def predict_gene(input_data: SequenceInput):
    """Predict f-gene regions for a DNA sequence.

    The sequence is stripped and upper-cased before validation. When ground
    truth is supplied, ``predictor.compute_accuracy`` is run against it and
    its result returned as ``metrics``.

    Raises:
        HTTPException: 400 for invalid sequence or ground-truth input;
            500 if any predictor call fails.
    """
    sequence = _validate_sequence(input_data.sequence)
    labels = _resolve_ground_truth(input_data, len(sequence))

    # NOTE(review): the predictor runs synchronously inside this async
    # handler, so the event loop (including /health) is blocked for the whole
    # inference. Offloading it to a thread pool would let requests call
    # GenePredictor concurrently — confirm it is thread-safe before doing so.
    try:
        predictions, _probs, confidence = predictor.predict(sequence)
        regions = predictor.extract_gene_regions(predictions, sequence)
        metrics = (
            predictor.compute_accuracy(predictions, labels)
            if labels is not None
            else None
        )
        # Kept inside the try so a non-scalar confidence still maps to a 500.
        confidence = float(confidence)
    except Exception as exc:
        logger.exception("Prediction failed: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {exc}"
        ) from exc

    return {
        "regions": regions,
        "confidence": confidence,
        "metrics": metrics,
        "message": "Prediction successful",
    }


@app.get("/health")
async def health_check():
    """Liveness probe: report that the API process is up."""
    return {"status": "API is running"}