File size: 2,773 Bytes
0e038f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# backend/app/inference.py
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from pathlib import Path
from .model import create_detection_model


class InferenceEngine:
    """Load a classifier checkpoint once and run single-image inference.

    The model is built with ``create_detection_model`` and its weights are
    loaded from ``model_path`` during construction. It runs on CUDA when
    available, otherwise on CPU.
    """

    def __init__(self, model_path: str):
        """
        Build the engine and load the model weights immediately.

        Args:
            model_path: Path to a saved ``state_dict`` for the model returned
                by ``create_detection_model``.

        Raises:
            Exception: Whatever model construction or ``torch.load`` raises
                (printed and re-raised by ``_load_model``).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.model_path = model_path

        # Model output index -> class label. The model's output width is
        # derived from this mapping in _load_model, so the two cannot drift.
        self.class_map = {
            0: 'glioma',
            1: 'meningioma',
            2: 'pituitary',
            3: 'notumor'
        }

        # Preprocessing: resize to 224x224, convert to a float tensor, then
        # normalize per channel. NOTE(review): these are the widely used
        # ImageNet mean/std values; presumably they match the preprocessing
        # used at training time — confirm.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])

        self._load_model()

    def _load_model(self):
        """Construct the model, load its weights, and switch it to eval mode.

        Raises:
            Exception: Any error during construction or loading is printed
                and then re-raised unchanged.
        """
        try:
            # One output per entry in class_map, so predicted indices always
            # have a label.
            self.model = create_detection_model(num_classes=len(self.class_map))
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from trusted sources, or
            # pass weights_only=True on torch versions that support it.
            state_dict = torch.load(self.model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()
            print(f"✅ Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def predict(self, image_path: str) -> dict:
        """
        Run inference on an image.

        Args:
            image_path: Path to the image file.

        Returns:
            On success: ``{"success": True, "predicted_class": str,
            "confidence": float, "all_probabilities": {label: float}}``.
            On any failure: ``{"success": False, "error": str}``. Exceptions
            are not propagated.
        """
        try:
            # Open in a context manager so the underlying file handle is
            # released promptly; convert() returns an independent image.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            # Add a leading batch dimension of size 1.
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.model(input_tensor)
                # NOTE(review): assumes the model returns logits shaped
                # (batch, num_classes) — confirm against create_detection_model.
                probabilities = F.softmax(output, dim=1)

            confidence, predicted_index = torch.max(probabilities, 1)
            # One host transfer for the whole probability row.
            scores = probabilities[0].tolist()

            return {
                "success": True,
                "predicted_class": self.class_map[predicted_index.item()],
                "confidence": float(confidence.item()),
                "all_probabilities": {
                    label: float(scores[index])
                    for index, label in self.class_map.items()
                },
            }

        except Exception as e:
            # API boundary: report failures to the caller instead of raising.
            return {
                "success": False,
                "error": str(e)
            }