# HuggingFace Spaces app (page status at scrape time: Runtime error)
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision.transforms import Compose | |
| import shutil | |
| import os | |
| from depthAnything.depth_anything.dpt import DepthAnything | |
| from depthAnything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet | |
| from TEED.main import parse_args, main | |
# Depth-estimation helper
def depth_anything_image(image, encoder='vitl', pred_only=True, grayscale=True):
    """Estimate a depth map for a PIL image with the Depth Anything model.

    Args:
        image: PIL.Image input; converted to RGB internally.
        encoder: backbone variant, one of 'vitl', 'vitb', 'vits'.
        pred_only: unused; kept for backward compatibility with callers.
        grayscale: if True, return a 3-channel grayscale depth image;
            otherwise an INFERNO-colormapped BGR image.

    Returns:
        np.uint8 array of shape (H, W, 3) at the input image's resolution.
    """
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_configs = {
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
    }
    # Cache loaded models on the function object so repeated calls (e.g. from
    # the Gradio handler) don't re-read the checkpoint from disk every time.
    cache = getattr(depth_anything_image, '_model_cache', None)
    if cache is None:
        cache = depth_anything_image._model_cache = {}
    if encoder not in cache:
        model = DepthAnything(model_configs[encoder])
        model.load_state_dict(torch.load(f'./checkpoints/depth_anything_{encoder}14.pth'))
        cache[encoder] = model.to(DEVICE).eval()
    depth_anything = cache[encoder]
    transform = Compose([
        Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
               ensure_multiple_of=14, resize_method='lower_bound', image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ])
    raw_image = np.array(image.convert('RGB'))[:, :, ::-1].copy()  # RGB to BGR
    h, w = raw_image.shape[:2]
    image_tensor = transform({'image': raw_image / 255.0})['image']
    image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        depth = depth_anything(image_tensor)
    # Resample the prediction back to the input resolution.
    depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
    # Normalize to 0..255; guard against a constant depth map (max == min),
    # which would otherwise divide by zero.
    d_min, d_max = depth.min(), depth.max()
    span = d_max - d_min
    if span <= 0:
        span = 1.0
    depth = (depth - d_min) / span * 255.0
    depth = depth.cpu().numpy().astype(np.uint8)
    return np.repeat(depth[..., np.newaxis], 3, axis=-1) if grayscale else cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
# TEED image-processing helper
def teed_process_image(image):
    """Run TEED edge detection on a PIL image via its file-based interface.

    Writes the image to a temp directory, invokes TEED's main() on it, and
    reads the processed result back from the output directory.

    Args:
        image: PIL.Image input.

    Returns:
        np.ndarray (BGR, as loaded by cv2.imread) with the TEED edge map,
        or None if the expected output file was not produced.

    Raises:
        FileNotFoundError: if the TEED checkpoint is missing.
    """
    os.makedirs('./output/teed_imgs', exist_ok=True)
    os.makedirs('./teed_tmp', exist_ok=True)
    temp_image_path = './teed_tmp/temp_image.png'
    # PIL arrays are RGB but cv2.imwrite expects BGR; convert so TEED
    # receives the image with correct channel order.
    cv2.imwrite(temp_image_path, cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR))
    try:
        args, train_info = parse_args(is_testing=True, pl_opt_dir='./output/teed_imgs')
        args.input_val_dir = './teed_tmp'
        args.output_dir = './output/teed_imgs'
        checkpoint_path = './TEED/checkpoints/BIPED/5/5_model.pth'
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
        args.checkpoint_data = checkpoint_path  # make sure the correct path is used
        main(args, train_info)
    finally:
        # Clean up the temp dir even if TEED fails, so a failed run doesn't
        # leave stale inputs behind for the next call.
        shutil.rmtree('./teed_tmp', ignore_errors=True)
    return cv2.imread(os.path.join('./output/teed_imgs', 'processed_image.png'))
| # 处理单个图像 | |
| def process_single_image(image): | |
| depth_result = depth_anything_image(image, 'vitl') | |
| teed_result = teed_process_image(image) | |
| merged_result = multiply_blend(depth_result, teed_result) | |
| return merged_result | |
# Gradio handler function
def gradio_process_line(img):
    """Run the full depth + edge pipeline on *img* and return a PIL image."""
    return Image.fromarray(process_single_image(img))
# ---- Gradio UI ----
import gradio as gr

# Wire the handler into a one-image-in, one-image-out interface.
iface = gr.Interface(
    gradio_process_line,
    gr.Image(type="pil"),
    gr.Image(type="pil"),
    title="Image Processing with Depth Anything and TEED",
    description="Upload an image to process it with depth estimation and edge detection.",
)

# Start the Gradio app with a public share link.
iface.launch(share=True)