# Provenance (web-page residue): uploaded by codeby-hp, "Upload 7 files", commit ae467e7 (verified).
import warnings
import os
import tempfile
from pathlib import Path
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from PIL import Image
import torch
from scripts.logging import get_logger
from scripts.utils import ViTBrainTumorClassifier
from scripts.data_model import ClassificationResponse, Prediction
# Suppress every Python warning for the whole process.
warnings.filterwarnings(action="ignore")

logger = get_logger(__name__)

# Directory containing this file; the templates folder is resolved against it.
BASE_DIR = Path(__file__).parent
templates = Jinja2Templates(directory=str(BASE_DIR.joinpath("templates")))

app = FastAPI(
    title="Brain Tumor Classification Inference API",
    description="Vision Transformer based brain tumor classification",
    version="1.0.0",
)

# Classifier instance; stays None until the startup handler assigns it.
MODEL = None
@app.on_event("startup")
async def startup_event():
    """Create the classifier once when the application starts.

    Selects "cuda" when ``torch.cuda.is_available()`` is true and "cpu"
    otherwise, then stores a ``ViTBrainTumorClassifier`` built for that
    device in the module-level ``MODEL``.

    Raises:
        Exception: any error raised during initialization is logged and
            re-raised unchanged.
    """
    global MODEL
    try:
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        logger.info(f"Using device: {device}")
        MODEL = ViTBrainTumorClassifier(device=device)
        logger.info("Application startup complete")
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        raise
@app.on_event("shutdown")
async def shutdown_event():
    """Log a message when the application is shutting down."""
    logger.info("Application shutting down...")
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the ``index.html`` template for the root URL.

    Args:
        request: the incoming request, passed to the template context
            under the ``"request"`` key.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.get("/health")
async def health_check():
    """Return a status payload for the service.

    Returns:
        A dict with a fixed ``"status"`` of ``"healthy"``, a
        ``"model_loaded"`` flag that is true once ``MODEL`` has been set,
        and the ``"version"`` string.
    """
    model_ready = MODEL is not None
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "version": "1.0.0",
    }
@app.post("/api/v1/classify")
async def classify_image(file: UploadFile = File(...)) -> ClassificationResponse:
    """
    Classify a brain tumor from an uploaded image.

    The upload is written to a temporary file, checked with PIL's
    ``verify()``, passed to ``MODEL.predict`` by path, and the temporary file
    is removed before the request finishes.

    Args:
        file: the uploaded image; its filename extension must be one of
            ``.jpg``, ``.jpeg``, ``.png``, ``.gif`` or ``.bmp``.

    Returns:
        A ``ClassificationResponse`` carrying the predicted class, its
        confidence and the full prediction list returned by the model.

    Raises:
        HTTPException: 500 if the model is not initialized or classification
            fails; 400 if the filename is missing, the extension is not
            allowed, or the content is not a valid image.
    """
    if MODEL is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    allowed_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp"}
    # The filename can be missing on a nameless multipart part; treat that as
    # "no extension" so the client gets a 400 instead of a TypeError.
    file_extension = Path(file.filename or "").suffix.lower()
    if file_extension not in allowed_extensions:
        # Sort so the error message is stable (set iteration order is not).
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Allowed: {', '.join(sorted(allowed_extensions))}",
        )

    try:
        contents = await file.read()
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) as tmp:
                # Record the path before writing so a failed write is still
                # cleaned up by the finally clause below.
                tmp_path = tmp.name
                tmp.write(contents)

            try:
                # The context manager releases the image's file handle even
                # when verify() raises.
                with Image.open(tmp_path) as img:
                    img.verify()
            except Exception as exc:
                raise HTTPException(status_code=400, detail="Invalid image file") from exc

            logger.info(f"Processing: {file.filename}")
            prediction_result = MODEL.predict(tmp_path)
            response = ClassificationResponse(
                success=True,
                prediction=Prediction(
                    predicted_class=prediction_result["predicted_class"],
                    confidence=prediction_result["confidence"],
                    all_predictions=prediction_result["all_predictions"],
                ),
                message=f"Successfully classified as {prediction_result['predicted_class']}",
            )
            logger.info(f"Complete: {prediction_result['predicted_class']}")
            return response
        finally:
            # Single cleanup point: the file was created with delete=False,
            # so it must be removed explicitly on every exit path.
            if tmp_path is not None and os.path.exists(tmp_path):
                os.remove(tmp_path)
    except HTTPException:
        # Already a client-facing error; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Error: {e}")
        raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}") from e
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running this file directly.
    import uvicorn

    # Serve the "app:app" import string on all interfaces, port 8000, with
    # reload enabled.
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)