"""FastAPI inference service.

Each uploaded image is first screened by two binary cancer / no-cancer
EfficientNet-B0 models. The image is reported as "Cancer Not Detected" only
when both models agree on "nocancer". Otherwise it is passed to the selected
multi-class model, whose per-class confidences are returned sorted from
highest to lowest.
"""
import io
import os
from contextlib import asynccontextmanager

import gdown
import torch
import torch.nn.functional as F
from colorama import Fore, Style
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import models, transforms

from models.efficient_net import EfficientNetB7

####################### MODEL WEIGHT LOCATIONS #######################

WEIGHTS_DRIVE_LINK = "https://drive.google.com/drive/folders/1JOd3O1c3me5JWE3Elq0MhhShP_XhTYRV"
WEIGHTS_DIR = "model_weights"

# Single source of truth for every weight file. The post-download check and
# lifespan() must agree on these paths; the check previously joined
# WEIGHTS_DIR twice and looked for "model_weights/model_weights/...".
EFFICIENT_NET_B7_WEIGHTS_PATH = os.path.join(
    WEIGHTS_DIR, "EfficientNetV7Large_v1", "saved_models", "best_test_model.pth"
)
CANCER_NOCANCER_1_WEIGHTS_PATH = os.path.join(WEIGHTS_DIR, "cancer_nocancer.pth")
CANCER_NOCANCER_2_WEIGHTS_PATH = os.path.join(WEIGHTS_DIR, "cancer_nocancer_100.pth")
REQUIRED_WEIGHT_PATHS = (
    EFFICIENT_NET_B7_WEIGHTS_PATH,
    CANCER_NOCANCER_1_WEIGHTS_PATH,
    CANCER_NOCANCER_2_WEIGHTS_PATH,
)

####################### DOWNLOAD MODEL WEIGHTS #######################


def _ensure_model_weights():
    """Download the weights folder from Google Drive unless every required file exists.

    Runs once at import time. The check looks at the individual weight files
    rather than only the directory. An interrupted earlier download that left
    a partial ``model_weights`` folder is therefore retried instead of skipped.
    """
    if all(os.path.exists(path) for path in REQUIRED_WEIGHT_PATHS):
        print(Fore.YELLOW + "⚠️ Model weights already exist. Skipping download." + Style.RESET_ALL)
        print("Download skipped!")
        return

    # Download the folder recursively. The loader paths above assume the
    # folder's contents land directly in WEIGHTS_DIR — verify against the
    # installed gdown version if the files end up nested one level deeper.
    gdown.download_folder(WEIGHTS_DRIVE_LINK, output=WEIGHTS_DIR, quiet=False, use_cookies=False)

    missing = [path for path in REQUIRED_WEIGHT_PATHS if not os.path.exists(path)]
    if missing:
        print(Fore.RED + "❌ Model weights not found. Check the folder or link." + Style.RESET_ALL)
        for path in missing:
            print(f"   missing: {path}")
    else:
        print(Fore.GREEN + "✅ Model weights downloaded successfully!" + Style.RESET_ALL)
    print("Download completed!")


_ensure_model_weights()

######################################################################

# Selectable multi-class models, keyed by the upper-cased name a client sends
# in the "model" form field. Populated by lifespan() and cleared at shutdown.
ml_models = {}
# Binary cancer / no-cancer screening models. Populated by lifespan().
cancer_nocancer_model_1 = None
cancer_nocancer_model_2 = None

################### for cancer no cancer model ################################


def load_model(model_path, num_classes):
    """Load a fine-tuned EfficientNet-B0 classifier from a saved state dict.

    Args:
        model_path: Path to a ``.pth`` file holding the model's ``state_dict``.
        num_classes: Output size of the replaced final linear layer. This must
            match the layer shape stored in the state dict.

    Returns:
        The model on CPU, switched to eval mode.
    """
    # weights=None: the strict load_state_dict() below overwrites every
    # parameter and buffer. Fetching ImageNet weights first only added a
    # network download at startup without changing the resulting model.
    model = models.efficientnet_b0(weights=None)
    model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
    # NOTE(review): torch.load unpickles the file. This is acceptable only
    # because the weights come from the project's own Drive folder; consider
    # weights_only=True once the minimum torch version supports it.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Preprocessing for the screening models: 224x224 resize, tensor conversion,
# then per-channel normalization with the ImageNet mean/std.
efficient_net_b0_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Output index -> label for both screening models.
# (The misspelled name is kept so any external importers keep working.)
cancer_no_caner_class_mapping = {0: 'cancer', 1: 'nocancer'}
# Index of 'nocancer' in the mapping above.
_NOCANCER_INDEX = 1

# Output classes of the multi-class model. predict() pairs pred[i] with
# PREDICTION_CLASSES[i], so the order must match the model's outputs.
PREDICTION_CLASSES = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']

################################################################################


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model at application startup and release them at shutdown."""
    global ml_models
    global cancer_nocancer_model_1
    global cancer_nocancer_model_2

    num_classes = len(cancer_no_caner_class_mapping)
    efficient_net_model = EfficientNetB7(weights_path=EFFICIENT_NET_B7_WEIGHTS_PATH)
    cancer_nocancer_model_1 = load_model(CANCER_NOCANCER_1_WEIGHTS_PATH, num_classes)
    cancer_nocancer_model_2 = load_model(CANCER_NOCANCER_2_WEIGHTS_PATH, num_classes)
    ml_models = {
        "EFFICIENT NET V7": efficient_net_model,
    }
    try:
        yield
    finally:
        # Runs even if the server is torn down with an error, so the model
        # references are always dropped.
        ml_models.clear()
        cancer_nocancer_model_1 = None
        cancer_nocancer_model_2 = None


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (or specify your frontend URL)
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/api/classify", response_class=JSONResponse)
async def classify(
    image: UploadFile = File(...),  # Accepts an image file
    model: str = Form(...),  # Accepts a model name as form data
):
    """Screen an uploaded image for cancer and, if flagged, classify it.

    Returns ``{"result": "Cancer Not Detected"}`` when both screening models
    predict 'nocancer'. Otherwise returns the selected model's per-class
    confidences (see ``predict``).

    Raises:
        HTTPException(400): unknown model name, or an upload that is not a
            decodable image.
    """
    # NOTE(review): the model calls below are synchronous and run on the event
    # loop, blocking other requests while they execute; consider a threadpool.
    model = model.upper()
    if model not in ml_models:
        raise HTTPException(status_code=400, detail="Invalid model specified")

    image_data = await image.read()
    try:
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # Unreadable, truncated or oversized uploads are client errors, not 500s.
        raise HTTPException(status_code=400, detail="Invalid image file") from err

    # check whether image is of cancer or not
    img_for_check = efficient_net_b0_transform(img).unsqueeze(0)
    # Inference only: skip building autograd graphs on every request.
    with torch.no_grad():
        check_cancer_1 = F.softmax(cancer_nocancer_model_1(img_for_check), dim=1)
        check_cancer_2 = F.softmax(cancer_nocancer_model_2(img_for_check), dim=1)
    check_cancer_1_index = check_cancer_1.argmax().item()
    check_cancer_2_index = check_cancer_2.argmax().item()

    print("-" * 25, "Cancer No Cancer", "-" * 25)
    print(check_cancer_1_index)
    print(check_cancer_2_index)
    print("-" * 50)

    # Only report "no cancer" when both screening models agree.
    if check_cancer_1_index == _NOCANCER_INDEX and check_cancer_2_index == _NOCANCER_INDEX:
        print("Cancer Not Detected")
        return JSONResponse({"result": "Cancer Not Detected"})

    print("Cancer Detected")
    prediction = predict(ml_models[model], model, img)
    return JSONResponse(prediction)


def predict(model, model_name, image_data):
    """Run the multi-class model and return its confidences sorted descending.

    Args:
        model: A model object exposing ``make_prediction(image)``. The first
            element of its result is indexed per class.
        model_name: Name used only for logging.
        image_data: The RGB PIL image to classify.

    Returns:
        ``{"result": [{"class": str, "confidence": float}, ...]}`` with one
        entry per class in ``PREDICTION_CLASSES``, highest confidence first.
    """
    print("-" * 50)
    print(f"Model Used: {model_name}")
    # Inference only; no_grad is harmless if make_prediction already disables it.
    with torch.no_grad():
        pred = model.make_prediction(image_data)[0]
    print(pred)
    results = [
        {"class": cls, "confidence": float(pred[i])}
        for i, cls in enumerate(PREDICTION_CLASSES)
    ]
    pred_json = {"result": sorted(results, key=lambda x: x["confidence"], reverse=True)}
    print(pred_json)
    print("-" * 50)
    return pred_json


@app.get("/api/model_performance", response_class=JSONResponse)
async def home():
    """Return static descriptive metadata for the models offered by the UI."""
    return {
        "models": [
            {
                "name": "Efficient Net V7",
                "description": "A high-efficiency architecture that leverages compound scaling for superior performance across various tasks.",
                "performance_tags": [
                    {"icon": "fa-tachometer-alt", "label": "High Efficiency"},
                    {"icon": "fa-star", "label": "State-of-the-Art"},
                ],
            },
        ]
    }


@app.get("/api/health", response_class=JSONResponse)
def health():
    """Liveness probe: always reports that the service is up."""
    return JSONResponse(status_code=200, content={"status": "Working fine!"})


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8501)