code refactorized
Browse files- .gitignore +10 -0
- Dockerfile +0 -16
- README.md +8 -5
- app.py +97 -121
- models/model.py +29 -56
- packages.txt +0 -1
- pyrightconfig.json +0 -15
- server.sh +0 -1
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/
|
| 2 |
+
Deep3DFaceRecon_pytorch/models/__pycache__/
|
| 3 |
+
Deep3DFaceRecon_pytorch/util/__pycache__/
|
| 4 |
+
arcface_torch/backbones/__pycache__/
|
| 5 |
+
benchmark/__pycache__/
|
| 6 |
+
HRNet/__pycache__/
|
| 7 |
+
models/__pycache__/
|
| 8 |
+
configs/__pycache__/
|
| 9 |
+
.idea
|
| 10 |
+
playground.py
|
Dockerfile
DELETED
|
@@ -1,16 +0,0 @@
|
|
| 1 |
-
FROM xuehy93/hififace:1.0
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
RUN apt update && apt install -y wget
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
WORKDIR /
|
| 8 |
-
|
| 9 |
-
RUN wget https://public.ph.files.1drv.com/y4m_El1_AyFLmGuZaPWOqkytzM4qYtDc3BvNNL99JV1OLCEkmD4RTQjtHEXZ0SAWb7UPLV1IPB0KO2rFlyGJaV_kITLbuAHzJ73GwR_cgvXpkIGywaTnKsKVV1jJe1LoFcl7XsxatyGpaC8-Gupq6jjBnaqSBH4dgfYAmzUk8Wqiiuj_ml2duU7No0M1T426y3RqOJsqVHXEMVfV0B6HjzQFKCCZIgfHjjHvLIB3B3xP8Q?AVOverride=1 -O checkpoints.tar.gz
|
| 10 |
-
|
| 11 |
-
RUN tar xfz checkpoints.tar.gz
|
| 12 |
-
|
| 13 |
-
WORKDIR /app
|
| 14 |
-
ADD ./ /app
|
| 15 |
-
RUN chmod +x ./server.sh
|
| 16 |
-
CMD ["./server.sh"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,11 +1,14 @@
|
|
| 1 |
---
|
| 2 |
-
title: HiFiFace
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk:
|
|
|
|
|
|
|
| 7 |
pinned: false
|
| 8 |
license: mit
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: HiFiFace image swap
|
| 3 |
+
emoji: 👁
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.47.0
|
| 8 |
+
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
+
short_description: Swap faces from photos
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
|
@@ -1,134 +1,110 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import gradio as gr
|
|
|
|
|
|
|
| 4 |
|
| 5 |
from benchmark.app_image import ImageSwap
|
| 6 |
-
from
|
| 7 |
-
from configs.train_config import TrainConfig
|
| 8 |
-
from models.model import HifiFace
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
class ConfigPath:
|
| 12 |
-
face_detector_weights =
|
| 13 |
model_path = ""
|
| 14 |
model_idx = 80000
|
| 15 |
-
ffmpeg_device =
|
| 16 |
-
device =
|
| 17 |
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
model1 = HifiFace(opt.identity_extractor_config, is_training=False, device=cfg.device, load_checkpoint=checkpoint1)
|
| 42 |
-
image_infer = ImageSwap(cfg, model)
|
| 43 |
-
image_infer1 = ImageSwap(cfg, model1)
|
| 44 |
-
def inference_image(source_face, target_face, shape_rate, id_rate, iterations):
|
| 45 |
-
return image_infer.inference(source_face, target_face, shape_rate, id_rate, int(iterations))
|
| 46 |
-
|
| 47 |
-
def inference_image1(source_face, target_face, shape_rate, id_rate, iterations):
|
| 48 |
-
return image_infer1.inference(source_face, target_face, shape_rate, id_rate, int(iterations))
|
| 49 |
-
|
| 50 |
-
model_name = cfg.model_path.split("/")[-1] + ":" + f"{cfg.model_idx}"
|
| 51 |
-
model_name1 = model_path_1.split("/")[-1] + ":" + "190000"
|
| 52 |
-
with gr.Blocks(title="FaceSwap") as demo:
|
| 53 |
-
gr.Markdown(
|
| 54 |
-
f"""
|
| 55 |
-
### standard model: {model_name} \n
|
| 56 |
-
### model with eye and mouth hm loss: {model_name1}
|
| 57 |
-
"""
|
| 58 |
-
)
|
| 59 |
-
with gr.Tab("Image swap with standard model"):
|
| 60 |
-
with gr.Row():
|
| 61 |
-
source_image = gr.Image(shape=None, label="source image")
|
| 62 |
-
target_image = gr.Image(shape=None, label="target image")
|
| 63 |
-
with gr.Row():
|
| 64 |
-
with gr.Column():
|
| 65 |
-
structure_sim = gr.Slider(
|
| 66 |
-
minimum=0.0,
|
| 67 |
-
maximum=1.0,
|
| 68 |
-
value=1.0,
|
| 69 |
-
step=0.1,
|
| 70 |
-
label="3d similarity",
|
| 71 |
-
)
|
| 72 |
-
id_sim = gr.Slider(
|
| 73 |
-
minimum=0.0,
|
| 74 |
-
maximum=1.0,
|
| 75 |
-
value=1.0,
|
| 76 |
-
step=0.1,
|
| 77 |
-
label="id similarity",
|
| 78 |
-
)
|
| 79 |
-
iters = gr.Slider(
|
| 80 |
-
minimum=1,
|
| 81 |
-
maximum=10,
|
| 82 |
-
value=1,
|
| 83 |
-
step=1,
|
| 84 |
-
label="iters",
|
| 85 |
-
)
|
| 86 |
-
image_btn = gr.Button("image swap")
|
| 87 |
-
output_image = gr.Image(shape=None, label="Result")
|
| 88 |
-
|
| 89 |
-
image_btn.click(
|
| 90 |
-
fn=inference_image,
|
| 91 |
-
inputs=[source_image, target_image, structure_sim, id_sim, iters],
|
| 92 |
-
outputs=output_image,
|
| 93 |
-
)
|
| 94 |
-
|
| 95 |
-
with gr.Tab("Image swap with eye&mouth hm loss model"):
|
| 96 |
with gr.Row():
|
| 97 |
-
source_image = gr.Image(
|
| 98 |
-
target_image = gr.Image(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
with gr.Row():
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
maximum=1.0,
|
| 111 |
-
value=1.0,
|
| 112 |
-
step=0.1,
|
| 113 |
-
label="id similarity",
|
| 114 |
-
)
|
| 115 |
-
iters = gr.Slider(
|
| 116 |
-
minimum=1,
|
| 117 |
-
maximum=10,
|
| 118 |
-
value=1,
|
| 119 |
-
step=1,
|
| 120 |
-
label="iters",
|
| 121 |
-
)
|
| 122 |
-
image_btn = gr.Button("image swap")
|
| 123 |
-
output_image = gr.Image(shape=None, label="Result")
|
| 124 |
-
|
| 125 |
-
image_btn.click(
|
| 126 |
-
fn=inference_image1,
|
| 127 |
-
inputs=[source_image, target_image, structure_sim, id_sim, iters],
|
| 128 |
-
outputs=output_image,
|
| 129 |
-
)
|
| 130 |
-
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
if __name__ == "__main__":
|
| 134 |
-
main()
|
|
|
|
| 1 |
+
#######################################################################################
|
| 2 |
+
#
|
| 3 |
+
# MIT License
|
| 4 |
+
#
|
| 5 |
+
# Copyright (c) [2025] [leonelhs@gmail.com]
|
| 6 |
+
#
|
| 7 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 8 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 9 |
+
# in the Software without restriction, including without limitation the rights
|
| 10 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 11 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 12 |
+
# furnished to do so, subject to the following conditions:
|
| 13 |
+
#
|
| 14 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 15 |
+
# copies or substantial portions of the Software.
|
| 16 |
+
#
|
| 17 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 18 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 19 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 20 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 21 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 22 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 23 |
+
# SOFTWARE.
|
| 24 |
+
#
|
| 25 |
+
#######################################################################################
|
| 26 |
+
#
|
| 27 |
+
# Source code is based on or inspired by several projects.
|
| 28 |
+
# For more details and proper attribution, please refer to the following resources:
|
| 29 |
+
#
|
| 30 |
+
# - [hyxue] - [https://huggingface.co/spaces/hyxue/HiFiFace-inference-demo]
|
| 31 |
+
# - [maum-ai] [https://github.com/maum-ai/hififace]
|
| 32 |
+
#
|
| 33 |
|
| 34 |
import gradio as gr
|
| 35 |
+
import torch
|
| 36 |
+
from huggingface_hub import hf_hub_download
|
| 37 |
|
| 38 |
from benchmark.app_image import ImageSwap
|
| 39 |
+
from models.model import HifiFaceST, HifiFaceWGM
|
|
|
|
|
|
|
| 40 |
|
| 41 |
+
REPO_ID = "leonelhs/HiFiFace"
|
| 42 |
+
|
| 43 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 44 |
+
|
| 45 |
+
gen_st_path = hf_hub_download(repo_id=REPO_ID,
|
| 46 |
+
filename="hififace_pretrained/standard_model/generator_320000.pth")
|
| 47 |
+
|
| 48 |
+
gen_wgm_path = hf_hub_download(repo_id=REPO_ID,
|
| 49 |
+
filename="hififace_pretrained/with_gaze_and_mouth/generator_190000.pth")
|
| 50 |
+
|
| 51 |
+
fade_detector_path = hf_hub_download(repo_id=REPO_ID,
|
| 52 |
+
filename="face_detector/face_detector_scrfd_10g_bnkps.onnx")
|
| 53 |
+
|
| 54 |
+
identity_extractor_config = {
|
| 55 |
+
"f_3d_checkpoint_path": hf_hub_download(repo_id=REPO_ID, filename="Deep3DFaceRecon/epoch_20.pth"),
|
| 56 |
+
"f_id_checkpoint_path": hf_hub_download(repo_id=REPO_ID, filename="arcface/ms1mv3_arcface_r100_fp16_backbone.pth")
|
| 57 |
+
}
|
| 58 |
|
| 59 |
class ConfigPath:
|
| 60 |
+
face_detector_weights = fade_detector_path
|
| 61 |
model_path = ""
|
| 62 |
model_idx = 80000
|
| 63 |
+
ffmpeg_device = device
|
| 64 |
+
device = device
|
| 65 |
|
| 66 |
+
cfg = ConfigPath()
|
| 67 |
|
| 68 |
+
model_standard = HifiFaceST(identity_extractor_config, device=device, generator_path=gen_st_path)
|
| 69 |
+
|
| 70 |
+
model_wgm = HifiFaceWGM(identity_extractor_config, device=device, generator_path=gen_wgm_path)
|
| 71 |
+
|
| 72 |
+
image_infer_standard = ImageSwap(cfg, model_standard)
|
| 73 |
+
image_infer_wgm = ImageSwap(cfg, model_wgm)
|
| 74 |
+
|
| 75 |
+
MODELS = {
|
| 76 |
+
"Standard model": "standard",
|
| 77 |
+
"Eye and mouth hm loss": "eyeandmouth",
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
def inference_image(source_face, target_face, method="standard", shape_rate=1.0, id_rate=1.0, iterations=1):
|
| 81 |
+
if method == "standard":
|
| 82 |
+
return target_face, image_infer_standard.inference(source_face, target_face, shape_rate, id_rate, int(iterations))
|
| 83 |
+
return target_face, image_infer_wgm.inference(source_face, target_face, shape_rate, id_rate, int(iterations))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
with gr.Blocks(title="FaceSwap") as app:
|
| 87 |
+
gr.Markdown("## HiFiFace image swap")
|
| 88 |
+
with gr.Row():
|
| 89 |
+
with gr.Column(scale=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
with gr.Row():
|
| 91 |
+
source_image = gr.Image(type="numpy", label="Face image")
|
| 92 |
+
target_image = gr.Image(type="numpy", label="Body image")
|
| 93 |
+
mod = gr.Dropdown(choices=list(MODELS.items()), label="Model generator", value="standard")
|
| 94 |
+
image_btn = gr.Button("Swap image")
|
| 95 |
+
with gr.Accordion("Fine tunes", open=False):
|
| 96 |
+
structure_sim = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.1, label="3d similarity")
|
| 97 |
+
id_sim = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.1, label="id similarity")
|
| 98 |
+
iters = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="iters")
|
| 99 |
+
with gr.Column(scale=1):
|
| 100 |
with gr.Row():
|
| 101 |
+
output_image = gr.ImageSlider(label="Swapped image", type="pil")
|
| 102 |
+
|
| 103 |
+
image_btn.click(
|
| 104 |
+
fn=inference_image,
|
| 105 |
+
inputs=[source_image, target_image, mod, structure_sim, id_sim, iters],
|
| 106 |
+
outputs=output_image,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)
|
| 110 |
+
app.queue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/model.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
from typing import Dict
|
| 3 |
-
from typing import Optional
|
| 4 |
from typing import Tuple
|
| 5 |
|
| 6 |
import kornia
|
|
@@ -10,28 +10,30 @@ import torch.nn as nn
|
|
| 10 |
import torch.nn.functional as F
|
| 11 |
from loguru import logger
|
| 12 |
|
| 13 |
-
from arcface_torch.backbones.iresnet import iresnet100
|
| 14 |
-
from configs.train_config import TrainConfig
|
| 15 |
from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
|
| 16 |
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
|
| 17 |
from HRNet.hrnet import HighResolutionNet
|
|
|
|
| 18 |
from models.discriminator import Discriminator
|
| 19 |
from models.gan_loss import GANLoss
|
| 20 |
from models.generator import Generator
|
| 21 |
from models.init_weight import init_net
|
| 22 |
|
| 23 |
-
|
| 24 |
class HifiFace:
|
| 25 |
def __init__(
|
| 26 |
self,
|
| 27 |
identity_extractor_config,
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
):
|
| 32 |
super(HifiFace, self).__init__()
|
|
|
|
|
|
|
| 33 |
self.generator = Generator(identity_extractor_config)
|
| 34 |
self.is_training = is_training
|
|
|
|
|
|
|
| 35 |
|
| 36 |
if self.is_training:
|
| 37 |
self.lr = TrainConfig().lr
|
|
@@ -80,10 +82,9 @@ class HifiFace:
|
|
| 80 |
|
| 81 |
self.dilation_kernel = torch.ones(5, 5)
|
| 82 |
|
| 83 |
-
|
| 84 |
-
self.load(load_checkpoint[0], load_checkpoint[1])
|
| 85 |
|
| 86 |
-
self.setup(device)
|
| 87 |
|
| 88 |
def save(self, path, idx=None):
|
| 89 |
os.makedirs(path, exist_ok=True)
|
|
@@ -100,18 +101,9 @@ class HifiFace:
|
|
| 100 |
torch.save(self.generator.state_dict(), g_path)
|
| 101 |
torch.save(self.discriminator.state_dict(), d_path)
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
d_path = os.path.join(path, "discriminator.pth")
|
| 107 |
-
else:
|
| 108 |
-
g_path = os.path.join(path, f"generator_{idx}.pth")
|
| 109 |
-
d_path = os.path.join(path, f"discriminator_{idx}.pth")
|
| 110 |
-
logger.info(f"Loading generator from {g_path}")
|
| 111 |
-
self.generator.load_state_dict(torch.load(g_path, map_location="cpu"))
|
| 112 |
-
if self.is_training:
|
| 113 |
-
logger.info(f"Loading discriminator from {d_path}")
|
| 114 |
-
self.discriminator.load_state_dict(torch.load(d_path, map_location="cpu"))
|
| 115 |
|
| 116 |
def setup(self, device):
|
| 117 |
self.generator.to(device)
|
|
@@ -399,37 +391,18 @@ class HifiFace:
|
|
| 399 |
}
|
| 400 |
|
| 401 |
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
# src = src.transpose(2, 0, 1)[None, ...]
|
| 418 |
-
# tgt = tgt.transpose(2, 0, 1)[None, ...]
|
| 419 |
-
# source_img = torch.from_numpy(src).float() / 255.0
|
| 420 |
-
# target_img = torch.from_numpy(tgt).float() / 255.0
|
| 421 |
-
# same_id_mask = torch.Tensor([1]).unsqueeze(0)
|
| 422 |
-
# tgt_mask = target_img[:, 0, :, :].unsqueeze(1)
|
| 423 |
-
# if torch.cuda.is_available():
|
| 424 |
-
# model.to("cuda:3")
|
| 425 |
-
# source_img = source_img.to("cuda:3")
|
| 426 |
-
# target_img = target_img.to("cuda:3")
|
| 427 |
-
# tgt_mask = tgt_mask.to("cuda:3")
|
| 428 |
-
# same_id_mask = same_id_mask.to("cuda:3")
|
| 429 |
-
# source_img = source_img.repeat(16, 1, 1, 1)
|
| 430 |
-
# target_img = target_img.repeat(16, 1, 1, 1)
|
| 431 |
-
# tgt_mask = tgt_mask.repeat(16, 1, 1, 1)
|
| 432 |
-
# same_id_mask = same_id_mask.repeat(16, 1)
|
| 433 |
-
# while True:
|
| 434 |
-
# x = model.optimize(source_img, target_img, tgt_mask, same_id_mask)
|
| 435 |
-
# print(x[0]["loss_generator"])
|
|
|
|
| 1 |
import os
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
from typing import Dict
|
|
|
|
| 4 |
from typing import Tuple
|
| 5 |
|
| 6 |
import kornia
|
|
|
|
| 10 |
import torch.nn.functional as F
|
| 11 |
from loguru import logger
|
| 12 |
|
|
|
|
|
|
|
| 13 |
from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
|
| 14 |
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
|
| 15 |
from HRNet.hrnet import HighResolutionNet
|
| 16 |
+
from arcface_torch.backbones.iresnet import iresnet100
|
| 17 |
from models.discriminator import Discriminator
|
| 18 |
from models.gan_loss import GANLoss
|
| 19 |
from models.generator import Generator
|
| 20 |
from models.init_weight import init_net
|
| 21 |
|
|
|
|
| 22 |
class HifiFace:
|
| 23 |
def __init__(
|
| 24 |
self,
|
| 25 |
identity_extractor_config,
|
| 26 |
+
generator_path,
|
| 27 |
+
is_training=False,
|
| 28 |
+
device="cpu"
|
| 29 |
):
|
| 30 |
super(HifiFace, self).__init__()
|
| 31 |
+
self.d_optimizer = None
|
| 32 |
+
self.g_optimizer = None
|
| 33 |
self.generator = Generator(identity_extractor_config)
|
| 34 |
self.is_training = is_training
|
| 35 |
+
self.device = device
|
| 36 |
+
self.generator_path = generator_path
|
| 37 |
|
| 38 |
if self.is_training:
|
| 39 |
self.lr = TrainConfig().lr
|
|
|
|
| 82 |
|
| 83 |
self.dilation_kernel = torch.ones(5, 5)
|
| 84 |
|
| 85 |
+
self.load_checkpoint()
|
|
|
|
| 86 |
|
| 87 |
+
self.setup(self.device)
|
| 88 |
|
| 89 |
def save(self, path, idx=None):
|
| 90 |
os.makedirs(path, exist_ok=True)
|
|
|
|
| 101 |
torch.save(self.generator.state_dict(), g_path)
|
| 102 |
torch.save(self.discriminator.state_dict(), d_path)
|
| 103 |
|
| 104 |
+
@abstractmethod
|
| 105 |
+
def load_checkpoint(self):
|
| 106 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
def setup(self, device):
|
| 109 |
self.generator.to(device)
|
|
|
|
| 391 |
}
|
| 392 |
|
| 393 |
|
| 394 |
+
class HifiFaceST(HifiFace):
|
| 395 |
+
def __init__(self, identity_extractor_config, device, generator_path):
|
| 396 |
+
super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
|
| 397 |
+
|
| 398 |
+
def load_checkpoint(self):
|
| 399 |
+
self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
|
| 400 |
+
logger.info(f"Loading generator from {self.generator_path}")
|
| 401 |
+
|
| 402 |
+
class HifiFaceWGM(HifiFace):
|
| 403 |
+
def __init__(self, identity_extractor_config, device, generator_path):
|
| 404 |
+
super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
|
| 405 |
+
|
| 406 |
+
def load_checkpoint(self):
|
| 407 |
+
self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
|
| 408 |
+
logger.info(f"Loading generator from {self.generator_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
packages.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
wget
|
|
|
|
|
|
pyrightconfig.json
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"reportMissingImports": true,
|
| 3 |
-
"reportMissingTypeStubs": true,
|
| 4 |
-
"useLibraryCodeForTypes": true,
|
| 5 |
-
"reportUnusedImport": "warning",
|
| 6 |
-
"reportUnusedVariable": "warning",
|
| 7 |
-
"reportDuplicateImport": true,
|
| 8 |
-
"reportPrivateImportUsage": false,
|
| 9 |
-
"reportWildcardImportFromLibrary": "warning",
|
| 10 |
-
"reportTypedDictNotRequiredAccess": false,
|
| 11 |
-
"reportGeneralTypeIssues": false,
|
| 12 |
-
"venvPath": "/home/xuehongyang/miniconda3/envs/",
|
| 13 |
-
"venv": "pytorch-2.0",
|
| 14 |
-
"stubPath": "/home/xuehongyang/dev_configs/typings"
|
| 15 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server.sh
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
python3 app.py
|
|
|
|
|
|