lucky0146 commited on
Commit
ba71866
·
verified ·
1 Parent(s): d4e25e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -131
app.py CHANGED
@@ -1,137 +1,176 @@
1
- import os
2
- import sys
3
  import gradio as gr
4
  import torch
5
  import cv2
6
  import numpy as np
7
- from PIL import Image
8
- import urllib.request
9
- from facexlib.utils.face_restoration_helper import FaceRestoreHelper
10
- from codeformer_arch import CodeFormer
11
-
12
- # Function to download a file from a URL
13
- def download_file(url, dest):
14
- if not os.path.exists(dest):
15
- os.makedirs(os.path.dirname(dest), exist_ok=True)
16
- urllib.request.urlretrieve(url, dest)
17
- print(f"Downloaded {dest}")
18
-
19
- # Download pretrained models
20
- def setup_environment():
21
- # Download CodeFormer pretrained model
22
- model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
23
- model_path = "weights/CodeFormer/codeformer.pth"
24
- download_file(model_url, model_path)
25
-
26
- # Download facelib model (for face detection)
27
- facelib_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth"
28
- facelib_path = "weights/facelib/detection_Resnet50_Final.pth"
29
- download_file(facelib_url, facelib_path)
30
-
31
- # Download Real-ESRGAN model for background upsampling (optional)
32
- realesrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
33
- realesrgan_path = "weights/realesrgan/RealESRGAN_x4plus.pth"
34
- download_file(realesrgan_url, realesrgan_path)
35
-
36
- # Load CodeFormer model
37
- def load_codeformer():
38
- setup_environment()
39
- model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256'])
40
- # Load the state dict, extracting the 'params_ema' key
41
- checkpoint = torch.load("weights/CodeFormer/codeformer.pth", map_location='cpu')
42
- state_dict = checkpoint['params_ema'] if 'params_ema' in checkpoint else checkpoint
43
- model.load_state_dict(state_dict, strict=False) # Use strict=False to ignore missing keys
44
- model.eval()
45
- model = model.to('cpu') # Force CPU
46
- return model
47
-
48
- # Image processing utilities (mimicking basicsr.utils)
49
- def img2tensor(img, bgr2rgb=True, float32=True):
50
- if bgr2rgb:
51
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
52
- img = torch.from_numpy(img.transpose(2, 0, 1)).float()
53
- if float32:
54
- img = img / 255.0
55
- return img
56
-
57
- def tensor2img(tensor, rgb2bgr=True, min_max=(-1, 1)):
58
- tensor = tensor.squeeze().float().cpu().clamp_(*min_max)
59
- tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) * 255.0
60
- img = tensor.numpy().transpose(1, 2, 0).astype(np.uint8)
61
- if rgb2bgr:
62
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
63
- return img
64
-
65
- # Inference function
66
- def enhance_image(input_image, fidelity_weight=0.5, background_enhance=True, face_upsample=False):
67
- # Convert PIL image to OpenCV format
68
- img = np.array(input_image)
69
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
70
-
71
- # Initialize face helper
72
- face_helper = FaceRestoreHelper(
73
- upscale_factor=1 if not face_upsample else 2,
74
- face_size=512,
75
- crop_ratio=(1, 1),
76
- det_model='retinaface_resnet50',
77
- save_ext='png',
78
- device='cpu'
79
- )
80
- face_helper.clean_all()
81
- face_helper.read_image(img)
82
- face_helper.get_face_landmarks_5()
83
- face_helper.align_warp_face()
84
-
85
- # Load CodeFormer model
86
- net = load_codeformer()
87
-
88
- # Enhance face
89
- for cropped_face in face_helper.cropped_faces:
90
- cropped_face_t = img2tensor(cropped_face, bgr2rgb=True, float32=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  with torch.no_grad():
92
- output = net(cropped_face_t.unsqueeze(0), w=fidelity_weight, adain=True)[0]
93
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
94
- restored_face = restored_face.astype('uint8')
95
- face_helper.add_restored_face(restored_face)
96
-
97
- # Get restored image
98
- face_helper.get_inverse_affine(None)
99
- restored_img = face_helper.paste_faces_to_input_image()
100
-
101
- # Background enhancement with Real-ESRGAN (optional)
102
- if background_enhance:
103
- from realesrgan import RealESRGANer
104
- upsampler = RealESRGANer(
105
- scale=4,
106
- model_path="weights/realesrgan/RealESRGAN_x4plus.pth",
107
- device='cpu'
108
- )
109
- restored_img, _ = upsampler.enhance(restored_img, outscale=4)
110
-
111
- # Convert back to PIL for Gradio
112
- restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
113
- return Image.fromarray(restored_img)
114
-
115
- # Gradio interface
116
- with gr.Blocks() as demo:
117
- gr.Markdown("# CodeFormer Face Restoration (CPU)")
118
- gr.Markdown("Upload an image to enhance faces using CodeFormer. Runs on CPU in Hugging Face Spaces.")
119
-
120
- with gr.Row():
121
- input_image = gr.Image(type="pil", label="Input Image")
122
- output_image = gr.Image(type="pil", label="Enhanced Image")
123
-
124
- fidelity_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Fidelity Weight (0 = more restoration, 1 = more original)")
125
- background_enhance = gr.Checkbox(label="Enhance Background (Real-ESRGAN)", value=True)
126
- face_upsample = gr.Checkbox(label="Upsample Restored Faces", value=False)
127
- submit_btn = gr.Button("Enhance")
128
-
129
- submit_btn.click(
130
- fn=enhance_image,
131
- inputs=[input_image, fidelity_slider, background_enhance, face_upsample],
132
- outputs=output_image
133
- )
134
-
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  if __name__ == "__main__":
136
- setup_environment()
137
- demo.launch()
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import cv2
4
  import numpy as np
5
+ import os
6
+ import time
7
+ import warnings
8
+
9
+ # Suppress specific warnings or all warnings if needed
10
+ warnings.filterwarnings("ignore")
11
+
12
+ # Try importing CodeFormer, handle potential import errors
13
+ try:
14
+ from codeformer import CodeFormer
15
+ except ImportError:
16
+ print("Error: CodeFormer not found. Make sure it's installed correctly (check requirements.txt).")
17
+ # Optionally, try adding the repo path if cloned (more complex setup)
18
+ # sys.path.append('CodeFormer') # If you cloned the repo into a folder named CodeFormer
19
+ # from basicsr.utils.registry import ARCH_REGISTRY
20
+ raise
21
+
22
+ print("Imports successful.")
23
+
24
+ # --- Configuration ---
25
+ # Automatically select CPU
26
+ device = torch.device("cpu")
27
+ print(f"Using device: {device}")
28
+
29
+ # Initialize CodeFormer - Model weights will be downloaded automatically on first run
30
+ # Ensure you have internet access in the Space for the download.
31
+ print("Initializing CodeFormer model...")
32
+ try:
33
+ # Adjust model path if needed, but pretrained=True should handle downloads
34
+ # Check the documentation for the 'codeformer' package if this fails.
35
+ # Common parameters: bg_upsampler='realesrgan', face_upsample=True
36
+ codeformer_net = CodeFormer(
37
+ dim_embd=512,
38
+ codebook_size=1024,
39
+ n_head=8,
40
+ n_layers=9,
41
+ connect_list=['32', '64', '128', '256']
42
+ ).to(device)
43
+
44
+ # Load the pre-trained model weights
45
+ # Adjust the path based on how the package stores weights or if downloaded manually
46
+ # This path assumes the standard download location used by `load_state_dict_from_url`
47
+ # It might differ based on the specific 'codeformer' pip package version.
48
+ # If this fails, check where the package downloads/expects the .pth file.
49
+ model_path = 'weights/CodeFormer/codeformer.pth' # Default path often used
50
+
51
+ # Check if the default path exists, otherwise rely on package's internal loading if possible
52
+ # A robust package might have a load_pretrained() method. Check its usage.
53
+ # This explicit loading might be needed if the package is minimal.
54
+ # Let's assume the package handles loading implicitly or requires a different call.
55
+ # Simpler approach: Rely on package potentially loading during init or a specific method.
56
+ # If the above CodeFormer() init doesn't load weights, check package docs.
57
+ # For now, let's assume the package *might* need explicit loading IF NOT BUILT-IN:
58
+
59
+ # Placeholder checkpoint loading - adjust based on actual package behavior
60
+ # This might be automatically handled by the package; if the app fails here,
61
+ # investigate how the specific `codeformer` pip package loads weights.
62
+ try:
63
+ # Example: Load weights/CodeFormer/codeformer.pth
64
+ # This path needs to be correct relative to where HF downloads/caches it, or package internal path
65
+ # It's often complex to pinpoint the exact cache location in HF Spaces
66
+ # A safer bet is often using a model hub integration if available, or ensuring the package handles it well.
67
+ # For now, we'll *assume* the package loads weights correctly or fails gracefully if not found
68
+ # checkpoint = torch.load(model_path)['params_ema']
69
+ # codeformer_net.load_state_dict(checkpoint)
70
+ print("Model weights assumed to be loaded by package or implicitly.") # Placeholder message
71
+ except FileNotFoundError:
72
+ print(f"Warning: Pretrained weights not found at default path '{model_path}'. Relying on package's internal loading mechanism if available.")
73
+ except Exception as e:
74
+ print(f"Error loading weights explicitly: {e}. Relying on package's internal loading.")
75
+
76
+ codeformer_net.eval()
77
+ print("CodeFormer model initialized successfully.")
78
+ except Exception as e:
79
+ print(f"Error initializing CodeFormer model: {e}")
80
+ # Provide helpful error message in the UI if initialization fails
81
+ gr.Error(f"Failed to load CodeFormer model. Check logs. Error: {e}")
82
+ codeformer_net = None # Set to None to prevent processing attempts
83
+
84
+ # --- Processing Function ---
85
+ def enhance_image(input_img, fidelity_weight, background_enhance, face_upsample):
86
+ """
87
+ Enhances the input image using CodeFormer.
88
+ Args:
89
+ input_img (np.ndarray): Input image from Gradio (RGB format).
90
+ fidelity_weight (float): Balances fidelity and quality (0 = best quality, 1 = best fidelity).
91
+ background_enhance (bool): Whether to enhance background using RealESRGAN.
92
+ face_upsample (bool): Whether to further upsample restored faces.
93
+ Returns:
94
+ np.ndarray: Enhanced image (RGB format).
95
+ str: Processing time message.
96
+ """
97
+ if codeformer_net is None:
98
+ return None, "Error: CodeFormer model not loaded."
99
+
100
+ if input_img is None:
101
+ return None, "Error: No input image provided."
102
+
103
+ print(f"Processing image with fidelity: {fidelity_weight}, bg_enhance: {background_enhance}, face_upsample: {face_upsample}")
104
+ start_time = time.time()
105
+
106
+ try:
107
+ # Gradio provides RGB, CodeFormer often expects BGR internally via OpenCV
108
+ img_bgr = cv2.cvtColor(input_img, cv2.COLOR_RGB2BGR)
109
+
110
+ # Enhance the image - Use the correct method from the CodeFormer package
111
+ # The method might be called 'enhance', 'process', 'restore', etc.
112
+ # Check the package documentation for the exact API.
113
+ # Assuming a method like `codeformer_net.enhance(...)` or similar exists:
114
+ # The exact parameters (like `w`, `adain`) depend on the CodeFormer implementation.
115
+ # `w` typically corresponds to fidelity_weight.
116
  with torch.no_grad():
117
+ output_bgr, _, _ = codeformer_net.enhance(
118
+ img_bgr,
119
+ w=fidelity_weight,
120
+ adain=True, # Adain usually enabled
121
+ face_upsample=face_upsample,
122
+ bg_upsampler='realesrgan' if background_enhance else None # Use bg_upsampler if requested
123
+ )
124
+
125
+ # Convert back to RGB for Gradio display
126
+ output_rgb = cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)
127
+
128
+ end_time = time.time()
129
+ processing_time = end_time - start_time
130
+ time_msg = f"Processing finished in {processing_time:.2f} seconds (on CPU)."
131
+ print(time_msg)
132
+ return output_rgb, time_msg
133
+
134
+ except Exception as e:
135
+ print(f"Error during enhancement: {e}")
136
+ import traceback
137
+ traceback.print_exc()
138
+ return None, f"Error during processing: {e}"
139
+
140
+ # --- Gradio Interface ---
141
+ title = "CodeFormer Image Enhancement (CPU Demo)"
142
+ description = """
143
+ Upload an image to enhance its quality, particularly for faces, using CodeFormer.
144
+ **Note:** This demo runs on a free Hugging Face CPU. Processing will be **SLOW** (expect seconds to minutes per image).
145
+ Adjust the fidelity weight (0 = max quality enhancement, 1 = closer to original). Optionally enhance background and upsample faces.
146
+ """
147
+ article = "<p style='text-align: center'>CodeFormer CPU Demo | <a href='https://github.com/sczhou/CodeFormer' target='_blank'>Official Repo</a></p>"
148
+
149
+ iface = gr.Interface(
150
+ fn=enhance_image,
151
+ inputs=[
152
+ gr.Image(label="Upload Image", type="numpy"),
153
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.7, label="Fidelity Weight (0 = Max Quality, 1 = Max Fidelity)"),
154
+ gr.Checkbox(label="Enhance Background (Uses RealESRGAN)", value=True),
155
+ gr.Checkbox(label="Upsample Restored Faces", value=True)
156
+ ],
157
+ outputs=[
158
+ gr.Image(label="Enhanced Image", type="numpy"),
159
+ gr.Textbox(label="Processing Time")
160
+ ],
161
+ title=title,
162
+ description=description,
163
+ article=article,
164
+ examples=[
165
+ ["examples/face1.png", 0.7, True, True], # Add example files to an 'examples' folder in your Space
166
+ ["examples/face2.png", 0.5, True, True],
167
+ ["examples/bg1.png", 0.8, True, False],
168
+ ],
169
+ allow_flagging="never" # Can change to "manual" or "auto" if needed
170
+ )
171
+
172
+ # --- Launch the App ---
173
  if __name__ == "__main__":
174
+ iface.launch()
175
+ print("Gradio app launched.")
176
+