Spaces:
Runtime error
Runtime error
import logging

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError

from utils import BrainTumorModel, get_precautions_from_gemini
app = FastAPI()

# Load the trained weights once at import time; every request reuses this
# single model instance.
btd_model = BrainTumorModel()
btd_model_path = "brain_tumor_model.pth"
try:
    # map_location forces every stored tensor onto the CPU, so a checkpoint
    # saved on a GPU machine still loads on a CPU-only host.
    btd_model.load_state_dict(
        torch.load(btd_model_path, map_location=torch.device("cpu"))
    )
except Exception as e:
    # Best-effort: startup continues even if the checkpoint is missing or does
    # not match the architecture. logging.exception keeps the traceback that
    # the previous print() discarded.
    # NOTE(review): in that case btd_model keeps the weights it was constructed
    # with, so predictions will be meaningless. Consider failing fast instead.
    logging.getLogger(__name__).exception("❌ Error loading model: %s", e)
# Put the model in inference mode unconditionally. Previously this was skipped
# when loading failed. It matters for layers such as dropout/BatchNorm, if the
# architecture contains them.
btd_model.eval()
# Preprocessing pipeline for every uploaded image: scale it to a fixed
# 224x224 resolution, then turn the PIL image into a float tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
# Human-readable labels, indexed by the model's output position: entry i
# names the class whose score sits at index i. Keep this order in sync with
# the order used when the model was trained.
classes = "glioma meningioma notumor pituitary".split()
@app.get("/")
def read_root():
    """Health-check endpoint confirming the API process is up.

    Without the route decorator this function was never registered on
    ``app``, so ``GET /`` returned 404.

    Returns:
        A JSON-serialisable dict with a fixed status message.
    """
    return {"message": "Brain Tumor Detection API is running 🚀"}
@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and attach precautions for the predicted class.

    This is deliberately a plain ``def`` rather than ``async def``. Model
    inference and ``get_precautions_from_gemini`` are blocking calls.
    FastAPI runs sync path operations in a worker thread, so they no longer
    stall the event loop for every other request.

    Args:
        file: The uploaded image. Anything Pillow can decode is accepted; it
            is converted to RGB before preprocessing.

    Returns:
        JSONResponse with one of three outcomes:
        - 200 ``{"prediction": <label>, "precautions": <value>}`` on success.
        - 400 ``{"error": ...}`` if Pillow cannot identify the upload as an image.
        - 500 ``{"error": ...}`` for any other failure.
    """
    try:
        # The context manager releases Pillow's handle once decoding is done;
        # convert() returns a new, fully loaded image.
        with Image.open(file.file) as img:
            rgb_image = img.convert("RGB")
        batch = transform(rgb_image).unsqueeze(0)  # Shape: [1, 3, 224, 224]
        with torch.no_grad():
            outputs = btd_model(batch)
        # assumes outputs is [batch, num_classes] (same axis the original
        # torch.max(..., 1) reduced over) — TODO confirm against the model.
        predicted_class = classes[outputs.argmax(dim=1).item()]
        precautions = get_precautions_from_gemini(predicted_class)
        return JSONResponse(content={
            "prediction": predicted_class,
            "precautions": precautions,
        })
    except UnidentifiedImageError as e:
        # The client sent something that is not a decodable image: a 4xx,
        # not a server fault.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        # Top-level request boundary: record the traceback server-side, and
        # keep the original JSON error contract for the client.
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)