# NOTE: the three lines below are web-page residue from the model hub listing;
# commented out so the file is valid Python.
# jinfengxie's picture
# Update app.py
# 7581d53 verified
import pandas as pd
import numpy as np
import gradio as gr
import torch
import os
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torchvision import transforms
import albumentations as A
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
import io
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import KMeans
import cv2 # 使用 OpenCV 来进行颜色空间的转换
# RGB display color for each of the 43 material class ids (list index == id).
id2color = dict(enumerate([
    [0, 0, 0], [255, 204, 51], [240, 120, 240], [172, 196, 170],
    [178, 80, 80], [36, 179, 83], [89, 89, 89], [160, 146, 229],
    [18, 17, 20], [90, 98, 89], [155, 149, 205], [31, 73, 125],
    [204, 153, 51], [37, 12, 156], [52, 209, 183], [163, 160, 172],
    [61, 245, 61], [230, 203, 104], [125, 104, 227], [154, 150, 169],
    [51, 221, 255], [95, 95, 95], [128, 128, 128], [156, 239, 255],
    [153, 102, 51], [255, 106, 77], [0, 0, 226], [254, 242, 208],
    [255, 191, 0], [89, 134, 179], [115, 51, 128], [65, 112, 192],
    [255, 0, 204], [170, 240, 209], [140, 120, 240], [118, 255, 166],
    [250, 250, 55], [243, 232, 208], [1, 118, 141], [243, 241, 255],
    [245, 147, 49], [158, 108, 4], [132, 0, 0],
]))
# Human-readable material name for each class id (list index == id).
id2label = dict(enumerate([
    'Background', 'Wood/Bamboo', 'Ground tile', 'Brick',
    'Cardboard/Paper', 'Tree', 'Roof tile', 'Ceramic',
    'Chalkboard/Blackboard', 'Asphalt', 'Cement/ Concrete',
    'Composite decorative board', 'Rammed earth', 'Fabric/Cloth', 'Water',
    'Windows with metal fences on the outside (distant view)', 'Foliage',
    'Food', 'Fur', 'Pottery', 'Glass', 'Hair',
    'Roofing waterproof material', 'Ice', 'Leather', 'Carved brick',
    'Metal', 'Mirror', 'Enamel', 'Paint/ Coating/ Plaster',
    'Window screen', 'Whiteboard',
    'Photograph/ Painting/ Airbrushed fabric', 'Plastic, clear',
    'Plastic, non-clear', 'Rubber/Latex', 'Sand', 'Skin/Lips', 'Sky',
    'Snow', 'Engineered Stone/ Imitation Stone', 'Soil/Mud',
    'Natural Stone',
]))
# Load the fine-tuned Mask2Former segmentation checkpoint; inference is pinned
# to CPU (see predict_and_visualize, which runs the model under no_grad).
device =torch.device('cpu')
model = Mask2FormerForUniversalSegmentation.from_pretrained("jinfengxie/BFMS_1014")
model.to(device)
model.eval()
# The processor is used only for post_process_semantic_segmentation; resizing
# and rescaling are done manually in predict_and_visualize, so all of the
# processor's own preprocessing steps are disabled here.
preprocessor = Mask2FormerImageProcessor(ignore_index=0,
                                         reduce_labels=False,
                                         do_resize=False,
                                         do_rescale=False,
                                         do_normalize=False)
def predict_and_visualize(image):
    """Run semantic segmentation on an image and build a labeled color mask.

    The image is downscaled so its longest side is at most 2000 px before
    inference to bound CPU/memory use.

    Args:
        image: input image (PIL Image or anything ``np.array`` can convert
            to an (H, W, 3) uint8 array).

    Returns:
        tuple of
        - pred_mask_np: (H, W) int array of predicted class ids,
        - result_image: PIL Image of the colorized mask with class names drawn,
        - pixel_counts: dict mapping class id -> number of pixels of that class,
        - resized_image_np: the (possibly downscaled) input as an array.
    """
    image_np = np.array(image)
    image = Image.fromarray(image_np)
    height, width, _ = image_np.shape
    new_width, new_height = width, height
    # Cap the longest side at 2000 px to keep inference tractable.
    maxhl = max(height, width)
    if maxhl > 2000:
        scale_factor = 2000 / maxhl
        new_width = int(width * scale_factor)
        new_height = int(height * scale_factor)
        image = image.resize((new_width, new_height), Image.LANCZOS)
    # Keep the resized pixel array; downstream color extraction reuses it.
    resized_image_np = np.array(image)
    # HWC uint8 -> 1CHW float in [0, 1], placed on the same device as the model.
    image_tensor = (torch.tensor(resized_image_np, dtype=torch.float32)
                    .permute(2, 0, 1).unsqueeze(0) / 255.0).to(device)
    with torch.no_grad():
        outputs = model(image_tensor)
    pred_mask = preprocessor.post_process_semantic_segmentation(
        outputs, target_sizes=[(new_height, new_width)])[0]
    # .cpu() keeps this correct even if the model is ever moved off the CPU.
    pred_mask_np = pred_mask.cpu().numpy()
    color_mask = np.zeros((pred_mask_np.shape[0], pred_mask_np.shape[1], 3), dtype=np.uint8)
    label_positions = {}
    pixel_counts = {}
    for key, value in id2color.items():
        color_mask[pred_mask_np == key] = np.array(value)
        indices = np.where(pred_mask_np == key)
        if indices[0].size > 0:
            # Anchor each class label at the (x, y) centroid of its pixels.
            label_positions[key] = (np.mean(indices[1]), np.mean(indices[0]))
            pixel_counts[key] = indices[0].size
    result_image = Image.fromarray(color_mask)
    draw = ImageDraw.Draw(result_image)
    # BUGFIX: "arial.ttf" is absent on most Linux deployments (e.g. HF Spaces),
    # so ImageFont.truetype raised OSError; fall back to Pillow's built-in font.
    font_size = int(min(result_image.size) / 30)
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        font = ImageFont.load_default()
    for key, position in label_positions.items():
        if key in id2label:
            draw.text((position[0], position[1]), str(id2label[key]), font=font, fill='white')
    return pred_mask_np, result_image, pixel_counts, resized_image_np
def ext_colors(image_path, mask, n_clusters=4, v_weight_factor=1.0):
    """Extract dominant colors per mask class via brightness-weighted K-means.

    Args:
        image_path: RGB image as an (H, W, 3) array-like. NOTE(review): despite
            the name, this receives a pixel array, not a file path.
        mask: (H, W) array-like of class ids aligned with the image.
        n_clusters: number of dominant colors to extract per class.
        v_weight_factor: multiplier on the HSV-brightness sample weights.

    Returns:
        dict mapping class id -> (n_clusters, 3) uint8 array of dominant
        RGB colors; classes with fewer than n_clusters pixels are skipped.
    """
    img = np.array(image_path)
    msk = np.array(mask)
    colors_per_class = {}
    for cls in np.unique(msk):
        pixels = img[np.where(msk == cls)]
        # Not enough samples to form the requested number of clusters.
        if pixels.shape[0] < n_clusters:
            continue
        # Cluster in HSV so brightness (V) can drive the sample weights.
        hsv = cv2.cvtColor(pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV).reshape(-1, 3)
        # Normalized, sqrt-compressed V channel: brighter pixels count more.
        weights = v_weight_factor * np.sqrt(hsv[:, 2] / 255.0)
        clusterer = KMeans(n_clusters=n_clusters, n_init=10)
        clusterer.fit(hsv, sample_weight=weights)
        # Centers come back as floats; cast to uint8 before HSV -> RGB.
        centers = clusterer.cluster_centers_.astype(np.uint8)
        colors_per_class[cls] = cv2.cvtColor(
            centers.reshape(-1, 1, 3), cv2.COLOR_HSV2RGB).reshape(-1, 3)
    return colors_per_class
def plot_material_color_palette_grid(material_dict, materials_per_row=4):
    """Render a grid of per-material color palettes and return it as a PIL image.

    Each material gets one header cell followed by one swatch row per color,
    annotated with the RGB values. Materials are laid out left-to-right,
    `materials_per_row` per grid row.

    Args:
        material_dict: mapping material name -> iterable of RGB colors
            (each a length-3 sequence of 0-255 ints).
        materials_per_row: number of material columns per grid row.

    Returns:
        PIL.Image of the rendered palette grid.
    """
    # NOTE: the original also computed an unused `total_rows`; removed as dead code.
    material_names = list(material_dict)
    num_materials = len(material_names)
    grid_rows = (num_materials + materials_per_row - 1) // materials_per_row
    # Total vertical cells: per grid row, the tallest palette plus one header cell.
    total_grid_rows = 0
    for i in range(grid_rows):
        row_materials = material_names[i * materials_per_row:(i + 1) * materials_per_row]
        total_grid_rows += max(len(material_dict[mat]) for mat in row_materials) + 1
    # Layout constants (axis units).
    block_width = 1
    block_height = 0.5
    text_gap = 0.2
    row_gap = 0.2
    column_gap = 1.5  # horizontal gap between material columns
    # Size the figure to the computed grid.
    fig_width = materials_per_row * (block_width + text_gap + column_gap)
    fig_height = total_grid_rows * (block_height + row_gap)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    ax.axis('off')
    ax.invert_yaxis()  # top-align the origin so row 0 renders at the top
    current_row = 0  # running vertical cell offset
    for i in range(grid_rows):
        row_materials = material_names[i * materials_per_row:(i + 1) * materials_per_row]
        max_row_height = max(len(material_dict[mat]) for mat in row_materials) + 1
        for j, material in enumerate(row_materials):
            colors = material_dict[material]
            x0 = j * (block_width + text_gap + column_gap)
            # Header: the material name above its swatch column.
            ax.text(x0, current_row * (block_height + row_gap) + 0.5,
                    material, va='center', fontsize=12, fontweight='bold', ha='left')
            for k, color in enumerate(colors):
                # Matplotlib expects color channels in [0, 1].
                normalized_color = np.array(color) / 255.0
                y_pos = (current_row + 1 + k) * (block_height + row_gap)
                ax.add_patch(plt.Rectangle((x0, y_pos), block_width, block_height,
                                           color=normalized_color))
                # Annotate the RGB values to the right of each swatch.
                ax.text(x0 + block_width + text_gap, y_pos + block_height / 2,
                        str(color), va='center', fontsize=10)
        current_row += max_row_height
    ax.set_xlim(0, fig_width)
    ax.set_ylim(current_row * (block_height + row_gap), 0)
    # Render to an in-memory PNG instead of displaying.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def plt_to_image():
    """Serialize the current matplotlib figure to a PIL image and close it."""
    buffer = io.BytesIO()
    plt.savefig(buffer, format='png', dpi=300)
    plt.close()
    buffer.seek(0)
    return Image.open(buffer)
def calculate_slice_statistics(one_mask, slice_size=128):
    """Compute the per-class pixel fraction inside each square tile of a mask.

    Args:
        one_mask: (H, W) array of class ids.
        slice_size: side length of each square tile; trailing partial tiles
            at the right/bottom edges are ignored.

    Returns:
        dict mapping (row, col) tile position -> {class_id: fraction}.
    """
    rows = one_mask.shape[0] // slice_size
    cols = one_mask.shape[1] // slice_size
    stats = {}
    for r in range(rows):
        for c in range(cols):
            tile = one_mask[r * slice_size:(r + 1) * slice_size,
                            c * slice_size:(c + 1) * slice_size]
            ids, counts = np.unique(tile, return_counts=True)
            total = counts.sum()
            stats[(r, c)] = dict(zip(ids, counts / total))
    return stats
def find_top_slices(slice_stats, exclusion_list, min_percent=0.7, min_slices=1, top_k=3):
    """Select, per material, the tiles where that material dominates.

    A tile qualifies for a material when the material is not excluded and
    covers at least `min_percent` of the tile. Per material, the `top_k`
    highest fractions are kept, and a material is only returned when it has
    strictly more than `min_slices` qualifying tiles.

    Args:
        slice_stats: mapping (row, col) -> {material_id: fraction}, as
            produced by calculate_slice_statistics.
        exclusion_list: material ids to skip entirely.
        min_percent: minimum coverage fraction for a tile to qualify.
        min_slices: a material needs more than this many qualifying tiles.
        top_k: how many top tiles to keep per material.

    Returns:
        dict mapping material_id -> list of (fraction, (row, col)), sorted
        descending by fraction.
    """
    from collections import defaultdict
    import heapq
    candidates = defaultdict(list)
    for pos, fractions in slice_stats.items():
        for mat_id, frac in fractions.items():
            # Skip excluded materials and tiles below the coverage threshold.
            if mat_id in exclusion_list or frac < min_percent:
                continue
            # Bounded min-heap keeps the top_k largest fractions seen so far.
            entry = (frac, pos)
            if len(candidates[mat_id]) < top_k:
                heapq.heappush(candidates[mat_id], entry)
            else:
                heapq.heappushpop(candidates[mat_id], entry)
    # Keep only materials with more than `min_slices` qualifying tiles.
    return {mat_id: sorted(heap, reverse=True)
            for mat_id, heap in candidates.items() if len(heap) > min_slices}
def extract_and_visualize_top_slices(image, top_slices, slice_size=128):
    """Crop the selected tiles out of the image and arrange them on a figure.

    Args:
        image: (H, W, 3) numpy array of the (resized) source image.
        top_slices: mapping material_id -> list of (fraction, (row, col)),
            as produced by find_top_slices.
        slice_size: side length of each square tile in pixels.

    Returns:
        PIL.Image with one figure row per material and up to 3 tile crops,
        or a blank placeholder image when no tiles were selected.
    """
    # BUGFIX: an empty selection previously produced plt.subplots(nrows=0, ...)
    # and a zero-height figure, which crashes on savefig; return a placeholder.
    if not top_slices:
        return Image.new('RGB', (1, 1), 'white')
    fig, axs = plt.subplots(nrows=len(top_slices), ncols=3,
                            figsize=(15, 5 * len(top_slices)))
    image = Image.fromarray(image)
    if len(top_slices) == 1:
        axs = [axs]  # normalize to a list of rows for uniform indexing
    for idx, (material_id, slices) in enumerate(top_slices.items()):
        # Hide all cell frames up front so unused cells stay blank too.
        for col in range(3):
            axs[idx][col].axis('off')
        for col, (_, pos) in enumerate(slices):
            i, j = pos
            # PIL crop box is (left, top, right, bottom) in pixels.
            img_slice = image.crop((j * slice_size, i * slice_size,
                                    (j + 1) * slice_size, (i + 1) * slice_size))
            axs[idx][col].imshow(img_slice)
            axs[idx][col].set_title(f'Material {id2label[material_id]} - Slice {pos}')
    plt.tight_layout()
    # Render to an in-memory PNG rather than displaying.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# main program
def process_image(image_path, extract_colors):
    """End-to-end pipeline backing the Gradio submit button.

    Args:
        image_path: uploaded image (PIL Image from gr.Image).
        extract_colors: when True, also build the dominant-color palette.

    Returns:
        tuple of (color mask PIL image, palette PIL image or None,
        texture-slice PIL image, DataFrame of material percentages).
    """
    # Segmentation + colorized mask on the (possibly downscaled) image.
    one_mask, color_mask, counts_dict, resized_image = predict_and_visualize(image_path)
    palette_image = None
    if extract_colors:
        # Color extraction must use the resized image so it aligns with the mask.
        colors_per_class = ext_colors(resized_image, one_mask, n_clusters=4)
        colors_per_label = {id2label[key]: value for key, value in colors_per_class.items()}
        # BUGFIX: id2label[0] is 'Background', not 'Background/Unrecognized',
        # so the old list never actually filtered the background class out.
        labels_to_remove = {'Background', 'Background/Unrecognized', 'Sky',
                            'Glass', 'Tree', 'Water', 'Plastic, clear'}
        colors_per_label = {key: value for key, value in colors_per_label.items()
                            if key not in labels_to_remove}
        if colors_per_label:
            palette_image = plot_material_color_palette_grid(colors_per_label)
    # Render the colorized mask through matplotlib to a PIL image.
    plt.figure(figsize=(5, 5))
    plt.imshow(color_mask)
    plt.tight_layout()
    plt.axis('off')
    color_mask_img = plt_to_image()
    # Per-material pixel counts -> percentage table for the UI DataFrame.
    counts_by_label = {id2label[key]: value for key, value in counts_dict.items()}
    counts_df = pd.DataFrame(list(counts_by_label.items()), columns=['material', 'pixels'])
    total_count = counts_df['pixels'].sum()
    counts_df['percentage (%)'] = (counts_df['pixels'] / total_count * 100).round(2)
    percentage_df = counts_df[['material', 'percentage (%)']]
    # Representative texture tiles; class 38 (Sky) is excluded from tiles.
    slice_size = 64
    exclusion_list = [38]
    slice_stats = calculate_slice_statistics(one_mask, slice_size=slice_size)
    top_slices = find_top_slices(slice_stats, exclusion_list=exclusion_list,
                                 min_percent=0.5, min_slices=1)
    slice_image = extract_and_visualize_top_slices(resized_image, top_slices,
                                                   slice_size=slice_size)
    return color_mask_img, palette_image, slice_image, percentage_df
# Gradio UI: left column for input (image, palette checkbox, submit button),
# right column for clickable examples, then a results section with the mask,
# palette, texture slices, and the material-percentage table.
with gr.Blocks(title="Building Facade Material Segmentation") as iface:
    gr.Markdown("<h1>Building Facade Material Segmentation</h1>")
    gr.Markdown("Upload an image to segment materials. Images are compressed and mini version of the model is applied due to limited computing power, which may result in inaccuracies.")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")
            extract_colors_checkbox = gr.Checkbox(label="Extract Color Palette", value=False)
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=2):
            gr.Examples(
                examples=[
                    ["example1.jpg", False],
                    ["example2.jpg", False]
                ],
                inputs=[image_input, extract_colors_checkbox],
                label="Click an example to run"
            )
    gr.Markdown("---")
    gr.Markdown("<h2>Results</h2>")
    with gr.Row():
        color_mask_output = gr.Image(type="pil", label="Color Mask")
        palette_output = gr.Image(type="pil", label="Color Palette")
    with gr.Row():
        slice_image_output = gr.Image(type='pil', label='Texture Slices')
        df_output = gr.DataFrame(label="Material Percentages")
    # Wire the button to the pipeline; outputs map 1:1 to process_image's return.
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, extract_colors_checkbox],
        outputs=[color_mask_output, palette_output, slice_image_output, df_output]
    )
iface.launch()