# -*- coding: utf-8 -*-
"""predict_app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/18tTfrNXDQWf7MfzUe4SaNZe2yftvvJsn
"""
# app/main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
from model.predictor import GenePredictor
import logging
# Configure the root logger: INFO and above, with timestamp and level prefix.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# ASGI application object; the route decorators below register against it.
app = FastAPI(
    title="F Gene Prediction API",
    description="API for predicting f gene start and end positions in DNA sequences",
)
class SequenceInput(BaseModel):
    """Request body accepted by the ``/predict`` endpoint.

    Ground truth is optional. It can be supplied either as per-base labels or
    as a start/end pair; when both are present the endpoint uses the labels.
    """

    # DNA sequence. The endpoint strips and upper-cases it, then accepts only
    # the characters A, C, T, G and N.
    sequence: str
    # Comma-separated 0/1 labels; the endpoint requires one per sequence base.
    ground_truth_labels: Optional[str] = None
    # Coordinate pair, used only when both values are set. The endpoint
    # requires 0 <= start < end <= len(sequence).
    ground_truth_start: Optional[int] = None
    ground_truth_end: Optional[int] = None
class PredictionResponse(BaseModel):
    """Response body returned by the ``/predict`` endpoint."""

    # Result of predictor.extract_gene_regions for the submitted sequence.
    regions: list
    # Confidence value returned by predictor.predict, converted to float.
    confidence: float
    # Result of predictor.compute_accuracy; None when no ground truth was given.
    metrics: Optional[dict] = None
    # Human-readable status text.
    message: str
# Build the shared predictor once at import time; every request reuses it.
# A failure is logged and re-raised so the application does not start without
# a usable predictor.
try:
    predictor = GenePredictor(model_path='model/best_boundary_aware_model.pth')
except Exception as e:
    # logging.exception records the traceback with the message, so the cause
    # is visible in the configured log output and not only on stderr.
    logging.exception(f"Failed to initialize predictor: {e}")
    raise
# Characters accepted in a sequence (checked after upper-casing).
_VALID_BASES = frozenset('ACTGN')


def _clean_sequence(raw: str) -> str:
    """Return *raw* stripped and upper-cased, or raise HTTP 400 if unusable."""
    sequence = raw.strip().upper()
    if not sequence:
        raise HTTPException(status_code=400, detail="Sequence cannot be empty")
    # A single set comparison replaces the per-character substring scan.
    if not set(sequence) <= _VALID_BASES:
        raise HTTPException(status_code=400, detail="Sequence contains invalid characters. Only A, C, T, G, N allowed")
    return sequence


def _labels_from_string(raw: str, expected_len: int) -> list:
    """Parse comma-separated 0/1 labels and check them against the sequence length."""
    # Only the int() conversion can raise ValueError, so only it is guarded.
    try:
        labels = [int(token) for token in raw.split(',')]
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid labels format. Use comma-separated 0s and 1s") from exc
    if len(labels) != expected_len:
        raise HTTPException(status_code=400, detail=f"Labels length ({len(labels)}) must match sequence length ({expected_len})")
    if not all(value in (0, 1) for value in labels):
        raise HTTPException(status_code=400, detail="Labels must be 0 or 1")
    return labels


def _labels_from_span(seq_len: int, start: int, end: int):
    """Validate a start/end pair and convert it to labels via the predictor."""
    if start < 0 or end > seq_len or start >= end:
        raise HTTPException(status_code=400, detail=f"Invalid coordinates: start={start}, end={end}")
    try:
        return predictor.labels_from_coordinates(seq_len, start, end)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid start/end coordinates") from exc


def _ground_truth_labels(input_data, sequence: str):
    """Return ground-truth labels derived from the request, or None if absent.

    Explicit labels take precedence over coordinates. Coordinates are used
    only when both start and end are supplied.
    """
    if input_data.ground_truth_labels:
        return _labels_from_string(input_data.ground_truth_labels, len(sequence))

    start = input_data.ground_truth_start
    end = input_data.ground_truth_end
    if start is not None and end is not None:
        return _labels_from_span(len(sequence), start, end)
    if start is not None or end is not None:
        # The request is still served without metrics; the warning makes the
        # otherwise silent drop of a half-specified pair visible in the logs.
        logging.warning("Ignoring incomplete ground-truth coordinates: start=%s, end=%s", start, end)
    return None


@app.post("/predict", response_model=PredictionResponse)
async def predict_gene(input_data: SequenceInput):
    """Predict gene regions for a DNA sequence.

    The sequence is stripped, upper-cased and validated. When ground truth is
    supplied (labels, or a complete start/end pair) the response also carries
    the metrics computed by ``predictor.compute_accuracy``.

    Args:
        input_data: Request body holding the sequence and optional ground truth.

    Returns:
        A dict with ``regions``, ``confidence``, ``metrics`` and ``message``.

    Raises:
        HTTPException: 400 for an empty or invalid sequence, malformed labels
            or invalid coordinates; 500 when prediction itself fails.
    """
    sequence = _clean_sequence(input_data.sequence)
    labels = _ground_truth_labels(input_data, sequence)

    # NOTE(review): the predictor calls below run synchronously inside this
    # coroutine; if they are slow they presumably hold up the event loop.
    # Confirm whether they should be moved to a threadpool.
    try:
        # The middle element of the tuple is not used by this endpoint.
        predictions, _, confidence = predictor.predict(sequence)
        regions = predictor.extract_gene_regions(predictions, sequence)
        metrics = None
        if labels is not None:
            metrics = predictor.compute_accuracy(predictions, labels)
        return {
            "regions": regions,
            "confidence": float(confidence),
            "metrics": metrics,
            "message": "Prediction successful",
        }
    except Exception as e:
        # Keep the traceback in the log; the client only receives the message.
        logging.exception(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
@app.get("/health")
async def health_check():
    """Return a fixed status payload indicating that the API is running."""
    status_payload = {"status": "API is running"}
    return status_payload