Spaces:
Sleeping
Sleeping
File size: 3,088 Bytes
4ce9bb0 d576da9 4ce9bb0 d576da9 4ce9bb0 d576da9 4ce9bb0 d576da9 4ce9bb0 d576da9 4ce9bb0 d576da9 4ce9bb0 d576da9 4ce9bb0 d576da9 e6cc4e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# --- IMPORTS ---
import os
import subprocess
import threading

from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
import uvicorn  # For running the app directly if needed

# Your existing ML components
from cnnClassifier.utils.common import decodeImage
from cnnClassifier.pipeline.prediction import PredictionPipeline
# --- CONFIGURATION ---
# Force a UTF-8 locale for child processes (e.g. `dvc repro`, launched by the
# /train endpoint). Assigning through os.environ calls putenv() under the hood
# AND keeps os.environ in sync; a bare os.putenv() would leave os.environ stale
# for any Python code that later reads these variables.
os.environ["LANG"] = "en_US.UTF-8"
os.environ["LC_ALL"] = "en_US.UTF-8"
# --- INITIALIZE FastAPI APP ---
# Human-readable API metadata, passed to FastAPI's generated OpenAPI schema.
_API_TITLE = "Chest Cancer Classification API"
_API_DESCRIPTION = "An API to predict whether a chest CT scan shows signs of adenocarcinoma cancer."

app = FastAPI(title=_API_TITLE, description=_API_DESCRIPTION)
# --- MIDDLEWARE (for CORS) ---
# This is the FastAPI equivalent of Flask-CORS: any origin may call the API
# with any method and header.
# Credentials (cookies / Authorization headers) are deliberately NOT allowed
# cross-origin. FastAPI's CORS documentation states that a wildcard
# allow_origins cannot be combined with allow_credentials=True. None of the
# endpoints defined in this module read cookies or auth headers, so nothing
# here needs credentialed cross-origin requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- MOUNT STATIC FILES AND TEMPLATES ---
# Assets (CSS/JS) in ./static are exposed under the /static URL prefix, and
# HTML pages are rendered from ./templates. Both are relative paths, so they
# are resolved against the process's working directory.
static_assets = StaticFiles(directory="static")
app.mount("/static", static_assets, name="static")
templates = Jinja2Templates(directory="templates")
# --- LOAD THE PREDICTION PIPELINE ON STARTUP ---
# A single pipeline instance is built at import time and shared by every
# /predict request. It is pointed at the same file name that /predict decodes
# uploaded images into.
# NOTE(review): whether the model weights are loaded here once or on every
# predict() call depends on PredictionPipeline's implementation; confirm.
_INPUT_IMAGE_PATH = "inputImage.jpg"
classifier = PredictionPipeline(filename=_INPUT_IMAGE_PATH)
# --- DEFINE THE REQUEST BODY STRUCTURE ---
# Pydantic model for automatic validation of the incoming JSON
# Request body for POST /predict, e.g. {"image": "<base64-encoded image>"}.
class ImagePayload(BaseModel):
    # Base64-encoded image data; /predict hands it to decodeImage, which writes
    # it to disk for the classifier.
    # NOTE(review): only "is a string" is validated here; malformed base64 is
    # not rejected until decodeImage runs.
    image: str
# --- API ENDPOINTS ---
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """
    Serve the web front end.

    Renders the ``index.html`` template with the incoming request passed
    through in the template context, and returns it as an HTML response.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
@app.post("/train")
async def trainRoute():
    """
    Trigger the DVC pipeline (`dvc repro`) to retrain the model.

    The subprocess runs in a worker thread so the event loop keeps serving
    other requests while training runs. The HTTP call itself still waits until
    training finishes.
    NOTE: running long training inside a request is only suitable for this
    project's demonstration purposes, not for a high-traffic production server.
    NOTE(review): whether `classifier` picks up the retrained weights depends on
    whether PredictionPipeline reloads the model on each predict(); confirm.

    Returns:
        {"message": "Training done successfully!"} when `dvc repro` exits with 0.

    Raises:
        HTTPException: status 500 if the `dvc` executable cannot be found or
            `dvc repro` exits with a non-zero status.
    """
    try:
        # List form, no shell. The child inherits this process's environment,
        # including the locale variables set at import time.
        result = await run_in_threadpool(subprocess.run, ["dvc", "repro"])
    except FileNotFoundError as err:
        raise HTTPException(
            status_code=500,
            detail="Training failed: the 'dvc' executable was not found.",
        ) from err

    # Previously the exit status was ignored and failures were reported as success.
    if result.returncode != 0:
        raise HTTPException(
            status_code=500,
            detail=f"Training failed: 'dvc repro' exited with status {result.returncode}.",
        )
    return {"message": "Training done successfully!"}
# Every /predict request writes to, and predicts from, the same on-disk file
# ("inputImage.jpg"). Since the work now runs in a threadpool, this lock keeps
# overlapping requests from overwriting each other's image mid-prediction.
_predict_lock = threading.Lock()


def _decode_and_predict(image_b64: str):
    """Decode the base64 image to disk and run the classifier on it (blocking).

    Holds ``_predict_lock`` for the whole write-then-predict sequence, because
    the pipeline was constructed with the same fixed file name that the image
    is written to.
    """
    with _predict_lock:
        decodeImage(image_b64, "inputImage.jpg")
        return classifier.predict()


@app.post("/predict")
async def predictRoute(payload: ImagePayload):
    """
    Accept a base64 encoded image, save it, run prediction, and return the result.

    Decoding and inference are blocking, so they run in a worker thread rather
    than on the event loop.

    Returns:
        [{"prediction": "Normal"}] or [{"prediction": "Cancer"}].

    Raises:
        HTTPException: status 500 if the classifier returns a value other than
            the known class indices 0 or 1.
    """
    # 1 + 2. Decode/save the image and run the prediction pipeline off the event loop.
    prediction_value = await run_in_threadpool(_decode_and_predict, payload.image)

    # 3. Translate the numeric prediction into a human-readable string.
    # Based on your confirmed class indices: {'adenocarcinoma': 0, 'normal': 1}
    # Both indices are matched explicitly, so an unexpected value is reported
    # as an error instead of being silently labelled "Cancer".
    if prediction_value == 1:
        prediction_text = "Normal"
    elif prediction_value == 0:
        prediction_text = "Cancer"
    else:
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected prediction value: {prediction_value!r}",
        )

    # 4. Return the result. FastAPI handles the JSON conversion.
    return [{"prediction": prediction_text}]
# --- RUN THE APP ---
# Local-development entry point only: when a server such as Gunicorn/Uvicorn
# imports this module, __name__ is not "__main__" and this block is skipped.
if __name__ == "__main__":
    # 0.0.0.0 listens on all network interfaces, not just localhost.
    dev_host, dev_port = "0.0.0.0", 8080
    uvicorn.run(app, host=dev_host, port=dev_port)