File size: 14,356 Bytes
b4123b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
"""
Texture feature extraction for the Sorghum Pipeline.

This module handles extraction of texture features including:
- Local Binary Patterns (LBP)
- Histogram of Oriented Gradients (HOG)
- Lacunarity features
- Edge Histogram Descriptor (EHD)
"""

import numpy as np
import cv2
import torch
import torch.nn.functional as F
from skimage.feature import local_binary_pattern, hog
from skimage import exposure
from scipy import ndimage, signal
from sklearn.decomposition import PCA
from typing import Dict, Tuple, Optional, Any
import logging

logger = logging.getLogger(__name__)


class TextureExtractor:
    """
    Extract texture feature maps from grayscale images.

    Supported features: Local Binary Patterns (LBP), the Histogram of
    Oriented Gradients (HOG) visualisation, three lacunarity maps and the
    Edge Histogram Descriptor (EHD).  The extraction methods log failures
    and return zero-filled placeholders instead of raising, so a single
    failing feature does not abort a pipeline run.
    """

    # Side length of the base gradient kernel used for EHD.
    _EHD_MASK_SIZE = 3
    # Side length of the average-pooling window applied to the EHD one-hot maps.
    _EHD_POOL_SIZE = 5

    def __init__(self,
                 lbp_points: int = 8,
                 lbp_radius: int = 1,
                 hog_orientations: int = 9,
                 hog_pixels_per_cell: Tuple[int, int] = (8, 8),
                 hog_cells_per_block: Tuple[int, int] = (2, 2),
                 lacunarity_window: int = 15,
                 ehd_threshold: float = 0.3,
                 angle_resolution: int = 45,
                 ehd_dilation: int = 7):
        """
        Initialize texture extractor.

        Args:
            lbp_points: Number of circular sample points for LBP
            lbp_radius: Radius of the LBP sampling circle
            hog_orientations: Number of orientation bins for HOG
            hog_pixels_per_cell: Pixels per cell for HOG
            hog_cells_per_block: Cells per block for HOG
            lacunarity_window: Base window size for lacunarity
            ehd_threshold: Minimum edge response; weaker pixels are assigned
                to the extra "no edge" EHD bin
            angle_resolution: Step in degrees between EHD mask orientations
            ehd_dilation: Dilation of the EHD convolution. Because that
                convolution is unpadded, EHD outputs are smaller than the
                input by ``ehd_dilation * (mask_size - 1)`` pixels per axis
                (14 with the defaults).
        """
        self.lbp_points = lbp_points
        self.lbp_radius = lbp_radius
        self.hog_orientations = hog_orientations
        self.hog_pixels_per_cell = hog_pixels_per_cell
        self.hog_cells_per_block = hog_cells_per_block
        self.lacunarity_window = lacunarity_window
        self.ehd_threshold = ehd_threshold
        self.angle_resolution = angle_resolution
        self.ehd_dilation = ehd_dilation

    def extract_lbp(self, gray_image: np.ndarray) -> np.ndarray:
        """
        Extract a Local Binary Pattern map.

        Uses skimage's 'uniform' LBP with ``lbp_points`` samples at radius
        ``lbp_radius`` and min-max rescales the codes to uint8.

        Args:
            gray_image: Grayscale input image

        Returns:
            uint8 LBP map; an all-zero map shaped like the input on failure.
        """
        try:
            lbp = local_binary_pattern(
                gray_image,
                self.lbp_points,
                self.lbp_radius,
                method='uniform'
            )
            return self._convert_to_uint8(lbp)
        except Exception as e:
            logger.error(f"LBP extraction failed: {e}")
            return np.zeros_like(gray_image, dtype=np.uint8)

    def extract_hog(self, gray_image: np.ndarray) -> np.ndarray:
        """
        Extract the Histogram of Oriented Gradients visualisation image.

        Only skimage's HOG visualisation is kept (the descriptor vector is
        discarded); it is rescaled to 0-255 and cast to uint8.

        Args:
            gray_image: Grayscale input image

        Returns:
            uint8 HOG visualisation; an all-zero map shaped like the input
            on failure.
        """
        try:
            _, vis = hog(
                gray_image,
                orientations=self.hog_orientations,
                pixels_per_cell=self.hog_pixels_per_cell,
                cells_per_block=self.hog_cells_per_block,
                visualize=True,
                feature_vector=True
            )
            return exposure.rescale_intensity(vis, out_range=(0, 255)).astype(np.uint8)
        except Exception as e:
            logger.error(f"HOG extraction failed: {e}")
            return np.zeros_like(gray_image, dtype=np.uint8)

    def compute_local_lacunarity(self, gray_image: np.ndarray, window_size: int) -> np.ndarray:
        """
        Compute sliding-window lacunarity, Var/Mean^2 + 1 (= E[x^2] / E[x]^2).

        Args:
            gray_image: Grayscale input image
            window_size: Side length of the sliding window

        Returns:
            float32 lacunarity map, 0 wherever the local mean is ~0; an
            all-zero float32 map on failure.
        """
        try:
            arr = gray_image.astype(np.float32)
            local_mean = ndimage.uniform_filter(arr, size=window_size)
            local_mean_sq = ndimage.uniform_filter(arr * arr, size=window_size)
            # E[x^2] - E[x]^2 can dip slightly below zero from float32 round-off.
            var = np.maximum(local_mean_sq - local_mean * local_mean, 0.0)
            eps = 1e-6
            lac = var / (local_mean * local_mean + eps) + 1
            # Lacunarity is undefined on (near-)empty windows; report 0 there.
            lac[local_mean <= eps] = 0
            return lac
        except Exception as e:
            logger.error(f"Local lacunarity computation failed: {e}")
            return np.zeros_like(gray_image, dtype=np.float32)

    def compute_lacunarity_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute three lacunarity maps.

        - lac1: single-window lacunarity at ``lacunarity_window``
        - lac2: mean of lacunarity at half, one and two times that window
        - lac3: DBC lacunarity from the optional ``DBC_Lacunarity`` model,
          or a copy of lac2 when that model is unavailable or fails

        Args:
            gray_image: Grayscale input image

        Returns:
            Tuple of (lac1, lac2, lac3) rescaled to uint8; three independent
            all-zero maps on failure.
        """
        try:
            lac1 = self.compute_local_lacunarity(gray_image, self.lacunarity_window)

            scales = [max(3, self.lacunarity_window // 2),
                      self.lacunarity_window,
                      self.lacunarity_window * 2]
            lac2 = np.mean([
                self.compute_local_lacunarity(gray_image, s) for s in scales
            ], axis=0)

            lac3 = self._dbc_lacunarity(gray_image)
            if lac3 is None:
                lac3 = lac2.copy()

            return (
                self._convert_to_uint8(lac1),
                self._convert_to_uint8(lac2),
                self._convert_to_uint8(lac3)
            )
        except Exception as e:
            logger.error(f"Lacunarity features computation failed: {e}")
            # Separate arrays so callers mutating one map do not affect the others.
            return (np.zeros_like(gray_image, dtype=np.uint8),
                    np.zeros_like(gray_image, dtype=np.uint8),
                    np.zeros_like(gray_image, dtype=np.uint8))

    def _dbc_lacunarity(self, gray_image: np.ndarray) -> Optional[np.ndarray]:
        """Return the DBC lacunarity map, or None if the model is missing or fails (so lac1/lac2 survive)."""
        try:
            from ..models.dbc_lacunarity import DBC_Lacunarity
        except ImportError:
            logger.warning("DBC Lacunarity not available, using L2 as L3")
            return None
        try:
            x = torch.from_numpy(gray_image.astype(np.float32) / 255.0)[None, None]
            layer = DBC_Lacunarity(window_size=self.lacunarity_window).eval()
            with torch.no_grad():
                return layer(x).squeeze().cpu().numpy()
        except Exception as e:
            logger.warning(f"DBC Lacunarity computation failed ({e}), using L2 as L3")
            return None

    def generate_ehd_masks(self, mask_size: int = 3) -> np.ndarray:
        """
        Generate oriented gradient masks for the Edge Histogram Descriptor.

        The 3x3 kernel outer([1, 0, -1], [1, 2, 1]) is enlarged, when
        needed, by repeated full convolution with outer([1, 2, 1], [1, 2, 1]).
        It is then rotated in steps of ``angle_resolution`` degrees over
        [0, 360).

        Args:
            mask_size: Requested side length; values below 3 become 3 and
                even values are rounded up to the next odd number.

        Returns:
            float32 array of shape (n_angles, mask_size, mask_size).
        """
        if mask_size < 3:
            mask_size = 3
        if mask_size % 2 == 0:
            mask_size += 1

        # Base gradient mask
        Gy = np.outer([1, 0, -1], [1, 2, 1])

        # Each full convolution with a 3x3 kernel grows the mask by 2 per side pair.
        if mask_size > 3:
            expd = np.outer([1, 2, 1], [1, 2, 1])
            for _ in range((mask_size - 3) // 2):
                Gy = signal.convolve2d(expd, Gy, mode='full')

        angles = np.arange(0, 360, self.angle_resolution)
        masks = np.zeros((len(angles), mask_size, mask_size), dtype=np.float32)

        for i, angle in enumerate(angles):
            # ndimage.rotate keeps Gy's integer dtype, so off-axis masks are
            # quantised to integers before being stored as float32.
            masks[i] = ndimage.rotate(Gy, angle, reshape=False, mode='nearest')

        return masks

    def _ehd_feature_shape(self, image_shape: Tuple[int, ...]) -> Tuple[int, int, int]:
        """Return the (channels, height, width) that extract_ehd_features produces for ``image_shape``."""
        n_angles = len(np.arange(0, 360, self.angle_resolution)) if self.angle_resolution > 0 else 0
        # Unpadded dilated convolution trims dilation * (k - 1) pixels per axis.
        border = self.ehd_dilation * (self._EHD_MASK_SIZE - 1)
        dims = tuple(image_shape)[:2]
        dims = dims + (0,) * (2 - len(dims))
        height, width = (max(0, int(d) - border) for d in dims)
        return n_angles + 1, height, width

    def _empty_ehd_outputs(self, image_shape: Tuple[int, ...]) -> Tuple[np.ndarray, np.ndarray]:
        """Return zero-filled (ehd_features, ehd_map) with the same shapes/dtypes as a successful run."""
        n_channels, height, width = self._ehd_feature_shape(image_shape)
        return (np.zeros((n_channels, height, width), dtype=np.float32),
                np.zeros((height, width), dtype=np.uint8))

    def extract_ehd_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract Edge Histogram Descriptor features.

        Each pixel is assigned to the oriented mask with the strongest
        response. If that response is below ``ehd_threshold``, the pixel goes
        to an extra "no edge" bin instead. The one-hot assignments are then
        averaged over a 5x5 neighbourhood.

        Args:
            gray_image: 2-D grayscale image; intensities are divided by 255,
                so 0-255 input is assumed.

        Returns:
            Tuple of (ehd_features, ehd_map):
            - ehd_features: float32 (n_angles + 1, H - b, W - b), where
              b = ehd_dilation * (mask_size - 1)
            - ehd_map: uint8 (H - b, W - b) index of the dominant bin
            Zero-filled arrays of the same shapes are returned on failure.
        """
        try:
            masks = self.generate_ehd_masks(self._EHD_MASK_SIZE)
            n_masks = masks.shape[0]

            image = np.ascontiguousarray(gray_image, dtype=np.float32) / 255.0
            X = torch.from_numpy(image)[None, None]            # (1, 1, H, W)
            kernels = torch.from_numpy(masks).unsqueeze(1)      # (n_masks, 1, k, k)

            edge_responses = F.conv2d(X, kernels, dilation=self.ehd_dilation)

            # Strongest orientation per pixel; weak responses -> "no edge" bin.
            values, indices = torch.max(edge_responses, dim=1)
            indices[values < self.ehd_threshold] = n_masks

            one_hot = F.one_hot(indices, num_classes=n_masks + 1)   # (1, h, w, n+1)
            one_hot = one_hot.permute(0, 3, 1, 2).float()            # (1, n+1, h, w)
            pooled = F.avg_pool2d(one_hot,
                                  kernel_size=self._EHD_POOL_SIZE,
                                  stride=1,
                                  padding=self._EHD_POOL_SIZE // 2)

            ehd_features = pooled.squeeze(0).contiguous().cpu().numpy()
            ehd_map = np.argmax(ehd_features, axis=0).astype(np.uint8)

            return ehd_features, ehd_map

        except Exception as e:
            logger.error(f"EHD features extraction failed: {e}")
            return self._empty_ehd_outputs(np.shape(gray_image))

    def extract_all_texture_features(self, gray_image: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Extract all texture features from a grayscale image.

        Args:
            gray_image: Grayscale input image

        Returns:
            Dictionary with keys 'lbp', 'hog', 'lac1', 'lac2', 'lac3',
            'ehd_features' and 'ehd_map'. On failure every entry is a
            zero-filled array of the shape a successful run would produce.
        """
        features = {}

        try:
            features['lbp'] = self.extract_lbp(gray_image)
            features['hog'] = self.extract_hog(gray_image)

            lac1, lac2, lac3 = self.compute_lacunarity_features(gray_image)
            features['lac1'] = lac1
            features['lac2'] = lac2
            features['lac3'] = lac3

            ehd_features, ehd_map = self.extract_ehd_features(gray_image)
            features['ehd_features'] = ehd_features
            features['ehd_map'] = ehd_map

            logger.debug("All texture features extracted successfully")

        except Exception as e:
            logger.error(f"Texture feature extraction failed: {e}")
            empty_ehd_features, empty_ehd_map = self._empty_ehd_outputs(np.shape(gray_image))
            features = {
                'lbp': np.zeros_like(gray_image, dtype=np.uint8),
                'hog': np.zeros_like(gray_image, dtype=np.uint8),
                'lac1': np.zeros_like(gray_image, dtype=np.uint8),
                'lac2': np.zeros_like(gray_image, dtype=np.uint8),
                'lac3': np.zeros_like(gray_image, dtype=np.uint8),
                'ehd_features': empty_ehd_features,
                'ehd_map': empty_ehd_map
            }

        return features

    def _convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
        """
        Min-max rescale an array to the full 0-255 uint8 range.

        NaN and +/-inf become 0 before scaling; constant or empty arrays map
        to zeros. Work is done in float64 so integer inputs cannot overflow,
        and ``np.ptp`` is used because ``ndarray.ptp`` was removed in NumPy 2.0.
        """
        arr = np.nan_to_num(np.asarray(arr, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0)
        if arr.size == 0:
            return np.zeros(arr.shape, dtype=np.uint8)
        span = np.ptp(arr)
        if span > 0:
            # Divide by the exact span so the maximum lands on 255, not 254.
            normalized = (arr - arr.min()) / span * 255.0
        else:
            normalized = np.zeros_like(arr)
        return np.clip(normalized, 0, 255).astype(np.uint8)

    @staticmethod
    def _summary_stats(values: np.ndarray) -> Dict[str, float]:
        """Return mean/std/min/max/median over the non-NaN entries of ``values`` (all 0.0 if none)."""
        valid = values[~np.isnan(values)]
        if valid.size == 0:
            return {'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0, 'median': 0.0}
        return {
            'mean': float(np.mean(valid)),
            'std': float(np.std(valid)),
            'min': float(np.min(valid)),
            'max': float(np.max(valid)),
            'median': float(np.median(valid))
        }

    @staticmethod
    def _center_crop_mask(mask: np.ndarray, shape: Tuple[int, ...]) -> Optional[np.ndarray]:
        """
        Align ``mask`` to a feature map of ``shape`` by symmetric cropping.

        Valid-mode convolutions (as in extract_ehd_features) trim the same
        number of pixels from each side, so a centre crop maps every mask
        pixel onto the feature pixel it corresponds to. Returns the mask
        unchanged when shapes match, and None when a symmetric crop is
        impossible (mask smaller, odd size difference, or not 2-D).
        """
        mask = np.asarray(mask)
        shape = tuple(shape)
        if mask.shape == shape:
            return mask
        if mask.ndim != 2 or len(shape) != 2:
            return None
        dh = mask.shape[0] - shape[0]
        dw = mask.shape[1] - shape[1]
        if dh < 0 or dw < 0 or dh % 2 or dw % 2:
            return None
        top, left = dh // 2, dw // 2
        return mask[top:top + shape[0], left:left + shape[1]]

    def _ehd_channel_statistics(self, feature_data: np.ndarray,
                                mask: Optional[np.ndarray]) -> Dict[str, Dict[str, float]]:
        """Per-channel summary statistics for a (C, h, w) EHD feature stack, keyed 'channel_<i>'."""
        if mask is not None:
            channel_shape = feature_data.shape[1:]
            aligned = self._center_crop_mask(mask, channel_shape)
            if aligned is None:
                # No exact alignment: fall back to nearest-neighbour resizing.
                # cv2 rejects bool arrays, so binarise to uint8 first.
                aligned = cv2.resize((np.asarray(mask) > 0).astype(np.uint8),
                                     (channel_shape[1], channel_shape[0]),
                                     interpolation=cv2.INTER_NEAREST)
            feature_data = np.where(aligned[np.newaxis] > 0, feature_data, np.nan)
        # Zero-filled stats for empty channels keep the key set stable across images.
        return {f'channel_{i}': self._summary_stats(channel)
                for i, channel in enumerate(feature_data)}

    def compute_texture_statistics(self, features: Dict[str, np.ndarray],
                                 mask: Optional[np.ndarray] = None) -> Dict[str, Dict[str, float]]:
        """
        Compute mean/std/min/max/median for each texture feature.

        Only pixels where ``mask > 0`` are used. When a feature map is smaller
        than the mask by an even margin (e.g. the cropped EHD outputs), the
        mask is centre-cropped to match. For 2-D features whose shape cannot
        be aligned, the mask is ignored with a warning.

        Args:
            features: Dictionary of texture features (as returned by
                extract_all_texture_features)
            mask: Optional mask to apply

        Returns:
            {feature_name: stats}; for 'ehd_features' the value is
            {'channel_<i>': stats}. Features/channels with no valid pixels
            get all-zero stats.
        """
        stats = {}

        for feature_name, feature_data in features.items():
            if feature_name == 'ehd_features':
                stats[feature_name] = self._ehd_channel_statistics(feature_data, mask)
                continue

            masked_data = feature_data
            if mask is not None:
                aligned = self._center_crop_mask(mask, feature_data.shape)
                if aligned is None:
                    logger.warning(
                        f"Mask shape {np.shape(mask)} does not match '{feature_name}' "
                        f"shape {feature_data.shape}; computing unmasked statistics"
                    )
                else:
                    masked_data = np.where(aligned > 0, feature_data, np.nan)

            stats[feature_name] = self._summary_stats(masked_data)

        return stats