Spaces:
Runtime error
Runtime error
| import os | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import torch | |
| import torchvision.transforms as transforms | |
| from utils import BrainTumorModel, GliomaStageModel | |
# ASGI application object for this service.
app = FastAPI()

# Checkpoint files for the two classifiers (local .pth files).
btd_model_path = "brain_tumor_model.pth"
glioma_model_path = "glioma_stage_model.pth"


def _load_for_inference(model, checkpoint_path):
    """Load a state dict from *checkpoint_path* into *model* and return it.

    The checkpoint is mapped onto the CPU and the model is switched to
    evaluation mode before being returned.
    """
    # NOTE(review): torch.load deserialises the file; only point this at
    # checkpoints from a trusted source.
    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state)
    model.eval()
    return model


# Brain tumor detection model.
btd_model = _load_for_inference(BrainTumorModel(), btd_model_path)

# Glioma stage detection model.
glioma_model = _load_for_inference(GliomaStageModel(), glioma_model_path)
# Preprocessing applied to every uploaded image before inference:
# resize to 224x224, then convert the PIL image to a tensor.
# NOTE(review): no normalisation step is applied here -- confirm this
# matches how the models were trained.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` (e.g. via @app.post(...)).
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the brain tumor detection model.

    Args:
        file: The uploaded image; its underlying file object is handed to
            Pillow for decoding.

    Returns:
        JSONResponse: ``{"prediction": <label>}`` on success. If the upload
        cannot be decoded as an image the body is ``{"error": <message>}``
        with HTTP 400; if preprocessing or inference fails the same body
        shape is returned with HTTP 500.
    """
    classes = ('No Tumor', 'Pituitary', 'Meningioma', 'Glioma')

    # Decoding is separated from inference so that a bad upload is reported
    # as a client error rather than a server error.
    try:
        image = Image.open(file.file).convert("RGB")
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=400)

    try:
        # Add a leading batch dimension: the model is called on one image.
        batch = transform(image).unsqueeze(0)
        with torch.no_grad():
            output = btd_model(batch)
        predicted = torch.argmax(output, dim=1).item()
        result = classes[predicted]
    except Exception as e:
        # Previously failures were returned with the default 200 status,
        # which made them indistinguishable from successes by status code.
        return JSONResponse(content={"error": str(e)}, status_code=500)

    return JSONResponse(content={"prediction": result})
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` (e.g. via @app.post(...)).
async def glioma_stage(file: UploadFile = File(...)):
    """Classify an uploaded image with the glioma stage detection model.

    Args:
        file: The uploaded image; its underlying file object is handed to
            Pillow for decoding.

    Returns:
        JSONResponse: ``{"glioma_stage": <label>}`` on success. If the
        upload cannot be decoded as an image the body is
        ``{"error": <message>}`` with HTTP 400; if preprocessing or
        inference fails the same body shape is returned with HTTP 500.
    """
    stages = ('Stage 1', 'Stage 2', 'Stage 3', 'Stage 4')

    # Decoding is separated from inference so that a bad upload is reported
    # as a client error rather than a server error.
    try:
        image = Image.open(file.file).convert("RGB")
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=400)

    try:
        # Add a leading batch dimension: the model is called on one image.
        batch = transform(image).unsqueeze(0)
        with torch.no_grad():
            output = glioma_model(batch)
        predicted = torch.argmax(output, dim=1).item()
        result = stages[predicted]
    except Exception as e:
        # Previously failures were returned with the default 200 status,
        # which made them indistinguishable from successes by status code.
        return JSONResponse(content={"error": str(e)}, status_code=500)

    return JSONResponse(content={"glioma_stage": result})