| import torch
|
| import torch.nn as nn
|
| from torchvision import transforms, models
|
| from PIL import Image
|
|
|
|
|
# Run on the GPU when CUDA is usable in this process; otherwise stay on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
|
# Preprocessing pipeline applied to every PIL image before it reaches the
# network: resize to a fixed 224x224, convert to a tensor, then normalize
# each of the three channels.
# NOTE(review): the mean/std values match the widely used ImageNet channel
# statistics; presumably the same normalization was used in training -- confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
# Maps the integer class index produced by the model to its class name.
# The position of each name in the tuple is its class index (0..3).
label_map = dict(
    enumerate(
        (
            "glioma",
            "meningioma",
            "no_tumor",
            "pituitary",
        )
    )
)
|
|
|
|
|
# Build a ResNet-18 with no pretrained weights; the real parameters come
# from the checkpoint loaded below.
model = models.resnet18(weights=None)

# Swap the final fully connected layer for one with four outputs, matching
# the four entries of label_map.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4)

# nn.Module.to() moves the parameters in place and returns the same module.
model.to(device)

# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints that come from a trusted source.
model.load_state_dict(
    torch.load("brain_tumor_resnet18.pth", map_location=device)
)

# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()
|
|
|
|
|
def predict_image(image_path):
    """Classify a single image and report the predicted class name.

    The image is loaded, converted to RGB, run through the module-level
    ``transform`` pipeline, and passed as a batch of one to the module-level
    ``model`` on ``device``. The index of the highest-scoring output is
    looked up in ``label_map``.

    Args:
        image_path: Location of the image file; passed straight to
            ``PIL.Image.open``.

    Returns:
        The predicted class name from ``label_map``. The same label is also
        printed as ``Prediction: <label>``.

    Raises:
        KeyError: If the predicted index has no entry in ``label_map``.
        Any exception raised by ``PIL.Image.open`` for a missing or
        unreadable file propagates unchanged.
    """
    # Image.open is lazy; convert() loads the pixel data and returns a new
    # image, so the file can be closed as soon as the RGB copy exists.
    with Image.open(image_path) as raw_img:
        rgb_img = raw_img.convert('RGB')

    # unsqueeze(0) adds the leading batch dimension before the forward pass.
    batch = transform(rgb_img).unsqueeze(0).to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        outputs = model(batch)

    # Index of the largest score along dim 1 for the single batch item.
    pred_idx = outputs.argmax(dim=1)
    pred_label = label_map[pred_idx.item()]

    print(f"Prediction: {pred_label}")
    return pred_label
|
|
|
|
|
# Run one prediction on the sample scan; the result is printed by
# predict_image.
predict_image(
    image_path="D:\\python\\Advanced_tumor\\img_3.png",
)
|
|
|