Spaces:
Configuration error
Configuration error
| # -*- coding: utf-8 -*- | |
| """predict_app.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/18tTfrNXDQWf7MfzUe4SaNZe2yftvvJsn | |
| """ | |
| # app/main.py | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| from model.predictor import GenePredictor | |
| import logging | |
# Configure root logging once at import time: timestamp, level, message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# FastAPI application instance for this module; the title and description
# are the metadata passed to FastAPI for the generated API docs.
app = FastAPI(
    title="F Gene Prediction API",
    description="API for predicting f gene start and end positions in DNA sequences",
)
class SequenceInput(BaseModel):
    """Request body for a gene prediction.

    Ground truth is optional and may be supplied in one of two forms:
    per-base labels, or a start/end coordinate pair. The request handler
    prefers ``ground_truth_labels`` when it is non-empty and otherwise
    falls back to the coordinates, but only if both are present.
    """

    # DNA sequence; the handler strips and upper-cases it and accepts only
    # the characters A, C, T, G and N.
    sequence: str
    # Comma-separated 0/1 labels, one per base of ``sequence``.
    ground_truth_labels: Optional[str] = None
    # Coordinates checked by the handler as 0 <= start < end <= len(sequence).
    ground_truth_start: Optional[int] = None
    ground_truth_end: Optional[int] = None
class PredictionResponse(BaseModel):
    """Shape of the payload returned for a prediction request."""

    # Gene regions as produced by ``predictor.extract_gene_regions``.
    regions: list
    # Confidence value reported by the predictor, coerced to ``float``.
    confidence: float
    # Output of ``predictor.compute_accuracy``; ``None`` when the request
    # carried no usable ground truth.
    metrics: Optional[dict] = None
    # Human-readable status text.
    message: str
# Build the shared predictor once, at import time. Any failure is logged and
# then re-raised, so the module fails to import instead of running without
# a predictor.
# NOTE(review): presumably GenePredictor loads weights from `model_path`
# (relative to the working directory) - confirm against model/predictor.py.
try:
    predictor = GenePredictor(model_path='model/best_boundary_aware_model.pth')
except Exception as init_error:
    logging.error(f"Failed to initialize predictor: {init_error}")
    raise
# NOTE(review): no route decorator was visible on this handler, so it was
# never reachable through `app`. The "/predict" path is inferred from the
# handler name - confirm it against existing clients.
@app.post("/predict", response_model=PredictionResponse)
async def predict_gene(input_data: SequenceInput):
    """Predict gene regions for a DNA sequence and optionally score them.

    The sequence is stripped and upper-cased, then validated. If ground
    truth is supplied (per-base labels, or a start/end coordinate pair when
    no labels are given), the predictions are also scored against it.

    Args:
        input_data: Request body holding the sequence and optional ground
            truth.

    Returns:
        A dict with the keys ``regions``, ``confidence``, ``metrics``
        (``None`` when no ground truth was usable) and ``message``.

    Raises:
        HTTPException: 400 when the sequence is empty or contains characters
            other than A, C, T, G, N; when the labels are malformed, of the
            wrong length, or not all 0/1; or when the coordinates are out of
            range. 500 when the predictor itself fails.
    """
    sequence = input_data.sequence.strip().upper()
    if not sequence:
        raise HTTPException(status_code=400, detail="Sequence cannot be empty")
    if not all(c in 'ACTGN' for c in sequence):
        raise HTTPException(status_code=400, detail="Sequence contains invalid characters. Only A, C, T, G, N allowed")

    labels = None
    start = input_data.ground_truth_start
    end = input_data.ground_truth_end

    if input_data.ground_truth_labels:
        # Only the int() conversion can raise ValueError, so the try covers
        # just that; the follow-up checks raise HTTPException directly.
        try:
            labels = [int(x) for x in input_data.ground_truth_labels.split(',')]
        except ValueError as exc:
            raise HTTPException(status_code=400, detail="Invalid labels format. Use comma-separated 0s and 1s") from exc
        if len(labels) != len(sequence):
            raise HTTPException(status_code=400, detail=f"Labels length ({len(labels)}) must match sequence length ({len(sequence)})")
        if not all(x in (0, 1) for x in labels):
            raise HTTPException(status_code=400, detail="Labels must be 0 or 1")
    elif start is not None and end is not None:
        # Coordinates are used only when both are given; a lone start or end
        # is ignored and the request proceeds without metrics.
        if start < 0 or end > len(sequence) or start >= end:
            raise HTTPException(status_code=400, detail=f"Invalid coordinates: start={start}, end={end}")
        try:
            labels = predictor.labels_from_coordinates(len(sequence), start, end)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail="Invalid start/end coordinates") from exc

    try:
        # The middle element of predict()'s result is not needed here.
        predictions, _probs, confidence = predictor.predict(sequence)
        regions = predictor.extract_gene_regions(predictions, sequence)
        metrics = None
        if labels is not None:
            metrics = predictor.compute_accuracy(predictions, labels)
        return {
            "regions": regions,
            "confidence": float(confidence),
            "metrics": metrics,
            "message": "Prediction successful",
        }
    except Exception as e:
        # Boundary handler: keep the traceback in the log and turn any
        # predictor failure into a 500 for the client.
        logging.exception("Prediction failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
# NOTE(review): no route decorator was visible on this handler, so it was
# never reachable through `app`. The "/health" path is inferred from the
# handler name - confirm it against whatever probes the service.
@app.get("/health")
async def health_check():
    """Report that the API process is up.

    Returns:
        A dict with a single ``status`` key.
    """
    return {"status": "API is running"}