Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File | |
| from PIL import Image | |
| import os | |
| import uvicorn | |
| import torch | |
| import numpy as np | |
| from io import BytesIO | |
| from torchvision import transforms , models | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
| import tempfile | |
| from pathlib import Path | |
# Model downloads are cached under the system temp directory, a location the
# process can write to even when its home directory is not writable.
CACHE_DIR = Path(tempfile.gettempdir()).joinpath("huggingface_cache")
CACHE_DIR.mkdir(exist_ok=True, parents=True)
# NOTE(review): TRANSFORMERS_CACHE is an env var for the `transformers`
# library, which this file does not import; hf_hub_download below is given
# cache_dir explicitly, so this is presumably kept only as a fallback.
os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR)
# ASGI application instance; served by uvicorn when this module is run directly.
app = FastAPI()
# Define preprocessing
#
# Inference-time preprocessing for the DenseNet-169 ensemble member.
# Only deterministic transforms belong here: random augmentations (flips,
# affine jitter, colour jitter, blur, random erasing) are training-time tools.
# Applied at inference they perturb the input and make repeated predictions
# for the same image disagree.
# Steps: resize to 224x224, convert to a CHW float tensor in [0, 1], then
# normalise with the ImageNet channel mean/std.
preprocessDensenet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Inference-time preprocessing for the ResNet-18 ensemble member.
# Only deterministic transforms belong here: random augmentations (flips,
# affine jitter, colour jitter, blur, random erasing) are training-time tools.
# Applied at inference they perturb the input and make repeated predictions
# for the same image disagree.
# Steps: resize to 224x224, convert to a CHW float tensor in [0, 1], then
# normalise with the ImageNet channel mean/std.
preprocessResnet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Inference-time preprocessing for the GoogLeNet ensemble member.
# Only deterministic transforms belong here: random flips, affine jitter and
# colour jitter are training-time augmentations. Applied at inference they
# perturb the input and make repeated predictions for the same image disagree.
# Steps: resize to 224x224, convert to a CHW float tensor in [0, 1], then
# normalise with the ImageNet channel mean/std.
preprocessGooglenet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def create_densenet169():
    """Build a DenseNet-169 (no pretrained weights) with a 2-class head.

    The stock classifier is replaced by
    BatchNorm1d -> Dropout(0.4) -> Linear(in, 512) -> ReLU -> Dropout(0.3) -> Linear(512, 2).
    The layer order presumably mirrors the training code so that the
    checkpoint's ``classifier.<i>.*`` state-dict keys line up — confirm before
    changing it.

    Returns:
        The untrained ``torchvision`` DenseNet-169 with the custom head.
    """
    net = models.densenet169(pretrained=False)
    n_features = net.classifier.in_features
    head_layers = [
        nn.BatchNorm1d(n_features),
        nn.Dropout(p=0.4),
        nn.Linear(n_features, 512),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 2),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
def create_resnet18():
    """Build a ResNet-18 (no pretrained weights) whose ``fc`` is Dropout(0.5) -> Linear(in, 2).

    Returns:
        The untrained ``torchvision`` ResNet-18 with a 2-class head.
    """
    net = models.resnet18(pretrained=False)
    n_features = net.fc.in_features
    net.fc = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(n_features, 2))
    return net
def create_googlenet():
    """Build a GoogLeNet (no pretrained weights) with a 2-class head.

    The auxiliary classifiers ``aux1``/``aux2`` are removed (set to ``None``).
    Presumably this matches a checkpoint saved without them — confirm. The
    final ``fc`` becomes Dropout(0.5) -> Linear(in, 2).

    Returns:
        The untrained ``torchvision`` GoogLeNet with the custom head.
    """
    net = models.googlenet(pretrained=False)
    net.aux1, net.aux2 = None, None
    n_features = net.fc.in_features
    net.fc = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(n_features, 2))
    return net
def load_model_from_hf(repo_id, model_creator):
    """Download ``model.pth`` from a Hugging Face Hub repo and load it into a fresh model.

    Args:
        repo_id: Hub repository id to fetch ``model.pth`` from.
        model_creator: zero-argument callable returning the (untrained)
            architecture whose state dict matches the checkpoint.

    Returns:
        The model switched to eval mode, or ``None`` if any step fails. Failures
        are printed, not raised, so the service can still start.
    """
    try:
        weights_file = hf_hub_download(
            repo_id=repo_id,
            filename="model.pth",
            cache_dir=CACHE_DIR,
        )
        net = model_creator()
        # NOTE(review): torch.load unpickles the downloaded file, and pickle can
        # execute arbitrary code. Only point this at trusted repos, or pass
        # weights_only=True if the installed torch supports it and the
        # checkpoint holds only tensors/primitive containers.
        ckpt = torch.load(weights_file, map_location=torch.device('cpu'))
        # A training checkpoint may wrap the weights under "model_state_dict";
        # otherwise the file is treated as the bare state dict itself.
        net.load_state_dict(ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt)
        net.eval()
    except Exception as err:
        print(f"Error loading model from {repo_id}: {str(err)}")
        return None
    else:
        return net
# Ensemble members, downloaded and loaded at import time in this order.
# A value is None when its load failed (load_model_from_hf prints the reason).
# NOTE: the "pnuemonia" spelling is part of the repo ids — presumably it
# matches the Hub repos, so don't "fix" it without checking.
modelss = {
    name: load_model_from_hf(repo, creator)
    for name, repo, creator in (
        ("Densenet169", "Arham-Irfan/Densenet169_pnuemonia_binaryclassification", create_densenet169),
        ("Resnet18", "Arham-Irfan/Resnet18_pnuemonia_binaryclassification", create_resnet18),
        ("Googlenet", "Arham-Irfan/Googlenet_pnuemonia_binaryclassification", create_googlenet),
    )
}
# Label for each output index of the 2-way classifier heads.
classes = ["Normal", "Pneumonia"]
# Register the handler with the app; without a route decorator FastAPI never
# exposes it. NOTE(review): the "/predict" path is chosen here — confirm it
# matches what clients call.
@app.post("/predict")
async def predict_pneumonia(file: UploadFile = File(...)):
    """Classify an uploaded image as Normal or Pneumonia with a weighted 3-model ensemble.

    The image is preprocessed separately for each model. Each model's softmax
    output is computed, and the outputs are combined with fixed weights:
    DenseNet-169 0.45, ResNet-18 0.33, GoogLeNet 0.22.

    Args:
        file: uploaded image; decoded with Pillow and converted to RGB.

    Returns:
        On success: ``prediction`` (class label), ``confidence`` (ensemble
        probability of that label) and ``model_details`` (each model's
        probability for the predicted label). On failure: a dict with a single
        ``error`` message.
    """
    # load_model_from_hf stores None for a model that failed to load; report
    # that explicitly instead of failing later with an opaque
    # "'NoneType' object is not callable".
    missing = [name for name, model in modelss.items() if model is None]
    if missing:
        return {"error": f"Prediction error: model(s) not loaded: {', '.join(missing)}"}
    try:
        image = Image.open(BytesIO(await file.read())).convert("RGB")
        # unsqueeze(0) adds the batch dimension each model expects.
        img_tensor1 = preprocessDensenet(image).unsqueeze(0)
        img_tensor2 = preprocessResnet(image).unsqueeze(0)
        img_tensor3 = preprocessGooglenet(image).unsqueeze(0)
        # NOTE(review): inference runs synchronously inside an async handler and
        # blocks the event loop while it runs; consider offloading to a thread.
        with torch.no_grad():
            output1 = torch.softmax(modelss["Densenet169"](img_tensor1), dim=1).numpy()[0]
            output2 = torch.softmax(modelss["Resnet18"](img_tensor2), dim=1).numpy()[0]
            output3 = torch.softmax(modelss["Googlenet"](img_tensor3), dim=1).numpy()[0]
        weights = [0.45, 0.33, 0.22]  # sum to 1, so the ensemble stays a distribution
        ensemble_prob = weights[0] * output1 + weights[1] * output2 + weights[2] * output3
        pred_index = int(np.argmax(ensemble_prob))
        return {
            "prediction": classes[pred_index],
            "confidence": float(ensemble_prob[pred_index]),
            "model_details": {
                "Densenet169": float(output1[pred_index]),
                "Resnet18": float(output2[pred_index]),
                "Googlenet": float(output3[pred_index]),
            },
        }
    except Exception as e:
        # Endpoint boundary: report any decode/inference failure to the client.
        return {"error": f"Prediction error: {str(e)}"}
if __name__ == "__main__":
    # Serve on all interfaces. Port 7860 is presumably what the hosting
    # platform (Hugging Face Spaces) routes to — confirm if deploying elsewhere.
    uvicorn.run(app, port=7860, host="0.0.0.0")