garment_print_extractor / delete_bg.py
Anna Beloborodova
upd delete_bg
3011b82
from sklearn.cluster import KMeans
import cv2
import numpy as np
import os
import shutil
def remove_background(image_path, n_clusters=10, edge_threshold=0.05, bg_ratio_threshold=1.5):
"""
Удаление фона с изображения принта на одежде
Args:
image_path: Путь к изображению
n_clusters: Количество кластеров для k-means
edge_threshold: Процент краёв изображения для определения фона
bg_ratio_threshold: Пороговое отношение частоты цвета на краях к общей частоте
Returns:
Изображение с прозрачным фоном (RGBA)
"""
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
height, width, _ = img.shape
edge_mask = np.zeros((height, width), dtype=bool)
edge_size = int(min(height, width) * edge_threshold)
edge_mask[:edge_size, :] = True
edge_mask[-edge_size:, :] = True
edge_mask[:, :edge_size] = True
edge_mask[:, -edge_size:] = True
pixels = img.reshape(-1, 3)
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
kmeans.fit(pixels)
labels = kmeans.labels_
segmented_img = labels.reshape(height, width)
total_counts = np.bincount(labels, minlength=n_clusters)
edge_counts = np.bincount(
labels[edge_mask.reshape(-1)],
minlength=n_clusters
)
edge_to_total_ratio = np.zeros(n_clusters)
for i in range(n_clusters):
if total_counts[i] > 0:
edge_to_total_ratio[i] = edge_counts[i] / total_counts[i]
background_clusters = np.where(edge_to_total_ratio > bg_ratio_threshold * np.mean(edge_to_total_ratio))[0]
if len(background_clusters) == 0:
background_clusters = [np.argmax(edge_to_total_ratio)]
result_mask = np.ones((height, width), dtype=np.uint8) * 255
for cluster_id in background_clusters:
result_mask[segmented_img == cluster_id] = 0
result_mask = cv2.GaussianBlur(result_mask, (5, 5), 0)
result = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
result = cv2.cvtColor(result, cv2.COLOR_RGB2RGBA)
#result = img
result[:, :, 3] = result_mask
print(f"Обработка {os.path.basename(image_path)}:")
print(f"Всего кластеров: {n_clusters}")
print(f"Отношения край/всего: {edge_to_total_ratio}")
print(f"Фоновые кластеры: {background_clusters}")
return result