lucky0146 commited on
Commit
fa500fc
·
verified ·
1 Parent(s): f828903

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -213
app.py CHANGED
@@ -1,219 +1,227 @@
1
- # app.py
2
- # Updated: 2025-04-05 18:56:50 PM IST (Ludhiana, Punjab, India)
3
- # Fix: Removed .to(device) call after CodeFormer() initialization due to AttributeError.
4
- # Added check for .eval() method.
5
-
6
- import gradio as gr
7
  import torch
8
- import cv2
 
9
  import numpy as np
10
- import os
11
- import time
12
- import warnings
13
- import traceback # Import traceback for detailed error printing
14
-
15
- print("--- Script Start ---")
16
- print(f"Current Time (IST): {time.strftime('%Y-%m-%d %H:%M:%S')}") # Using system time from HF Space
17
-
18
- # Suppress specific warnings or all warnings if needed
19
- warnings.filterwarnings("ignore")
20
-
21
- # --- Globals ---
22
- codeformer_net = None
23
- is_initialized = False
24
- init_error_message = ""
25
-
26
- # --- Attempt to Import CodeFormer ---
27
- try:
28
- from codeformer import CodeFormer
29
- print("Successfully imported CodeFormer.")
30
- except ImportError as e:
31
- init_error_message = f"Error: CodeFormer package not found or import failed. Check requirements.txt installation. Details: {e}"
32
- print(init_error_message)
33
- # Keep script running so Gradio might load with an error message
34
- except Exception as e:
35
- init_error_message = f"An unexpected error occurred during import: {e}\n{traceback.format_exc()}"
36
- print(init_error_message)
37
-
38
-
39
- # --- Initialize Model (only if import succeeded) ---
40
- if 'CodeFormer' in globals(): # Check if import was successful
41
- print("Attempting to initialize CodeFormer model (simplified init, no .to(device))...")
42
- try:
43
- # Define the target device (though we won't call .to() on the model object)
44
- device = torch.device("cpu")
45
- print(f"Target device check: {device}")
46
-
47
- # *** MODIFIED LINE BELOW ***
48
- # Initialize WITHOUT .to(device)
49
- codeformer_net = CodeFormer()
50
- print("CodeFormer() object created.")
51
-
52
- # Check if the object has an eval method (common for models)
53
- # If it doesn't, it might not be a standard torch.nn.Module
54
- try:
55
- codeformer_net.eval()
56
- print("codeformer_net.eval() called successfully.")
57
- except AttributeError:
58
- print("Note: codeformer_net object does not have an eval() method. Assuming it's not needed or handled internally.")
59
- except Exception as e_eval:
60
- print(f"Warning: Calling codeformer_net.eval() failed: {e_eval}")
61
-
62
-
63
- is_initialized = True
64
- print("CodeFormer model object initialization appears successful (removed .to(device)).")
65
-
66
- except FileNotFoundError as e:
67
- # This error might occur if CodeFormer() tries to load weights internally and fails
68
- init_error_message = f"Error: Could not find CodeFormer model weights during init. The package might have failed to download them automatically. Check logs. Details: {e}\n{traceback.format_exc()}"
69
- print(init_error_message)
70
- codeformer_net = None # Ensure model is None if init fails
71
- except TypeError as e:
72
- # Catch TypeErrors specifically from the init call
73
- init_error_message = f"TypeError during CodeFormer init. The constructor signature might be wrong. Details: {e}\n{traceback.format_exc()}"
74
- print(init_error_message)
75
- codeformer_net = None
76
- except Exception as e:
77
- # Catch any other initialization errors
78
- init_error_message = f"Error initializing CodeFormer model object: {e}\n{traceback.format_exc()}"
79
- print(init_error_message)
80
- codeformer_net = None # Ensure model is None if init fails
81
- else:
82
- # If import failed, ensure message reflects that
83
- if not init_error_message: # Safety net if import error wasn't captured somehow
84
- init_error_message = "CodeFormer could not be imported. Cannot initialize model."
85
- print("Skipping model initialization due to import failure.")
86
-
87
-
88
- # --- Processing Function ---
89
- def enhance_image(input_img, fidelity_weight, background_enhance, face_upsample):
90
- """
91
- Enhances the input image using CodeFormer.
92
- Args:
93
- input_img (np.ndarray): Input image from Gradio (RGB format).
94
- fidelity_weight (float): Balances fidelity and quality (0 = best quality, 1 = best fidelity).
95
- background_enhance (bool): Whether to enhance background using RealESRGAN.
96
- face_upsample (bool): Whether to further upsample restored faces.
97
- Returns:
98
- np.ndarray: Enhanced image (RGB format) or None on error.
99
- str: Status or processing time message.
100
- """
101
- print("--- enhance_image function called ---") # Log function entry
102
-
103
- if not is_initialized or codeformer_net is None:
104
- error_msg = f"ERROR: CodeFormer model is not available. Initialization failed. Check logs for details. Message: {init_error_message}"
105
- print(error_msg)
106
- # Return None for the image and the error message for the status textbox
107
- return None, error_msg
108
-
109
- if input_img is None:
110
- print("Error: No input image provided.")
111
- return None, "Error: No input image provided."
112
-
113
- print(f"Processing image with fidelity: {fidelity_weight}, bg_enhance: {background_enhance}, face_upsample: {face_upsample}")
114
- start_time = time.time()
115
-
116
- try:
117
- # 1. Convert RGB (from Gradio) to BGR (often expected by OpenCV/CodeFormer backend)
118
- img_bgr = cv2.cvtColor(input_img, cv2.COLOR_RGB2BGR)
119
- print("Input image converted from RGB to BGR.")
120
-
121
- # 2. Select background upsampler based on checkbox
122
- bg_upsampler = 'realesrgan' if background_enhance else None
123
- print(f"Background upsampler selected: {bg_upsampler}")
124
-
125
- # 3. Run CodeFormer enhancement
126
- # Ensure parameters match the CodeFormer API you have installed.
127
- # `w` corresponds to fidelity_weight.
128
- print("Starting CodeFormer enhancement...")
129
- # *** Assuming enhance method exists and works on CPU by default ***
130
- with torch.no_grad(): # Keep torch.no_grad() as internal ops might use Torch
131
- output_bgr, _, _ = codeformer_net.enhance(
132
- img_bgr,
133
- w=fidelity_weight,
134
- adain=True, # AdaIN is commonly used
135
- face_upsample=face_upsample,
136
- bg_upsampler=bg_upsampler
137
- )
138
- print("CodeFormer enhancement finished.")
139
-
140
- # 4. Convert BGR output back to RGB for Gradio display
141
- output_rgb = cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)
142
- print("Output image converted from BGR to RGB.")
143
-
144
- end_time = time.time()
145
- processing_time = end_time - start_time
146
- time_msg = f"Processing finished in {processing_time:.2f} seconds (on CPU)."
147
- print(time_msg)
148
- return output_rgb, time_msg
149
-
150
- except Exception as e:
151
- # Catch errors during the enhancement process
152
- error_details = f"Error during enhancement: {e}\n{traceback.format_exc()}"
153
- print(error_details)
154
- return None, error_details # Return error message to the status textbox
155
-
156
- # --- Gradio Interface Definition ---
157
- title = "CodeFormer Image Enhancement (CPU Demo)"
158
- # Dynamically update description based on model load status
159
- status_message = 'Model initialization appears successful.' if is_initialized else f'Model Load FAILED: {init_error_message}'
160
- description = f"""
161
- Upload an image to enhance its quality, particularly for faces, using CodeFormer.
162
- **Note:** This demo runs on a free Hugging Face CPU. Processing will be **SLOW** (expect seconds to minutes per image).
163
- Adjust the fidelity weight (0 = max quality enhancement, 1 = closer to original). Optionally enhance background and upsample faces.
164
- **Status:** {status_message} Check Logs for details if failed.
165
- """
166
- article = "<p style='text-align: center'>CodeFormer CPU Demo | <a href='https://github.com/sczhou/CodeFormer' target='_blank'>Official Repo</a></p>"
167
-
168
- # Define examples (Make sure the 'examples' folder and files exist in your Space repo)
169
- example_list = []
170
- if os.path.exists("examples"):
171
- example_list = [
172
- ["examples/face1.png", 0.7, True, True],
173
- ["examples/face2.png", 0.5, True, True],
174
- ["examples/bg1.png", 0.8, True, False],
175
- ]
176
- print("Example files folder found.")
177
- else:
178
- print("Note: 'examples' folder not found. Gradio examples will be empty.")
179
-
180
-
181
- print("Defining Gradio interface...")
182
- try:
183
- iface = gr.Interface(
184
- fn=enhance_image,
185
- inputs=[
186
- gr.Image(label="Upload Image", type="numpy"),
187
- gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.7, label="Fidelity Weight (0 = Max Quality, 1 = Max Fidelity)"),
188
- gr.Checkbox(label="Enhance Background (Uses RealESRGAN)", value=True),
189
- gr.Checkbox(label="Upsample Restored Faces", value=True)
190
- ],
191
- outputs=[
192
- gr.Image(label="Enhanced Image", type="numpy"),
193
- gr.Textbox(label="Status / Processing Time") # Output for status messages and time
194
- ],
195
- title=title,
196
- description=description,
197
- article=article,
198
- examples=example_list,
199
- allow_flagging="never"
200
  )
201
- print("Gradio interface defined successfully.")
202
- except Exception as e:
203
- print(f"FATAL: Failed to define Gradio interface: {e}\n{traceback.format_exc()}")
204
- # If interface definition fails, we can't launch.
205
- raise RuntimeError("Could not create Gradio Interface.") from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
 
 
 
 
207
 
208
- # --- Launch the App ---
209
  if __name__ == "__main__":
210
- print("Attempting to launch Gradio app...")
211
- try:
212
- iface.launch()
213
- print("Gradio app launched successfully.")
214
- print("--- Script End (App Running) ---")
215
- except Exception as e:
216
- print(f"FATAL: Failed to launch Gradio interface: {e}\n{traceback.format_exc()}")
217
- # Exit if launch fails. Logs should show the error.
218
- exit(1)
219
-
 
1
+ import os
2
+ import sys
 
 
 
 
3
  import torch
4
+ import gradio as gr
5
+ from PIL import Image
6
  import numpy as np
7
+ import requests
8
+ from io import BytesIO
9
+ from basicsr.utils import imwrite
10
+ from torchvision.transforms.functional import normalize
11
+
12
+ # Clone the CodeFormer repository if it doesn't exist
13
+ if not os.path.exists('CodeFormer'):
14
+ !git clone https://github.com/sczhou/CodeFormer.git
15
+ !pip install -r CodeFormer/requirements.txt
16
+ !pip install basicsr
17
+ !pip install facexlib
18
+ !pip install gradio>=3.25.0
19
+ !pip install realesrgan
20
+ !python CodeFormer/basicsr/setup.py develop
21
+
22
+ # Add the CodeFormer directory to the system path
23
+ sys.path.append('CodeFormer')
24
+
25
+ # Import necessary modules from CodeFormer
26
+ from CodeFormer.basicsr.archs.codeformer_arch import CodeFormer
27
+ from CodeFormer.basicsr.utils.registry import ARCH_REGISTRY
28
+ from CodeFormer.facelib.utils.face_restoration_helper import FaceRestoreHelper
29
+ from CodeFormer.facelib.detection.retinaface import retinaface
30
+
31
+ # Function to download model weights
32
+ def download_model_weights():
33
+ if not os.path.exists('CodeFormer/weights'):
34
+ os.makedirs('CodeFormer/weights/CodeFormer', exist_ok=True)
35
+ os.makedirs('CodeFormer/weights/facelib', exist_ok=True)
36
+
37
+ # Download CodeFormer weights
38
+ codeformer_weight_path = 'CodeFormer/weights/CodeFormer/codeformer.pth'
39
+ if not os.path.exists(codeformer_weight_path):
40
+ print('Downloading CodeFormer weights...')
41
+ url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
42
+ r = requests.get(url, allow_redirects=True)
43
+ with open(codeformer_weight_path, 'wb') as f:
44
+ f.write(r.content)
45
+
46
+ # Download detection model weights
47
+ detection_model_path = 'CodeFormer/weights/facelib/detection_Resnet50_Final.pth'
48
+ if not os.path.exists(detection_model_path):
49
+ print('Downloading face detection model weights...')
50
+ url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth'
51
+ r = requests.get(url, allow_redirects=True)
52
+ with open(detection_model_path, 'wb') as f:
53
+ f.write(r.content)
54
+
55
+ # Download parsing model weights
56
+ parsing_model_path = 'CodeFormer/weights/facelib/parsing_parsenet.pth'
57
+ if not os.path.exists(parsing_model_path):
58
+ print('Downloading face parsing model weights...')
59
+ url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
60
+ r = requests.get(url, allow_redirects=True)
61
+ with open(parsing_model_path, 'wb') as f:
62
+ f.write(r.content)
63
+
64
+ # Load CodeFormer model
65
+ def load_codeformer_model():
66
+ # Force to use CPU
67
+ device = torch.device('cpu')
68
+ print(f'Running on device: {device}')
69
+
70
+ # Download model weights if they don't exist
71
+ download_model_weights()
72
+
73
+ # Load CodeFormer model
74
+ codeformer_net = ARCH_REGISTRY.get('CodeFormer')(
75
+ dim_embd=512,
76
+ codebook_size=1024,
77
+ n_head=8,
78
+ n_layers=9,
79
+ connect_list=['32', '64', '128', '256']
80
+ ).to(device)
81
+
82
+ ckpt_path = 'CodeFormer/weights/CodeFormer/codeformer.pth'
83
+ checkpoint = torch.load(ckpt_path, map_location=device)
84
+
85
+ if 'params_ema' in checkpoint:
86
+ codeformer_net.load_state_dict(checkpoint['params_ema'])
87
+ else:
88
+ codeformer_net.load_state_dict(checkpoint['params'])
89
+
90
+ codeformer_net.eval()
91
+
92
+ # Setup face restoration helper
93
+ face_helper = FaceRestoreHelper(
94
+ upscale_factor=1,
95
+ face_size=512,
96
+ crop_ratio=(1, 1),
97
+ det_model='retinaface_resnet50',
98
+ save_ext='png',
99
+ use_parse=True,
100
+ device=device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  )
102
+
103
+ return codeformer_net, face_helper
104
+
105
+ # Process the image with CodeFormer
106
+ def process_image(image, w=0.5, has_aligned=False):
107
+ device = torch.device('cpu')
108
+ codeformer_net, face_helper = load_codeformer_model()
109
+
110
+ # Convert the input image to numpy array
111
+ if isinstance(image, Image.Image):
112
+ img = np.array(image)
113
+ else:
114
+ img = image
115
+
116
+ if has_aligned:
117
+ # The input image is already a cropped and aligned face
118
+ face_helper.is_gray = len(img.shape) == 2 or (len(img.shape) == 3 and img.shape[2] == 1)
119
+ if face_helper.is_gray:
120
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
121
+ # Prepare the face for processing
122
+ face_helper.cropped_faces = [img]
123
+ else:
124
+ # Detect and crop faces from the input image
125
+ face_helper.clean_all()
126
+ face_helper.read_image(img)
127
+ face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
128
+ face_helper.align_warp_face()
129
+
130
+ # If no face is detected
131
+ if len(face_helper.cropped_faces) == 0:
132
+ return image, "No face detected. Please try another image."
133
+
134
+ # CodeFormer inference
135
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
136
+ # Prepare the image for the model
137
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
138
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
139
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
140
+
141
+ try:
142
+ with torch.no_grad():
143
+ output = codeformer_net(cropped_face_t, w=w, adain=True)[0]
144
+ # Convert tensor to image
145
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
146
+ del output
147
+ torch.cuda.empty_cache()
148
+ except Exception as error:
149
+ print(f'Error: {error}')
150
+ restored_face = cropped_face
151
+
152
+ # Save the restored face
153
+ face_helper.add_restored_face(restored_face)
154
+
155
+ # Get the final result
156
+ if not has_aligned:
157
+ # Paste the restored faces back to the original image
158
+ face_helper.get_inverse_affine(None)
159
+ restored_img = face_helper.paste_faces_to_input_image()
160
+ restored_img = Image.fromarray(restored_img)
161
+ else:
162
+ restored_img = Image.fromarray(face_helper.restored_faces[0])
163
+
164
+ return restored_img, "Face successfully restored."
165
+
166
+ # Helper functions for image conversion
167
+ def img2tensor(img, bgr2rgb=True, float32=True):
168
+ img = img.astype(np.float32) if float32 else img
169
+ if bgr2rgb:
170
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
171
+ img = torch.from_numpy(img.transpose(2, 0, 1))
172
+ return img
173
+
174
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
175
+ tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
176
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
177
+ n_dim = tensor.dim()
178
+ if n_dim == 3:
179
+ img_np = tensor.numpy()
180
+ img_np = img_np.transpose(1, 2, 0)
181
+ if rgb2bgr:
182
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
183
+ elif n_dim == 2:
184
+ img_np = tensor.numpy()
185
+ else:
186
+ raise TypeError(f'Only support 3D and 2D tensor. But got {n_dim}D tensor.')
187
+ if out_type == np.uint8:
188
+ img_np = (img_np * 255.0).round().astype(np.uint8)
189
+ return img_np
190
+
191
+ # Create a Gradio interface for the app
192
+ def create_gradio_interface():
193
+ with gr.Blocks(title="CodeFormer Face Restoration (CPU Version)") as app:
194
+ gr.Markdown("# CodeFormer Face Restoration (CPU Version)")
195
+ gr.Markdown("Upload a photo with faces to restore the quality. This model runs on CPU, so it might take a few minutes to process.")
196
+
197
+ with gr.Row():
198
+ with gr.Column():
199
+ input_image = gr.Image(label="Input Image", type="pil")
200
+ w_slider = gr.Slider(0, 1, value=0.5, step=0.1, label="Fidelity Weight (0: more quality, 1: more identity)")
201
+ aligned_checkbox = gr.Checkbox(label="Input is an already aligned face", value=False)
202
+ process_button = gr.Button("Restore Face")
203
+
204
+ with gr.Column():
205
+ output_image = gr.Image(label="Restored Image")
206
+ output_text = gr.Textbox(label="Status")
207
+
208
+ process_button.click(
209
+ fn=process_image,
210
+ inputs=[input_image, w_slider, aligned_checkbox],
211
+ outputs=[output_image, output_text]
212
+ )
213
+
214
+ gr.Markdown("Note: Lower fidelity weight (w) values create higher-quality results with more modifications, while higher values preserve more of the original identity.")
215
+
216
+ return app
217
+
218
+ # Import CV2 only when needed to avoid issues
219
+ import cv2
220
 
221
+ # Main function
222
+ def main():
223
+ app = create_gradio_interface()
224
+ app.launch(share=True)
225
 
 
226
  if __name__ == "__main__":
227
+ main()