File size: 3,909 Bytes
ae467e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import warnings
import os
import tempfile
from pathlib import Path

from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from PIL import Image
import torch

from scripts.logging import get_logger
from scripts.utils import ViTBrainTumorClassifier
from scripts.data_model import ClassificationResponse, Prediction

# Suppress every Python warning for the whole process.
# NOTE(review): this also hides library deprecation warnings; consider
# narrowing the filter by category or module.
warnings.filterwarnings("ignore")
logger = get_logger(__name__)

# Templates are looked up in a "templates" directory next to this file.
BASE_DIR = Path(__file__).parent
templates = Jinja2Templates(directory=os.fspath(BASE_DIR.joinpath("templates")))

app = FastAPI(
    title="Brain Tumor Classification Inference API",
    description="Vision Transformer based brain tumor classification",
    version="1.0.0",
)

# Classifier instance; stays None until startup_event assigns it.
MODEL = None


@app.on_event("startup")
async def startup_event():
    """Build the ViT classifier and store it in the module-level ``MODEL``.

    Uses "cuda" when ``torch.cuda.is_available()`` is true, otherwise "cpu".
    Any exception raised while selecting the device or constructing the
    classifier is logged and then re-raised, so the failure propagates to the
    server's startup handling rather than being swallowed.
    """
    global MODEL
    try:
        if torch.cuda.is_available():
            target_device = "cuda"
        else:
            target_device = "cpu"
        logger.info(f"Using device: {target_device}")
        MODEL = ViTBrainTumorClassifier(device=target_device)
        logger.info("Application startup complete")
    except Exception as exc:
        logger.error(f"Failed to initialize model: {exc}")
        raise


@app.on_event("shutdown")
async def shutdown_event():
    """Log that the application is shutting down; nothing is released here."""
    shutdown_message = "Application shutting down..."
    logger.info(shutdown_message)


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the ``index.html`` template as the HTML landing page.

    The incoming request is the only context value, passed under the
    ``"request"`` key.
    """
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)


@app.get("/health")
async def health_check():
    """Return a liveness payload.

    The payload contains a fixed ``status``, whether ``MODEL`` has been
    assigned, and the API version string.
    """
    model_ready = MODEL is not None
    # NOTE(review): "1.0.0" duplicates the version passed to FastAPI(...);
    # keep the two in sync.
    return dict(status="healthy", model_loaded=model_ready, version="1.0.0")


@app.post("/api/v1/classify")
async def classify_image(file: UploadFile = File(...)) -> ClassificationResponse:
    """
    Classify a brain tumor from an uploaded image.

    Processing steps:
    1. Check the upload's extension against an allow-list.
    2. Write its bytes to a temporary file.
    3. Sanity-check that file with Pillow's ``verify()``.
    4. Pass its path to ``MODEL.predict``.

    The temporary file is removed on every exit path.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ClassificationResponse built from the ``predicted_class``,
        ``confidence`` and ``all_predictions`` entries of the dict returned
        by ``MODEL.predict``.

    Raises:
        HTTPException: 500 if ``MODEL`` has not been initialized, or if
            reading, saving or classifying the upload fails. 400 if the
            filename is missing or has a disallowed extension, or if the
            bytes are not an image Pillow can open and verify.
    """
    if MODEL is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    allowed_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp"}
    # UploadFile.filename may be None; Path(None) would raise an unhandled
    # TypeError (a bare 500) instead of the intended 400 below.
    filename = file.filename or ""
    file_extension = Path(filename).suffix.lower()

    if file_extension not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            # sorted(): set iteration order is arbitrary, so the message would
            # otherwise list the extensions in an unstable order.
            detail=f"Invalid file type. Allowed: {', '.join(sorted(allowed_extensions))}",
        )

    # Initialised up front so the `finally` below can clean up even if
    # writing the temporary file fails part-way.
    tmp_path = None
    try:
        # NOTE(review): the whole upload is read into memory with no size cap;
        # consider enforcing a maximum upload size.
        contents = await file.read()

        # delete=False: the file must outlive this `with` so Pillow and the
        # model can reopen it by path; removal happens in `finally`.
        with tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(contents)

        try:
            # The context manager closes the handle opened by Image.open;
            # calling .verify() on a bare Image.open() result leaves it open.
            with Image.open(tmp_path) as img:
                img.verify()
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid image file") from None

        logger.info(f"Processing: {filename}")
        # NOTE(review): MODEL.predict is a synchronous call inside an async
        # endpoint, so it blocks the event loop while inference runs. Consider
        # offloading it to a worker thread once MODEL is confirmed thread-safe.
        prediction_result = MODEL.predict(tmp_path)

        response = ClassificationResponse(
            success=True,
            prediction=Prediction(
                predicted_class=prediction_result["predicted_class"],
                confidence=prediction_result["confidence"],
                all_predictions=prediction_result["all_predictions"],
            ),
            message=f"Successfully classified as {prediction_result['predicted_class']}",
        )

        logger.info(f"Complete: {prediction_result['predicted_class']}")
        return response

    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error: {e}")
        # NOTE(review): str(e) is echoed to the client and may expose internals.
        raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}") from e
    finally:
        # Single cleanup point for every path that created the temp file.
        if tmp_path is not None:
            Path(tmp_path).unlink(missing_ok=True)


if __name__ == "__main__":
    # Deferred import: uvicorn is only needed when this file is run directly.
    import uvicorn

    # Dev-server launch: auto-reload, all interfaces, port 8000.
    # uvicorn needs the "module:attribute" import string (not the app object)
    # when reload=True. This assumes the file is importable as `app` from the
    # working directory; confirm.
    uvicorn.run(app="app:app", host="0.0.0.0", port=8000, reload=True)