Spaces:
Runtime error
Runtime error
| # newapi.py | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import io | |
| import os | |
# Redirect the Hugging Face cache locations to a directory under /tmp.
# Both variables are set to the same path, overriding any existing values.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": "/tmp/.cache",
        "HF_HOME": "/tmp/.cache",
    }
)
# Application object that the route handlers attach to.
app = FastAPI(title="🧠 Brain Tumor Detection API")

# Fully permissive CORS policy: any origin, any HTTP method, any header.
# Each option gets its own fresh ["*"] list.
app.add_middleware(
    CORSMiddleware,
    **{
        option: ["*"]
        for option in ("allow_origins", "allow_methods", "allow_headers")
    },
)
# Preprocessing pipeline applied to each image before it is fed to the model.
# Step order matters: resize -> single-channel grayscale -> tensor -> normalize.
_resize_step = transforms.Resize((224, 224))
_grayscale_step = transforms.Grayscale(num_output_channels=1)
_to_tensor_step = transforms.ToTensor()
# Single-channel statistics; adjust mean/std if the training setup differs.
_normalize_step = transforms.Normalize(mean=[0.286], std=[0.229])

transform = transforms.Compose(
    [_resize_step, _grayscale_step, _to_tensor_step, _normalize_step]
)
# Network definition whose parameter names and shapes must match the checkpoint.
import torch.nn as nn


class BrainTumorModel(nn.Module):
    """Three-stage CNN classifier for single-channel images with 4 output classes.

    Each stage is ``Conv2d(kernel_size=3, padding=1) -> ReLU -> MaxPool2d(2)``.
    With ``padding=1`` a convolution keeps the spatial size, so a 224x224 input
    is halved three times (224 -> 112 -> 56 -> 28) and reaches the classifier
    as 128 feature maps of 28x28, which is exactly what ``fc1`` expects.

    Without padding the same input arrives as 128x26x26 and cannot be reshaped
    to ``128 * 28 * 28`` features, so the forward pass raises.

    NOTE(review): padding does not appear in the state dict, so the checkpoint
    loads either way. ``padding=1`` is inferred from the ``fc1`` input size and
    the 224x224 resize used by this service; confirm against the training code.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are state-dict keys; do not rename them.
        self.con1d = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # grayscale input
        self.con2d = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.con3d = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(128 * 28 * 28, 512)  # matches the saved model's input size
        self.fc2 = nn.Linear(512, 256)
        self.output = nn.Linear(256, 4)  # 4 classes expected

    def forward(self, x):
        """Return raw class scores of shape ``(batch, 4)``.

        Args:
            x: Tensor of shape ``(batch, 1, H, W)``; ``H`` and ``W`` must reduce
                to 28 after the three pooling stages (e.g. 224x224).

        Raises:
            RuntimeError: If the flattened feature size does not match ``fc1``.
        """
        x = self.pool(torch.relu(self.con1d(x)))
        x = self.pool(torch.relu(self.con2d(x)))
        x = self.pool(torch.relu(self.con3d(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, N),
        # this can never fold a size mismatch into the batch dimension.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.output(x)
# Load model
# Use a checkpoint in the working directory when present; otherwise download
# it from the Hugging Face Hub into the /tmp cache.
model_path = "BTD_model.pth"
if not os.path.exists(model_path):
    # Imported lazily so the hub client is only needed when a download happens.
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(
        repo_id="Codewithsalty/brain-tumor-models",
        filename="BTD_model.pth",
        cache_dir="/tmp/.cache",
    )

btd_model = BrainTumorModel()
# SECURITY: the checkpoint may come from the network. weights_only=True limits
# unpickling to tensors and plain containers, so a tampered file cannot run
# arbitrary code. A state dict needs nothing more than that.
btd_model.load_state_dict(
    torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
# Inference mode: the model is only used for prediction from here on.
btd_model.eval()
# Prediction endpoint
# Class index -> human-readable label, built once instead of on every request.
# NOTE(review): the index order must match the label encoding used in
# training; it cannot be verified from this file.
_CLASS_LABELS = {
    0: "No tumor",
    1: "Glioma",
    2: "Meningioma",
    3: "Pituitary tumor",
}


# Without a route decorator the handler is never registered with the app.
# NOTE(review): the "/predict" path is a guess; confirm it against the clients.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        ``{"prediction": <label>}`` on success, or ``{"error": <message>}`` if
        reading, decoding, preprocessing or inference raises.
    """
    try:
        contents = await file.read()
        # Decode from memory and force single-channel grayscale.
        image = Image.open(io.BytesIO(contents)).convert("L")
        # Add the batch dimension expected by the model.
        image_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            output = btd_model(image_tensor)
        prediction = torch.argmax(output, dim=1).item()
        return {"prediction": _CLASS_LABELS[prediction]}
    except Exception as e:
        # Deliberate best-effort boundary: failures are reported in the
        # response body rather than raised, as callers of this API expect.
        return {"error": str(e)}
# Health check
# Without a route decorator the handler is never registered with the app.
@app.get("/")
def root():
    """Return a static message confirming the API process is up."""
    return {"message": "🧠 Brain Tumor Detection API is running!"}