Spaces:
Runtime error
Runtime error
add tqdm
Browse files- requirements.txt +2 -0
- train.py +7 -6
requirements.txt
CHANGED
|
@@ -6,3 +6,5 @@ scikit-image>=0.14.0
|
|
| 6 |
torchvision>=0.2.1
|
| 7 |
pillow>=7.2.0
|
| 8 |
lpips>=0.1.3
|
|
|
|
|
|
|
|
|
| 6 |
torchvision>=0.2.1
|
| 7 |
pillow>=7.2.0
|
| 8 |
lpips>=0.1.3
|
| 9 |
+
gdown
|
| 10 |
+
tqdm
|
train.py
CHANGED
|
@@ -12,6 +12,7 @@ from data_loader import (FileDataset,
|
|
| 12 |
RandomResizedCropWithAutoCenteringAndZeroPadding)
|
| 13 |
from torch.utils.data.distributed import DistributedSampler
|
| 14 |
from conr import CoNR
|
|
|
|
| 15 |
|
| 16 |
def data_sampler(dataset, shuffle, distributed):
|
| 17 |
|
|
@@ -123,7 +124,7 @@ def infer(args, humanflowmodel, image_names_list):
|
|
| 123 |
time_stamp = time.time()
|
| 124 |
prev_frame_rgb = []
|
| 125 |
prev_frame_a = []
|
| 126 |
-
for i, data in enumerate(train_data):
|
| 127 |
data_time_interval = time.time() - time_stamp
|
| 128 |
time_stamp = time.time()
|
| 129 |
with torch.no_grad():
|
|
@@ -137,11 +138,11 @@ def infer(args, humanflowmodel, image_names_list):
|
|
| 137 |
|
| 138 |
train_time_interval = time.time() - time_stamp
|
| 139 |
time_stamp = time.time()
|
| 140 |
-
if i % 5 == 0 and args.local_rank == 0:
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
with torch.no_grad():
|
| 146 |
|
| 147 |
if args.test_output_video:
|
|
|
|
| 12 |
RandomResizedCropWithAutoCenteringAndZeroPadding)
|
| 13 |
from torch.utils.data.distributed import DistributedSampler
|
| 14 |
from conr import CoNR
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
|
| 17 |
def data_sampler(dataset, shuffle, distributed):
|
| 18 |
|
|
|
|
| 124 |
time_stamp = time.time()
|
| 125 |
prev_frame_rgb = []
|
| 126 |
prev_frame_a = []
|
| 127 |
+
for i, data in tqdm(enumerate(train_data)):
|
| 128 |
data_time_interval = time.time() - time_stamp
|
| 129 |
time_stamp = time.time()
|
| 130 |
with torch.no_grad():
|
|
|
|
| 138 |
|
| 139 |
train_time_interval = time.time() - time_stamp
|
| 140 |
time_stamp = time.time()
|
| 141 |
+
# if i % 5 == 0 and args.local_rank == 0:
|
| 142 |
+
# print("[infer batch: %4d/%4d] time:%2f+%2f" % (
|
| 143 |
+
# i, train_num,
|
| 144 |
+
# data_time_interval, train_time_interval
|
| 145 |
+
# ))
|
| 146 |
with torch.no_grad():
|
| 147 |
|
| 148 |
if args.test_output_video:
|