ImageProcessing / app.py
plutosss's picture
Update app.py
71bfff1 verified
raw
history blame
6.89 kB
from PIL import Image
import cv2
import numpy as np
import os
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
from tqdm import tqdm
import TEED.main as teed
from TEED.main import parse_args
from depthAnything.depth_anything.dpt import DepthAnything
from depthAnything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
import shutil
def multiply_blend(image1, image2):
    """Multiply-blend two uint8 images.

    image1: H x W x 3 uint8 image.
    image2: H x W single-channel uint8 image (broadcast to 3 channels), or
        an image already shaped like image1.
    Returns the per-pixel product (image1/255 * image2/255) * 255 as uint8.
    """
    # Broadcast a single-channel mask to 3 channels; leave 3-channel input
    # as-is (the original stacked unconditionally, which produced a 4-D
    # array and a broadcasting error for color input).
    if image2.ndim == 2:
        image2 = np.stack((image2,) * 3, axis=-1)
    multiplied = np.multiply(image1 / 255.0, image2 / 255.0) * 255.0
    return multiplied.astype(np.uint8)
def screen_blend(image1, image2):
    """Screen-blend two same-shaped uint8 images.

    Screen mode inverts both inputs, multiplies, and inverts the result:
    out = (1 - (1 - a/255) * (1 - b/255)) * 255 — the output is always at
    least as bright as either input.
    """
    image1 = image1.astype(float)
    image2 = image2.astype(float)
    # Fix: the original computed 1 - (1-a/255)*(1-b/255)*255, i.e. the 255
    # scale bound only to the product, yielding negative values that all
    # clipped to 0. Scale the whole screened expression instead.
    screened = (1 - (1 - image1 / 255) * (1 - image2 / 255)) * 255
    result = np.clip(screened, 0, 255).astype('uint8')
    return result
def erosion(img, kernel_size=3, iterations=1, dilate=False):
    """Apply morphological erosion (or dilation when dilate is truthy).

    Color input is first converted BGR -> grayscale; the result is always
    a single-channel image. kernel is a square of ones of size kernel_size.
    """
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    morph_op = cv2.dilate if dilate else cv2.erode
    return morph_op(img, kernel, iterations=iterations)
def teed_imgs(img_path='./input', outdir='./output/teed_imgs', gaussianBlur=[0, 3, 0]):
    """Run TEED edge detection on one image and write results to outdir.

    img_path: path to a single image file. The file is staged into a scratch
        dir 'teed_tmp' that TEED consumes via args.input_val_dir.
        NOTE(review): when img_path is a directory nothing is staged, so
        TEED runs over an empty dir — confirm whether directory input is
        still meant to be supported.
    gaussianBlur: [enabled_flag, kernel_size, sigma]; blur is applied before
        edge detection when enabled_flag != 0.
    """
    os.makedirs(outdir, exist_ok=True)
    os.makedirs('teed_tmp', exist_ok=True)
    try:
        if os.path.isfile(img_path):
            img = cv2.imread(img_path)
            if gaussianBlur[0] != 0:
                img = cv2.GaussianBlur(img, (gaussianBlur[1], gaussianBlur[1]), gaussianBlur[2])
            # Both original branches wrote the same file; write once here.
            cv2.imwrite(os.path.join('teed_tmp', 'temp_image.png'), img)
        args, train_info = parse_args(is_testing=True, pl_opt_dir=outdir)
        args.input_val_dir = 'teed_tmp'
        teed.main(args, train_info)
    finally:
        # Fix: scratch dir previously leaked if TEED raised.
        shutil.rmtree('teed_tmp')
def depth_anything(img_path='./input', outdir='./output/depth_anything', encoder='vitl', pred_only=True, grayscale=True):
    """Estimate depth for a single image with Depth-Anything and save a PNG.

    img_path: path to one image file (non-file paths are silently skipped).
    outdir: destination folder; created if missing.
    encoder: backbone selector — 'vitl', 'vitb', or 'vits'.
    pred_only: unused; kept for backward compatibility with existing callers.
    grayscale: if True save a 3-channel grayscale depth map, otherwise apply
        the INFERNO colormap.
    """
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_configs = {
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
    }
    depth_model = DepthAnything(model_configs[encoder])
    # Fix: map_location lets a CUDA-saved checkpoint load on CPU-only hosts.
    depth_model.load_state_dict(torch.load('./checkpoints/depth_anything_{}14.pth'.format(encoder), map_location='cpu'))
    depth_model = depth_model.to(DEVICE).eval()
    transform = Compose([
        Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=14, resize_method='lower_bound', image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ])
    # Fix: cv2.imwrite silently returns False when outdir does not exist.
    os.makedirs(outdir, exist_ok=True)
    if os.path.isfile(img_path):
        raw_image = cv2.imread(img_path)
        image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
        h, w = image.shape[:2]
        image = transform({'image': image})['image']
        image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            depth = depth_model(image)
        # Resize prediction back to the input resolution.
        depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
        # Min-max normalize to 0-255 for 8-bit output.
        depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
        depth = depth.cpu().numpy().astype(np.uint8)
        if grayscale:
            depth = np.repeat(depth[..., np.newaxis], 3, axis=-1)
        else:
            depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
        filename = os.path.basename(img_path)
        cv2.imwrite(os.path.join(outdir, filename[:filename.rfind('.')] + '_depth.png'), depth)
def process_line(img_path='./input', outdir='./output'):
    """Full line-art pipeline: depth map, edge maps, then multiply-blend.

    Outputs land in subfolders of outdir: depth_anything/, teed_imgs/,
    dp_teed_imgs/, merged_imgs/.
    """
    depth_dir = os.path.join(outdir, "depth_anything")
    edge_dir = os.path.join(outdir, "teed_imgs")
    depth_edge_dir = os.path.join(outdir, "dp_teed_imgs")
    merged_dir = os.path.join(outdir, "merged_imgs")
    # 1) depth estimation, 2) edges of the (pre-blurred) input,
    # 3) edges of the depth map.
    depth_anything(img_path, depth_dir)
    teed_imgs(img_path, edge_dir, [1, 7, 2])
    teed_imgs(depth_dir, depth_edge_dir, [0, 7, 2])
    # Multiply-blend the two edge maps; '_depth' names in folder 2 are
    # matched against folder 1, with per-image erosion/dilation settings.
    merge_images_in_2_folder(
        edge_dir,
        depth_edge_dir,
        merged_dir,
        '_depth',
        1,
        'multiply',
        [[2, 0], [2, 1]],
        [1, 0]
    )
def merge_2_images(img1, img2, mode, erosion_para=[[0, 0], [0, 0]], dilate=[0, 0]):
    """Load two image files, optionally erode/dilate each, and blend them.

    img1, img2: file paths; img1 is resized to img2's dimensions.
    mode: 'multiply' or 'screen'.
    erosion_para: [[kernel, iterations], [kernel, iterations]] for the two
        images; iterations == 0 disables the morphology step for that image.
    dilate: per-image flags — truthy selects dilation instead of erosion.
    Raises ValueError for an unknown mode (previously returned None
    implicitly, crashing later in cv2.imwrite).
    """
    img1 = cv2.imread(img1)
    img2 = cv2.imread(img2)
    img1 = cv2.resize(img1, (img2.shape[1], img2.shape[0]))
    if erosion_para[0][1] != 0:
        img1 = erosion(img1, erosion_para[0][0], erosion_para[0][1], dilate[0])
    if erosion_para[1][1] != 0:
        img2 = erosion(img2, erosion_para[1][0], erosion_para[1][1], dilate[1])
    if mode == 'multiply':
        return multiply_blend(img1, img2)
    if mode == 'screen':
        return screen_blend(img1, img2)
    raise ValueError("mode must be 'multiply' or 'screen', got {!r}".format(mode))
def merge_images_in_2_folder(folder1, folder2, outdir, suffix_need_remove=None, suffix_floder=0, mode='multiply', erosion_para=[[0, 0], [0, 0]], dilate=[0, 0]):
    """Blend images with matching stems from two folders into outdir.

    suffix_need_remove: optional stem suffix (e.g. '_depth') stripped before
        matching from the folder selected by suffix_floder (0 = folder1,
        1 = folder2). Fix: these two parameters were previously accepted but
        ignored, so '_depth'-suffixed files never matched anything and the
        output folder stayed empty.
    mode / erosion_para / dilate are forwarded to merge_2_images.
    Output files are named after folder1's stem + folder1's extension.
    Also fixed: empty folders no longer crash (the old zip(*...) unpacking
    raised on zero files), and '.tif' gets its missing dot.
    """
    image_exts = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.tif')
    os.makedirs(outdir, exist_ok=True)

    def _stem_map(folder, strip_suffix):
        # Map matching-stem -> actual filename, stripping strip_suffix if set.
        mapping = {}
        for filename in os.listdir(folder):
            if not filename.endswith(image_exts):
                continue
            stem, _ext = os.path.splitext(filename)
            if strip_suffix and stem.endswith(strip_suffix):
                stem = stem[:-len(strip_suffix)]
            mapping[stem] = filename
        return mapping

    map1 = _stem_map(folder1, suffix_need_remove if suffix_floder == 0 else None)
    map2 = _stem_map(folder2, suffix_need_remove if suffix_floder == 1 else None)
    for stem, fname1 in map1.items():
        fname2 = map2.get(stem)
        if fname2 is None:
            continue
        result = merge_2_images(
            os.path.join(folder1, fname1),
            os.path.join(folder2, fname2),
            mode, erosion_para, dilate
        )
        cv2.imwrite(os.path.join(outdir, stem + os.path.splitext(fname1)[1]), result)
if __name__ == '__main__':
    import gradio as gr

    def gradio_process_line(img):
        """Save the uploaded PIL image, run the pipeline, return the blend."""
        input_path = './temp_input.png'
        img.save(input_path)
        process_line(input_path, './output')
        # Path produced by the pipeline for the temp input name.
        merged_path = './output/merged_imgs/temp_input.png'
        return Image.open(merged_path)

    demo = gr.Interface(
        fn=gradio_process_line,
        inputs=gr.Image(type="pil"),
        outputs=gr.Image(type="pil"),
        title="Image Processing with Depth Anything and TEED",
        description="Upload an image to process it with depth estimation and edge detection."
    )
    demo.launch()