|
|
import cv2
|
|
|
import numpy as np
|
|
|
import joblib
|
|
|
from matplotlib import pyplot as plt
|
|
|
import os
|
|
|
import matplotlib
|
|
|
matplotlib.use('Agg')
|
|
|
from .glcm_feature_extractor import GLCMFeatureExtractor
|
|
|
|
|
|
class FracturePredictor:
    """Classify X-ray images using a joblib-loaded model and GLCM features.

    Instance attributes set by ``__init__``:
        model: Object loaded from ``model_path``; its ``predict`` and
            ``predict_proba`` methods are called during prediction.
        le: Object loaded from ``encoder_path``; its ``inverse_transform``
            method and ``classes_`` attribute map predictions to labels.
        extractor: ``GLCMFeatureExtractor`` used for preprocessing and
            feature extraction.
    """

    def __init__(self, model_path='models/fracture_detection_model.joblib',
                 encoder_path='models/label_encoder.joblib'):
        """Load the model and the label encoder from disk.

        Args:
            model_path: Path to the joblib-serialized model.
            encoder_path: Path to the joblib-serialized label encoder.

        Raises:
            FileNotFoundError: If either file does not exist.
        """
        # Fail early with a message naming the missing artifact.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder file not found: {encoder_path}")

        # NOTE(review): joblib.load unpickles arbitrary objects, so these
        # paths must only ever point at trusted files.
        self.model = joblib.load(model_path)
        self.le = joblib.load(encoder_path)
        self.extractor = GLCMFeatureExtractor()

    def predict(self, img_input, visualize=True, save_path='prediction_result.png'):
        """Predict the class of an image and optionally save a summary plot.

        Args:
            img_input: Image input (file path) handed to the extractor's
                ``preprocess_xray``.
            visualize: When true, also render the prediction figure.
            save_path: Destination of the figure when ``visualize`` is true.

        Returns:
            A ``(label, confidence, visualization_path)`` tuple.
            ``confidence`` is the highest class probability as a plain
            ``float``. ``visualization_path`` is ``save_path`` when the
            figure was written, otherwise ``None``. On failure the label is
            an error message and the confidence is ``0.0``.
        """
        try:
            img = self.extractor.preprocess_xray(img_input)
            if img is None:
                return "Error: Invalid image", 0.0, None

            feat = self.extractor.extract_features(img)
            if feat is None:
                return "Error: Feature extraction failed", 0.0, None

            # The model is queried with a single-row 2-D sample; build it once.
            sample = feat.reshape(1, -1)
            proba = self.model.predict_proba(sample)[0]
            pred = self.model.predict(sample)[0]
            label = self.le.inverse_transform([pred])[0]
            # Cast so callers get a built-in float rather than a NumPy scalar.
            confidence = float(max(proba))
        except Exception as e:
            # Best-effort API: report the problem and return an error tuple.
            print(f"Prediction error: {str(e)}")
            return "Prediction error", 0.0, None

        if not visualize:
            return label, confidence, None

        # Rendering is handled separately so that a plotting or file-system
        # failure does not throw away a prediction that already succeeded.
        try:
            self.visualize_prediction(img, label, confidence, proba, save_path)
        except Exception as e:
            print(f"Visualization error: {str(e)}")
            return label, confidence, None
        return label, confidence, save_path

    def visualize_prediction(self, img, label, confidence, proba, save_path):
        """Create and save the prediction visualization.

        The left panel shows the image with the predicted label and
        confidence; the right panel shows one probability bar per class in
        ``self.le.classes_``, green for the predicted label and red otherwise.

        Args:
            img: Image to display (rendered with a gray colormap).
            label: Predicted class label.
            confidence: Confidence value shown in the title.
            proba: Per-class probabilities, one per entry of
                ``self.le.classes_``.
            save_path: File path the figure is written to.

        Returns:
            ``save_path``.
        """
        fig = plt.figure(figsize=(12, 6))
        try:
            image_ax = fig.add_subplot(1, 2, 1)
            image_ax.imshow(img, cmap='gray')
            image_ax.set_title(
                f"Original Image\nPrediction: {label}\nConfidence: {confidence:.2f}"
            )
            image_ax.axis('off')

            bar_ax = fig.add_subplot(1, 2, 2)
            colors = ['green' if cls == label else 'red' for cls in self.le.classes_]
            bar_ax.bar(self.le.classes_, proba, color=colors)
            bar_ax.set_title("Classification Probabilities")
            bar_ax.set_ylabel("Probability")
            bar_ax.set_ylim(0, 1)

            fig.tight_layout()
            fig.savefig(save_path)
        finally:
            # Always release the figure, even when plotting or saving fails,
            # so repeated calls do not accumulate open figures.
            plt.close(fig)
        return save_path
|
|
|
|