File size: 3,088 Bytes
4ce9bb0
d576da9
4ce9bb0
 
 
 
 
 
 
 
 
d576da9
 
 
4ce9bb0
d576da9
 
 
4ce9bb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d576da9
4ce9bb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d576da9
 
4ce9bb0
d576da9
 
4ce9bb0
 
d576da9
 
4ce9bb0
 
d576da9
e6cc4e5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# --- IMPORTS ---
import os
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn # For running the app directly if needed

# Your existing ML components
from cnnClassifier.utils.common import decodeImage
from cnnClassifier.pipeline.prediction import PredictionPipeline

# --- CONFIGURATION ---
# os.putenv changes the process-level environment that child processes inherit
# (e.g. the `dvc repro` run launched by /train); it does NOT update the
# os.environ mapping inside this interpreter.
os.putenv('LANG', 'en_US.UTF-8')
os.putenv('LC_ALL', 'en_US.UTF-8')

# --- INITIALIZE FastAPI APP ---
# title/description are shown in FastAPI's auto-generated OpenAPI docs.
app = FastAPI(
    title="Chest Cancer Classification API",
    description="An API to predict whether a chest CT scan shows signs of adenocarcinoma cancer."
)

# --- MIDDLEWARE (for CORS) ---
# This is the FastAPI equivalent of Flask-CORS: every origin, method and
# header is allowed.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True means
# credentialed cross-origin requests are accepted from ANY site (Starlette
# echoes the caller's Origin instead of "*" when cookies are sent). Nothing
# visible here uses cookies/auth, but restrict allow_origins or drop
# allow_credentials before adding any — TODO confirm intended policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- MOUNT STATIC FILES AND TEMPLATES ---
# Serves ./static under the /static URL prefix and loads Jinja2 templates from
# ./templates. Both are relative paths, so they resolve against the current
# working directory — start the server from the project root.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


# --- LOAD THE PREDICTION PIPELINE ON STARTUP ---
# One pipeline instance is created at import time and shared by every request.
# Its filename matches the path /predict writes the decoded upload to, so it
# presumably reads that file on each predict() call.
# NOTE(review): whether the model weights are loaded here once or on every
# predict() depends on PredictionPipeline — TODO confirm.
classifier = PredictionPipeline(filename="inputImage.jpg")


# --- DEFINE THE REQUEST BODY STRUCTURE ---
# Pydantic model for automatic validation of the incoming JSON
class ImagePayload(BaseModel):
    """JSON request body accepted by ``POST /predict``.

    Pydantic validates the incoming JSON against this schema before the
    handler runs.

    Attributes:
        image: The uploaded image as a base64-encoded string; it is handed
            unchanged to ``decodeImage``.
    """

    image: str


# --- API ENDPOINTS ---

@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """
    Serve the single-page user interface.

    Renders ``index.html`` from the configured templates directory, passing
    the incoming request into the template context as Starlette's template
    responses expect.
    """
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)


@app.post("/train")
async def trainRoute():
    """
    Trigger the DVC pipeline (``dvc repro``) to retrain the model.

    The command runs in a worker thread via FastAPI's threadpool, so the event
    loop keeps serving other requests while training runs. The /train request
    itself still waits for the run to finish before responding.

    Returns:
        ``{"message": "Training done successfully!"}`` when ``dvc repro``
        exits with status 0.

    Raises:
        HTTPException: 500 if the ``dvc`` executable cannot be launched or the
            pipeline exits with a non-zero status.
    """
    import subprocess

    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    try:
        # Argument list, no shell: gives us the real exit status and avoids an
        # unnecessary /bin/sh hop. Running it in the threadpool (instead of a
        # blocking call inside this async handler) keeps the server responsive.
        result = await run_in_threadpool(
            subprocess.run, ["dvc", "repro"], check=False
        )
    except FileNotFoundError as err:
        # `dvc` is not on PATH. The old os.system() call silently returned 127
        # here and the endpoint still reported success.
        raise HTTPException(
            status_code=500,
            detail="Training failed: 'dvc' executable not found.",
        ) from err

    if result.returncode != 0:
        raise HTTPException(
            status_code=500,
            detail=f"Training failed: 'dvc repro' exited with status {result.returncode}.",
        )

    return {"message": "Training done successfully!"}


@app.post("/predict")
async def predictRoute(payload: ImagePayload):
    """
    Classify one chest CT image sent as a base64 string.

    The decoded image is written to ``inputImage.jpg`` (the same filename the
    module-level ``classifier`` was constructed with), then the pipeline is run.

    Returns:
        A one-element list: ``[{"prediction": "Normal"}]`` or
        ``[{"prediction": "Cancer"}]``.

    Note:
        Decoding and prediction are plain synchronous calls inside this async
        handler, so they run on the event loop. No other request is served
        meanwhile, which also guarantees two predictions never interleave on
        the shared image file.
    """
    image_path = "inputImage.jpg"
    decodeImage(payload.image, image_path)

    # Class indices: {'adenocarcinoma': 0, 'normal': 1}. Any value other than 1
    # is reported as cancer.
    label = "Normal" if classifier.predict() == 1 else "Cancer"

    # FastAPI serialises this list to JSON.
    return [{"prediction": label}]


# --- RUN THE APP ---
# This block is for local development. Gunicorn/Uvicorn will run the app in production.
if __name__ == "__main__":
    # Local-development entry point: running this file directly serves the app
    # on all interfaces. In production an ASGI server (Gunicorn/Uvicorn)
    # imports `app` itself and this branch is never taken.
    bind_host, bind_port = "0.0.0.0", 8080
    uvicorn.run(app, host=bind_host, port=bind_port)