Spaces:
Running
Running
| """ | |
| CodeFormer Flask Application | |
| Deployment on Hugging Face Spaces | |
| """ | |
| import os | |
| import cv2 | |
| import torch | |
| import uuid | |
| import numpy as np | |
| import zipfile | |
| import base64 | |
| from flask import Flask, render_template, request, send_file, url_for, jsonify, send_from_directory | |
| from flask_cors import CORS | |
| from werkzeug.utils import secure_filename | |
| from torchvision.transforms.functional import normalize | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from basicsr.utils import imwrite, img2tensor, tensor2img | |
| from basicsr.utils.download_util import load_file_from_url | |
| from basicsr.utils.misc import gpu_is_available, get_device | |
| from basicsr.utils.realesrgan_utils import RealESRGANer | |
| from basicsr.utils.registry import ARCH_REGISTRY | |
| from facelib.utils.face_restoration_helper import FaceRestoreHelper | |
| from facelib.utils.misc import is_gray | |
# --- Initialization ---
app = Flask(__name__)
CORS(app)  # cross-origin requests are allowed on every route

app.config.update(
    UPLOAD_FOLDER='static/uploads',
    RESULT_FOLDER='static/results',
    # Flask rejects request bodies larger than this (100 MB).
    MAX_CONTENT_LENGTH=100 * 1024 * 1024,
)

# Create the upload/result folders and the per-model weight folders up front;
# exist_ok makes this a no-op on restarts.
for _folder in (
    app.config['UPLOAD_FOLDER'],
    app.config['RESULT_FOLDER'],
    'weights/CodeFormer',
    'weights/facelib',
    'weights/realesrgan',
):
    os.makedirs(_folder, exist_ok=True)

# Download locations of the pretrained checkpoints, keyed by model role.
pretrain_model_url = dict(
    codeformer='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
    detection='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
    parsing='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
    realesrgan='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth',
)
def download_weights():
    """Download every pretrained checkpoint that is not already on disk.

    Each entry names a key of ``pretrain_model_url``, the folder the file is
    stored in, and the file name checked for existence. Present files are
    skipped, so repeated calls are cheap.
    """
    targets = (
        ('codeformer', 'weights/CodeFormer', 'codeformer.pth'),
        ('detection', 'weights/facelib', 'detection_Resnet50_Final.pth'),
        ('parsing', 'weights/facelib', 'parsing_parsenet.pth'),
        ('realesrgan', 'weights/realesrgan', 'RealESRGAN_x2plus.pth'),
    )
    for key, folder, checkpoint_name in targets:
        if os.path.exists(f'{folder}/{checkpoint_name}'):
            continue
        load_file_from_url(
            url=pretrain_model_url[key],
            model_dir=folder,
            progress=True,
            file_name=None,
        )
# Fetch missing checkpoints at import time, before any model is constructed.
print("Checking weights...")
download_weights()

# Shared inference state: the torch device, plus the two model slots that
# init_models() assigns through `global`.
device = get_device()
upsampler, codeformer_net = None, None
def init_models():
    """Build the global RealESRGAN upsampler and CodeFormer network.

    Assigns the module-level ``upsampler`` (RealESRGAN x2, tiled) and
    ``codeformer_net`` (CodeFormer in eval mode on ``device``) from the
    checkpoints under ``weights/``.
    """
    global upsampler, codeformer_net

    # RealESRGAN x2: half precision only when gpu_is_available() reports a GPU.
    half = bool(gpu_is_available())
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    upsampler = RealESRGANer(
        scale=2,
        model_path="weights/realesrgan/RealESRGAN_x2plus.pth",
        model=model,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=half,
    )

    # CodeFormer face restoration network.
    codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layers=9,
        connect_list=["32", "64", "128", "256"],
    ).to(device)
    ckpt_path = "weights/CodeFormer/codeformer.pth"
    # Deserialize onto the CPU so a checkpoint saved with CUDA tensors still loads
    # on CPU-only hosts; load_state_dict copies the weights onto the model's device.
    checkpoint = torch.load(ckpt_path, map_location="cpu")["params_ema"]
    codeformer_net.load_state_dict(checkpoint)
    codeformer_net.eval()
    print("Models loaded successfully.")


init_models()
def process_image(img_path, background_enhance, face_upsample, upscale, codeformer_fidelity,
                  *, has_aligned=False, only_center_face=False, draw_box=False,
                  detection_model="retinaface_resnet50"):
    """Restore the faces in one image file with CodeFormer (core inference logic).

    Args:
        img_path: Path of the image, read with ``cv2.imread(..., IMREAD_COLOR)``.
        background_enhance: Upscale the whole background with the RealESRGAN upsampler.
        face_upsample: Also pass the RealESRGAN upsampler when pasting faces back.
        upscale: Requested output scale; converted with ``int`` and clamped to
            [1, 4], then lowered for large inputs (see ``_limit_upscale``).
        codeformer_fidelity: CodeFormer ``w`` weight; clamped to [0, 1].
        has_aligned: Treat the input as an already-aligned face crop (resized to
            512x512, no detection / paste-back).
        only_center_face: Forwarded to ``get_face_landmarks_5``.
        draw_box: Forwarded to ``paste_faces_to_input_image``.
        detection_model: Face detector name passed to ``FaceRestoreHelper``.

    Returns:
        The restored image array, or ``None`` if the image cannot be read or
        any step fails (errors are printed, not raised).
    """
    import traceback

    try:
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread returns None instead of raising for unreadable files.
            print(f"Global processing error: could not read image '{img_path}'")
            return None

        # Memory safety checks
        upscale, background_enhance, face_upsample = _limit_upscale(
            upscale, img.shape, background_enhance, face_upsample)
        # Clamp untrusted fidelity; with this argument order NaN maps to 0.0.
        codeformer_fidelity = min(1.0, max(0.0, float(codeformer_fidelity)))

        face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=detection_model,
            save_ext="png",
            use_parse=True,
            device=device,
        )
        bg_upsampler = upsampler if background_enhance else None
        face_upsampler = upsampler if face_upsample else None

        if has_aligned:
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=5)
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, resize=640, eye_dist_threshold=5)
            face_helper.align_warp_face()

        # Face restoration
        for cropped_face in face_helper.cropped_faces:
            # [0, 255] BGR crop -> [-1, 1] RGB tensor with a leading batch dimension.
            face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
            normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            face_t = face_t.unsqueeze(0).to(device)
            try:
                with torch.no_grad():
                    output = codeformer_net(face_t, w=codeformer_fidelity, adain=True)[0]
                    restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
            except Exception as e:
                # Fall back to the unrestored crop so one failing face does not
                # drop the whole image.
                print(f"Inference error: {e}")
                restored_face = tensor2img(face_t, rgb2bgr=True, min_max=(-1, 1))
            face_helper.add_restored_face(restored_face.astype("uint8"))

        if has_aligned:
            return face_helper.restored_faces[0]

        # Paste back, optionally onto a RealESRGAN-upscaled background.
        bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] if bg_upsampler else None
        face_helper.get_inverse_affine(None)
        if face_upsampler is not None:
            return face_helper.paste_faces_to_input_image(
                upsample_img=bg_img, draw_box=draw_box, face_upsampler=face_upsampler)
        return face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
    except Exception as e:
        print(f"Global processing error: {e}")
        traceback.print_exc()
        return None


def _limit_upscale(upscale, img_shape, background_enhance, face_upsample):
    """Clamp the upscale factor and enhancer flags so large inputs stay within memory.

    Returns an ``(upscale, background_enhance, face_upsample)`` tuple.
    """
    # Keep the factor in [1, 4]; without the lower bound, 0 or negative values
    # would be handed to FaceRestoreHelper unchecked.
    upscale = max(1, min(int(upscale), 4))
    longest_side = max(img_shape[:2])
    if upscale > 2 and longest_side > 1000:
        upscale = 2
    if longest_side > 1500:
        # Very large inputs: no upscaling and no RealESRGAN passes at all.
        upscale = 1
        background_enhance = False
        face_upsample = False
    return upscale, background_enhance, face_upsample
# --- Routes ---
def index():
    """Render the landing page template ``index.html``.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file, so as shown it is not registered — confirm the URL rule.
    """
    page_template = 'index.html'
    return render_template(page_template)
def process():
    """Restore every image uploaded through the HTML form and render ``result.html``.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file, so as shown it is not registered — confirm the URL rule
    and methods.

    Form fields:
        image: one or more uploaded files, all processed with the same settings.
        fidelity: CodeFormer fidelity, parsed as float (default 0.5) and
            clamped to [0, 1].
        background_enhance / face_upsample: enabled when the field is present.

    Upscale is requested at 4x; process_image may lower it for large inputs.
    For each successful image, a full-size PNG result is written to RESULT_FOLDER.
    A preview is also written when the result is larger than 1000 px.
    Finally, a ZIP of all full-size results is written.

    Returns:
        The rendered result page. On error, returns a plain-text message:
        400 for missing or invalid input, 500 when every image failed.
    """
    if 'image' not in request.files:
        return "No image uploaded", 400
    files = request.files.getlist('image')
    if not files or files[0].filename == '':
        return "No selected file", 400

    # Get params (same for all images)
    try:
        # Clamp untrusted input; with this argument order NaN maps to 0.0 and
        # +/-inf to the nearest bound.
        fidelity = min(1.0, max(0.0, float(request.form.get('fidelity', 0.5))))
    except ValueError:
        return "Invalid parameters", 400
    upscale = 4  # Enforce 4x upscale
    background_enhance = 'background_enhance' in request.form
    face_upsample = 'face_upsample' in request.form

    results = []
    for file in files:
        if file.filename == '':
            continue
        # Save input under a unique, sanitised name.
        filename = str(uuid.uuid4()) + "_" + secure_filename(file.filename)
        input_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(input_path)

        result_img = process_image(input_path, background_enhance, face_upsample, upscale, fidelity)
        if result_img is None:
            # Failed inputs are never referenced by the results page; remove them
            # so they do not accumulate in UPLOAD_FOLDER.
            try:
                os.remove(input_path)
            except OSError:
                pass
            continue

        # Save output
        output_filename = "result_" + filename.rsplit('.', 1)[0] + ".png"
        output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename)
        imwrite(result_img, output_path)

        results.append({
            'original': filename,
            'preview': _write_preview(result_img, output_filename),
            'download': output_filename,
        })

    if not results:
        return "Processing failed for all images", 500

    # Create ZIP of all full-size results
    zip_filename = f"batch_{uuid.uuid4()}.zip"
    zip_path = os.path.join(app.config['RESULT_FOLDER'], zip_filename)
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for item in results:
            file_path = os.path.join(app.config['RESULT_FOLDER'], item['download'])
            zipf.write(file_path, item['download'])

    return render_template('result.html', results=results, zip_filename=zip_filename)


def _write_preview(result_img, output_filename, max_side=1000):
    """Write a downscaled preview of *result_img* and return its file name.

    When the longest side is at most *max_side*, the full-size result doubles as
    the preview, so *output_filename* is returned and nothing is written.
    """
    h, w = result_img.shape[:2]
    longest = max(h, w)
    if longest <= max_side:
        return output_filename
    scale = max_side / longest
    preview_img = cv2.resize(result_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
    preview_filename = "preview_" + output_filename
    imwrite(preview_img, os.path.join(app.config['RESULT_FOLDER'], preview_filename))
    return preview_filename
# --- API Routes ---
def api_process():
    """
    API endpoint for image processing.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file, so as shown it is not registered — confirm the URL rule.

    Accepts:
    - multipart/form-data with one or more 'image' files.
    - application/json object with 'image_base64' string (single image) or
      'images_base64' list (a "data:...;base64," prefix is stripped).
    Parameters (form or JSON):
    - fidelity: (float) clamped to 0-1, default 0.5.
    - background_enhance: (bool) default False.
    - face_upsample: (bool) default False.
    - upscale: (int) clamped to 1-4, default 2.
    - return_base64: (bool) default False.
    Boolean flags accept JSON booleans, or strings compared case-insensitively
    to 'true' (on both the form and JSON paths).

    Returns JSON. Status 400 for malformed input or parameters, 500 when every
    image fails or on an unexpected error. Uploaded/decoded inputs are deleted
    once the request finishes; results stay in RESULT_FOLDER.
    """
    inputs = []
    try:
        is_json = request.is_json
        if is_json:
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                return jsonify({"status": "error", "message": "Request body must be a JSON object"}), 400
        else:
            data = request.form

        try:
            fidelity = min(1.0, max(0.0, float(data.get('fidelity', 0.5))))
            upscale = max(1, min(int(data.get('upscale', 2)), 4))
        except (TypeError, ValueError):
            return jsonify({"status": "error", "message": "Invalid 'fidelity' or 'upscale' parameter"}), 400
        background_enhance = _as_bool(data.get('background_enhance'))
        face_upsample = _as_bool(data.get('face_upsample'))
        return_base64 = _as_bool(data.get('return_base64'))

        # Inputs are appended before being written so `finally` can clean them up.
        try:
            if is_json:
                _collect_json_inputs(data, inputs)
            else:
                _collect_multipart_inputs(inputs)
        except ValueError as err:
            return jsonify({"status": "error", "message": str(err)}), 400

        if not inputs:
            return jsonify({"status": "error", "message": "No images provided"}), 400

        processed_images = []
        for inp in inputs:
            result_img = process_image(inp['path'], background_enhance, face_upsample, upscale, fidelity)
            if result_img is None:
                continue
            output_filename = f"api_result_{inp['temp_id']}.png"
            output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename)
            imwrite(result_img, output_path)
            res = {
                "original_name": inp['name'],
                "image_url": url_for('static', filename=f'results/{output_filename}', _external=True),
                "filename": output_filename,
            }
            if return_base64:
                _, buffer = cv2.imencode('.png', result_img)
                res["image_base64"] = base64.b64encode(buffer).decode('utf-8')
            processed_images.append(res)

        if not processed_images:
            return jsonify({"status": "error", "message": "Processing failed for all images"}), 500
        return jsonify({
            "status": "success",
            "count": len(processed_images),
            "results": processed_images
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
    finally:
        # Inputs are only needed while processing; results are kept separately.
        for inp in inputs:
            try:
                os.remove(inp['path'])
            except OSError:
                pass


def _as_bool(value, default=False):
    """Interpret a form/JSON flag value.

    ``None`` (missing or JSON null) gives *default*. Strings are true only when
    they equal 'true' case-insensitively, so a JSON string "false" no longer
    counts as truthy. Any other value uses Python truthiness.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.lower() == 'true'
    return bool(value)


def _collect_json_inputs(data, inputs):
    """Decode base64 images from a JSON body into UPLOAD_FOLDER temp files.

    Appends ``{'path', 'name', 'temp_id'}`` dicts to *inputs*.
    Raises ValueError when the payload is malformed.
    """
    if 'image_base64' in data:
        items = [('image.png', data['image_base64'])]
    elif 'images_base64' in data:
        payload = data['images_base64']
        if not isinstance(payload, list):
            raise ValueError("'images_base64' must be a list of base64 strings")
        items = [(f'image_{idx}.png', img_b64) for idx, img_b64 in enumerate(payload)]
    else:
        return

    for name, img_b64 in items:
        if not isinstance(img_b64, str):
            raise ValueError(f"{name}: base64 payload must be a string")
        try:
            # Keep only the part after the last comma (drops a data-URL prefix).
            image_data = base64.b64decode(img_b64.split(',')[-1])
        except ValueError as err:  # binascii.Error subclasses ValueError
            raise ValueError(f"{name}: invalid base64 data") from err
        temp_filename = str(uuid.uuid4())
        input_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{temp_filename}.png")
        inputs.append({'path': input_path, 'name': name, 'temp_id': temp_filename})
        with open(input_path, 'wb') as f:
            f.write(image_data)


def _collect_multipart_inputs(inputs):
    """Save every non-empty multipart 'image' upload into UPLOAD_FOLDER.

    Appends ``{'path', 'name', 'temp_id'}`` dicts to *inputs*.
    """
    for file in request.files.getlist('image'):
        if file.filename == '':
            continue
        temp_filename = str(uuid.uuid4())
        filename = f"{temp_filename}_{secure_filename(file.filename)}"
        input_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        inputs.append({'path': input_path, 'name': file.filename, 'temp_id': temp_filename})
        file.save(input_path)
def health_check():
    """Return a JSON liveness payload naming the torch device in use.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file, so as shown it is not registered — confirm the URL rule.
    """
    payload = {"status": "online", "device": str(device)}
    return jsonify(payload)
if __name__ == '__main__':
    # Docker / Hugging Face Spaces entry point: serve on all interfaces, port 7860.
    app.run(port=7860, host='0.0.0.0')