Spaces:
Sleeping
Sleeping
| import itertools | |
| import warnings | |
| from io import BytesIO | |
| from copy import deepcopy | |
| from pathlib import Path | |
| from typing import List, Tuple, Dict, Optional, Any, Union | |
| from dataclasses import dataclass, field | |
| from PIL import Image | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from scipy import ndimage | |
| from skimage import measure | |
| import torchvision.transforms.functional as tvf | |
| from ultralytics import YOLO | |
| import torch | |
| import cv2 | |
| import gradio | |
| import sys, os | |
| sys.path.append(os.path.abspath('angle_calculation')) | |
| # from classic import measure_object | |
| from sampling import get_crop_batch | |
| from angle_model import PatchedPredictor, StripsModelLumenWidth | |
| from period_calculation.period_measurer import PeriodMeasurer | |
| # from grana_detection.mmwrapper import MMDetector # mmdet installation in docker is problematic for now | |
@dataclass(eq=False)
class Granum:
    """Container for one detected granum and its measurements.

    ``eq=False`` keeps identity-based equality/hashing: several fields hold
    numpy arrays, for which a generated field-wise ``__eq__`` would raise
    "truth value of an array is ambiguous".
    """
    id: Optional[int] = None  # sequential ID, assigned after detection
    image: Any = None  # image the granum was detected in
    mask: Any = None  # binary mask of the granum (detection resolution, padding removed)
    scaler: Any = None  # ScalerPadder used to prepare the detector input
    nm_per_px: float = float('nan')
    detection_confidence: float = float('nan')
    img_oriented: Optional[np.ndarray] = None  # oriented fragment of the image
    mask_oriented: Optional[np.ndarray] = None  # oriented fragment of the mask
    measurements: dict = field(default_factory=dict)  # dict with grana measurements
class ScalerPadder:
    """Resize and pad an image into the size range expected by the detector.

    The image is rescaled so that ``max(min_size_nm, physical extent of its
    longest edge)`` maps onto ``target_size`` pixels (minus the obligatory
    padding). It is then reflect-padded by at least ``minimal_pad`` on every
    side, the short edge is padded up to ``target_short_edge_min``, and both
    sides are rounded up to a multiple of ``pad_to_multiply``.

    minimal_pad: obligatory padding, e.g. required for detector.
    """

    def __init__(self, target_size=1024, target_short_edge_min=640, minimal_pad=16, pad_to_multiply=32):
        self.minimal_pad = minimal_pad
        # the obligatory padding is part of the final size, so exclude it here
        self.target_size = target_size - 2 * self.minimal_pad
        self.target_short_edge_min = target_short_edge_min - 2 * self.minimal_pad
        self.max_size_nm = 6000  # training images cover ~3100 nm
        self.min_size_nm = 2400  # training images cover ~3100 nm
        self.pad_to_multiply = pad_to_multiply

    def transform(self, image: Image.Image, px_per_nm: float = 1.298) -> Image.Image:
        """Resize and pad ``image``; remembers what is needed to undo it.

        Args:
            image: PIL image to transform.
            px_per_nm: scale of ``image`` in pixels per nanometre.

        Returns:
            The resized and reflect-padded PIL image.

        Images wider than ``max_size_nm`` only trigger a gradio warning.
        """
        self.original_size = image.size  # PIL order: (width, height)
        self.original_px_per_nm = px_per_nm
        w, h = self.original_size
        longest_size = max(h, w)
        img_size_nm = longest_size / px_per_nm
        if img_size_nm > self.max_size_nm:
            error_message = f'too large image, image size: {img_size_nm:0.1f}nm, max allowed: {self.max_size_nm}nm'
            gradio.Warning(error_message)
        self.resize_factor = self.target_size / (max(self.min_size_nm, img_size_nm) * px_per_nm)
        self.px_per_nm_transformed = px_per_nm * self.resize_factor
        resized_image = resize_with_cv2(image, (int(h * self.resize_factor), int(w * self.resize_factor)))
        resized_w, resized_h = resized_image.size
        if w >= h:
            pad_w = self.target_size - resized_w
            pad_h = max(0, self.target_short_edge_min - resized_h)
        else:
            pad_w = max(0, self.target_short_edge_min - resized_w)
            pad_h = self.target_size - resized_h
        # apply minimal padding
        pad_w += 2 * self.minimal_pad
        pad_h += 2 * self.minimal_pad
        # round the *padded* size (not the resized size) up to a multiple of pad_to_multiply
        m = self.pad_to_multiply
        pad_w += (m - (resized_w + pad_w) % m) % m
        pad_h += (m - (resized_h + pad_h) % m) % m
        self.pad_right = pad_w // 2
        self.pad_left = pad_w - self.pad_right
        self.pad_up = pad_h // 2
        self.pad_bottom = pad_h - self.pad_up
        # torchvision pad order: [left, top, right, bottom]; reflect padding
        padded_image = tvf.pad(resized_image, [self.pad_left, self.pad_up, self.pad_right, self.pad_bottom], padding_mode='reflect')
        return padded_image

    @property
    def unpad_slice(self) -> Tuple[slice, slice]:
        """(row slice, column slice) that removes the padding added by ``transform``."""
        return (
            slice(self.pad_up, -self.pad_bottom if self.pad_bottom > 0 else None),
            slice(self.pad_left, -self.pad_right if self.pad_right > 0 else None),
        )

    def inverse_transform(self, image: Union[np.ndarray, Image.Image], output_size: Optional[Tuple[int]] = None, output_nm_per_px: Optional[float] = None, return_pil: bool = True) -> Image.Image:
        """Undo ``transform``: strip the padding and resize.

        Args:
            image: image in the transformed (padded) geometry.
            output_size: target (width, height); mutually exclusive with
                ``output_nm_per_px``.
            output_nm_per_px: target scale in nanometres per pixel.
            return_pil: return a PIL image instead of a numpy array.

        Without ``output_size`` and ``output_nm_per_px`` the original image
        size is restored.

        Raises:
            ValueError: if both ``output_size`` and ``output_nm_per_px`` are given.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        unpadded_image = image[self.unpad_slice]
        if output_size is not None and output_nm_per_px is not None:
            raise ValueError("at most one of output_size or output_nm_per_px may be given")
        if output_nm_per_px is not None:
            original_nm_per_px = 1 / self.original_px_per_nm
            resize_factor = original_nm_per_px / output_nm_per_px
            output_size = (int(self.original_size[0] * resize_factor), int(self.original_size[1] * resize_factor))
        elif output_size is None:
            output_size = self.original_size
        # resize_with_cv2 expects (height, width)
        return resize_with_cv2(unpadded_image, (output_size[1], output_size[0]), return_pil=return_pil)
def close_contour(contour):
    """Ensure a contour ends where it starts.

    Returns ``contour`` itself when its last point already equals its first
    point; otherwise a new array with the first point appended as a row.
    """
    first_point = contour[0]
    if np.array_equal(first_point, contour[-1]):
        return contour
    return np.vstack((contour, first_point))
def binary_mask_to_polygon(binary_mask, tol=0.01):
    """Approximate the outline of a binary mask by a closed polygon.

    Args:
        binary_mask: 2-D mask of the object.
        tol: tolerance passed to ``skimage.measure.approximate_polygon``.

    Returns:
        (N, 2) array of (x, y) vertices; the last vertex repeats the first.

    When the mask yields several contours (separate parts or holes), the
    longest one (by number of contour points) is used.

    Raises:
        ValueError: if the mask contains no foreground.
    """
    # pad with background so objects touching the border still give closed contours
    padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
    contours = measure.find_contours(padded_binary_mask, 0.5)
    if not contours:
        raise ValueError("binary_mask contains no foreground pixels")
    contour = max(contours, key=len) - 1  # correct for padding
    contour = close_contour(contour)
    polygon = measure.approximate_polygon(contour, tol)
    polygon = np.flip(polygon, axis=1)  # (row, col) -> (x, y)
    # after padding and subtracting 1 we may get -0.5 points in our polygon. Replace it with 0
    polygon = np.where(polygon >= 0, polygon, 0)
    return polygon
def measure_shape(binary_mask):
    """Return ``(perimeter, area)`` of a binary mask, both in pixel units.

    The perimeter is the summed edge length of the polygon produced by
    ``binary_mask_to_polygon``; the area is the count of true pixels.
    """
    outline = binary_mask_to_polygon(binary_mask)
    edge_lengths = np.linalg.norm(np.diff(outline, axis=0), axis=1)
    return np.sum(edge_lengths), np.sum(binary_mask)
def calculate_gsi(perimeter, height, area):
    """Return the GSI as ``1 - area / (width * height)``.

    ``width`` is derived from the perimeter as if the shape were a rectangle
    of the given height: perimeter = 2 * (width + height).
    """
    width = (perimeter - 2 * height) / 2
    return 1 - area / (width * height)
def object_slice(mask):
    """Return ``(row_slice, col_slice)`` of the tightest box around the true pixels of a 2-D mask.

    An all-false mask raises IndexError (there is no box to return).
    """
    occupied_rows = np.flatnonzero(np.any(mask, axis=1))
    occupied_cols = np.flatnonzero(np.any(mask, axis=0))
    return (
        slice(occupied_rows[0], occupied_rows[-1] + 1),
        slice(occupied_cols[0], occupied_cols[-1] + 1),
    )
def figure_to_pil(fig):
    """Render a matplotlib figure to PNG in memory and return it as a PIL image.

    The image is deep-copied before the in-memory buffer is closed, so the
    returned object does not depend on the buffer.
    """
    with BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png')
        png_buffer.seek(0)
        rendered = deepcopy(Image.open(png_buffer))
    return rendered
def resize_to(image: Image.Image, s: int=4032, return_factor: bool =False) -> Image.Image:
    """Resize ``image`` so that its longer edge becomes ``s`` pixels (aspect ratio kept).

    New sizes are truncated to int. When ``return_factor`` is true, returns
    ``(resized_image, factor)``, where ``factor`` is the old longer edge divided by ``s``.
    """
    width, height = image.size
    factor = max(width, height) / s
    resized = image.resize((int(width / factor), int(height / factor)))
    return (resized, factor) if return_factor else resized
def resize_with_cv2(image, shape, return_pil=True):
    """resize using cv2 with cv2.INTER_LINEAR - consistent with YOLO

    Args:
        image: PIL image or numpy array.
        shape: target ``(height, width)``.
        return_pil: wrap the result in a PIL image (default) instead of
            returning the numpy array produced by cv2.
    """
    target_h, target_w = shape
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    # cv2 expects the destination size as (width, height)
    resized = cv2.resize(pixels, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(resized) if return_pil else resized
def select_unique_mask(mask):
    """if mask consists of multiple parts, select the largest

    Connected parts come from ``scipy.ndimage.label`` with its default
    structuring element; on ties the part with the lowest label wins.
    An all-false mask is returned unchanged; otherwise a new boolean array
    holding only the largest part is returned.
    """
    if not np.any(mask):  # if mask is empty, return without change
        return mask
    blobs = ndimage.label(mask)[0]
    # bincount always has an entry for label 0 (background), even when the mask
    # has no background pixels -- np.unique would omit it and shift the labels.
    blob_sizes = np.bincount(blobs.ravel())
    best_blob_label = np.argmax(blob_sizes[1:]) + 1
    return blobs == best_blob_label
def sliced_mean(x, slice_size):
    """Mean of every ``slice_size`` x ``slice_size`` window of a 2-D array.

    Only windows that fit completely are returned: ``result[i, j]`` is the
    mean of ``x[i:i + slice_size, j:j + slice_size]``. Window sums come from
    differences of zero-prefixed cumulative sums, first along rows, then columns.
    """
    down = np.pad(np.cumsum(x, axis=0), ((1, 0), (0, 0)))
    column_means = (down[slice_size:] - down[:-slice_size]) / slice_size
    across = np.pad(np.cumsum(column_means, axis=1), ((0, 0), (1, 0)))
    return (across[:, slice_size:] - across[:, :-slice_size]) / slice_size
def sliced_var(x, slice_size):
    """Variance of every ``slice_size`` x ``slice_size`` window of a 2-D array.

    Computed as E[x^2] - E[x]^2 over the windows produced by ``sliced_mean``.
    Floating-point cancellation can make that difference slightly negative for
    (near-)constant windows, so the result is clipped at 0; negative values
    would otherwise turn into NaN under fractional powers.
    """
    x = x.astype('float64')
    variance = sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
    return np.maximum(variance, 0.0)
def calculate_distance_map(mask):
    """Euclidean distance from each foreground pixel to the nearest background pixel.

    Everything outside the array counts as background: the mask is framed
    by a one-pixel false border before the transform, and the frame is
    cropped off again afterwards.
    """
    framed = np.pad(mask, 1, constant_values=False)
    distances = ndimage.distance_transform_edt(framed)
    return distances[1:-1, 1:-1]
def select_samples(granum_image, granum_mask, crop_size=96, n_samples=64, granum_fraction_min=0.75, variance_p=0.):
    """Randomly choose square crops of ``granum_image`` lying mostly inside the granum.

    A crop position is eligible when at least ``granum_fraction_min`` of its
    ``crop_size`` x ``crop_size`` window is covered by ``granum_mask``. Up to
    ``n_samples`` distinct positions are drawn without replacement:
      - uniformly when ``variance_p == 0``;
      - otherwise with probability proportional to ``window_variance ** variance_p``.
    If these weights are degenerate (non-finite, negative, or all zero), the
    draw falls back to uniform.

    Returns:
        Array of shape ``(k, crop_size, crop_size)``. ``k`` is at most
        ``n_samples``, and is 0 when no position is eligible.
    """
    granum_occupancy = sliced_mean(granum_mask, crop_size)
    possible_indices = np.stack(np.where(granum_occupancy >= granum_fraction_min), axis=1)
    n_possible = len(possible_indices)
    if n_possible == 0:
        # np.random.choice cannot draw from an empty population with a probability vector
        return np.empty((0, crop_size, crop_size), dtype=granum_image.dtype)

    p = np.ones(n_possible) / n_possible  # uniform
    if variance_p != 0:
        variance_map = sliced_var(granum_image, crop_size)
        weights = variance_map[possible_indices[:, 0], possible_indices[:, 1]]**variance_p
        total = np.sum(weights)
        if np.all(np.isfinite(weights)) and np.all(weights >= 0) and total > 0:
            p = weights / total

    # sampling without replacement needs at least `size` non-zero probabilities
    n_chosen = min(n_possible, n_samples, np.count_nonzero(p))
    chosen_indices = np.random.choice(np.arange(n_possible), n_chosen, replace=False, p=p)
    crops = [
        granum_image[row:row + crop_size, col:col + crop_size]
        for row, col in possible_indices[chosen_indices]
    ]
    return np.array(crops)
def calculate_height(mask_oriented):  # HACK
    """Height of an oriented granum mask, in pixels.

    For every column that contains mask pixels, the vertical extent from its
    first to its last mask pixel (inclusive) is measured. The 0.8 quantile of
    these extents is returned. Columns without mask pixels are ignored;
    previously they counted as full-height. An empty mask gives NaN.
    """
    mask = np.asarray(mask_oriented, dtype=bool)
    occupied = np.any(mask, axis=0)
    if not np.any(occupied):
        return float('nan')
    columns = mask[:, occupied]
    first = np.argmax(columns, axis=0)
    end = columns.shape[0] - np.argmax(columns[::-1], axis=0)  # one past the last pixel
    return np.quantile(end - first, 0.8)
def calculate_diameter(mask_oriented):
    """returns mean diameter

    The mean horizontal extent (first to last mask pixel, inclusive, in
    pixels) is taken over the rows lying between 25% and 75% of the mask's
    vertical extent. Rows without mask pixels are ignored; previously they
    counted as full-width. Returns NaN when no such row exists, e.g. for an
    empty mask.
    """
    mask = np.asarray(mask_oriented, dtype=bool)
    # calculate 0.25 and 0.75 lines
    vertical_mask = np.any(mask, axis=1)
    if not np.any(vertical_mask):
        return float('nan')
    upper_granum_bound = np.argmax(vertical_mask)
    lower_granum_bound = mask.shape[0] - np.argmax(vertical_mask[::-1])
    upper = int(round(0.75 * upper_granum_bound + 0.25 * lower_granum_bound))
    lower = max(upper + 1, int(round(0.25 * upper_granum_bound + 0.75 * lower_granum_bound)))
    rows = mask[upper:lower]
    rows = rows[np.any(rows, axis=1)]
    if len(rows) == 0:
        return float('nan')
    # calculate diameters
    span = rows.shape[1] - np.argmax(rows[:, ::-1], axis=1) - np.argmax(rows, axis=1)
    return np.mean(span)
def robust_mean(x, q=0.1):
    """Mean of ``x`` after dropping the values farthest from the median.

    Values whose absolute deviation from the median exceeds the ``1 - q``
    quantile of all deviations are excluded. That is roughly the ``q``
    fraction of largest outliers; ties at the threshold are kept. When all
    values are equal, nothing is excluded. Empty input gives NaN.
    """
    x = np.asarray(x, dtype=float).ravel()
    if x.size == 0:
        return float('nan')
    deviations = np.abs(x - np.median(x))
    if np.max(deviations) == 0:
        keep = np.ones(x.shape, dtype=bool)
    else:
        threshold = np.quantile(deviations, 1 - q)
        keep = deviations <= threshold  # boolean selector, not the selected values
    return np.mean(x[keep])
def rotate_image_and_mask(image, mask, direction):
    """Rotate an image and its mask by ``-direction`` degrees.

    Both are rotated with ``scipy.ndimage.rotate(..., reshape=True)``, so the
    outputs are enlarged to contain the whole rotated input. The mask is
    rotated as an int array and converted back to bool. No cropping is done
    here. The previous unused bounding-box computation raised IndexError
    for an empty rotated mask, so it has been removed.
    """
    mask_oriented = ndimage.rotate(mask.astype('int'), -direction, reshape=True).astype('bool')
    img_oriented = ndimage.rotate(image, -direction, reshape=True)
    return img_oriented, mask_oriented
class GranaAnalyser:
    """Detects grana in images and measures their shape, orientation and period."""

    def __init__(self, weights_detector: str, weights_orientation: str, weights_period: str, period_sd_threshold_nm: float = 2.5) -> None:
        """
        Initializes the GranaAnalyser with specified weights for detection and measuring.

        Parameters:
            weights_detector (str): The file path to the weights file for the grana detection algorithm (YOLO).
            weights_orientation (str): The file path to the checkpoint of the grana orientation model.
            weights_period (str): The file path to the weights file for the grana period algorithm.
            period_sd_threshold_nm (float): SD threshold in nm forwarded to PeriodMeasurer.
        """
        self.detector = YOLO(weights_detector)
        self.orienter = PatchedPredictor(
            StripsModelLumenWidth.load_from_checkpoint(weights_orientation, map_location='cpu').eval(),
            normalization=dict(mean=0.250, std=0.135),
            n_samples=32,
            mask=None,
            crop_size=64,
            angle_confidence_threshold=0.2,
        )
        self.measurement_px_per_nm = 1 / 0.768  # image scale required for measurement
        self.period_measurer = PeriodMeasurer(
            weights_period,
            px_per_nm=self.measurement_px_per_nm,
            sd_threshold_nm=period_sd_threshold_nm,
            period_threshold_nm_min=14,
            period_threshold_nm_max=30,
        )

    def get_grana_data(self, image, detections, scaler, border_margin=1, min_count=1) -> List[Granum]:
        """Filter detections and create grana data.

        Each detection mask has the scaler padding removed, is reduced to its
        largest connected part, and has its holes filled. It is dropped when:
          - it is empty;
          - more than ``min_count`` of its pixels lie within ``border_margin``
            pixels of the image border;
          - at least 20% of it overlaps grana accepted before it.
        """
        image_numpy = np.array(image)
        if image_numpy.ndim == 3:
            image_numpy = image_numpy[:, :, 0]
        mask_all = None  # union of the accepted masks
        grana = []
        for mask, confidence in zip(
            detections.masks.data.cpu().numpy().astype('bool'),
            detections.boxes.conf.cpu().numpy(),
        ):
            granum_mask = select_unique_mask(mask[scaler.unpad_slice])
            # check if mask is empty after removing the padding
            if not np.any(granum_mask):
                continue
            granum_mask = ndimage.binary_fill_holes(granum_mask)
            # check if touches boundary:
            touches_border = (
                np.sum(granum_mask[:border_margin]) > min_count
                or np.sum(granum_mask[-border_margin:]) > min_count
                or np.sum(granum_mask[:, :border_margin]) > min_count
                or np.sum(granum_mask[:, -border_margin:]) > min_count
            )
            if touches_border:
                continue
            # check grana overlap
            if mask_all is None:
                mask_all = granum_mask
            else:
                intersection = mask_all & granum_mask
                if intersection.sum() >= (granum_mask.sum() * 0.2):
                    continue
                mask_all = mask_all | granum_mask
            granum = Granum(
                image=image,
                mask=granum_mask,
                scaler=scaler,
                detection_confidence=float(confidence),
            )
            granum.image_numpy = image_numpy
            grana.append(granum)
        return grana

    def measure_grana(self, grana: List[Granum], measurement_image: np.ndarray) -> List[Granum]:
        """measure grana: includes orientation detection, period detection and geometric measurements

        ``measurement_image`` is the whole image resampled to
        ``self.measurement_px_per_nm``; each granum mask is resized to its
        shape. Results go into ``granum.measurements`` (lengths/areas in pixels
        of the measurement image, period in nm). Height, GSI, diameter, period
        and number of layers are only computed when the orientation model
        returns a non-NaN direction.
        """
        for granum in grana:
            measurement_mask = resize_with_cv2(granum.mask.astype(np.uint8), measurement_image.shape[:2], return_pil=False).astype('bool')
            granum.bounding_box_slice = object_slice(measurement_mask)
            granum.image_crop = measurement_image[granum.bounding_box_slice]
            granum.mask_crop = measurement_mask[granum.bounding_box_slice]

            granum.measurements = {}
            measurements = granum.measurements
            measurements['perimeter px'], measurements['area px'] = measure_shape(granum.mask_crop)

            # measure orientation
            orienter_predictions = self.orienter(granum.image_crop, granum.mask_crop)
            measurements['direction'] = orienter_predictions["est_angle"]
            measurements['direction confidence'] = orienter_predictions["est_angle_confidence"]
            if np.isnan(measurements["direction"]):
                continue  # the remaining measurements need the oriented crop

            img_oriented, mask_oriented = rotate_image_and_mask(granum.image_crop, granum.mask_crop, measurements["direction"])
            oriented_granum_slice = object_slice(mask_oriented)
            granum.img_oriented = img_oriented[oriented_granum_slice]
            granum.mask_oriented = mask_oriented[oriented_granum_slice]

            measurements['height px'] = calculate_height(granum.mask_oriented)
            measurements['GSI'] = calculate_gsi(
                measurements['perimeter px'],
                measurements['height px'],
                measurements['area px'],
            )
            measurements['diameter px'] = calculate_diameter(granum.mask_oriented)
            measurements["period nm"], measurements["period SD nm"] = self.period_measurer(granum.img_oriented, granum.mask_oriented)
            if not pd.isna(measurements['period nm']):
                measurements['Number of layers'] = round(measurements['height px'] / self.measurement_px_per_nm / measurements['period nm'])
        return grana

    def extract_grana_data(self, grana: List[Granum]) -> pd.DataFrame:
        """Collect and scale grana data into one DataFrame row per granum.

        Pixel lengths are converted to nm and the pixel area to nm^2.
        Measurements that were not computed become None/NaN. With no grana,
        an empty DataFrame with the same columns is returned, so downstream
        column selection still works.
        """
        optional_keys = ['direction', 'Number of layers', 'GSI', 'period nm', 'period SD nm']
        linear_keys = ['height px', 'diameter px', 'perimeter px']
        columns = (
            ['Granum ID', 'detection confidence']
            + optional_keys
            + [f"{key[:-3]} nm" for key in linear_keys]
            + ['area nm^2']
        )
        grana_data = []
        for granum in grana:
            granum_entry = {
                'Granum ID': granum.id,
                'detection confidence': granum.detection_confidence,
            }
            # fill with None if absent:
            for key in optional_keys:
                granum_entry[key] = granum.measurements.get(key, None)
            # scale linearly:
            for key in linear_keys:
                granum_entry[f"{key[:-3]} nm"] = granum.measurements.get(key, np.nan) / self.measurement_px_per_nm
            # scale quadratically
            granum_entry['area nm^2'] = granum.measurements['area px'] / self.measurement_px_per_nm**2
            grana_data.append(granum_entry)
        if not grana_data:
            return pd.DataFrame({column: pd.Series(dtype='float64') for column in columns})
        return pd.DataFrame(grana_data, columns=columns)

    def visualize_detections(self, grana: List[Granum], image: Image.Image) -> Image.Image:
        """Overlay all granum masks (green tint) and granum IDs on an RGB copy of ``image`` with a 1024 px longer edge."""
        visualization_longer_edge = 1024
        scale = visualization_longer_edge / max(image.size)
        visualization_size = (round(scale * image.size[0]), round(scale * image.size[1]))  # PIL (width, height)
        visualization_image = np.array(image.resize(visualization_size).convert('RGB'))
        if len(grana) > 0:
            grana_mask = resize_with_cv2(
                np.any(np.array([granum.mask for granum in grana]), axis=0).astype(np.uint8),
                visualization_size[::-1],  # resize_with_cv2 expects (height, width)
                return_pil=False,
            ).astype('bool')
            visualization_image[grana_mask] = (0.7 * visualization_image[grana_mask] + 0.3 * np.array([[[39, 179, 115]]])).astype(np.uint8)
        for granum in grana:
            # label position is computed on the mask, which need not share the image's resolution
            mask_scale = visualization_longer_edge / max(granum.mask.shape)
            y, x = ndimage.center_of_mass(granum.mask)
            cv2.putText(visualization_image, f'{granum.id}', org=(int(x * mask_scale) - 10, int(y * mask_scale) + 10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(39, 179, 115), thickness=2)
        return Image.fromarray(visualization_image)

    def generate_grana_images(self, grana: List[Granum], image_name: str = "") -> Dict[int, Image.Image]:
        """Render one image per granum, keyed by granum ID.

        The oriented crop is shown when available; otherwise the unrotated crop
        is shown with a note. Pixels outside the mask are blended halfway to white.
        """
        grana_images = {}
        for granum in grana:
            fig, ax = plt.subplots()
            if granum.img_oriented is None:
                image_to_plot = granum.image_crop
                mask_to_plot = granum.mask_crop
                extra_caption = " orientation and period unknown"
            else:
                image_to_plot = granum.img_oriented
                mask_to_plot = granum.mask_oriented
                extra_caption = ""
            outside = ~mask_to_plot
            ax.imshow(0.5 * 255 * outside + image_to_plot * (1 - 0.5 * outside), cmap='gray', vmin=0, vmax=255)
            ax.axis('off')
            ax.set_title(f'[{granum.id}]{image_name}\n{extra_caption}')
            grana_images[granum.id] = figure_to_pil(fig)
            plt.close(fig)  # avoid accumulating open figures
        return grana_images

    def format_data(self, grana_data: pd.DataFrame) -> pd.DataFrame:
        """Round the reported measurements and select/order the output columns."""
        rounding_roles = {'area nm^2': 0, 'perimeter nm': 1, 'diameter nm': 1, 'height nm': 1, 'period nm': 1, 'period SD nm': 2, 'GSI': 2, 'direction': 1}
        rounded_data = grana_data.round(rounding_roles)
        columns_order = ['Granum ID', 'File name', 'area nm^2', 'perimeter nm', 'GSI', 'diameter nm', 'height nm', 'Number of layers', 'period nm', 'period SD nm', 'direction']
        return rounded_data[columns_order]

    def aggregate_data(self, grana_data: pd.DataFrame, confidence: Optional[float] = None) -> Dict:
        """Mean and standard deviation of the main measurements, plus the number of grana used."""
        if confidence is None:
            filtered = grana_data
        else:
            # NOTE(review): extract_grana_data produces no 'aggregated confidence'
            # column, so this raises KeyError unless a caller adds it -- confirm
            # which confidence column is intended here.
            filtered = grana_data.loc[grana_data['aggregated confidence'] >= confidence]
        measured_columns = ['area nm^2', 'perimeter nm', 'diameter nm', 'height nm', 'Number of layers', 'period nm', 'GSI']
        aggregation = filtered[measured_columns].mean().to_dict()
        aggregation_std = filtered[measured_columns].std().to_dict()
        aggregation_std = {f"{k} std": v for k, v in aggregation_std.items()}
        return {**aggregation, **aggregation_std, 'N grana': len(filtered)}

    def predict_on_single(self, image: Image.Image, scale: float, detection_confidence: float = 0.25, granum_id_start=1, image_name: str = "") -> Tuple[Image.Image, pd.DataFrame, Dict[int, Image.Image]]:
        """
        Detects and measures grana in a single image.

        Parameters:
            image (Image.Image): PIL Image object to be analyzed
            scale (float): scale of the image: px per nm.
            detection_confidence (float): The detection confidence threshold passed to the detector.
            granum_id_start (int): ID given to the first detected granum.
            image_name (str): name used in the granum image titles.

        Returns:
            Tuple[Image.Image, pandas.DataFrame, Dict[int, Image.Image]]:
                - detection_visualization (Image.Image): image with the detections drawn on it.
                - grana_data (pandas.DataFrame): one row of measurements per granum.
                - grana_images (Dict[int, Image.Image]): granum images keyed by granum ID.
        """
        # convert to grayscale
        image = image.convert("L")

        # detect
        scaler = ScalerPadder(target_size=1024, target_short_edge_min=640)
        scaled_image = scaler.transform(image, px_per_nm=scale)
        detections = self.detector.predict(source=scaled_image, conf=detection_confidence)[0]

        # get grana data
        grana = self.get_grana_data(image, detections, scaler)
        for granum_id, granum in enumerate(grana, start=granum_id_start):
            granum.id = granum_id

        # visualize detections
        detection_visualization = self.visualize_detections(grana, image)

        # measure grana on a copy of the image resampled to the measurement scale
        measurement_image_resize_factor = self.measurement_px_per_nm / scale
        measurement_image_shape = (
            int(image.size[1] * measurement_image_resize_factor),
            int(image.size[0] * measurement_image_resize_factor),
        )
        measurement_image = resize_with_cv2(image, measurement_image_shape, return_pil=False)
        grana = self.measure_grana(grana, measurement_image)

        grana_data = self.extract_grana_data(grana)
        grana_images = self.generate_grana_images(grana, image_name=image_name)
        return detection_visualization, grana_data, grana_images

    def predict(self, images: Dict[str, Image.Image], scale: float, detection_confidence: float = 0.25, parameter_confidence: Optional[float] = None) -> Tuple[Dict[str, Image.Image], pd.DataFrame, Dict[int, Image.Image], Dict]:
        """
        Predicts and aggregates data related to grana using a dictionary of images.

        Parameters:
            images (Dict[str, Image.Image]): A dictionary of PIL Image objects to be analyzed,
                                             keyed by their names.
            scale (float): scale of the image: px per nm
            detection_confidence (float): The detection confidence threshold passed to the detector.
            parameter_confidence (float): Confidence threshold forwarded to aggregate_data; None disables filtering.

        Returns:
            Tuple[Dict[str, Image.Image], pandas.DataFrame, Dict[int, Image.Image], Dict]:
                - detection_visualizations: detection images keyed by image name.
                - grana_data: formatted measurements of all grana, with a 'File name' column.
                - grana_images: granum images keyed by granum ID (IDs are unique across images).
                - aggregated_data: aggregated measurement statistics.
        """
        detection_visualizations_all = {}
        grana_images_all = {}
        frames = []
        granum_id_start = 1
        for image_name, image in images.items():
            detection_visualization, grana_data, grana_images = self.predict_on_single(image, scale=scale, detection_confidence=detection_confidence, granum_id_start=granum_id_start, image_name=image_name)
            granum_id_start += len(grana_data)
            detection_visualizations_all[image_name] = detection_visualization
            grana_images_all.update(grana_images)
            grana_data['File name'] = image_name
            frames.append(grana_data)

        if frames:
            grana_data_all = pd.concat(frames)
        else:  # no images given: keep the output schema
            grana_data_all = self.extract_grana_data([])
            grana_data_all['File name'] = []

        aggregated_data = self.aggregate_data(grana_data_all, parameter_confidence)
        formatted_grana_data = self.format_data(grana_data_all)
        return detection_visualizations_all, formatted_grana_data, grana_images_all, aggregated_data
class GranaDetector(GranaAnalyser):
    """supplementary class for grana detection only

    Only the detector is loaded; ``GranaAnalyser.__init__`` is not called.
    As a result, the orientation and period models, and the methods that use
    them, are unavailable. The inherited ``predict`` does not work with this
    class either, because it passes ``image_name`` to ``predict_on_single``.
    """

    def __init__(self, weights_detector: str, detector_config: Optional[str] = None, model_type="yolo") -> None:
        if model_type == "yolo":
            self.detector = YOLO(weights_detector)
        elif model_type == "mmdetection":
            # imported lazily: mmdet installation in docker is problematic, so it is
            # only required when this backend is actually requested
            from grana_detection.mmwrapper import MMDetector
            self.detector = MMDetector(model=detector_config, weights=weights_detector)
        else:
            raise NotImplementedError(f"unsupported model_type: {model_type!r}")

    def predict_on_single(self, image: Image.Image, scale: float, detection_confidence: float = 0.25, granum_id_start=1, use_scaling=True, granum_border_margin=1, granum_border_min_count=1, scaler_sizes=(1024, 640)) -> List[Granum]:
        """Detect grana in one image (no measurements).

        Parameters:
            image: PIL image; converted to grayscale.
            scale: image scale in px per nm.
            detection_confidence: confidence threshold passed to the detector.
            granum_id_start: ID given to the first returned granum.
            use_scaling: resize/pad with ``ScalerPadder(*scaler_sizes)``; otherwise
                use a scaler without minimal padding or rounding.
            granum_border_margin, granum_border_min_count: border filter passed to
                ``get_grana_data`` as ``border_margin`` / ``min_count``.
            scaler_sizes: ``(target_size, target_short_edge_min)`` for the scaler.

        Returns:
            List of detected Granum objects with IDs assigned.
        """
        # convert to grayscale
        image = image.convert("L")
        # detect
        if use_scaling:
            scaler = ScalerPadder(target_size=scaler_sizes[0], target_short_edge_min=scaler_sizes[1])
        else:
            # "dummy" scaler: no minimal padding, no rounding to a multiple.
            # NOTE(review): ScalerPadder.transform still rescales images whose longest
            # edge covers less than scaler.min_size_nm, so this is not an exact identity.
            scaler = ScalerPadder(target_size=max(image.size), target_short_edge_min=min(image.size), minimal_pad=0, pad_to_multiply=1)
        scaled_image = scaler.transform(image, px_per_nm=scale)
        detections = self.detector.predict(source=scaled_image, conf=detection_confidence)[0]
        # get grana data
        grana = self.get_grana_data(image, detections, scaler, border_margin=granum_border_margin, min_count=granum_border_min_count)
        for granum_id, granum in enumerate(grana, start=granum_id_start):
            granum.id = granum_id
        return grana