"""Gradio web app that restores faces in an uploaded photo with CodeFormer.

At import time the module loads the CodeFormer network and a face
detection/alignment helper; ``predict`` is the Gradio callback and
``process_image`` is the underlying BGR-in / BGR-out pipeline.
"""
import sys
from typing import Optional

import cv2
import numpy as np
import torch
from PIL import Image

import gradio as gr
from gradio import Interface

# The CodeFormer checkout provides the ``basicsr`` package used below, so it
# has to be on sys.path *before* those imports are executed.
sys.path.insert(0, '/CodeFormer')

from basicsr.utils import img2tensor, tensor2img  # noqa: E402
from basicsr.archs.codeformer_arch import CodeFormer  # noqa: E402
# facexlib names this class ``FaceRestoreHelper``.
from facexlib.utils.face_restoration_helper import FaceRestoreHelper  # noqa: E402

WEIGHTS_PATH = '/CodeFormer/weights/CodeFormer/codeformer.pth'

# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models
net = CodeFormer(
    dim_embd=512,
    codebook_size=1024,
    n_head=8,
    n_layers=9,
    connect_list=['32', '64', '128', '256'],
).to(device)

# Read the checkpoint once, onto the CPU, so that weights saved from a CUDA
# process still load on a CPU-only host; load_state_dict then copies the
# values into the parameters on ``device``.
checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')
net.load_state_dict(checkpoint['params_ema'])
net.eval()

face_helper = FaceRestoreHelper(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='png',
    use_parse=True,
    device=device,
)


def process_image(img: np.ndarray, w: float = 0.7) -> np.ndarray:
    """Restore every detected face in ``img`` and paste it back.

    Args:
        img: Input image as an array in OpenCV channel order (BGR); the
            crops are converted with ``bgr2rgb=True`` before inference.
        w: Fidelity weight forwarded to the network as ``w``.

    Returns:
        The array produced by ``face_helper.paste_faces_to_input_image()``,
        in the same BGR channel order as the input.
    """
    face_helper.clean_all()
    face_helper.read_image(img)
    face_helper.get_face_landmarks_5()
    face_helper.align_warp_face()

    for cropped_face in face_helper.cropped_faces:
        face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        # The output below is decoded with min_max=(-1, 1); map the input
        # from [0, 1] to that same range (mean 0.5, std 0.5), as the
        # CodeFormer reference inference script does.
        face_t = (face_t - 0.5) / 0.5
        face_t = face_t.unsqueeze(0).to(device)

        with torch.no_grad():
            output = net(face_t, w=w, adain=True)[0]

        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        face_helper.add_restored_face(restored_face)

    face_helper.get_inverse_affine(None)
    return face_helper.paste_faces_to_input_image()


def predict(input_img: Image.Image, w: float = 0.7) -> Optional[Image.Image]:
    """Gradio callback: restore faces in a PIL image.

    Args:
        input_img: Uploaded image, or ``None`` when nothing was uploaded.
        w: Fidelity weight from the slider.

    Returns:
        The restored image as an RGB PIL image, or ``None`` if no input
        image was supplied.
    """
    if input_img is None:
        return None

    # PIL delivers RGB (and possibly other modes such as RGBA or L), whereas
    # process_image works in BGR; convert on the way in and on the way out.
    rgb = np.array(input_img.convert('RGB'))
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    restored_bgr = process_image(bgr, w)
    return Image.fromarray(cv2.cvtColor(restored_bgr, cv2.COLOR_BGR2RGB))


iface = Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Slider(0.0, 1.0, value=0.7, label="Fidelity Weight"),
    ],
    outputs=gr.Image(label="Enhanced Image", type="pil"),
    title="CodeFormer Face Restoration",
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)