# ImageProcessing / app.py
# Source: Hugging Face Space "plutosss/ImageProcessing", revision d25393a (verified), 3.61 kB
from PIL import Image
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import shutil
import os
from depthAnything.depth_anything.dpt import DepthAnything
from depthAnything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
from TEED.main import parse_args, main
# Depth estimation: run Depth Anything on a PIL image and return an 8-bit depth map.
def depth_anything_image(image, encoder='vitl', pred_only=True, grayscale=True):
    """Estimate a depth map for a PIL image with the Depth Anything model.

    Args:
        image: PIL.Image input; converted to RGB (then BGR) internally.
        encoder: backbone variant, one of 'vitl', 'vitb', 'vits'.
        pred_only: unused here; kept for backward compatibility of the signature.
        grayscale: if True, return the depth replicated across 3 gray channels;
            otherwise apply cv2's INFERNO colormap (BGR output).

    Returns:
        uint8 ndarray of shape (H, W, 3), same spatial size as the input image.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_configs = {
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    }
    depth_anything = DepthAnything(model_configs[encoder])
    # map_location makes CUDA-saved checkpoints loadable on CPU-only hosts;
    # without it torch.load raises when CUDA is unavailable.
    depth_anything.load_state_dict(
        torch.load(f'./checkpoints/depth_anything_{encoder}14.pth', map_location=device))
    depth_anything = depth_anything.to(device).eval()
    transform = Compose([
        Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
               ensure_multiple_of=14, resize_method='lower_bound',
               image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ])
    raw_image = np.array(image.convert('RGB'))[:, :, ::-1].copy()  # RGB -> BGR for cv2-style transforms
    h, w = raw_image.shape[:2]
    image_tensor = transform({'image': raw_image / 255.0})['image']
    image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device)
    with torch.no_grad():
        depth = depth_anything(image_tensor)
    # Resize prediction back to the original resolution.
    depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
    # Normalize to [0, 255]; guard against a constant depth map (max == min)
    # to avoid a division by zero producing NaNs.
    d_min, d_max = depth.min(), depth.max()
    d_range = d_max - d_min
    if d_range <= 0:
        d_range = 1.0
    depth = (depth - d_min) / d_range * 255.0
    depth = depth.cpu().numpy().astype(np.uint8)
    return np.repeat(depth[..., np.newaxis], 3, axis=-1) if grayscale else cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
# TEED edge-detection wrapper: runs the TEED pipeline on a single image.
def teed_process_image(image):
    """Run TEED edge detection on a PIL image and return the result as a BGR ndarray.

    The TEED entry point operates on directories, so the image is written to a
    temporary input directory and the processed result is read back from the
    output directory.

    Raises:
        FileNotFoundError: if the TEED checkpoint file is missing.
    """
    os.makedirs('./output/teed_imgs', exist_ok=True)
    os.makedirs('./teed_tmp', exist_ok=True)
    temp_image_path = './teed_tmp/temp_image.png'
    # PIL arrays are RGB while cv2.imwrite expects BGR; convert so the file on
    # disk has correct channel order (the original wrote swapped channels).
    cv2.imwrite(temp_image_path, cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR))
    try:
        args, train_info = parse_args(is_testing=True, pl_opt_dir='./output/teed_imgs')
        args.input_val_dir = './teed_tmp'
        args.output_dir = './output/teed_imgs'
        checkpoint_path = './TEED/checkpoints/BIPED/5/5_model.pth'
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
        args.checkpoint_data = checkpoint_path  # make sure the correct path is used
        main(args, train_info)
    finally:
        # Always clean up the temp input dir, even if TEED fails mid-run
        # (the original leaked ./teed_tmp on any exception from main()).
        shutil.rmtree('./teed_tmp', ignore_errors=True)
    # NOTE(review): assumes TEED names its output 'processed_image.png' — confirm.
    return cv2.imread(os.path.join('./output/teed_imgs', 'processed_image.png'))
# Blend helper: classic "multiply" blend mode for uint8 images.
# Defined here because the original file called multiply_blend without ever
# defining it, which raised NameError at runtime.
def multiply_blend(base, overlay):
    """Multiply-blend two uint8 images of identical shape.

    Each output pixel is base * overlay / 255, so dark pixels in either
    input darken the result (255 acts as the identity).

    Args:
        base: uint8 ndarray, e.g. (H, W, 3).
        overlay: uint8 ndarray with the same shape as base.

    Returns:
        uint8 ndarray of the same shape.
    """
    blended = base.astype(np.float32) * overlay.astype(np.float32) / 255.0
    return np.clip(blended, 0, 255).astype(np.uint8)


# Full pipeline for a single image: depth map x edge map, multiply-blended.
def process_single_image(image):
    """Run depth estimation and TEED edge detection, then merge the results.

    Args:
        image: PIL.Image input.

    Returns:
        uint8 ndarray: the multiply-blend of the depth map and the edge map.
    """
    depth_result = depth_anything_image(image, 'vitl')
    teed_result = teed_process_image(image)
    merged_result = multiply_blend(depth_result, teed_result)
    return merged_result
# Gradio callback: bridges the UI widget to the processing pipeline.
def gradio_process_line(img):
    """Process one uploaded image and return the merged result as a PIL image."""
    return Image.fromarray(process_single_image(img))
# Gradio UI wiring.
import gradio as gr

# Single-image demo: upload -> depth + edge pipeline -> merged PIL image.
iface = gr.Interface(
    fn=gradio_process_line,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Processing with Depth Anything and TEED",
    description="Upload an image to process it with depth estimation and edge detection."
)

# Launch the Gradio app; share=True exposes a temporary public link.
iface.launch(share=True)