acmyu committed on
Commit
e0fbca6
·
1 Parent(s): 9ff3844

get fvd metric

Browse files
Files changed (3) hide show
  1. evaluate.py +118 -49
  2. main.py +1 -1
  3. requirements.txt +3 -1
evaluate.py CHANGED
@@ -8,14 +8,26 @@ import torch
8
  import torchvision.transforms as transforms
9
  import lpips
10
  from pytorch_fid.fid_score import calculate_fid_given_paths
 
11
  import os
12
  import json
 
13
  from huggingface_hub import snapshot_download
14
 
15
  # Convert PIL to numpy
16
  def pil_to_np(img):
17
  return np.array(img).astype(np.float32) / 255.0
18
 
 
 
 
 
 
 
 
 
 
 
19
  # SSIM
20
  def compute_ssim(img1, img2):
21
  img1_np = pil_to_np(img1)
@@ -46,53 +58,96 @@ def compute_lpips(img1, img2):
46
  img2_tensor = lpips_transform(img2).unsqueeze(0)
47
  return lpips_model(img1_tensor, img2_tensor).item()
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # FID: Save images to temp folders for FID calculation
50
  def compute_fid(img1, img2):
51
  os.makedirs('temp/img1', exist_ok=True)
52
  os.makedirs('temp/img2', exist_ok=True)
53
  img1.save('temp/img1/0.png')
54
  img2.save('temp/img2/0.png')
55
- fid = calculate_fid_given_paths(['temp/img1', 'temp/img2'], batch_size=1, device='cpu', dims=2048)
56
  return fid
57
 
58
 
59
-
60
-
61
  def get_score(item, image_paths, video_path, metrics, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
62
-
63
  images = []
64
  for path in image_paths:
65
  img = Image.open(path)
66
  images.append([img])
67
-
68
- gt_frames = extract_frames(video_path, fps)
69
- gt_frames = gt_frames[:1000]
70
- for f in gt_frames:
71
- f.thumbnail((512,512))
72
-
73
- os.makedirs('out/'+item, exist_ok=True)
74
-
75
-
76
- for i, frame in enumerate(gt_frames):
77
- frame.save("out/"+item+"/frame_"+str(i)+".png")
78
-
79
- #results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
80
- results, results_base = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
81
 
82
- for i, result in enumerate(results):
83
- result.save("out/"+item+"/result_"+str(i)+".png")
84
-
85
- for i, result in enumerate(results_base):
86
- result.save("out/"+item+"/base_"+str(i)+".png")
87
-
88
- """
89
- img1=gt_frames[0]
90
- img2=Image.open("out/base_0.png")
91
- print("SSIM:", compute_ssim(img1, img2))
92
- print("PSNR:", compute_psnr(img1, img2))
93
- print("LPIPS:", compute_lpips(img1, img2))
94
- print("FID:", compute_fid(img1, img2))
95
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  ssim = []
98
  psnr = []
@@ -102,37 +157,50 @@ def get_score(item, image_paths, video_path, metrics, train_steps=100, inference
102
  psnr2 = []
103
  lpips2 = []
104
  fid2 = []
 
 
105
  for gt, result, base in zip(gt_frames, results, results_base):
106
  ssim.append(float(compute_ssim(gt, result)))
107
  psnr.append(float(compute_psnr(gt, result)))
108
  lpips.append(float(compute_lpips(gt, result)))
109
- fid.append(float(compute_fid(gt, result)))
110
 
111
  ssim2.append(float(compute_ssim(gt, base)))
112
  psnr2.append(float(compute_psnr(gt, base)))
113
  lpips2.append(float(compute_lpips(gt, base)))
114
- fid2.append(float(compute_fid(gt, base)))
115
 
 
 
 
 
 
 
 
 
 
116
 
117
  print("SSIM:", sum(ssim)/len(ssim))
118
  print("PSNR:", sum(psnr)/len(psnr))
119
  print("LPIPS:", sum(lpips)/len(lpips))
120
  print("FID:", sum(fid)/len(fid))
 
121
  print('baseline:')
122
  print("SSIM:", sum(ssim2)/len(ssim2))
123
  print("PSNR:", sum(psnr2)/len(psnr2))
124
  print("LPIPS:", sum(lpips2)/len(lpips2))
125
  print("FID:", sum(fid2)/len(fid2))
 
126
 
127
- metrics[item] = {'ft': {}, 'base': {}}
128
  metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
129
  metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
130
  metrics[item]['ft']['lpips'] = {'avg': sum(lpips)/len(lpips), 'vals': lpips}
131
  metrics[item]['ft']['fid'] = {'avg': sum(fid)/len(fid), 'vals': fid}
 
132
  metrics[item]['base']['ssim'] = {'avg': sum(ssim2)/len(ssim2), 'vals': ssim2}
133
  metrics[item]['base']['psnr'] = {'avg': sum(psnr2)/len(psnr2), 'vals': psnr2}
134
  metrics[item]['base']['lpips'] = {'avg': sum(lpips2)/len(lpips2), 'vals': lpips2}
135
  metrics[item]['base']['fid'] = {'avg': sum(fid2)/len(fid2), 'vals': fid2}
 
136
 
137
  #print(metrics)
138
 
@@ -154,6 +222,7 @@ def get_files(directory_path):
154
 
155
 
156
  def run_evaluate():
 
157
  snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
158
 
159
  with open('metrics.json', 'r') as file:
@@ -169,20 +238,20 @@ def run_evaluate():
169
  continue
170
  print(item)
171
 
172
- try:
173
- files = get_files('test/'+item)
174
- images = list(filter(lambda x: not x.endswith('.mp4'), files))
175
- images = ['test/'+item+'/'+img for img in images]
176
- videos = [x for x in files if x.endswith('.mp4')]
177
- print(images, videos)
178
-
179
- if len(videos) == 1:
180
- get_score(item, images, 'test/'+item+'/'+videos[0], metrics)
181
- #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
182
- else:
183
- print('Error: mp4 not found')
184
- except Exception as e:
185
- print("Error", item, e)
186
 
187
 
188
  ssim = []
 
8
  import torchvision.transforms as transforms
9
  import lpips
10
  from pytorch_fid.fid_score import calculate_fid_given_paths
11
+ from cdfvd import fvd
12
  import os
13
  import json
14
+ import cv2
15
  from huggingface_hub import snapshot_download
16
 
17
  # Convert PIL to numpy
18
  def pil_to_np(img):
19
  return np.array(img).astype(np.float32) / 255.0
20
 
21
def save_mp4(images, name, fps=12):
    """Encode a sequence of PIL images as an MP4 file with OpenCV.

    Args:
        images: Non-empty sequence of PIL images. The first image fixes the
            frame size of the video.
        name: Output file path passed to ``cv2.VideoWriter``.
        fps: Frame rate of the written video. Defaults to 12, the value that
            was previously hard-coded.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("save_mp4 needs at least one image")

    # PIL reports size as (width, height), the order VideoWriter expects.
    frame_size = images[0].size
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4
    video = cv2.VideoWriter(name, fourcc, fps, frame_size)
    try:
        for image in images:
            # COLOR_RGB2BGR needs a 3-channel input, so normalise the mode
            # first (RGBA / greyscale / palette images would otherwise fail).
            frame = image.convert('RGB')
            # Keep every frame at the size the writer was opened with.
            if frame.size != frame_size:
                frame = frame.resize(frame_size)
            video.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Always release so the container is finalised and the handle closed,
        # even when a frame conversion raises.
        video.release()
30
+
31
  # SSIM
32
  def compute_ssim(img1, img2):
33
  img1_np = pil_to_np(img1)
 
58
  img2_tensor = lpips_transform(img2).unsqueeze(0)
59
  return lpips_model(img1_tensor, img2_tensor).item()
60
 
61
+
62
def trans(x):
    """Convert a batch of videos from BTCHW layout to BCTHW layout.

    Single-channel (greyscale) input is expanded to three channels by
    repeating the channel axis before the axes are permuted.

    Args:
        x: 5-D tensor laid out as (batch, time, channel, height, width).

    Returns:
        Tensor laid out as (batch, channel, time, height, width).

    Raises:
        ValueError: If ``x`` is not 5-dimensional.
    """
    # repeat() with five factors would silently prepend axes to a lower-rank
    # tensor, so reject anything that is not BTCHW up front.
    if x.dim() != 5:
        raise ValueError(
            "trans expects a 5-D BTCHW tensor, got %d dimension(s)" % x.dim()
        )

    # if greyscale images add channel
    if x.shape[-3] == 1:
        x = x.repeat(1, 1, 3, 1, 1)

    # permute BTCHW -> BCTHW
    return x.permute(0, 2, 1, 3, 4)
71
+
72
+
73
def compute_fvd(item, gt_imgs, results):
    """Compute the FVD score between ground-truth and generated frames.

    Both frame sequences are written as MP4 files into ``temp/gt`` and
    ``temp/result``; the cd-fvd evaluator then loads those folders.

    Args:
        item: Name of the evaluated item. Currently unused; kept so existing
            callers do not need to change.
        gt_imgs: Sequence of PIL images with the ground-truth frames.
        results: Sequence of PIL images with the generated frames.

    Returns:
        The value returned by ``evaluator.compute_fvd_from_stats()``.
    """
    os.makedirs('temp/gt', exist_ok=True)
    os.makedirs('temp/result', exist_ok=True)

    # Write the frames that were actually passed in. With these calls disabled
    # the score was computed from whatever videos happened to be in the temp
    # folders, so every call scored the same stale data. The fixed file names
    # overwrite the previous item's videos on each call.
    save_mp4(gt_imgs, "temp/gt/gt.mp4")
    save_mp4(results, "temp/result/result.mp4")

    # Fall back to CPU so the metric still runs on hosts without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    evaluator = fvd.cdfvd('i3d', ckpt_path=None, device=device, n_real=1, n_fake=1)
    evaluator.compute_real_stats(evaluator.load_videos('temp/gt', data_type='video_folder'))
    evaluator.compute_fake_stats(evaluator.load_videos('temp/result', data_type='video_folder'))
    score = evaluator.compute_fvd_from_stats()
    evaluator.offload_model_to_cpu()
    print(score)
    return score
87
+
88
+
89
def compute_fidx(item, gt_imgs, results):
    """Compute FID between all ground-truth frames and all generated frames.

    The frames are saved as numbered PNG files into two per-item folders under
    ``temp/`` and those folders are passed to ``calculate_fid_given_paths``.

    Args:
        item: Name of the evaluated item; used to build the folder names.
        gt_imgs: Sequence of PIL images with the ground-truth frames.
        results: Sequence of PIL images with the generated frames.

    Returns:
        The value returned by ``calculate_fid_given_paths``.
    """
    gt_dir = 'temp/' + item + '_gt'
    result_dir = 'temp/' + item
    os.makedirs(gt_dir, exist_ok=True)
    os.makedirs(result_dir, exist_ok=True)

    for index, img in enumerate(gt_imgs):
        img.save(gt_dir + '/' + str(index) + '.png')

    # Save the generated frames here. Previously this loop iterated gt_imgs
    # again, so the FID compared the ground truth with itself.
    for index, img in enumerate(results):
        img.save(result_dir + '/' + str(index) + '.png')

    # Fall back to CPU so the metric still runs on hosts without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fid = calculate_fid_given_paths([gt_dir, result_dir], batch_size=8, device=device, dims=2048)
    return fid
103
+
104
  # FID: Save images to temp folders for FID calculation
105
  def compute_fid(img1, img2):
106
  os.makedirs('temp/img1', exist_ok=True)
107
  os.makedirs('temp/img2', exist_ok=True)
108
  img1.save('temp/img1/0.png')
109
  img2.save('temp/img2/0.png')
110
+ fid = calculate_fid_given_paths(['temp/img1', 'temp/img2'], batch_size=1, device='cuda', dims=2048)
111
  return fid
112
 
113
 
 
 
114
  def get_score(item, image_paths, video_path, metrics, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
 
115
  images = []
116
  for path in image_paths:
117
  img = Image.open(path)
118
  images.append([img])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ results = []
121
+ results_base = []
122
+ gt_frames = []
123
+ if os.path.isdir('out/'+item):
124
+ for filename in os.listdir('out/'+item):
125
+ img = Image.open('out/'+item+'/'+filename)
126
+ if filename.startswith('result_'):
127
+ results.append(img)
128
+ elif filename.startswith('base_'):
129
+ results_base.append(img)
130
+ elif filename.startswith('frame_'):
131
+ gt_frames.append(img)
132
+ else:
133
+ gt_frames = extract_frames(video_path, fps)
134
+ gt_frames = gt_frames[:200]
135
+ for f in gt_frames:
136
+ f.thumbnail((512,512))
137
+
138
+ #results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
139
+ results, results_base = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
140
+
141
+ os.makedirs('out/'+item, exist_ok=True)
142
+
143
+ for i, frame in enumerate(gt_frames):
144
+ frame.save("out/"+item+"/frame_"+str(i)+".png")
145
+
146
+ for i, result in enumerate(results):
147
+ result.save("out/"+item+"/result_"+str(i)+".png")
148
+
149
+ for i, result in enumerate(results_base):
150
+ result.save("out/"+item+"/base_"+str(i)+".png")
151
 
152
  ssim = []
153
  psnr = []
 
157
  psnr2 = []
158
  lpips2 = []
159
  fid2 = []
160
+ c = 0
161
+ #print(len(gt_frames), len(results), len(results_base))
162
  for gt, result, base in zip(gt_frames, results, results_base):
163
  ssim.append(float(compute_ssim(gt, result)))
164
  psnr.append(float(compute_psnr(gt, result)))
165
  lpips.append(float(compute_lpips(gt, result)))
 
166
 
167
  ssim2.append(float(compute_ssim(gt, base)))
168
  psnr2.append(float(compute_psnr(gt, base)))
169
  lpips2.append(float(compute_lpips(gt, base)))
 
170
 
171
+ if c<50:
172
+ print(c)
173
+ fid.append(float(compute_fid(gt, result)))
174
+ fid2.append(float(compute_fid(gt, base)))
175
+ c = c+1
176
+
177
+ fvd = float(compute_fvd(item, gt_frames, results))
178
+ fvd2 = float(compute_fvd(item, gt_frames, results_base))
179
+
180
 
181
  print("SSIM:", sum(ssim)/len(ssim))
182
  print("PSNR:", sum(psnr)/len(psnr))
183
  print("LPIPS:", sum(lpips)/len(lpips))
184
  print("FID:", sum(fid)/len(fid))
185
+ print("FVD:", fvd)
186
  print('baseline:')
187
  print("SSIM:", sum(ssim2)/len(ssim2))
188
  print("PSNR:", sum(psnr2)/len(psnr2))
189
  print("LPIPS:", sum(lpips2)/len(lpips2))
190
  print("FID:", sum(fid2)/len(fid2))
191
+ print("FVD:", fvd2)
192
 
193
+ metrics[item] = {'ft': {}, 'base': {}, 'frames': len(gt_frames), 'complexity': len(images)}
194
  metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
195
  metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
196
  metrics[item]['ft']['lpips'] = {'avg': sum(lpips)/len(lpips), 'vals': lpips}
197
  metrics[item]['ft']['fid'] = {'avg': sum(fid)/len(fid), 'vals': fid}
198
+ metrics[item]['ft']['fvd'] = fvd
199
  metrics[item]['base']['ssim'] = {'avg': sum(ssim2)/len(ssim2), 'vals': ssim2}
200
  metrics[item]['base']['psnr'] = {'avg': sum(psnr2)/len(psnr2), 'vals': psnr2}
201
  metrics[item]['base']['lpips'] = {'avg': sum(lpips2)/len(lpips2), 'vals': lpips2}
202
  metrics[item]['base']['fid'] = {'avg': sum(fid2)/len(fid2), 'vals': fid2}
203
+ metrics[item]['base']['fvd'] = fvd2
204
 
205
  #print(metrics)
206
 
 
222
 
223
 
224
  def run_evaluate():
225
+ print("run_evaluate")
226
  snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
227
 
228
  with open('metrics.json', 'r') as file:
 
238
  continue
239
  print(item)
240
 
241
+ #try:
242
+ files = get_files('test/'+item)
243
+ images = list(filter(lambda x: not x.endswith('.mp4'), files))
244
+ images = ['test/'+item+'/'+img for img in images]
245
+ videos = [x for x in files if x.endswith('.mp4')]
246
+ print(images, videos)
247
+
248
+ if len(videos) == 1:
249
+ get_score(item, images, 'test/'+item+'/'+videos[0], metrics)
250
+ #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
251
+ else:
252
+ print('Error: mp4 not found')
253
+ #except Exception as e:
254
+ # print("Error", item, e)
255
 
256
 
257
  ssim = []
main.py CHANGED
@@ -85,7 +85,7 @@ debug = False
85
  save_model = True
86
  should_gen_vid = False
87
  max_batch_size = 8
88
- max_frame_count = 1000
89
 
90
  def save_temp_imgs(imgs):
91
  os.makedirs('temp', exist_ok=True)
 
85
  save_model = True
86
  should_gen_vid = False
87
  max_batch_size = 8
88
+ max_frame_count = 200
89
 
90
  def save_temp_imgs(imgs):
91
  os.makedirs('temp', exist_ok=True)
requirements.txt CHANGED
@@ -23,4 +23,6 @@ spaces
23
  matplotlib
24
 
25
  lpips
26
- pytorch-fid
 
 
 
23
  matplotlib
24
 
25
  lpips
26
+ pytorch-fid
27
+ cd-fvd
28
+ av