lucky0146 commited on
Commit
6e685c3
·
verified ·
1 Parent(s): cd12255

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from gradio import Interface
7
+
8
+ sys.path.insert(0, '/CodeFormer')
9
+
10
+ from basicsr.utils import img2tensor, tensor2img
11
+ from basicsr.archs.codeformer_arch import CodeFormer
12
+ from facexlib.utils.face_restoration_helper import FaceRestorationHelper
13
+
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
+ # Initialize models
17
+ net = CodeFormer(
18
+ dim_embd=512,
19
+ codebook_size=1024,
20
+ n_head=8,
21
+ n_layers=9,
22
+ connect_list=['32', '64', '128', '256']
23
+ ).to(device)
24
+ net.load_state_dict(torch.load('/CodeFormer/weights/CodeFormer/codeformer.pth')['params_ema'])
25
+ net.eval()
26
+
27
+ face_helper = FaceRestorationHelper(
28
+ upscale_factor=1,
29
+ face_size=512,
30
+ crop_ratio=(1, 1),
31
+ det_model='retinaface_resnet50',
32
+ save_ext='png',
33
+ use_parse=True,
34
+ device=device
35
+ )
36
+
37
+ def process_image(img: np.ndarray, w: float = 0.7) -> np.ndarray:
38
+ face_helper.clean_all()
39
+ face_helper.read_image(img)
40
+ face_helper.get_face_landmarks_5()
41
+ face_helper.align_warp_face()
42
+
43
+ for cropped_face in face_helper.cropped_faces:
44
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
45
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
46
+
47
+ with torch.no_grad():
48
+ output = net(cropped_face_t, w=w, adain=True)[0]
49
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
50
+
51
+ face_helper.add_restored_face(restored_face)
52
+
53
+ face_helper.get_inverse_affine(None)
54
+ return face_helper.paste_faces_to_input_image()
55
+
56
+ def predict(input_img: Image.Image, w: float = 0.7) -> Image.Image:
57
+ img = np.array(input_img)
58
+ result = process_image(img, w)
59
+ return Image.fromarray(result)
60
+
61
+ iface = Interface(
62
+ fn=predict,
63
+ inputs=[
64
+ gr.Image(label="Input Image", type="pil"),
65
+ gr.Slider(0.0, 1.0, value=0.7, label="Fidelity Weight")
66
+ ],
67
+ outputs=gr.Image(label="Enhanced Image", type="pil"),
68
+ title="CodeFormer Face Restoration"
69
+ )
70
+
71
+ if __name__ == "__main__":
72
+ iface.launch(server_name="0.0.0.0", server_port=7860)