File size: 2,988 Bytes
62bd7f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import cv2
import numpy as np
import joblib
from matplotlib import pyplot as plt
import os
import matplotlib
matplotlib.use('Agg')  # For headless environments
from .glcm_feature_extractor import GLCMFeatureExtractor

class FracturePredictor:
    """Run a persisted classifier on X-ray images and optionally plot the result.

    The instance holds three collaborators:

    * ``model``     - estimator loaded from ``model_path``; must provide
      ``predict`` and ``predict_proba``.
    * ``le``        - label encoder loaded from ``encoder_path``; must provide
      ``inverse_transform`` and ``classes_``.
    * ``extractor`` - a ``GLCMFeatureExtractor`` used for preprocessing and
      feature extraction.
    """

    def __init__(self, model_path='models/fracture_detection_model.joblib',
                 encoder_path='models/label_encoder.joblib'):
        """Load the model and label encoder from disk.

        Args:
            model_path: Path to the joblib-serialized classifier.
            encoder_path: Path to the joblib-serialized label encoder.

        Raises:
            FileNotFoundError: If either file does not exist.
        """
        # Check up front so the error names exactly which artifact is missing.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder file not found: {encoder_path}")

        # NOTE(security): joblib.load unpickles arbitrary Python objects, so
        # these paths must only ever point at trusted files.
        self.model = joblib.load(model_path)
        self.le = joblib.load(encoder_path)
        self.extractor = GLCMFeatureExtractor()

    def predict(self, img_input, visualize=True, save_path='prediction_result.png'):
        """Classify one image and optionally save a visualization of the result.

        Args:
            img_input: Image reference passed unchanged to
                ``self.extractor.preprocess_xray`` (documented by the original
                author as a file path).
            visualize: When true, also render a figure via
                ``visualize_prediction``.
            save_path: Destination of the figure when ``visualize`` is true.

        Returns:
            A ``(label, confidence, visualization_path)`` tuple.

            * On success ``label`` is the decoded class and ``confidence`` is
              the largest value returned by ``predict_proba``.
            * ``visualization_path`` is ``save_path`` when a figure was
              written, otherwise ``None``.
            * On failure ``label`` is one of ``"Error: Invalid image"``,
              ``"Error: Feature extraction failed"`` or ``"Prediction error"``
              and ``confidence`` is ``0.0``. This method does not raise.
        """
        try:
            img = self.extractor.preprocess_xray(img_input)
            if img is None:
                return "Error: Invalid image", 0.0, None

            feat = self.extractor.extract_features(img)
            if feat is None:
                return "Error: Feature extraction failed", 0.0, None

            # Build the single-sample 2-D input once and reuse it for both calls.
            sample = feat.reshape(1, -1)
            proba = self.model.predict_proba(sample)[0]
            # predict() is kept (rather than arg-maxing proba) because some
            # estimators' predict() can disagree with their predict_proba().
            pred = self.model.predict(sample)[0]
            label = self.le.inverse_transform([pred])[0]
            confidence = float(max(proba))
        except Exception as e:
            print(f"Prediction error: {str(e)}")
            return "Prediction error", 0.0, None

        # Plotting is best-effort: a failure to render or save the figure must
        # not discard a prediction that was computed successfully.
        vis_path = None
        if visualize:
            try:
                self.visualize_prediction(img, label, confidence, proba, save_path)
                vis_path = save_path
            except Exception as e:
                print(f"Visualization error: {str(e)}")

        return label, confidence, vis_path

    def visualize_prediction(self, img, label, confidence, proba, save_path):
        """Render the image next to a per-class probability bar chart and save it.

        Args:
            img: Image array shown in grayscale on the left panel.
            label: Predicted class; its bar is drawn green, all others red.
            confidence: Value printed (two decimals) in the left panel title.
            proba: Per-class probabilities plotted against ``self.le.classes_``
                (assumes both share the same ordering - TODO confirm against
                how the model was trained).
            save_path: Where the figure is written.

        Returns:
            ``save_path``.
        """
        # Create the target directory when save_path is path-like; file-like
        # objects accepted by savefig are passed through untouched.
        if isinstance(save_path, (str, os.PathLike)):
            out_dir = os.path.dirname(os.fspath(save_path))
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

        fig, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(12, 6))
        try:
            # Left panel: the image with the prediction summary as its title.
            ax_img.imshow(img, cmap='gray')
            ax_img.set_title(
                f"Original Image\nPrediction: {label}\nConfidence: {confidence:.2f}"
            )
            ax_img.axis('off')

            # Right panel: probability per class, predicted class highlighted.
            colors = ['red' if cls != label else 'green' for cls in self.le.classes_]
            ax_bar.bar(self.le.classes_, proba, color=colors)
            ax_bar.set_title("Classification Probabilities")
            ax_bar.set_ylabel("Probability")
            ax_bar.set_ylim(0, 1)

            fig.tight_layout()
            fig.savefig(save_path)
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
        return save_path