# Skin-Cancer / app.py
# Source: umergohar, commit f94b780 ("1st Commit", verified)
import io
from PIL import Image
from models.efficient_net import EfficientNetB7
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from torchvision import models, transforms
import torch.nn.functional as F
import torch
import gdown
import os
from colorama import Fore, Style
####################### DOWNLOAD MODEL WEIGHTS #######################
# Fetch all model weights from the shared Google Drive folder on first start-up.
# Only the existence of the "model_weights" directory is checked, so once that
# directory exists (even after a partial download) the download is skipped.
if not os.path.exists("model_weights"):
    drive_link = "https://drive.google.com/drive/folders/1JOd3O1c3me5JWE3Elq0MhhShP_XhTYRV"
    output_dir = "model_weights"
    # Download the folder recursively
    gdown.download_folder(drive_link, output=output_dir, quiet=False, use_cookies=False)
    # Verify that the main classifier checkpoint arrived. The path is relative to
    # output_dir and must match the path the lifespan handler loads
    # ("model_weights/EfficientNetV7Large_v1/saved_models/best_test_model.pth");
    # prefixing it with "model_weights/" again would look in the wrong place.
    model_weights_path = os.path.join(
        output_dir, "EfficientNetV7Large_v1", "saved_models", "best_test_model.pth"
    )
    if os.path.exists(model_weights_path):
        print(Fore.GREEN + "✅ Model weights downloaded successfully!" + Style.RESET_ALL)
    else:
        print(Fore.RED + "❌ Model weights not found. Check the folder or link." + Style.RESET_ALL)
    print("Download completed!")
else:
    print(Fore.YELLOW + "⚠️ Model weights already exist. Skipping download." + Style.RESET_ALL)
    print("Download skipped!")
######################################################################
# Registry of selectable lesion classifiers, keyed by upper-case model name;
# populated by the lifespan handler at start-up.
ml_models = {}
# The two EfficientNet-B0 cancer/no-cancer gatekeeper models; assigned at start-up.
cancer_nocancer_model_1 = cancer_nocancer_model_2 = None
################### for cancer no cancer model ################################
def load_model(model_path, num_classes):
    """
    Load an EfficientNet-B0 classifier from a saved state dict.

    The backbone is built without pretrained weights. ``load_state_dict`` is strict
    by default, so every parameter and buffer is overwritten by the checkpoint.
    Fetching the ImageNet weights first would only cost a network download (and
    fail offline) without changing the resulting model.

    Args:
        model_path: Path to a ``.pth`` file holding the model's state dict.
        num_classes: Output size of the replacement classification head. It must
            match the head shape stored in the checkpoint.

    Returns:
        The model with the checkpoint loaded, on CPU, switched to eval mode.
    """
    model = models.efficientnet_b0(weights=None)
    # Swap the final Linear layer so its output size matches the fine-tuned checkpoint.
    model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
    # NOTE(review): torch.load unpickles the file, and these checkpoints come from a
    # remote Drive folder. Consider weights_only=True once the installed torch
    # supports it.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Preprocessing for the EfficientNet-B0 cancer/no-cancer gatekeeper models:
# resize to 224x224, convert to a tensor, then normalise each channel with the
# standard ImageNet channel statistics.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
efficient_net_b0_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
    ]
)
# Output index -> label for the gatekeeper models. The "caner" typo in the name
# is kept because the rest of the file refers to it.
cancer_no_caner_class_mapping = dict(enumerate(["cancer", "nocancer"]))
################################################################################
# make a lifespan
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Load every model once at application start-up and release them on shutdown.

    Start-up builds:
      * the EfficientNetB7 lesion classifier, exposed through ``ml_models``;
      * the two EfficientNet-B0 cancer/no-cancer gatekeeper models, stored in
        module globals.

    After the app stops, all three references are dropped. The cleanup sits in a
    ``finally`` clause, so it also runs if an exception is raised into the context
    manager at the ``yield``.
    """
    global ml_models
    global cancer_nocancer_model_1
    global cancer_nocancer_model_2
    cancer_nocancer_1_weight_path = "model_weights/cancer_nocancer.pth"
    cancer_nocancer_2_weight_path = "model_weights/cancer_nocancer_100.pth"
    efficient_net_model = EfficientNetB7(weights_path="model_weights/EfficientNetV7Large_v1/saved_models/best_test_model.pth")
    num_gate_classes = len(cancer_no_caner_class_mapping)
    cancer_nocancer_model_1 = load_model(cancer_nocancer_1_weight_path, num_gate_classes)
    cancer_nocancer_model_2 = load_model(cancer_nocancer_2_weight_path, num_gate_classes)
    # Keys are upper-case because /api/classify upper-cases the requested model name.
    ml_models = {
        "EFFICIENT NET V7": efficient_net_model,
    }
    try:
        yield
    finally:
        ml_models.clear()
        # Release the gatekeeper models too, so shutdown frees every loaded model.
        cancer_nocancer_model_1 = None
        cancer_nocancer_model_2 = None
app = FastAPI(lifespan=lifespan)

# CORS policy: every origin, method and header is accepted.
# Narrow "allow_origins" to the frontend URL if the API should not be public.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.post("/api/classify", response_class=JSONResponse)
async def classify(
    image: UploadFile = File(...),  # Accepts an image file
    model: str = Form(...)  # Accepts a model name as form data
):
    """
    Classify an uploaded skin-lesion image.

    The image is first screened by the two EfficientNet-B0 cancer/no-cancer models.
    Only when both predict "nocancer" does the endpoint return the short
    "Cancer Not Detected" answer. Otherwise the image goes to the requested lesion
    classifier via ``predict``.

    Form fields:
        image: The uploaded image file, opened with PIL and converted to RGB.
        model: Name of a model in ``ml_models``. It is upper-cased before lookup,
            so matching is case-insensitive.

    Returns:
        ``{"result": "Cancer Not Detected"}``, or the ranked class/confidence list
        produced by ``predict``.

    Raises:
        HTTPException(400): Unknown model name, or an upload PIL cannot decode.
    """
    global ml_models
    global cancer_nocancer_model_1
    global cancer_nocancer_model_2
    model = model.upper()
    if model not in ml_models:
        raise HTTPException(status_code=400, detail="Invalid model specified")

    image_data = await image.read()
    try:
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except OSError as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for unknown formats
        # and OSError for truncated/corrupt data. These are client errors, not 500s.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from err

    # NOTE(review): model inference below runs synchronously inside this async
    # handler and blocks the event loop. Consider offloading it to a thread pool.

    # Check whether the image shows cancer at all (two-model gatekeeper).
    img_for_check = efficient_net_b0_transform(img).unsqueeze(0)
    # Inference only: no_grad skips building the autograd graph, saving memory and time.
    with torch.no_grad():
        check_cancer_1 = F.softmax(cancer_nocancer_model_1(img_for_check), dim=1)
        check_cancer_2 = F.softmax(cancer_nocancer_model_2(img_for_check), dim=1)
    # .item() turns the 0-d index tensors into plain ints for the mapping lookup.
    check_cancer_1_index = check_cancer_1.argmax().item()
    check_cancer_2_index = check_cancer_2.argmax().item()
    print("-" * 25, "Cancer No Cancer", "-" * 25)
    print(check_cancer_1_index)
    print(check_cancer_2_index)
    print("-" * 50)
    # The image is treated as benign only when BOTH gatekeepers say "nocancer".
    if (
        cancer_no_caner_class_mapping[check_cancer_1_index] == "nocancer"
        and cancer_no_caner_class_mapping[check_cancer_2_index] == "nocancer"
    ):
        print("Cancer Not Detected")
        return JSONResponse({"result": "Cancer Not Detected"})
    print("Cancer Detected")
    prediction = predict(ml_models[model], model, img)
    return JSONResponse(prediction)
def predict(model, model_name, image_data):
    """
    Score an image with a lesion classifier and rank the seven classes.

    Args:
        model: Object exposing ``make_prediction(image)``. The first element of its
            result must be indexable by class position 0-6 (presumably the scores
            for a single-image batch; verify against the model wrapper).
        model_name: Label used only in the log output.
        image_data: Image passed unchanged to ``make_prediction``.

    Returns:
        ``{"result": [{"class": str, "confidence": float}, ...]}``, ordered by
        confidence with the highest first. Ties keep the original class order.
    """
    labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
    divider = "-" * 50
    print(divider)
    print(f"Model Used: {model_name}")
    scores = model.make_prediction(image_data)[0]
    print(scores)
    ranked = []
    for position, label in enumerate(labels):
        ranked.append({"class": label, "confidence": float(scores[position])})
    # list.sort is stable even with reverse=True, so equal scores keep label order.
    ranked.sort(key=lambda entry: entry["confidence"], reverse=True)
    pred_json = {"result": ranked}
    print(pred_json)
    print(divider)
    return pred_json
@app.get("/api/model_performance", response_class=JSONResponse)
async def home():
    """Describe the selectable models (name, blurb, UI tags) for the frontend."""
    efficient_net_tags = [
        {"icon": "fa-tachometer-alt", "label": "High Efficiency"},
        {"icon": "fa-star", "label": "State-of-the-Art"},
    ]
    efficient_net_card = {
        "name": "Efficient Net V7",
        "description": (
            "A high-efficiency architecture that leverages compound scaling "
            "for superior performance across various tasks."
        ),
        "performance_tags": efficient_net_tags,
    }
    return {"models": [efficient_net_card]}
@app.get("/api/health", response_class=JSONResponse)
def health():
    """Liveness probe: always answers HTTP 200 with a fixed status message."""
    payload = {"status": "Working fine!"}
    return JSONResponse(content=payload, status_code=200)
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run directly.
    import uvicorn

    # Listen on all interfaces, port 8501.
    bind_host, bind_port = "0.0.0.0", 8501
    uvicorn.run(app, host=bind_host, port=bind_port)