leonelhs commited on
Commit
d12923a
·
1 Parent(s): 54a5078

code refactorized

Browse files
Files changed (8) hide show
  1. .gitignore +10 -0
  2. Dockerfile +0 -16
  3. README.md +8 -5
  4. app.py +97 -121
  5. models/model.py +29 -56
  6. packages.txt +0 -1
  7. pyrightconfig.json +0 -15
  8. server.sh +0 -1
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/
2
+ Deep3DFaceRecon_pytorch/models/__pycache__/
3
+ Deep3DFaceRecon_pytorch/util/__pycache__/
4
+ arcface_torch/backbones/__pycache__/
5
+ benchmark/__pycache__/
6
+ HRNet/__pycache__/
7
+ models/__pycache__/
8
+ configs/__pycache__/
9
+ .idea
10
+ playground.py
Dockerfile DELETED
@@ -1,16 +0,0 @@
1
- FROM xuehy93/hififace:1.0
2
-
3
-
4
- RUN apt update && apt install -y wget
5
-
6
-
7
- WORKDIR /
8
-
9
- RUN wget https://public.ph.files.1drv.com/y4m_El1_AyFLmGuZaPWOqkytzM4qYtDc3BvNNL99JV1OLCEkmD4RTQjtHEXZ0SAWb7UPLV1IPB0KO2rFlyGJaV_kITLbuAHzJ73GwR_cgvXpkIGywaTnKsKVV1jJe1LoFcl7XsxatyGpaC8-Gupq6jjBnaqSBH4dgfYAmzUk8Wqiiuj_ml2duU7No0M1T426y3RqOJsqVHXEMVfV0B6HjzQFKCCZIgfHjjHvLIB3B3xP8Q?AVOverride=1 -O checkpoints.tar.gz
10
-
11
- RUN tar xfz checkpoints.tar.gz
12
-
13
- WORKDIR /app
14
- ADD ./ /app
15
- RUN chmod +x ./server.sh
16
- CMD ["./server.sh"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,11 +1,14 @@
1
  ---
2
- title: HiFiFace Inference
3
- emoji: 📉
4
- colorFrom: indigo
5
- colorTo: indigo
6
- sdk: docker
 
 
7
  pinned: false
8
  license: mit
 
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: HiFiFace image swap
3
+ emoji: 👁
4
+ colorFrom: green
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.47.0
8
+ app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: Swap faces from photos
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,134 +1,110 @@
1
- import argparse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import gradio as gr
 
 
4
 
5
  from benchmark.app_image import ImageSwap
6
- from benchmark.app_video import VideoSwap
7
- from configs.train_config import TrainConfig
8
- from models.model import HifiFace
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class ConfigPath:
12
- face_detector_weights = "/checkpoints/face_detector/face_detector_scrfd_10g_bnkps.onnx"
13
  model_path = ""
14
  model_idx = 80000
15
- ffmpeg_device = "cuda"
16
- device = "cuda"
17
 
 
18
 
19
- def main():
20
- cfg = ConfigPath()
21
- parser = argparse.ArgumentParser(
22
- prog="benchmark", description="What the program does", epilog="Text at the bottom of help"
23
- )
24
- parser.add_argument("-m", "--model_path", default="/checkpoints/hififace_pretrained/standard_model")
25
- parser.add_argument("-i", "--model_idx", default="320000")
26
- parser.add_argument("-f", "--ffmpeg_device", default="cpu")
27
- parser.add_argument("-d", "--device", default="cpu")
28
-
29
- args = parser.parse_args()
30
-
31
- cfg.model_path = args.model_path
32
- cfg.model_idx = int(args.model_idx)
33
- cfg.ffmpeg_device = args.ffmpeg_device
34
- cfg.device = args.device
35
- opt = TrainConfig()
36
- checkpoint = (cfg.model_path, cfg.model_idx)
37
- model_path_1 = "/checkpoints/hififace_pretrained/with_gaze_and_mouth"
38
- checkpoint1 = ("/checkpoints/hififace_pretrained/with_gaze_and_mouth", "190000")
39
- model = HifiFace(opt.identity_extractor_config, is_training=False, device=cfg.device, load_checkpoint=checkpoint)
40
-
41
- model1 = HifiFace(opt.identity_extractor_config, is_training=False, device=cfg.device, load_checkpoint=checkpoint1)
42
- image_infer = ImageSwap(cfg, model)
43
- image_infer1 = ImageSwap(cfg, model1)
44
- def inference_image(source_face, target_face, shape_rate, id_rate, iterations):
45
- return image_infer.inference(source_face, target_face, shape_rate, id_rate, int(iterations))
46
-
47
- def inference_image1(source_face, target_face, shape_rate, id_rate, iterations):
48
- return image_infer1.inference(source_face, target_face, shape_rate, id_rate, int(iterations))
49
-
50
- model_name = cfg.model_path.split("/")[-1] + ":" + f"{cfg.model_idx}"
51
- model_name1 = model_path_1.split("/")[-1] + ":" + "190000"
52
- with gr.Blocks(title="FaceSwap") as demo:
53
- gr.Markdown(
54
- f"""
55
- ### standard model: {model_name} \n
56
- ### model with eye and mouth hm loss: {model_name1}
57
- """
58
- )
59
- with gr.Tab("Image swap with standard model"):
60
- with gr.Row():
61
- source_image = gr.Image(shape=None, label="source image")
62
- target_image = gr.Image(shape=None, label="target image")
63
- with gr.Row():
64
- with gr.Column():
65
- structure_sim = gr.Slider(
66
- minimum=0.0,
67
- maximum=1.0,
68
- value=1.0,
69
- step=0.1,
70
- label="3d similarity",
71
- )
72
- id_sim = gr.Slider(
73
- minimum=0.0,
74
- maximum=1.0,
75
- value=1.0,
76
- step=0.1,
77
- label="id similarity",
78
- )
79
- iters = gr.Slider(
80
- minimum=1,
81
- maximum=10,
82
- value=1,
83
- step=1,
84
- label="iters",
85
- )
86
- image_btn = gr.Button("image swap")
87
- output_image = gr.Image(shape=None, label="Result")
88
-
89
- image_btn.click(
90
- fn=inference_image,
91
- inputs=[source_image, target_image, structure_sim, id_sim, iters],
92
- outputs=output_image,
93
- )
94
-
95
- with gr.Tab("Image swap with eye&mouth hm loss model"):
96
  with gr.Row():
97
- source_image = gr.Image(shape=None, label="source image")
98
- target_image = gr.Image(shape=None, label="target image")
 
 
 
 
 
 
 
99
  with gr.Row():
100
- with gr.Column():
101
- structure_sim = gr.Slider(
102
- minimum=0.0,
103
- maximum=1.0,
104
- value=1.0,
105
- step=0.1,
106
- label="3d similarity",
107
- )
108
- id_sim = gr.Slider(
109
- minimum=0.0,
110
- maximum=1.0,
111
- value=1.0,
112
- step=0.1,
113
- label="id similarity",
114
- )
115
- iters = gr.Slider(
116
- minimum=1,
117
- maximum=10,
118
- value=1,
119
- step=1,
120
- label="iters",
121
- )
122
- image_btn = gr.Button("image swap")
123
- output_image = gr.Image(shape=None, label="Result")
124
-
125
- image_btn.click(
126
- fn=inference_image1,
127
- inputs=[source_image, target_image, structure_sim, id_sim, iters],
128
- outputs=output_image,
129
- )
130
- demo.launch(server_name="0.0.0.0", server_port=7860)
131
-
132
-
133
- if __name__ == "__main__":
134
- main()
 
1
+ #######################################################################################
2
+ #
3
+ # MIT License
4
+ #
5
+ # Copyright (c) [2025] [leonelhs@gmail.com]
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+ #
25
+ #######################################################################################
26
+ #
27
+ # Source code is based on or inspired by several projects.
28
+ # For more details and proper attribution, please refer to the following resources:
29
+ #
30
+ # - [hyxue] - [https://huggingface.co/spaces/hyxue/HiFiFace-inference-demo]
31
+ # - [maum-ai] [https://github.com/maum-ai/hififace]
32
+ #
33
 
34
  import gradio as gr
35
+ import torch
36
+ from huggingface_hub import hf_hub_download
37
 
38
  from benchmark.app_image import ImageSwap
39
+ from models.model import HifiFaceST, HifiFaceWGM
 
 
40
 
41
+ REPO_ID = "leonelhs/HiFiFace"
42
+
43
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+
45
+ gen_st_path = hf_hub_download(repo_id=REPO_ID,
46
+ filename="hififace_pretrained/standard_model/generator_320000.pth")
47
+
48
+ gen_wgm_path = hf_hub_download(repo_id=REPO_ID,
49
+ filename="hififace_pretrained/with_gaze_and_mouth/generator_190000.pth")
50
+
51
+ fade_detector_path = hf_hub_download(repo_id=REPO_ID,
52
+ filename="face_detector/face_detector_scrfd_10g_bnkps.onnx")
53
+
54
+ identity_extractor_config = {
55
+ "f_3d_checkpoint_path": hf_hub_download(repo_id=REPO_ID, filename="Deep3DFaceRecon/epoch_20.pth"),
56
+ "f_id_checkpoint_path": hf_hub_download(repo_id=REPO_ID, filename="arcface/ms1mv3_arcface_r100_fp16_backbone.pth")
57
+ }
58
 
59
  class ConfigPath:
60
+ face_detector_weights = fade_detector_path
61
  model_path = ""
62
  model_idx = 80000
63
+ ffmpeg_device = device
64
+ device = device
65
 
66
+ cfg = ConfigPath()
67
 
68
+ model_standard = HifiFaceST(identity_extractor_config, device=device, generator_path=gen_st_path)
69
+
70
+ model_wgm = HifiFaceWGM(identity_extractor_config, device=device, generator_path=gen_wgm_path)
71
+
72
+ image_infer_standard = ImageSwap(cfg, model_standard)
73
+ image_infer_wgm = ImageSwap(cfg, model_wgm)
74
+
75
+ MODELS = {
76
+ "Standard model": "standard",
77
+ "Eye and mouth hm loss": "eyeandmouth",
78
+ }
79
+
80
+ def inference_image(source_face, target_face, method="standard", shape_rate=1.0, id_rate=1.0, iterations=1):
81
+ if method == "standard":
82
+ return target_face, image_infer_standard.inference(source_face, target_face, shape_rate, id_rate, int(iterations))
83
+ return target_face, image_infer_wgm.inference(source_face, target_face, shape_rate, id_rate, int(iterations))
84
+
85
+
86
+ with gr.Blocks(title="FaceSwap") as app:
87
+ gr.Markdown("## HiFiFace image swap")
88
+ with gr.Row():
89
+ with gr.Column(scale=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  with gr.Row():
91
+ source_image = gr.Image(type="numpy", label="Face image")
92
+ target_image = gr.Image(type="numpy", label="Body image")
93
+ mod = gr.Dropdown(choices=list(MODELS.items()), label="Model generator", value="standard")
94
+ image_btn = gr.Button("Swap image")
95
+ with gr.Accordion("Fine tunes", open=False):
96
+ structure_sim = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.1, label="3d similarity")
97
+ id_sim = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.1, label="id similarity")
98
+ iters = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="iters")
99
+ with gr.Column(scale=1):
100
  with gr.Row():
101
+ output_image = gr.ImageSlider(label="Swapped image", type="pil")
102
+
103
+ image_btn.click(
104
+ fn=inference_image,
105
+ inputs=[source_image, target_image, mod, structure_sim, id_sim, iters],
106
+ outputs=output_image,
107
+ )
108
+
109
+ app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)
110
+ app.queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/model.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
 
2
  from typing import Dict
3
- from typing import Optional
4
  from typing import Tuple
5
 
6
  import kornia
@@ -10,28 +10,30 @@ import torch.nn as nn
10
  import torch.nn.functional as F
11
  from loguru import logger
12
 
13
- from arcface_torch.backbones.iresnet import iresnet100
14
- from configs.train_config import TrainConfig
15
  from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
16
  from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
17
  from HRNet.hrnet import HighResolutionNet
 
18
  from models.discriminator import Discriminator
19
  from models.gan_loss import GANLoss
20
  from models.generator import Generator
21
  from models.init_weight import init_net
22
 
23
-
24
  class HifiFace:
25
  def __init__(
26
  self,
27
  identity_extractor_config,
28
- is_training=True,
29
- device="cpu",
30
- load_checkpoint: Optional[Tuple[str, int]] = None,
31
  ):
32
  super(HifiFace, self).__init__()
 
 
33
  self.generator = Generator(identity_extractor_config)
34
  self.is_training = is_training
 
 
35
 
36
  if self.is_training:
37
  self.lr = TrainConfig().lr
@@ -80,10 +82,9 @@ class HifiFace:
80
 
81
  self.dilation_kernel = torch.ones(5, 5)
82
 
83
- if load_checkpoint is not None:
84
- self.load(load_checkpoint[0], load_checkpoint[1])
85
 
86
- self.setup(device)
87
 
88
  def save(self, path, idx=None):
89
  os.makedirs(path, exist_ok=True)
@@ -100,18 +101,9 @@ class HifiFace:
100
  torch.save(self.generator.state_dict(), g_path)
101
  torch.save(self.discriminator.state_dict(), d_path)
102
 
103
- def load(self, path, idx=None):
104
- if idx is None:
105
- g_path = os.path.join(path, "generator.pth")
106
- d_path = os.path.join(path, "discriminator.pth")
107
- else:
108
- g_path = os.path.join(path, f"generator_{idx}.pth")
109
- d_path = os.path.join(path, f"discriminator_{idx}.pth")
110
- logger.info(f"Loading generator from {g_path}")
111
- self.generator.load_state_dict(torch.load(g_path, map_location="cpu"))
112
- if self.is_training:
113
- logger.info(f"Loading discriminator from {d_path}")
114
- self.discriminator.load_state_dict(torch.load(d_path, map_location="cpu"))
115
 
116
  def setup(self, device):
117
  self.generator.to(device)
@@ -399,37 +391,18 @@ class HifiFace:
399
  }
400
 
401
 
402
- if __name__ == "__main__":
403
- import torch
404
- import cv2
405
- from configs.train_config import TrainConfig
406
-
407
- identity_extractor_config = TrainConfig().identity_extractor_config
408
-
409
- model = HifiFace(identity_extractor_config, is_training=True)
410
-
411
- # src = cv2.imread("/home/xuehongyang/data/test1.jpg")
412
- # tgt = cv2.imread("/home/xuehongyang/data/test2.jpg")
413
- # src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
414
- # tgt = cv2.cvtColor(tgt, cv2.COLOR_BGR2RGB)
415
- # src = cv2.resize(src, (256, 256))
416
- # tgt = cv2.resize(tgt, (256, 256))
417
- # src = src.transpose(2, 0, 1)[None, ...]
418
- # tgt = tgt.transpose(2, 0, 1)[None, ...]
419
- # source_img = torch.from_numpy(src).float() / 255.0
420
- # target_img = torch.from_numpy(tgt).float() / 255.0
421
- # same_id_mask = torch.Tensor([1]).unsqueeze(0)
422
- # tgt_mask = target_img[:, 0, :, :].unsqueeze(1)
423
- # if torch.cuda.is_available():
424
- # model.to("cuda:3")
425
- # source_img = source_img.to("cuda:3")
426
- # target_img = target_img.to("cuda:3")
427
- # tgt_mask = tgt_mask.to("cuda:3")
428
- # same_id_mask = same_id_mask.to("cuda:3")
429
- # source_img = source_img.repeat(16, 1, 1, 1)
430
- # target_img = target_img.repeat(16, 1, 1, 1)
431
- # tgt_mask = tgt_mask.repeat(16, 1, 1, 1)
432
- # same_id_mask = same_id_mask.repeat(16, 1)
433
- # while True:
434
- # x = model.optimize(source_img, target_img, tgt_mask, same_id_mask)
435
- # print(x[0]["loss_generator"])
 
1
  import os
2
+ from abc import abstractmethod
3
  from typing import Dict
 
4
  from typing import Tuple
5
 
6
  import kornia
 
10
  import torch.nn.functional as F
11
  from loguru import logger
12
 
 
 
13
  from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
14
  from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
15
  from HRNet.hrnet import HighResolutionNet
16
+ from arcface_torch.backbones.iresnet import iresnet100
17
  from models.discriminator import Discriminator
18
  from models.gan_loss import GANLoss
19
  from models.generator import Generator
20
  from models.init_weight import init_net
21
 
 
22
  class HifiFace:
23
  def __init__(
24
  self,
25
  identity_extractor_config,
26
+ generator_path,
27
+ is_training=False,
28
+ device="cpu"
29
  ):
30
  super(HifiFace, self).__init__()
31
+ self.d_optimizer = None
32
+ self.g_optimizer = None
33
  self.generator = Generator(identity_extractor_config)
34
  self.is_training = is_training
35
+ self.device = device
36
+ self.generator_path = generator_path
37
 
38
  if self.is_training:
39
  self.lr = TrainConfig().lr
 
82
 
83
  self.dilation_kernel = torch.ones(5, 5)
84
 
85
+ self.load_checkpoint()
 
86
 
87
+ self.setup(self.device)
88
 
89
  def save(self, path, idx=None):
90
  os.makedirs(path, exist_ok=True)
 
101
  torch.save(self.generator.state_dict(), g_path)
102
  torch.save(self.discriminator.state_dict(), d_path)
103
 
104
+ @abstractmethod
105
+ def load_checkpoint(self):
106
+ pass
 
 
 
 
 
 
 
 
 
107
 
108
  def setup(self, device):
109
  self.generator.to(device)
 
391
  }
392
 
393
 
394
+ class HifiFaceST(HifiFace):
395
+ def __init__(self, identity_extractor_config, device, generator_path):
396
+ super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
397
+
398
+ def load_checkpoint(self):
399
+ self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
400
+ logger.info(f"Loading generator from {self.generator_path}")
401
+
402
+ class HifiFaceWGM(HifiFace):
403
+ def __init__(self, identity_extractor_config, device, generator_path):
404
+ super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
405
+
406
+ def load_checkpoint(self):
407
+ self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
408
+ logger.info(f"Loading generator from {self.generator_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages.txt DELETED
@@ -1 +0,0 @@
1
- wget
 
 
pyrightconfig.json DELETED
@@ -1,15 +0,0 @@
1
- {
2
- "reportMissingImports": true,
3
- "reportMissingTypeStubs": true,
4
- "useLibraryCodeForTypes": true,
5
- "reportUnusedImport": "warning",
6
- "reportUnusedVariable": "warning",
7
- "reportDuplicateImport": true,
8
- "reportPrivateImportUsage": false,
9
- "reportWildcardImportFromLibrary": "warning",
10
- "reportTypedDictNotRequiredAccess": false,
11
- "reportGeneralTypeIssues": false,
12
- "venvPath": "/home/xuehongyang/miniconda3/envs/",
13
- "venv": "pytorch-2.0",
14
- "stubPath": "/home/xuehongyang/dev_configs/typings"
15
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server.sh DELETED
@@ -1 +0,0 @@
1
- python3 app.py