lucky0146 commited on
Commit
0c96f3e
·
verified ·
1 Parent(s): 37aabe6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -27
app.py CHANGED
@@ -6,7 +6,6 @@ import cv2
6
  import numpy as np
7
  from PIL import Image
8
  import urllib.request
9
- import tarfile
10
 
11
  # Function to download a file from a URL
12
  def download_file(url, dest):
@@ -22,39 +21,23 @@ def setup_environment():
22
  model_path = "weights/codeformer.pth"
23
  download_file(model_url, model_path)
24
 
25
- # Download facexlib detection models (needed for face detection)
26
  retinaface_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth"
27
  retinaface_path = "weights/detection_Resnet50_Final.pth"
28
  download_file(retinaface_url, retinaface_path)
29
 
30
- # Define a simplified CodeFormer architecture (placeholder)
31
- class CodeFormer(torch.nn.Module):
32
- def __init__(self, dim_embd=512, codebook_size=1024, n_head=8, n_layer=9, connect_list=['32', '64', '128', '256']):
33
- super(CodeFormer, self).__init__()
34
- # Simplified placeholder (full architecture requires codeformer_arch.py)
35
- self.encoder = torch.nn.Sequential(
36
- torch.nn.Conv2d(3, dim_embd, kernel_size=3, stride=1, padding=1),
37
- torch.nn.ReLU(),
38
- torch.nn.Conv2d(dim_embd, dim_embd, kernel_size=3, stride=1, padding=1)
39
- )
40
- self.decoder = torch.nn.Sequential(
41
- torch.nn.ConvTranspose2d(dim_embd, 3, kernel_size=3, stride=1, padding=1),
42
- torch.nn.Sigmoid()
43
- )
44
-
45
- def forward(self, x, w=0.5, adain=True):
46
- # Simplified forward pass (placeholder)
47
- enc = self.encoder(x)
48
- dec = self.decoder(enc)
49
- return dec
50
 
51
  # Load CodeFormer model
52
  def load_codeformer():
53
  setup_environment()
 
54
  model_path = "weights/codeformer.pth"
55
- net = CodeFormer().to('cpu')
56
  checkpoint = torch.load(model_path, map_location='cpu')
57
- net.load_state_dict(checkpoint, strict=False) # strict=False due to simplified architecture
58
  net.eval()
59
  return net
60
 
@@ -90,8 +73,8 @@ def enhance_image(image, fidelity_weight=0.5):
90
  face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device='cpu')
91
  face_helper.clean_all()
92
  face_helper.read_image(img)
93
- face_helper.get_face_landmarks_5() # Removed align=True
94
- face_helper.align_warp_face() # Ensure alignment happens here
95
 
96
  # Enhance face with CodeFormer
97
  for cropped_face in face_helper.cropped_faces:
@@ -129,6 +112,5 @@ with gr.Blocks() as demo:
129
  )
130
 
131
  if __name__ == "__main__":
132
- # Ensure setup runs once
133
  setup_environment()
134
  demo.launch()
 
6
  import numpy as np
7
  from PIL import Image
8
  import urllib.request
 
9
 
10
  # Function to download a file from a URL
11
  def download_file(url, dest):
 
21
  model_path = "weights/codeformer.pth"
22
  download_file(model_url, model_path)
23
 
24
+ # Download facexlib detection models
25
  retinaface_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth"
26
  retinaface_path = "weights/detection_Resnet50_Final.pth"
27
  download_file(retinaface_url, retinaface_path)
28
 
29
+ # Download codeformer_arch.py
30
+ arch_url = "https://raw.githubusercontent.com/sczhou/CodeFormer/master/codeformer_arch.py"
31
+ download_file(arch_url, "codeformer_arch.py")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  # Load CodeFormer model
34
  def load_codeformer():
35
  setup_environment()
36
+ from codeformer_arch import CodeFormer
37
  model_path = "weights/codeformer.pth"
38
+ net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9, connect_list=['32', '64', '128', '256']).to('cpu')
39
  checkpoint = torch.load(model_path, map_location='cpu')
40
+ net.load_state_dict(checkpoint)
41
  net.eval()
42
  return net
43
 
 
73
  face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device='cpu')
74
  face_helper.clean_all()
75
  face_helper.read_image(img)
76
+ face_helper.get_face_landmarks_5()
77
+ face_helper.align_warp_face()
78
 
79
  # Enhance face with CodeFormer
80
  for cropped_face in face_helper.cropped_faces:
 
112
  )
113
 
114
  if __name__ == "__main__":
 
115
  setup_environment()
116
  demo.launch()