local5 committed on
Commit
c971a0f
·
verified ·
1 Parent(s): 985d8d9

Upload texture_classification.py

Browse files
Files changed (1) hide show
  1. texture_classification.py +182 -0
texture_classification.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ from skimage.feature import local_binary_pattern, graycomatrix, graycoprops
4
+ from sklearn.svm import LinearSVC
5
+ import os
6
+ from sklearn.metrics import accuracy_score, precision_score, \
7
+ classification_report, confusion_matrix
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ from sklearn.model_selection import train_test_split
11
+ import joblib
12
+
13
# Side length (in pixels) of the square image fed to each feature extractor.
IMAGE_SIZE_GLCM = 256
IMAGE_SIZE_LBP = 128

# LBP parameters
RADIUS = 1           # neighborhood radius for the LBP operator
N_POINTS = 8 * RADIUS  # number of sampling points around each pixel
LBP_METHOD = "uniform"  # skimage LBP variant; yields n_points + 2 code bins
20
+
21
+
22
def compute_glcm_histogram_pil(image, distances=(1,), angles=(0,), levels=8,
                               symmetric=True):
    """Compute a 2-element GLCM texture feature vector.

    Parameters
    ----------
    image : PIL.Image.Image or array-like
        Grayscale image with pixel values in [0, 255].
    distances : sequence of int
        Pixel-pair distance offsets passed to ``graycomatrix``.
    angles : sequence of float
        Pixel-pair angles (radians) passed to ``graycomatrix``.
    levels : int
        Number of gray levels after quantization.
    symmetric : bool
        Whether the co-occurrence matrix is made symmetric.

    Returns
    -------
    numpy.ndarray
        Shape (2,): ``[homogeneity, correlation]`` of the first
        distance/angle pair.
    """
    # Convert the PIL image to a NumPy array.
    image_np = np.asarray(image)

    # Quantize the grayscale image to `levels` gray levels.
    # Cast to float64 BEFORE multiplying: a uint8 array times
    # (levels - 1) overflows for pixel values above 255 // (levels - 1)
    # and silently corrupts the quantization.
    image_np = (image_np.astype(np.float64) * (levels - 1) / 255).astype(np.uint8)

    # Compute the normalized GLCM using skimage.
    glcm = graycomatrix(image_np,
                        distances=distances,
                        angles=angles,
                        levels=levels,
                        symmetric=symmetric,
                        normed=True)

    # Extract scalar GLCM properties for the first distance/angle pair.
    homogeneity = graycoprops(glcm, 'homogeneity')[0, 0]
    correlation = graycoprops(glcm, 'correlation')[0, 0]

    # Assemble the feature vector.
    return np.array([homogeneity, correlation])
47
+
48
+
49
def image_resize(img, n):
    """Center-crop `img` to a square, then resize it to n x n pixels.

    Cropping to the largest centered square first avoids aspect-ratio
    distortion when the final resize is applied.
    """
    side = min(img.size)
    # Offsets that center the square crop box inside the image.
    x0 = (img.width - side) / 2
    y0 = (img.height - side) / 2
    cropped = img.crop((x0, y0, x0 + side, y0 + side))
    return cropped.resize((n, n))
59
+
60
+
61
def get_lbp_hist(gray_image, n_points, radius, method):
    """Return the normalized LBP histogram of a grayscale image array.

    Parameters
    ----------
    gray_image : numpy.ndarray
        2-D grayscale image.
    n_points, radius, method :
        Passed straight through to ``skimage.feature.local_binary_pattern``.

    Returns
    -------
    numpy.ndarray
        Histogram over the n_points + 2 LBP code bins, normalized to
        sum to (approximately) 1.
    """
    # Per-pixel LBP codes.
    lbp_codes = local_binary_pattern(gray_image, n_points, radius, method)

    # Histogram over codes 0 .. n_points + 1 (uniform LBP code range).
    bin_edges = np.arange(0, n_points + 3)
    hist, _ = np.histogram(lbp_codes.ravel(), bins=bin_edges,
                           range=(0, n_points + 2))

    # Normalize to a probability distribution; the epsilon guards
    # against division by zero on an empty image.
    hist = hist.astype("float")
    hist /= (hist.sum() + 1e-6)
    return hist
73
+
74
+
75
def get_features(input_folder, class_label, method):
    """Extract texture features for every image file in a folder.

    Parameters
    ----------
    input_folder : str
        Directory containing the image files.
    class_label : str
        Label assigned to every image in the folder (e.g. "Grass").
    method : str
        Feature extractor to use: "GLCM" or "LBP".

    Returns
    -------
    tuple of (list, list, list)
        Feature vectors, the repeated class label, and the source
        file names, index-aligned.

    Raises
    ------
    ValueError
        If `method` is not "GLCM" or "LBP".
    """
    # Fail fast on a bad method name.  Previously an unknown method left
    # `hist` unbound, raising a NameError per image that was silently
    # swallowed by the broad except below.
    if method not in ("GLCM", "LBP"):
        raise ValueError(f"Unknown feature method: {method!r}")

    data = []
    labels = []
    filenames = []
    image_files = [f for f in os.listdir(input_folder) if f.lower().endswith((
        '.png', '.jpg', '.jpeg', '.bmp', '.tiff'))]

    print(f"Total images found: {len(image_files)}")

    for file_name in sorted(image_files):
        img_path = os.path.join(input_folder, file_name)
        try:
            # verify() detects truncated/corrupt files but leaves the
            # image unusable, so the file is reopened for real work.
            # Context managers close the file handles promptly.
            with Image.open(img_path) as img:
                img.verify()
            with Image.open(img_path) as img:
                img_gray = img.convert("L")  # convert() loads pixel data

            if method == "GLCM":
                img_resized = image_resize(img_gray, IMAGE_SIZE_GLCM)
                hist = compute_glcm_histogram_pil(img_resized)
            else:  # method == "LBP", guaranteed by the check above
                img_resized = image_resize(img_gray, IMAGE_SIZE_LBP)
                hist = get_lbp_hist(np.array(img_resized), N_POINTS, RADIUS,
                                    LBP_METHOD)

            data.append(hist)
            labels.append(class_label)
            filenames.append(file_name)  # Keep names aligned with features

        except (FileNotFoundError, PermissionError) as file_err:
            print(f"File error with {file_name}: {file_err}")
        except Image.UnidentifiedImageError:
            print(f"Unidentified image file: {file_name}. Skipping this file.")
        except Exception as e:
            # Best-effort batch processing: log and continue with the rest.
            print(f"Unexpected error processing {file_name}: {e}")

    return data, labels, filenames
112
+
113
+
114
def main():
    """Train and evaluate a Grass-vs-Wood texture classifier."""
    # Feature extraction method for this run.
    method = "LBP"

    # Extract features for each class, then pool everything.
    grass_data, grass_labels, grass_filenames = get_features(
        "./raw_data/raw_grass_dataset", "Grass", method)
    wood_data, wood_labels, wood_filenames = get_features(
        "./raw_data/raw_wood_dataset", "Wood", method)
    data = grass_data + wood_data
    labels = grass_labels + wood_labels
    filenames = grass_filenames + wood_filenames

    # Stratified 70/30 split; filenames ride along so misclassified
    # test images can later be reported by name.
    (X_train, X_test, y_train, y_test,
     train_filenames, test_filenames) = train_test_split(
        data, labels, filenames, test_size=0.3, random_state=9,
        stratify=labels)

    # Fit a linear SVM and predict on the held-out set.
    model = LinearSVC(C=100, loss="squared_hinge")
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    # Headline metrics.
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, average='macro')
    print(f"Accuracy: {accuracy:.2f}")
    print(f"Precision: {precision:.2f}")

    # Per-class precision/recall/F1 breakdown.
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred))

    # Confusion matrix: print it, then visualize it as a heatmap.
    conf_matrix = confusion_matrix(y_test, y_pred)
    print("Confusion Matrix:")
    print(conf_matrix)

    plt.figure(figsize=(6, 4))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=["Grass", "Wood"], yticklabels=["Grass", "Wood"])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    # Report the test images the model got wrong.
    misclassified = [fname
                     for fname, true, pred in zip(test_filenames,
                                                  y_test, y_pred)
                     if true != pred]
    print("Misclassified Images:")
    for fname in misclassified:
        print(fname)

    # Persist the trained model for deployment.
    joblib.dump(model, method + '_model.joblib')
179
+
180
+
181
# Run the full train/evaluate pipeline when executed as a script.
if __name__ == "__main__":
    main()