Naman2302 commited on
Commit
3f9e49a
·
verified ·
1 Parent(s): d7cbbd5

uploaded 9 min

Browse files
__pycache__/train_pipeline.cpython-310.pyc ADDED
Binary file (8.59 kB). View file
 
samples/fractured_1.jpg ADDED
samples/normal_1.jpg ADDED
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes). View file
 
src/__pycache__/glcm_feature_extractor.cpython-310.pyc ADDED
Binary file (3.09 kB). View file
 
src/__pycache__/predict_fracture.cpython-310.pyc ADDED
Binary file (2.79 kB). View file
 
src/glcm_feature_extractor.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from skimage.feature import graycomatrix, graycoprops
4
+ import os
5
+ from glob import glob
6
+ from PIL import Image, UnidentifiedImageError
7
+
8
+ class GLCMFeatureExtractor:
9
+ def __init__(self, distances=[1, 3, 5], angles=[0, np.pi/4, np.pi/2, 3*np.pi/4]):
10
+ self.distances = distances
11
+ self.angles = angles
12
+
13
+ def preprocess_xray(self, img_path):
14
+ """Robust image loading with multiple fallbacks"""
15
+ try:
16
+ # First try with OpenCV
17
+ img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
18
+ if img is None:
19
+ # Fallback to PIL for problematic images
20
+ try:
21
+ with Image.open(img_path) as pil_img:
22
+ img = np.array(pil_img.convert('L'))
23
+ except (IOError, UnidentifiedImageError) as e:
24
+ raise ValueError(f"PIL cannot read image: {img_path}") from e
25
+
26
+ # Handle empty images
27
+ if img.size == 0:
28
+ raise ValueError(f"Empty image: {img_path}")
29
+
30
+ # Resize and normalize
31
+ img = cv2.resize(img, (256, 256))
32
+
33
+ # Improved normalization
34
+ img = img.astype(np.float32)
35
+ min_val = np.min(img)
36
+ max_val = np.max(img)
37
+
38
+ # Handle zero-contrast images
39
+ if max_val - min_val < 1e-5:
40
+ img = np.zeros_like(img) # Return black image
41
+ else:
42
+ img = (img - min_val) / (max_val - min_val) * 255
43
+
44
+ return img.astype(np.uint8)
45
+ except Exception as e:
46
+ print(f"Error processing {img_path}: {str(e)}")
47
+ return None
48
+
49
+ def extract_features(self, img):
50
+ """Extract GLCM features with validation"""
51
+ if img is None:
52
+ return None
53
+
54
+ try:
55
+ # Calculate GLCM with optimized parameters
56
+ glcm = graycomatrix(
57
+ img,
58
+ distances=self.distances,
59
+ angles=self.angles,
60
+ levels=256,
61
+ symmetric=True,
62
+ normed=True
63
+ )
64
+
65
+ # Extract texture properties
66
+ features = []
67
+ props = ['contrast', 'dissimilarity', 'homogeneity',
68
+ 'energy', 'correlation', 'ASM']
69
+
70
+ for prop in props:
71
+ feat = graycoprops(glcm, prop)
72
+ features.extend(feat.flatten())
73
+
74
+ return np.array(features)
75
+ except Exception as e:
76
+ print(f"Feature extraction error: {str(e)}")
77
+ return None
78
+
79
+ def extract_from_folder(self, folder_path, max_samples=None):
80
+ """Batch feature extraction with error handling"""
81
+ features = []
82
+ labels = []
83
+ class_name = os.path.basename(folder_path)
84
+
85
+ # Find all image files
86
+ image_paths = []
87
+ for ext in ('*.png', '*.jpg', '*.jpeg', '*.dcm', '*.tif', '*.bmp'):
88
+ image_paths.extend(glob(os.path.join(folder_path, ext)))
89
+
90
+ if not image_paths:
91
+ print(f"Warning: No images found in {folder_path}")
92
+ return np.array([]), np.array([])
93
+
94
+ # Apply sampling if requested
95
+ if max_samples and len(image_paths) > max_samples:
96
+ image_paths = np.random.choice(image_paths, max_samples, replace=False)
97
+
98
+ # Process each image
99
+ for img_path in image_paths:
100
+ img = self.preprocess_xray(img_path)
101
+ if img is None:
102
+ continue
103
+
104
+ feat = self.extract_features(img)
105
+ if feat is not None:
106
+ features.append(feat)
107
+ labels.append(class_name)
108
+
109
+ print(f"Successfully processed {len(features)}/{len(image_paths)} images in {folder_path}")
110
+ return np.array(features), np.array(labels)
src/predict_fracture.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import joblib
4
+ from matplotlib import pyplot as plt
5
+ import os
6
+ import matplotlib
7
+ matplotlib.use('Agg') # For headless environments
8
+ from .glcm_feature_extractor import GLCMFeatureExtractor
9
+
10
+ class FracturePredictor:
11
+ def __init__(self, model_path='models/fracture_detection_model.joblib',
12
+ encoder_path='models/label_encoder.joblib'):
13
+ # Verify model paths
14
+ if not os.path.exists(model_path):
15
+ raise FileNotFoundError(f"Model file not found: {model_path}")
16
+ if not os.path.exists(encoder_path):
17
+ raise FileNotFoundError(f"Encoder file not found: {encoder_path}")
18
+
19
+ self.model = joblib.load(model_path)
20
+ self.le = joblib.load(encoder_path)
21
+ self.extractor = GLCMFeatureExtractor()
22
+
23
+ def predict(self, img_input, visualize=True, save_path='prediction_result.png'):
24
+ """
25
+ Predict fracture from image input (file path)
26
+ Returns: (label, confidence, visualization_path)
27
+ """
28
+ try:
29
+ # Preprocess image
30
+ img = self.extractor.preprocess_xray(img_input)
31
+ if img is None:
32
+ return "Error: Invalid image", 0.0, None
33
+
34
+ # Extract features
35
+ feat = self.extractor.extract_features(img)
36
+ if feat is None:
37
+ return "Error: Feature extraction failed", 0.0, None
38
+
39
+ # Make prediction
40
+ proba = self.model.predict_proba(feat.reshape(1, -1))[0]
41
+ pred = self.model.predict(feat.reshape(1, -1))[0]
42
+ label = self.le.inverse_transform([pred])[0]
43
+ confidence = max(proba)
44
+
45
+ # Generate visualization
46
+ vis_path = None
47
+ if visualize:
48
+ vis_path = save_path
49
+ self.visualize_prediction(img, label, confidence, proba, save_path)
50
+
51
+ return label, confidence, vis_path
52
+ except Exception as e:
53
+ print(f"Prediction error: {str(e)}")
54
+ return "Prediction error", 0.0, None
55
+
56
+ def visualize_prediction(self, img, label, confidence, proba, save_path):
57
+ """Create and save prediction visualization"""
58
+ plt.figure(figsize=(12, 6))
59
+
60
+ # Original image
61
+ plt.subplot(1, 2, 1)
62
+ plt.imshow(img, cmap='gray')
63
+ plt.title(f"Original Image\nPrediction: {label}\nConfidence: {confidence:.2f}")
64
+ plt.axis('off')
65
+
66
+ # Probability distribution
67
+ plt.subplot(1, 2, 2)
68
+ colors = ['red' if cls != label else 'green' for cls in self.le.classes_]
69
+ plt.bar(self.le.classes_, proba, color=colors)
70
+ plt.title("Classification Probabilities")
71
+ plt.ylabel("Probability")
72
+ plt.ylim(0, 1)
73
+
74
+ plt.tight_layout()
75
+ plt.savefig(save_path)
76
+ plt.close()
77
+ return save_path