plutosss commited on
Commit
cbc569c
·
verified ·
1 Parent(s): 7cd04f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -6,11 +6,11 @@ import torch.nn.functional as F
6
  from torchvision.transforms import Compose
7
  import shutil
8
  import os
9
- import teed
10
 
11
  from depthAnything.depth_anything.dpt import DepthAnything
12
  from depthAnything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
13
- from TEED.main import parse_args
14
 
15
  # 深度处理函数
16
  def depth_anything_image(image, encoder='vitl', pred_only=True, grayscale=True):
@@ -54,13 +54,14 @@ def teed_process_image(image):
54
  temp_image_path = './teed_tmp/temp_image.png'
55
  cv2.imwrite(temp_image_path, np.array(image))
56
 
57
- # 获取解析后的参数
58
  args, train_info = parse_args(is_testing=True, pl_opt_dir='./output/teed_imgs')
59
- args.input_val_dir = './teed_tmp' # 临时目录
60
- args.output_dir = './output/teed_imgs' # 输出目录
 
 
 
61
 
62
- # 调用 TEED 主函数进行处理
63
- teed.main(args, train_info)
64
 
65
  shutil.rmtree('./teed_tmp')
66
  return cv2.imread(os.path.join('./output/teed_imgs', 'processed_image.png'))
@@ -90,4 +91,4 @@ iface = gr.Interface(
90
  )
91
 
92
  # 启动 Gradio 应用
93
- iface.launch()
 
6
  from torchvision.transforms import Compose
7
  import shutil
8
  import os
9
+
10
 
11
  from depthAnything.depth_anything.dpt import DepthAnything
12
  from depthAnything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
13
+ from TEED.main import parse_args, main
14
 
15
  # 深度处理函数
16
  def depth_anything_image(image, encoder='vitl', pred_only=True, grayscale=True):
 
54
  temp_image_path = './teed_tmp/temp_image.png'
55
  cv2.imwrite(temp_image_path, np.array(image))
56
 
 
57
  args, train_info = parse_args(is_testing=True, pl_opt_dir='./output/teed_imgs')
58
+ args.input_val_dir = './teed_tmp'
59
+ args.output_dir = './output/teed_imgs'
60
+
61
+ print(args) # 调试信息
62
+ print(train_info) # 调试信息
63
 
64
+ main(args, train_info) # 确保调用正确
 
65
 
66
  shutil.rmtree('./teed_tmp')
67
  return cv2.imread(os.path.join('./output/teed_imgs', 'processed_image.png'))
 
91
  )
92
 
93
  # 启动 Gradio 应用
94
+ iface.launch(share=True)