Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import os | |
| import matplotlib.pyplot as plt | |
| from torchvision.datasets import ImageFolder | |
| from torchvision import transforms | |
| import albumentations as A | |
| from pathlib import Path | |
| from torch.utils.data import Dataset, DataLoader | |
| from tqdm.auto import tqdm | |
| from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor | |
| import io | |
| from PIL import Image, ImageDraw, ImageFont | |
| from sklearn.cluster import KMeans | |
| import cv2 # 使用 OpenCV 来进行颜色空间的转换 | |
# Display color (RGB, 0-255) for each material class id.
# Keys line up with id2label; 0 is the background class.
id2color = {
    0: [0, 0, 0],
    1: [255, 204, 51],
    2: [240, 120, 240],
    3: [172, 196, 170],
    4: [178, 80, 80],
    5: [36, 179, 83],
    6: [89, 89, 89],
    7: [160, 146, 229],
    8: [18, 17, 20],
    9: [90, 98, 89],
    10: [155, 149, 205],
    11: [31, 73, 125],
    12: [204, 153, 51],
    13: [37, 12, 156],
    14: [52, 209, 183],
    15: [163, 160, 172],
    16: [61, 245, 61],
    17: [230, 203, 104],
    18: [125, 104, 227],
    19: [154, 150, 169],
    20: [51, 221, 255],
    21: [95, 95, 95],
    22: [128, 128, 128],
    23: [156, 239, 255],
    24: [153, 102, 51],
    25: [255, 106, 77],
    26: [0, 0, 226],
    27: [254, 242, 208],
    28: [255, 191, 0],
    29: [89, 134, 179],
    30: [115, 51, 128],
    31: [65, 112, 192],
    32: [255, 0, 204],
    33: [170, 240, 209],
    34: [140, 120, 240],
    35: [118, 255, 166],
    36: [250, 250, 55],
    37: [243, 232, 208],
    38: [1, 118, 141],
    39: [243, 241, 255],
    40: [245, 147, 49],
    41: [158, 108, 4],
    42: [132, 0, 0],
}
# Human-readable material name for each class id predicted by the model.
# Keys line up with id2color; 0 is the background class.
id2label = {
    0: 'Background',
    1: 'Wood/Bamboo',
    2: 'Ground tile',
    3: 'Brick',
    4: 'Cardboard/Paper',
    5: 'Tree',
    6: 'Roof tile',
    7: 'Ceramic',
    8: 'Chalkboard/Blackboard',
    9: 'Asphalt',
    10: 'Cement/ Concrete',
    11: 'Composite decorative board',
    12: 'Rammed earth',
    13: 'Fabric/Cloth',
    14: 'Water',
    15: 'Windows with metal fences on the outside (distant view)',
    16: 'Foliage',
    17: 'Food',
    18: 'Fur',
    19: 'Pottery',
    20: 'Glass',
    21: 'Hair',
    22: 'Roofing waterproof material',
    23: 'Ice',
    24: 'Leather',
    25: 'Carved brick',
    26: 'Metal',
    27: 'Mirror',
    28: 'Enamel',
    29: 'Paint/ Coating/ Plaster',
    30: 'Window screen',
    31: 'Whiteboard',
    32: 'Photograph/ Painting/ Airbrushed fabric',
    33: 'Plastic, clear',
    34: 'Plastic, non-clear',
    35: 'Rubber/Latex',
    36: 'Sand',
    37: 'Skin/Lips',
    38: 'Sky',
    39: 'Snow',
    40: 'Engineered Stone/ Imitation Stone',
    41: 'Soil/Mud',
    42: 'Natural Stone',
}
# Force CPU inference; the tensor built in predict_and_visualize stays on CPU too.
device = torch.device('cpu')

# Fine-tuned Mask2Former checkpoint for building-facade material segmentation,
# loaded once at startup and put into evaluation mode.
model = Mask2FormerForUniversalSegmentation.from_pretrained("jinfengxie/BFMS_1014").to(device)
model.eval()

# Post-processing helper only: resizing, rescaling and normalization are done
# manually before the forward pass, so they are disabled here.
preprocessor = Mask2FormerImageProcessor(
    ignore_index=0,
    reduce_labels=False,
    do_resize=False,
    do_rescale=False,
    do_normalize=False,
)
def predict_and_visualize(image):
    """Run semantic segmentation on an input image and build a labeled color mask.

    Args:
        image: input PIL image (any mode; normalized to RGB).

    Returns:
        Tuple of (pred_mask_np, result_image, pixel_counts, resized_image_np):
            pred_mask_np: (H, W) integer array of predicted class ids.
            result_image: PIL image of the color-coded mask with class names drawn
                at each class region's centroid.
            pixel_counts: dict mapping class id -> number of pixels of that class.
            resized_image_np: the (possibly downscaled) input as an RGB numpy array,
                aligned pixel-for-pixel with pred_mask_np.
    """
    # Normalize the input to a 3-channel RGB image so the shape unpacking below
    # and the model both get what they expect (grayscale or RGBA uploads would
    # otherwise fail on `height, width, _ = image_np.shape`).
    image = Image.fromarray(np.array(image)).convert("RGB")
    image_np = np.array(image)
    height, width, _ = image_np.shape

    # Cap the longest side at 2000 px to bound CPU inference time and memory.
    new_width, new_height = width, height
    longest_side = max(height, width)
    if longest_side > 2000:
        scale_factor = 2000 / longest_side
        new_width = int(width * scale_factor)
        new_height = int(height * scale_factor)
        image = image.resize((new_width, new_height), Image.LANCZOS)

    # Keep the resized array: downstream color/texture extraction must use the
    # same pixel grid as the predicted mask.
    resized_image_np = np.array(image)
    image_tensor = torch.tensor(resized_image_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0

    with torch.no_grad():
        outputs = model(image_tensor)

    pred_mask = preprocessor.post_process_semantic_segmentation(outputs, target_sizes=[(new_height, new_width)])
    pred_mask = pred_mask[0]
    pred_mask_np = pred_mask.numpy()

    # Paint every class with its display color, and record each class's centroid
    # (label anchor) and pixel count.
    color_mask = np.zeros((pred_mask_np.shape[0], pred_mask_np.shape[1], 3), dtype=np.uint8)
    label_positions = {}
    pixel_counts = {}
    for key, value in id2color.items():
        color_mask[pred_mask_np == key] = np.array(value)
        indices = np.where(pred_mask_np == key)
        if indices[0].size > 0:
            # (mean column, mean row) -> (x, y) drawing coordinates.
            label_positions[key] = (np.mean(indices[1]), np.mean(indices[0]))
            pixel_counts[key] = indices[0].size

    result_image = Image.fromarray(color_mask)
    draw = ImageDraw.Draw(result_image)
    font_size = int(min(result_image.size) / 30)
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        # arial.ttf is not installed on most Linux hosts (e.g. HF Spaces);
        # fall back to PIL's bundled default font instead of crashing.
        font = ImageFont.load_default()
    for key, position in label_positions.items():
        if key in id2label:
            draw.text((position[0], position[1]), str(id2label[key]), font=font, fill='white')

    return pred_mask_np, result_image, pixel_counts, resized_image_np
def ext_colors(image_path, mask, n_clusters=4, v_weight_factor=1.0):
    """Extract dominant colors per mask class using brightness-weighted K-means.

    Args:
        image_path: RGB image as a PIL image or numpy array, aligned with `mask`.
        mask: per-pixel class-id array of the same height/width as the image.
        n_clusters: number of dominant colors to extract per class.
        v_weight_factor: multiplier applied to the brightness-based sample weights.

    Returns:
        dict mapping class id -> (n_clusters, 3) uint8 array of dominant RGB colors.
        Classes with fewer than n_clusters pixels are skipped.
    """
    image_np = np.array(image_path)
    mask_np = np.array(mask)

    unique_classes = np.unique(mask_np)
    colors_per_class = {}
    for cls in unique_classes:
        # Gather all pixels belonging to this class.
        indices = np.where(mask_np == cls)
        pixels = image_np[indices]
        # Too few pixels to form n_clusters clusters — skip this class.
        if pixels.shape[0] < n_clusters:
            continue

        # Cluster in HSV so brightness (V) can serve as a sample weight.
        # NOTE: OpenCV's 8-bit HSV uses hue in [0, 179].
        pixels_hsv = cv2.cvtColor(pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV).reshape(-1, 3)
        # Weight samples by sqrt of normalized brightness: brighter pixels count
        # more, with diminishing returns.
        v_weights = v_weight_factor * np.sqrt(pixels_hsv[:, 2] / 255.0)

        kmeans = KMeans(n_clusters=n_clusters, n_init=10)
        kmeans.fit(pixels_hsv, sample_weight=v_weights)
        dominant_colors_hsv = kmeans.cluster_centers_

        # Cluster centers are floats; round to nearest (instead of truncating via
        # a bare astype) and clip to the valid 8-bit range so the converted RGB
        # colors are not systematically biased darker.
        dominant_colors_hsv_uint8 = np.clip(np.rint(dominant_colors_hsv), 0, 255).astype(np.uint8)
        dominant_colors_rgb = cv2.cvtColor(dominant_colors_hsv_uint8.reshape(-1, 1, 3), cv2.COLOR_HSV2RGB).reshape(-1, 3)

        colors_per_class[cls] = dominant_colors_rgb
    return colors_per_class
def plot_material_color_palette_grid(material_dict, materials_per_row=4):
    """Render a grid of color palettes, one header + color column per material.

    Args:
        material_dict: mapping of material name -> iterable of RGB colors
            (each color an array-like of three 0-255 ints).
        materials_per_row: number of material columns per grid row.

    Returns:
        PIL.Image.Image of the rendered palette grid.
    """
    num_materials = len(material_dict)
    grid_rows = (num_materials + materials_per_row - 1) // materials_per_row

    # Total vertical slots: per grid row, the tallest material column plus one
    # header slot. (A previously computed `total_rows` value was dead code and
    # has been removed.)
    total_grid_rows = 0
    for i in range(grid_rows):
        row_materials = list(material_dict.keys())[i * materials_per_row:(i + 1) * materials_per_row]
        row_height = max(len(material_dict[mat]) for mat in row_materials if mat in material_dict) + 1
        total_grid_rows += row_height

    # Layout constants, in axis units.
    block_width = 1
    block_height = 0.5
    text_gap = 0.2
    row_gap = 0.2
    column_gap = 1.5  # gap between material columns within the same grid row

    # Figure size grows with the grid so swatches keep a readable size.
    fig_width = materials_per_row * (block_width + text_gap + column_gap)
    fig_height = total_grid_rows * (block_height + row_gap)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    ax.axis('off')
    # Top-align the origin so slot 0 is drawn at the top of the figure.
    ax.invert_yaxis()

    current_row = 0  # current vertical slot position in the grid
    for i in range(grid_rows):
        row_materials = list(material_dict.keys())[i * materials_per_row:(i + 1) * materials_per_row]
        max_row_height = max(len(material_dict[mat]) for mat in row_materials if mat in material_dict) + 1
        for j, material in enumerate(row_materials):
            if material not in material_dict:
                continue
            colors = material_dict[material]
            # Header: the material name above its color column.
            ax.text(j * (block_width + text_gap + column_gap), current_row * (block_height + row_gap) + 0.5,
                    material, va='center', fontsize=12, fontweight='bold', ha='left')
            material_row_start = current_row
            for k, color in enumerate(colors):
                # Matplotlib expects RGB components in [0, 1].
                normalized_color = np.array(color) / 255.0
                y_pos = (material_row_start + 1 + k) * (block_height + row_gap)
                ax.add_patch(plt.Rectangle((j * (block_width + text_gap + column_gap), y_pos),
                                           block_width, block_height, color=normalized_color))
                # Annotate the raw RGB values to the right of each swatch.
                ax.text(j * (block_width + text_gap + column_gap) + block_width + text_gap, y_pos + block_height / 2,
                        str(color), va='center', fontsize=10)
        current_row += max_row_height

    # Fit the axes to exactly the drawn area (y inverted: top is 0).
    ax.set_xlim(0, fig_width)
    ax.set_ylim(current_row * (block_height + row_gap), 0)

    # Render to an in-memory buffer instead of displaying the figure.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    img = Image.open(buf)
    return img
| # 将matplotlib图转换为图像 | |
def plt_to_image():
    """Capture the current matplotlib figure into a PIL image (300 dpi PNG)."""
    buffer = io.BytesIO()
    plt.savefig(buffer, format='png', dpi=300)
    plt.close()
    buffer.seek(0)
    return Image.open(buffer)
def calculate_slice_statistics(one_mask, slice_size=128):
    """Compute per-tile material proportions over a grid of square tiles.

    The mask is divided into non-overlapping slice_size x slice_size tiles
    (any remainder at the right/bottom edge is ignored). For each tile the
    fraction of pixels belonging to every class id is recorded.

    Args:
        one_mask: 2-D array of class ids.
        slice_size: tile edge length in pixels.

    Returns:
        dict mapping (row, col) tile position -> {class_id: fraction}.
    """
    tile_rows = one_mask.shape[0] // slice_size
    tile_cols = one_mask.shape[1] // slice_size
    slice_stats = {}
    for r in range(tile_rows):
        for c in range(tile_cols):
            tile = one_mask[r * slice_size:(r + 1) * slice_size,
                            c * slice_size:(c + 1) * slice_size]
            labels, counts = np.unique(tile, return_counts=True)
            total = counts.sum()
            slice_stats[(r, c)] = dict(zip(labels, counts / total))
    return slice_stats
def find_top_slices(slice_stats, exclusion_list, min_percent=0.7, min_slices=1, top_k=3):
    """Select, per material, the tiles where that material dominates.

    A tile qualifies for a material when the material is not excluded and
    covers at least min_percent of the tile. Per material only the top_k
    highest-coverage tiles are retained, and a material is reported only if
    it has strictly more than min_slices qualifying tiles.

    Args:
        slice_stats: {(row, col): {class_id: fraction}} as produced by
            calculate_slice_statistics.
        exclusion_list: class ids to ignore entirely.
        min_percent: minimum coverage fraction for a tile to qualify.
        min_slices: a material must have more than this many tiles to be kept.
        top_k: maximum number of tiles retained per material.

    Returns:
        {class_id: [(fraction, (row, col)), ...]} sorted by fraction, descending.
    """
    from collections import defaultdict
    import heapq

    candidates = defaultdict(list)
    for position, stats in slice_stats.items():
        for material_id, coverage in stats.items():
            # Skip excluded materials and tiles below the coverage threshold.
            if material_id in exclusion_list or coverage < min_percent:
                continue
            entry = (coverage, position)
            # Bounded min-heap: once full, pushpop evicts the lowest coverage,
            # so only the top_k best tiles survive.
            if len(candidates[material_id]) < top_k:
                heapq.heappush(candidates[material_id], entry)
            else:
                heapq.heappushpop(candidates[material_id], entry)

    # Keep only materials with strictly more than min_slices qualifying tiles.
    return {
        material_id: sorted(entries, reverse=True)
        for material_id, entries in candidates.items()
        if len(entries) > min_slices
    }
def extract_and_visualize_top_slices(image, top_slices, slice_size=128):
    """Plot the top texture tiles for each selected material.

    Args:
        image: RGB numpy array aligned with the mask used for tile selection.
        top_slices: {class_id: [(fraction, (row, col)), ...]} from find_top_slices.
        slice_size: tile edge length in pixels.

    Returns:
        PIL.Image.Image of the assembled patch grid (or a small placeholder
        when no tiles were selected).
    """
    image = Image.fromarray(image)
    if not top_slices:
        # A zero-row subplot grid cannot be rendered (figsize height would be 0);
        # return a small blank placeholder instead of crashing.
        fig, ax = plt.subplots(figsize=(5, 2))
        ax.axis('off')
        ax.text(0.5, 0.5, 'No dominant texture slices found', ha='center', va='center')
    else:
        # squeeze=False keeps axs 2-D even with a single row, replacing the
        # fragile `axs = [axs]` special case with uniform indexing.
        fig, axs = plt.subplots(nrows=len(top_slices), ncols=3, squeeze=False,
                                figsize=(15, 5 * len(top_slices)))
        # Hide every axis up front: materials with fewer than 3 tiles would
        # otherwise leave empty framed subplots in their row.
        for row in axs:
            for ax in row:
                ax.axis('off')
        for idx, (material_id, slices) in enumerate(top_slices.items()):
            for col, (_, pos) in enumerate(slices):
                i, j = pos
                # Crop the tile from the image: (left, upper, right, lower).
                img_slice = image.crop((j * slice_size, i * slice_size,
                                        (j + 1) * slice_size, (i + 1) * slice_size))
                axs[idx][col].imshow(img_slice)
                axs[idx][col].set_title(f'Material {id2label[material_id]} - Slice {pos}')
    plt.tight_layout()
    # Render to an in-memory buffer instead of displaying the figure.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    img = Image.open(buf)
    return img
| # main program | |
def process_image(image_path, extract_colors):
    """End-to-end pipeline behind the Gradio submit button.

    Args:
        image_path: uploaded PIL image.
        extract_colors: when True, also compute a per-material color palette.

    Returns:
        Tuple of (color-mask image, palette image or None, texture-slice image,
        DataFrame with columns ['material', 'percentage (%)']).
    """
    # Segmentation mask, colored visualization, per-class pixel counts, and the
    # (possibly downscaled) input that the rest of the pipeline must align with.
    one_mask, color_mask, counts_dict, resized_image = predict_and_visualize(image_path)

    palette_image = None
    if extract_colors:
        # Extract dominant colors from the resized image so pixels line up
        # with the predicted mask.
        colors_per_class = ext_colors(resized_image, one_mask, n_clusters=4)
        colors_per_label = {id2label[key]: value for key, value in colors_per_class.items()}
        # Classes whose "color" is meaningless for a material palette.
        # BUG FIX: the background class is named 'Background' in id2label; the
        # previous 'Background/Unrecognized' entry never matched, so background
        # colors leaked into the palette.
        labels_to_remove = ['Background', 'Sky', 'Glass', 'Tree', 'Water', 'Plastic, clear']
        colors_per_label = {key: value for key, value in colors_per_label.items()
                            if key not in labels_to_remove}
        if colors_per_label:
            palette_image = plot_material_color_palette_grid(colors_per_label)

    # Render the labeled color mask through matplotlib for consistent sizing.
    plt.figure(figsize=(5, 5))
    plt.imshow(color_mask)
    plt.tight_layout()
    plt.axis('off')
    color_mask_img = plt_to_image()

    # Per-material pixel counts -> percentage table for the results grid.
    counts_by_label = {id2label[key]: value for key, value in counts_dict.items()}
    counts_df = pd.DataFrame(list(counts_by_label.items()), columns=['material', 'pixels'])
    total_count = counts_df['pixels'].sum()
    counts_df['percentage (%)'] = (counts_df['pixels'] / total_count * 100).round(2)
    percentage_df = counts_df[['material', 'percentage (%)']]

    # Pick representative texture patches per material; class 38 ('Sky' in
    # id2label) is excluded from patch selection.
    slice_size = 64
    exclusion_list = [38]
    slice_stats = calculate_slice_statistics(one_mask, slice_size=slice_size)
    top_slices = find_top_slices(slice_stats, exclusion_list=exclusion_list,
                                 min_percent=0.5, min_slices=1)
    slice_image = extract_and_visualize_top_slices(resized_image, top_slices, slice_size=slice_size)

    return color_mask_img, palette_image, slice_image, percentage_df
# --- Gradio UI --------------------------------------------------------------
# Layout: an input column (image upload, palette checkbox, submit button) next
# to a clickable-examples panel, followed by a results section with the four
# outputs produced by process_image.
with gr.Blocks(title="Building Facade Material Segmentation") as iface:
    gr.Markdown("<h1>Building Facade Material Segmentation</h1>")
    gr.Markdown("Upload an image to segment materials. Images are compressed and mini version of the model is applied due to limited computing power, which may result in inaccuracies.")
    with gr.Row():
        with gr.Column(scale=1):
            # Inputs: uploaded image plus an optional color-palette toggle.
            image_input = gr.Image(type="pil")
            extract_colors_checkbox = gr.Checkbox(label="Extract Color Palette", value=False)
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=2):
            # NOTE(review): no fn/cache_examples is passed here, so clicking an
            # example presumably only fills the inputs rather than running the
            # pipeline, despite the "run" wording in the label — confirm intended.
            gr.Examples(
                examples=[
                    ["example1.jpg", False],
                    ["example2.jpg", False]
                ],
                inputs=[image_input, extract_colors_checkbox],
                label="Click an example to run"
            )
    gr.Markdown("---")
    gr.Markdown("<h2>Results</h2>")
    with gr.Row():
        # First results row: the labeled segmentation mask and the palette grid.
        color_mask_output = gr.Image(type="pil", label="Color Mask")
        palette_output = gr.Image(type="pil", label="Color Palette")
    with gr.Row():
        # Second results row: representative texture patches and the percentage table.
        slice_image_output = gr.Image(type='pil', label='Texture Slices')
        df_output = gr.DataFrame(label="Material Percentages")
    # Wire the submit button to the full processing pipeline.
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, extract_colors_checkbox],
        outputs=[color_mask_output, palette_output, slice_image_output, df_output]
    )

# Start the app server (blocking call).
iface.launch()