"""Gradio app: blend Depth Anything depth estimation with TEED edge detection.

Pipeline: PIL image -> depth map (Depth Anything) and edge map (TEED)
-> multiply blend -> PIL image, served through a Gradio interface.
"""

from PIL import Image
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import shutil
import os

import gradio as gr

from depthAnything.depth_anything.dpt import DepthAnything
from depthAnything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
from TEED.main import parse_args, main

# Loaded depth models keyed by encoder name, so repeated Gradio requests do
# not reload the checkpoint from disk on every call.
_DEPTH_MODELS = {}


def _load_depth_model(encoder, device):
    """Return the DepthAnything model for *encoder*, loading it at most once.

    Raises:
        KeyError: for an unknown encoder name.
        FileNotFoundError: if the checkpoint file is missing.
    """
    if encoder not in _DEPTH_MODELS:
        model_configs = {
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        }
        model = DepthAnything(model_configs[encoder])
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # host; the model is moved to the target device afterwards.
        state = torch.load(f'./checkpoints/depth_anything_{encoder}14.pth', map_location='cpu')
        model.load_state_dict(state)
        _DEPTH_MODELS[encoder] = model.to(device).eval()
    return _DEPTH_MODELS[encoder]


def depth_anything_image(image, encoder='vitl', pred_only=True, grayscale=True):
    """Estimate depth for a PIL image with Depth Anything.

    Args:
        image: PIL.Image input (any mode; converted to RGB internally).
        encoder: backbone name — one of 'vitl', 'vitb', 'vits'.
        pred_only: kept for interface compatibility; currently unused.
        grayscale: if True, return the depth replicated to 3 channels;
            otherwise an INFERNO-colormapped BGR image.

    Returns:
        HxWx3 uint8 array with the same spatial size as the input.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    depth_anything = _load_depth_model(encoder, device)

    transform = Compose([
        Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
               ensure_multiple_of=14, resize_method='lower_bound',
               image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ])

    # Depth Anything normalizes with ImageNet RGB statistics and its reference
    # pipeline feeds RGB, so keep the PIL image's native RGB channel order
    # (the previous RGB->BGR flip fed the model swapped channels).
    raw_image = np.array(image.convert('RGB'))
    h, w = raw_image.shape[:2]
    image_tensor = transform({'image': raw_image / 255.0})['image']
    image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device)

    with torch.no_grad():
        depth = depth_anything(image_tensor)
    depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]

    # Normalize to 0..255; the epsilon guards against a constant depth map,
    # where max == min would otherwise divide by zero and produce NaNs.
    depth_min, depth_max = depth.min(), depth.max()
    depth = (depth - depth_min) / (depth_max - depth_min + 1e-8) * 255.0
    depth = depth.cpu().numpy().astype(np.uint8)

    if grayscale:
        return np.repeat(depth[..., np.newaxis], 3, axis=-1)
    return cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)


def teed_process_image(image):
    """Run TEED edge detection on a PIL image.

    The image is written into a temp directory, TEED's test entry point is
    invoked on that directory, and the resulting edge map is read back.

    Returns:
        BGR uint8 edge image as loaded by cv2.imread.

    Raises:
        FileNotFoundError: if the TEED checkpoint or the produced output
            image is missing.
    """
    os.makedirs('./output/teed_imgs', exist_ok=True)
    os.makedirs('./teed_tmp', exist_ok=True)
    temp_image_path = './teed_tmp/temp_image.png'
    # PIL arrays are RGB; cv2.imwrite expects BGR, so convert before saving
    # (writing the raw RGB array would swap red and blue on disk).
    cv2.imwrite(temp_image_path,
                cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR))

    try:
        args, train_info = parse_args(is_testing=True, pl_opt_dir='./output/teed_imgs')
        args.input_val_dir = './teed_tmp'
        args.output_dir = './output/teed_imgs'
        checkpoint_path = './TEED/checkpoints/BIPED/5/5_model.pth'
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
        args.checkpoint_data = checkpoint_path  # ensure the correct path is used
        main(args, train_info)
    finally:
        # Remove the temp dir even when TEED raises, so stale inputs never
        # leak into the next request.
        shutil.rmtree('./teed_tmp', ignore_errors=True)

    # NOTE(review): the output filename is assumed to be 'processed_image.png';
    # verify against TEED's naming convention (it may reuse the input basename,
    # i.e. 'temp_image.png').
    result_path = os.path.join('./output/teed_imgs', 'processed_image.png')
    result = cv2.imread(result_path)
    if result is None:
        # cv2.imread silently returns None for a missing/unreadable file.
        raise FileNotFoundError(f"TEED output not found or unreadable: {result_path}")
    return result


def multiply_blend(img_a, img_b):
    """Multiply-blend two equally sized uint8 images: out = a * b / 255.

    Standard 'multiply' compositing — the result darkens wherever either
    input is dark, which overlays the TEED edges onto the depth map.
    (This helper was referenced but never defined, causing a NameError.)
    """
    blended = img_a.astype(np.float32) * img_b.astype(np.float32) / 255.0
    return np.clip(blended, 0.0, 255.0).astype(np.uint8)


def process_single_image(image):
    """Run the full depth+edge pipeline on one PIL image.

    Returns:
        BGR uint8 array (OpenCV channel order).
    """
    depth_result = depth_anything_image(image, 'vitl')
    teed_result = teed_process_image(image)
    return multiply_blend(depth_result, teed_result)


def gradio_process_line(img):
    """Gradio callback: PIL image in, PIL image out."""
    processed_image = process_single_image(img)
    # The pipeline works in OpenCV BGR order; convert so PIL shows true colors.
    return Image.fromarray(cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB))


iface = gr.Interface(
    fn=gradio_process_line,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Processing with Depth Anything and TEED",
    description="Upload an image to process it with depth estimation and edge detection.",
)

if __name__ == "__main__":
    # share=True additionally exposes a public Gradio link beyond the local
    # server. Guarding with __main__ keeps the app importable without
    # immediately launching a server.
    iface.launch(share=True)