# DL_project/src/predict_fracture.py
# Uploaded by Naman2302 in commit 62bd7f8 ("Upload 15 files", verified).
import cv2
import numpy as np
import joblib
from matplotlib import pyplot as plt
import os
import matplotlib
matplotlib.use('Agg') # For headless environments
from .glcm_feature_extractor import GLCMFeatureExtractor
class FracturePredictor:
    """Predict a fracture label for an image with a pre-trained classifier.

    Images are preprocessed and turned into a feature vector by
    ``GLCMFeatureExtractor``; the vector is fed to a joblib-serialized
    classifier whose encoded outputs are decoded back to labels by a
    joblib-serialized label encoder.
    """

    def __init__(self, model_path='models/fracture_detection_model.joblib',
                 encoder_path='models/label_encoder.joblib'):
        """Load the classifier and label encoder from disk.

        Args:
            model_path: Path to the joblib file holding the classifier. It must
                provide ``predict`` and ``predict_proba``.
            encoder_path: Path to the joblib file holding the label encoder. It
                must provide ``inverse_transform`` and ``classes_``.

        Raises:
            FileNotFoundError: If either file does not exist.
        """
        # Fail fast with a clear message instead of joblib's lower-level error.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder file not found: {encoder_path}")
        # NOTE(review): joblib.load unpickles arbitrary objects; only load
        # model files from a trusted source.
        self.model = joblib.load(model_path)
        self.le = joblib.load(encoder_path)
        self.extractor = GLCMFeatureExtractor()

    def predict(self, img_input, visualize=True, save_path='prediction_result.png'):
        """Predict the fracture label for a single image.

        Args:
            img_input: Image input forwarded unchanged to
                ``GLCMFeatureExtractor.preprocess_xray`` (e.g. a file path;
                confirm which other input types that helper accepts).
            visualize: When true, also save a figure of the image and the
                class probabilities to ``save_path``.
            save_path: Destination of the visualization image.

        Returns:
            ``(label, confidence, visualization_path)``. ``confidence`` is the
            model's probability for the returned label, as a ``float``.
            ``visualization_path`` is ``None`` when ``visualize`` is false or
            the figure could not be saved. On a preprocessing, feature or
            prediction failure the tuple is ``(message, 0.0, None)`` where
            ``message`` describes the error; no exception is raised.
        """
        try:
            img = self.extractor.preprocess_xray(img_input)
            if img is None:
                return "Error: Invalid image", 0.0, None

            feat = self.extractor.extract_features(img)
            if feat is None:
                return "Error: Feature extraction failed", 0.0, None

            # The classifier expects a 2-D (n_samples, n_features) batch.
            sample = np.asarray(feat).reshape(1, -1)
            proba = self.model.predict_proba(sample)[0]
            pred = self.model.predict(sample)[0]
            label = self.le.inverse_transform([pred])[0]
            confidence = self._confidence_for(pred, proba)
        except Exception as e:
            # Public boundary: report and return a sentinel instead of raising.
            print(f"Prediction error: {str(e)}")
            return "Prediction error", 0.0, None

        vis_path = None
        if visualize:
            try:
                vis_path = self.visualize_prediction(img, label, confidence, proba, save_path)
            except Exception as e:
                # A plotting failure should not throw away a valid prediction.
                print(f"Visualization error: {e}")
                vis_path = None
        return label, confidence, vis_path

    def _confidence_for(self, pred, proba):
        """Return the probability that ``proba`` assigns to the predicted class ``pred``."""
        # predict() and argmax(predict_proba()) can disagree for some
        # estimators (e.g. SVC with probability=True), so look up the column
        # of the predicted class instead of blindly taking max(proba).
        model_classes = getattr(self.model, 'classes_', None)
        if model_classes is not None:
            matches = np.flatnonzero(np.asarray(model_classes) == pred)
            if matches.size:
                return float(proba[matches[0]])
        # Column order unknown: fall back to the highest probability.
        return float(np.max(proba))

    def _class_names(self):
        """Return decoded class labels aligned with the columns of ``predict_proba``."""
        model_classes = getattr(self.model, 'classes_', None)
        if model_classes is None:
            # Assumes the model was trained on every encoded label, so its
            # probability columns follow the encoder's class order.
            return list(self.le.classes_)
        return list(self.le.inverse_transform(model_classes))

    def visualize_prediction(self, img, label, confidence, proba, save_path):
        """Save a side-by-side figure of the image and the class probabilities.

        Args:
            img: Preprocessed image, drawn with a gray colormap.
            label: Predicted label; its bar is drawn green, the others red.
            confidence: Confidence shown in the image title.
            proba: Per-class probabilities in the classifier's column order.
            save_path: File path the figure is written to. Missing parent
                directories are created.

        Returns:
            ``save_path``.
        """
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        class_names = self._class_names()
        colors = ['green' if name == label else 'red' for name in class_names]

        fig, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(12, 6))
        try:
            # Left panel: the image the prediction was made on.
            ax_img.imshow(img, cmap='gray')
            ax_img.set_title(f"Original Image\nPrediction: {label}\nConfidence: {confidence:.2f}")
            ax_img.axis('off')

            # Right panel: per-class probabilities. String positions make
            # numeric labels render as categories rather than x-coordinates.
            ax_bar.bar([str(name) for name in class_names], proba, color=colors)
            ax_bar.set_title("Classification Probabilities")
            ax_bar.set_ylabel("Probability")
            ax_bar.set_ylim(0, 1)

            fig.tight_layout()
            fig.savefig(save_path)
        finally:
            # Always release the figure, even if drawing or saving fails, so
            # repeated calls do not accumulate open figures.
            plt.close(fig)
        return save_path