Spaces:
Configuration error
Configuration error
| import io | |
| from PIL import Image | |
| from models.efficient_net import EfficientNetB7 | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from contextlib import asynccontextmanager | |
| from torchvision import models, transforms | |
| import torch.nn.functional as F | |
| import torch | |
| import gdown | |
| import os | |
| from colorama import Fore, Style | |
####################### DOWNLOAD MODEL WEIGHTS #######################
# Fetch the shared Google Drive folder of model weights on first start-up.
# NOTE: only the existence of the "model_weights" directory is checked, so a
# partially populated directory (e.g. an interrupted download) is not re-fetched.
if not os.path.exists("model_weights"):
    drive_link = "https://drive.google.com/drive/folders/1JOd3O1c3me5JWE3Elq0MhhShP_XhTYRV"
    output_dir = "model_weights"
    # Download the folder recursively into output_dir.
    gdown.download_folder(drive_link, output=output_dir, quiet=False, use_cookies=False)
    # Verify the EfficientNet weights at the same location `lifespan` loads them
    # from (model_weights/EfficientNetV7Large_v1/...). The sub-path is relative to
    # output_dir, so it must not repeat the "model_weights" prefix.
    model_weights_path = os.path.join(
        output_dir, "EfficientNetV7Large_v1", "saved_models", "best_test_model.pth"
    )
    if os.path.exists(model_weights_path):
        print(Fore.GREEN + "✅ Model weights downloaded successfully!" + Style.RESET_ALL)
    else:
        print(Fore.RED + "❌ Model weights not found. Check the folder or link." + Style.RESET_ALL)
    print("Download completed!")
else:
    print(Fore.YELLOW + "⚠️ Model weights already exist. Skipping download." + Style.RESET_ALL)
    print("Download skipped!")
######################################################################
# Registry of user-selectable classifiers, keyed by upper-case display name.
# Filled in by `lifespan` at start-up and cleared on shutdown.
ml_models = {}
# The two cancer / no-cancer gate models; also assigned by `lifespan`.
cancer_nocancer_model_1 = cancer_nocancer_model_2 = None
################### for cancer no cancer model ################################
def load_model(model_path, num_classes):
    """
    Build an EfficientNet-B0 classifier and load saved weights into it.

    Args:
        model_path: Path to a ``state_dict`` file readable by ``torch.load``.
        num_classes: Output size of the replacement classification layer.

    Returns:
        The model, loaded onto CPU and switched to evaluation mode.
    """
    # No pretrained weights are requested: load_state_dict below is strict, so it
    # overwrites every parameter and buffer anyway. Requesting ImageNet weights
    # would only add a network download to every call.
    model = models.efficientnet_b0(weights=None)
    # Replace the final Linear layer of the classifier head with one sized to num_classes.
    model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
    # NOTE(review): torch.load unpickles the file, and these weights come from an
    # external Google Drive folder; consider weights_only=True if the installed
    # torch version supports it.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Preprocessing for the EfficientNet-B0 gate models:
# resize to 224x224, convert the PIL image to a tensor, then normalize each
# channel with the ImageNet mean / std values.
efficient_net_b0_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Output index -> label for the cancer / no-cancer gate models.
# The misspelled name ("caner") is kept because other code refers to it.
cancer_no_caner_class_mapping = dict(enumerate(("cancer", "nocancer")))
################################################################################
# Application lifespan: load every model once at start-up, release them on shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Load the classifier models before the app serves requests; drop them afterwards.

    FastAPI's ``lifespan`` argument expects an async context manager factory,
    which ``@asynccontextmanager`` builds from this generator. Code before
    ``yield`` runs at start-up and code after it runs at shutdown.
    """
    global ml_models
    global cancer_nocancer_model_1
    global cancer_nocancer_model_2
    cancer_nocancer_1_weight_path = "model_weights/cancer_nocancer.pth"
    cancer_nocancer_2_weight_path = "model_weights/cancer_nocancer_100.pth"
    efficient_net_model = EfficientNetB7(
        weights_path="model_weights/EfficientNetV7Large_v1/saved_models/best_test_model.pth"
    )
    num_classes = len(cancer_no_caner_class_mapping)
    cancer_nocancer_model_1 = load_model(cancer_nocancer_1_weight_path, num_classes)
    cancer_nocancer_model_2 = load_model(cancer_nocancer_2_weight_path, num_classes)
    # Keys are upper-case because the request handler upper-cases the requested name.
    ml_models = {
        "EFFICIENT NET V7": efficient_net_model,
    }
    try:
        yield
    finally:
        # Release every model reference, not just the registry, so they can be freed.
        ml_models.clear()
        cancer_nocancer_model_1 = None
        cancer_nocancer_model_2 = None
# The ASGI application; `lifespan` handles model loading and cleanup.
app = FastAPI(lifespan=lifespan)
# Permissive CORS: any origin may call this API with any method and header.
# NOTE(review): consider restricting allow_origins to the real front-end URL(s).
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],  # Allow all origins (or specify your frontend URL)
)
# NOTE(review): no route decorator is visible for this handler, so as shown it is
# never registered with `app`; presumably an `@app.post(...)` line was lost — confirm.
async def classify(
    image: UploadFile = File(...),  # Accepts an image file
    model: str = Form(...),  # Accepts a model name as form data
):
    """
    Screen an uploaded image for cancer and, if flagged, classify it.

    Two EfficientNet-B0 gate models run first. The image is reported as
    cancer-free only when BOTH predict "nocancer"; otherwise it is passed to
    the requested model through ``predict``.

    Args:
        image: Uploaded image file (any format PIL can open).
        model: Name of a model in ``ml_models``; matched case-insensitively.

    Returns:
        A JSONResponse holding either ``{"result": "Cancer Not Detected"}`` or
        the ranked class confidences produced by ``predict``.

    Raises:
        HTTPException: 400 for an unknown model name or an unreadable image.
    """
    model = model.upper()
    if model not in ml_models:
        raise HTTPException(status_code=400, detail="Invalid model specified")

    image_data = await image.read()
    try:
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except OSError as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for non-image
        # or truncated uploads; report that as a client error, not a 500.
        raise HTTPException(status_code=400, detail="Invalid image file") from err

    # Check whether the image is of cancer or not.
    img_for_check = efficient_net_b0_transform(img).unsqueeze(0)  # add batch dimension
    # NOTE(review): this is synchronous CPU inference inside an async handler, so
    # it blocks the event loop while it runs; consider offloading to a threadpool.
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        check_cancer_1 = F.softmax(cancer_nocancer_model_1(img_for_check), dim=1)
        check_cancer_2 = F.softmax(cancer_nocancer_model_2(img_for_check), dim=1)
    check_cancer_1_index = int(check_cancer_1.argmax())
    check_cancer_2_index = int(check_cancer_2.argmax())
    print("-" * 25, "Cancer No Cancer", "-" * 25)
    print(check_cancer_1_index)
    print(check_cancer_2_index)
    print("-" * 50)

    # Both gate models must agree on "nocancer" for a negative result.
    no_cancer = (
        cancer_no_caner_class_mapping[check_cancer_1_index] == "nocancer"
        and cancer_no_caner_class_mapping[check_cancer_2_index] == "nocancer"
    )
    if no_cancer:
        print("Cancer Not Detected")
        return JSONResponse({"result": "Cancer Not Detected"})

    print("Cancer Detected")
    prediction = predict(ml_models[model], model, img)
    return JSONResponse(prediction)
def predict(model, model_name, image_data):
    """
    Run ``model`` on an image and return its per-class confidences, highest first.

    Args:
        model: Object exposing ``make_prediction(image)``; element ``[0]`` of its
            result is indexed by class position.
        model_name: Label written to the log.
        image_data: Image handed unchanged to ``make_prediction``.

    Returns:
        ``{"result": [{"class": label, "confidence": float}, ...]}`` ordered by
        descending confidence, with ties keeping label order.
    """
    labels = ('akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc')
    separator = "-" * 50
    print(separator)
    print(f"Model Used: {model_name}")
    scores = model.make_prediction(image_data)[0]
    print(scores)
    ranked = []
    for idx, label in enumerate(labels):
        ranked.append({"class": label, "confidence": float(scores[idx])})
    # list.sort is stable, so equal confidences keep their label order.
    ranked.sort(key=lambda entry: entry["confidence"], reverse=True)
    pred_json = {"result": ranked}
    print(pred_json)
    print(separator)
    return pred_json
async def home():
    """Return static metadata describing the models this service offers."""
    efficient_net_card = {
        "name": "Efficient Net V7",
        "description": "A high-efficiency architecture that leverages compound scaling for superior performance across various tasks.",
        "performance_tags": [
            {"icon": "fa-tachometer-alt", "label": "High Efficiency"},
            {"icon": "fa-star", "label": "State-of-the-Art"},
        ],
    }
    return {"models": [efficient_net_card]}
def health():
    """Health check: always answer HTTP 200 with a fixed status message."""
    payload = {"status": "Working fine!"}
    return JSONResponse(content=payload, status_code=200)
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 8501.
    from uvicorn import run as serve

    serve(app, port=8501, host="0.0.0.0")