""" CodeFormer Flask Application Deployment on Hugging Face Spaces """ import os import cv2 import torch import uuid import numpy as np import zipfile import base64 from flask import Flask, render_template, request, send_file, url_for, jsonify, send_from_directory from flask_cors import CORS from werkzeug.utils import secure_filename from torchvision.transforms.functional import normalize from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.utils import imwrite, img2tensor, tensor2img from basicsr.utils.download_util import load_file_from_url from basicsr.utils.misc import gpu_is_available, get_device from basicsr.utils.realesrgan_utils import RealESRGANer from basicsr.utils.registry import ARCH_REGISTRY from facelib.utils.face_restoration_helper import FaceRestoreHelper from facelib.utils.misc import is_gray # --- Initialization --- app = Flask(__name__) CORS(app) # Enable CORS for all routes app.config['UPLOAD_FOLDER'] = 'static/uploads' app.config['RESULT_FOLDER'] = 'static/results' app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024 # 100MB limit # Ensure directories exist os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True) os.makedirs('weights/CodeFormer', exist_ok=True) os.makedirs('weights/facelib', exist_ok=True) os.makedirs('weights/realesrgan', exist_ok=True) # Pretrained model URLs pretrain_model_url = { 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth', 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth', 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth' } def download_weights(): if not os.path.exists('weights/CodeFormer/codeformer.pth'): load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='weights/CodeFormer', progress=True, file_name=None) if not os.path.exists('weights/facelib/detection_Resnet50_Final.pth'): load_file_from_url(url=pretrain_model_url['detection'], model_dir='weights/facelib', progress=True, file_name=None) if not os.path.exists('weights/facelib/parsing_parsenet.pth'): load_file_from_url(url=pretrain_model_url['parsing'], model_dir='weights/facelib', progress=True, file_name=None) if not os.path.exists('weights/realesrgan/RealESRGAN_x2plus.pth'): load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='weights/realesrgan', progress=True, file_name=None) # Download weights on startup print("Checking weights...") download_weights() # Global models device = get_device() upsampler = None codeformer_net = None def init_models(): global upsampler, codeformer_net # RealESRGAN half = True if gpu_is_available() else False model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) upsampler = RealESRGANer( scale=2, model_path="weights/realesrgan/RealESRGAN_x2plus.pth", model=model, tile=400, tile_pad=40, pre_pad=0, half=half, ) # CodeFormer codeformer_net = ARCH_REGISTRY.get("CodeFormer")( dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=["32", "64", "128", "256"], ).to(device) ckpt_path = "weights/CodeFormer/codeformer.pth" checkpoint = torch.load(ckpt_path)["params_ema"] codeformer_net.load_state_dict(checkpoint) codeformer_net.eval() print("Models loaded successfully.") init_models() def process_image(img_path, background_enhance, face_upsample, upscale, codeformer_fidelity): """Core inference logic""" try: # Defaults has_aligned = False only_center_face = False draw_box = False detection_model = "retinaface_resnet50" img = cv2.imread(img_path, cv2.IMREAD_COLOR) # Memory safety checks upscale = int(upscale) if upscale > 4: upscale = 4 if upscale > 2 and max(img.shape[:2]) > 1000: upscale = 2 if max(img.shape[:2]) > 1500: upscale = 1 background_enhance = False face_upsample = False face_helper = FaceRestoreHelper( upscale, face_size=512, crop_ratio=(1, 1), det_model=detection_model, save_ext="png", use_parse=True, device=device, ) bg_upsampler = upsampler if background_enhance else None face_upsampler = upsampler if face_upsample else None if has_aligned: img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) face_helper.is_gray = is_gray(img, threshold=5) face_helper.cropped_faces = [img] else: face_helper.read_image(img) face_helper.get_face_landmarks_5(only_center_face=only_center_face, resize=640, eye_dist_threshold=5) face_helper.align_warp_face() # Face restoration for idx, cropped_face in enumerate(face_helper.cropped_faces): cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) cropped_face_t = cropped_face_t.unsqueeze(0).to(device) try: with torch.no_grad(): output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0] restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) except Exception as e: print(f"Inference error: {e}") restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype("uint8") face_helper.add_restored_face(restored_face) # Paste back if not has_aligned: bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] if bg_upsampler else None face_helper.get_inverse_affine(None) if face_upsample and face_upsampler: restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box, face_upsampler=face_upsampler) else: restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box) else: restored_img = face_helper.restored_faces[0] return restored_img except Exception as e: print(f"Global processing error: {e}") return None # --- Routes --- @app.route('/', methods=['GET']) def index(): return render_template('index.html') @app.route('/process', methods=['POST']) def process(): if 'image' not in request.files: return "No image uploaded", 400 files = request.files.getlist('image') if not files or files[0].filename == '': return "No selected file", 400 results = [] # Get params (same for all images) try: fidelity = float(request.form.get('fidelity', 0.5)) upscale = 4 # Enforce 4x upscale background_enhance = 'background_enhance' in request.form face_upsample = 'face_upsample' in request.form except ValueError: return "Invalid parameters", 400 for file in files: if file.filename == '': continue # Save input filename = str(uuid.uuid4()) + "_" + secure_filename(file.filename) input_path = os.path.join(app.config['UPLOAD_FOLDER'], filename) file.save(input_path) # Process result_img = process_image(input_path, background_enhance, face_upsample, upscale, fidelity) if result_img is None: continue # Skip failed images or handle error appropriately # Save output output_filename = "result_" + filename.rsplit('.', 1)[0] + ".png" output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename) imwrite(result_img, output_path) # Generate preview (max 1000px width/height) preview_filename = "preview_" + output_filename preview_path = os.path.join(app.config['RESULT_FOLDER'], preview_filename) h, w = result_img.shape[:2] if max(h, w) > 1000: scale = 1000 / max(h, w) preview_img = cv2.resize(result_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) imwrite(preview_img, preview_path) else: preview_filename = output_filename results.append({ 'original': filename, 'preview': preview_filename, 'download': output_filename }) if not results: return "Processing failed for all images", 500 # Create ZIP of all results zip_filename = f"batch_{uuid.uuid4()}.zip" zip_path = os.path.join(app.config['RESULT_FOLDER'], zip_filename) with zipfile.ZipFile(zip_path, 'w') as zipf: for item in results: file_path = os.path.join(app.config['RESULT_FOLDER'], item['download']) zipf.write(file_path, item['download']) return render_template('result.html', results=results, zip_filename=zip_filename) # --- API Routes --- @app.route('/api/process', methods=['POST']) def api_process(): """ API endpoint for image processing. Accepts: - multipart/form-data with one or more 'image' files. - application/json with 'image_base64' string (single image) or 'images_base64' list. Parameters (form or JSON): - fidelity: (float) 0-1, default 0.5. - background_enhance: (bool) default False. - face_upsample: (bool) default False. - upscale: (int) 1-4, default 2. - return_base64: (bool) default False. """ try: is_json = request.is_json data = request.get_json() if is_json else request.form fidelity = float(data.get('fidelity', 0.5)) background_enhance = (str(data.get('background_enhance', 'false')).lower() == 'true') if not is_json else data.get('background_enhance', False) face_upsample = (str(data.get('face_upsample', 'false')).lower() == 'true') if not is_json else data.get('face_upsample', False) upscale = int(data.get('upscale', 2)) return_base64 = (str(data.get('return_base64', 'false')).lower() == 'true') if not is_json else data.get('return_base64', False) processed_images = [] inputs = [] # Handle JSON input if is_json: if 'image_base64' in data: inputs.append({'data': data['image_base64'], 'name': 'image.png'}) elif 'images_base64' in data: for idx, img_b64 in enumerate(data['images_base64']): inputs.append({'data': img_b64, 'name': f'image_{idx}.png'}) for inp in inputs: temp_filename = str(uuid.uuid4()) image_data = base64.b64decode(inp['data'].split(',')[-1]) input_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{temp_filename}.png") with open(input_path, 'wb') as f: f.write(image_data) inp['path'] = input_path inp['temp_id'] = temp_filename # Handle Multipart input elif 'image' in request.files: files = request.files.getlist('image') for file in files: if file.filename != '': temp_filename = str(uuid.uuid4()) filename = f"{temp_filename}_{secure_filename(file.filename)}" input_path = os.path.join(app.config['UPLOAD_FOLDER'], filename) file.save(input_path) inputs.append({'path': input_path, 'name': file.filename, 'temp_id': temp_filename}) if not inputs: return jsonify({"status": "error", "message": "No images provided"}), 400 for inp in inputs: # Process image result_img = process_image(inp['path'], background_enhance, face_upsample, upscale, fidelity) if result_img is not None: # Save result output_filename = f"api_result_{inp['temp_id']}.png" output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename) imwrite(result_img, output_path) res = { "original_name": inp['name'], "image_url": url_for('static', filename=f'results/{output_filename}', _external=True), "filename": output_filename } if return_base64: _, buffer = cv2.imencode('.png', result_img) img_base64 = base64.b64encode(buffer).decode('utf-8') res["image_base64"] = img_base64 processed_images.append(res) if not processed_images: return jsonify({"status": "error", "message": "Processing failed for all images"}), 500 return jsonify({ "status": "success", "count": len(processed_images), "results": processed_images }) except Exception as e: import traceback traceback.print_exc() return jsonify({"status": "error", "message": str(e)}), 500 @app.route('/api/health', methods=['GET']) def health_check(): return jsonify({"status": "online", "device": str(device)}) if __name__ == '__main__': # Docker/HF Spaces entry point app.run(host='0.0.0.0', port=7860)