Update app.py
Browse filesFix fatal error
app.py
CHANGED
|
@@ -169,19 +169,20 @@ model = models_vit.__dict__['vit_base_patch16'](
|
|
| 169 |
num_classes=args.nb_classes,
|
| 170 |
drop_path_rate=args.drop_path,
|
| 171 |
global_pool=args.global_pool,
|
| 172 |
-
)
|
| 173 |
|
| 174 |
|
| 175 |
def load_model(ckpt):
|
| 176 |
if ckpt == 'choose from here' or 'continuously updating...':
|
| 177 |
return gr.update()
|
| 178 |
-
args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
|
| 179 |
if os.path.isfile(args.resume) == False:
|
| 180 |
hf_hub_download(local_dir=CKPT_SAVE_PATH,
|
| 181 |
repo_id='Wolowolo/fsfm-3c/' + CKPT_NAME[ckpt],
|
| 182 |
filename=ckpt)
|
| 183 |
checkpoint = torch.load(args.resume, map_location='cpu')
|
| 184 |
model.load_state_dict(checkpoint['model'])
|
|
|
|
| 185 |
return gr.update()
|
| 186 |
|
| 187 |
|
|
@@ -276,9 +277,7 @@ def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None, dev
|
|
| 276 |
return frame_indices
|
| 277 |
|
| 278 |
|
| 279 |
-
def FSFM3C_video_detection(video):
|
| 280 |
-
model.to(device)
|
| 281 |
-
|
| 282 |
# extract frames
|
| 283 |
num_frames = 32
|
| 284 |
|
|
@@ -308,21 +307,18 @@ def FSFM3C_video_detection(video):
|
|
| 308 |
|
| 309 |
real_prob_video = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
|
| 310 |
if real_prob_video > 50:
|
| 311 |
-
result_message = "real"
|
| 312 |
else:
|
| 313 |
-
result_message = "fake"
|
|
|
|
|
|
|
| 314 |
|
| 315 |
-
video_results = (f"The face in this video may be {result_message}
|
| 316 |
-
f"and the video-level real_face_probability is {real_prob_video}% \n"
|
| 317 |
-
f"The frame-level detection results ['sampled_frame_index': 'real_face_probability']: \n"
|
| 318 |
-
f"{frame_results} \n")
|
| 319 |
|
| 320 |
return video_results
|
| 321 |
|
| 322 |
|
| 323 |
-
def FSFM3C_image_detection(image):
|
| 324 |
-
model.to(device)
|
| 325 |
-
|
| 326 |
files = os.listdir(FRAME_SAVE_PATH)
|
| 327 |
num_files = len(files)
|
| 328 |
frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
|
|
@@ -352,12 +348,11 @@ def FSFM3C_image_detection(image):
|
|
| 352 |
|
| 353 |
real_prob_image = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
|
| 354 |
if real_prob_image > 50:
|
| 355 |
-
result_message = "real"
|
| 356 |
else:
|
| 357 |
-
result_message = "fake"
|
| 358 |
-
|
| 359 |
-
image_results = (f"The face in this image may be {result_message}
|
| 360 |
-
f"and the real_face_probability is {real_prob_image}%")
|
| 361 |
|
| 362 |
return image_results
|
| 363 |
|
|
@@ -406,12 +401,12 @@ with gr.Blocks() as demo:
|
|
| 406 |
|
| 407 |
image_submit_btn.click(
|
| 408 |
fn=FSFM3C_image_detection,
|
| 409 |
-
inputs=[image],
|
| 410 |
outputs=[output_results_image],
|
| 411 |
)
|
| 412 |
video_submit_btn.click(
|
| 413 |
fn=FSFM3C_video_detection,
|
| 414 |
-
inputs=[video],
|
| 415 |
outputs=[output_results_video],
|
| 416 |
)
|
| 417 |
ckpt_select_dropdown.change(
|
|
|
|
| 169 |
num_classes=args.nb_classes,
|
| 170 |
drop_path_rate=args.drop_path,
|
| 171 |
global_pool=args.global_pool,
|
| 172 |
+
).to(device)
|
| 173 |
|
| 174 |
|
| 175 |
def load_model(ckpt):
|
| 176 |
if ckpt == 'choose from here' or 'continuously updating...':
|
| 177 |
return gr.update()
|
| 178 |
+
args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_NAME[ckpt])
|
| 179 |
if os.path.isfile(args.resume) == False:
|
| 180 |
hf_hub_download(local_dir=CKPT_SAVE_PATH,
|
| 181 |
repo_id='Wolowolo/fsfm-3c/' + CKPT_NAME[ckpt],
|
| 182 |
filename=ckpt)
|
| 183 |
checkpoint = torch.load(args.resume, map_location='cpu')
|
| 184 |
model.load_state_dict(checkpoint['model'])
|
| 185 |
+
model.eval()
|
| 186 |
return gr.update()
|
| 187 |
|
| 188 |
|
|
|
|
| 277 |
return frame_indices
|
| 278 |
|
| 279 |
|
| 280 |
+
def FSFM3C_video_detection(video, ckpt_select_dropdown):
|
|
|
|
|
|
|
| 281 |
# extract frames
|
| 282 |
num_frames = 32
|
| 283 |
|
|
|
|
| 307 |
|
| 308 |
real_prob_video = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
|
| 309 |
if real_prob_video > 50:
|
| 310 |
+
result_message = "real" if 'FAS' not in ckpt_select_dropdown else 'spoof'
|
| 311 |
else:
|
| 312 |
+
result_message = "fake" if 'FAS' not in ckpt_select_dropdown else 'real'
|
| 313 |
+
prob = 1 - real_prob_image if real_prob_video <= 50 else real_prob_video
|
| 314 |
+
image_results = (f"The face in this image may be {result_message} with probability is {real_prob_image}%")
|
| 315 |
|
| 316 |
+
video_results = (f"The face in this video may be {result_message} with probability {prob}")
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
return video_results
|
| 319 |
|
| 320 |
|
| 321 |
+
def FSFM3C_image_detection(image, ckpt_select_dropdown):
|
|
|
|
|
|
|
| 322 |
files = os.listdir(FRAME_SAVE_PATH)
|
| 323 |
num_files = len(files)
|
| 324 |
frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
|
|
|
|
| 348 |
|
| 349 |
real_prob_image = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
|
| 350 |
if real_prob_image > 50:
|
| 351 |
+
result_message = "real" if 'FAS' not in ckpt_select_dropdown else 'spoof'
|
| 352 |
else:
|
| 353 |
+
result_message = "fake" if 'FAS' not in ckpt_select_dropdown else 'real'
|
| 354 |
+
prob = 1 - real_prob_image if real_prob_image <= 50 else real_prob_image
|
| 355 |
+
image_results = (f"The face in this image may be {result_message} with probability is {real_prob_image}%")
|
|
|
|
| 356 |
|
| 357 |
return image_results
|
| 358 |
|
|
|
|
| 401 |
|
| 402 |
image_submit_btn.click(
|
| 403 |
fn=FSFM3C_image_detection,
|
| 404 |
+
inputs=[image, ckpt_select_dropdown],
|
| 405 |
outputs=[output_results_image],
|
| 406 |
)
|
| 407 |
video_submit_btn.click(
|
| 408 |
fn=FSFM3C_video_detection,
|
| 409 |
+
inputs=[video, ckpt_select_dropdown],
|
| 410 |
outputs=[output_results_video],
|
| 411 |
)
|
| 412 |
ckpt_select_dropdown.change(
|