Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import torch | |
| from flask import Flask, request, jsonify, send_file, render_template_string | |
| from werkzeug.utils import secure_filename | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| # import the necessary libraries from the local project directories | |
| import __init_paths | |
| from face_detect.retinaface_detection import RetinaFaceDetection | |
| from face_parse.face_parsing import FaceParse | |
| from face_model.face_gan import FaceGAN | |
| from sr_model.real_esrnet import RealESRNet | |
| from align_faces import warp_and_crop_face, get_reference_facial_points | |
# --- Flask application and working directories ---
app = Flask(__name__)

# Both directories are relative to the current working directory.
UPLOAD_FOLDER = 'uploads'    # where incoming uploads are stored
RESULTS_FOLDER = 'results'   # results directory, exposed through app.config
app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    RESULTS_FOLDER=RESULTS_FOLDER,
)
for _required_dir in (UPLOAD_FOLDER, RESULTS_FOLDER):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
# Single-page upload UI served by the `index` route (rendered through
# render_template_string, so it must not contain Jinja delimiters).
# The page POSTs the chosen file as multipart field "file" to /process_image
# and shows the returned image; error responses are expected to be JSON
# with an "error" field, with a fallback to the HTTP status when they are not.
html_content = """
<!DOCTYPE html>
<html lang="ar" dir="rtl">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>GPEN: تحسين جودة الوجه</title>
<style>
body { text-align:center; font-family: sans-serif; background-color: #f0f2f5; padding: 20px; }
.container { max-width: 800px; margin: auto; background-color: white; padding: 30px; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); }
.image-display { display:flex; justify-content:center; gap:20px; margin-top:20px; }
img { max-width:100%; height: auto; border:1px solid #ddd; border-radius: 4px; }
h3 { margin-bottom: 10px; }
.button { padding: 10px 20px; background-color: #007bff; color: white; border: none; border-radius: 5px; cursor: pointer; }
</style>
</head>
<body>
<div class="container">
<h1>GPEN: تحسين جودة الوجه</h1>
<p>قم بتحميل صورة وسيقوم النموذج بتحسين جودة الوجه فيها.</p>
<form id="uploadForm" enctype="multipart/form-data">
<input type="file" id="fileInput" name="file" accept="image/*" required>
<button type="submit" class="button">معالجة الصورة</button>
</form>
<p id="loading" style="display:none;">... جاري المعالجة، يرجى الانتظار.</p>
<div class="image-display">
<div>
<h3>الأصلية</h3>
<img id="original" style="display:none;">
</div>
<div>
<h3>النتيجة</h3>
<img id="result" style="display:none;">
</div>
</div>
</div>
<script>
const form = document.getElementById("uploadForm");
const fileInput = document.getElementById("fileInput");
const loading = document.getElementById("loading");
const original = document.getElementById("original");
const result = document.getElementById("result");
fileInput.addEventListener("change", (e) => {
if (e.target.files.length > 0) {
// Release the previous preview's blob before creating a new one.
if (original.src) URL.revokeObjectURL(original.src);
original.src = URL.createObjectURL(e.target.files[0]);
original.style.display = "block";
result.style.display = "none";
}
});
form.addEventListener("submit", async (e) => {
e.preventDefault();
if (fileInput.files.length === 0) return;
const formData = new FormData();
formData.append("file", fileInput.files[0]);
loading.style.display = "block";
try {
const response = await fetch("/process_image", {
method: "POST",
body: formData
});
if (response.ok) {
const blob = await response.blob();
if (result.src) URL.revokeObjectURL(result.src);
result.src = URL.createObjectURL(blob);
result.style.display = "block";
} else {
// Error bodies are normally JSON with an "error" field, but the framework
// or a proxy may answer with an HTML page; fall back to the HTTP status.
let message = "HTTP " + response.status;
try {
const body = await response.json();
if (body && body.error) message = body.error;
} catch (parseErr) {
// Not JSON: keep the status-based message.
}
alert("خطأ: " + message);
}
} catch (err) {
alert("فشل في الاتصال بالخادم: " + err);
} finally {
loading.style.display = "none";
}
});
</script>
</body>
</html>
"""
# Weight files the app needs: local path -> download URL.
MODEL_URLS = {
    'weights/RetinaFace-R50.pth': 'https://huggingface.co/akhaliq/RetinaFace-R50/resolve/main/RetinaFace-R50.pth',
    'weights/GPEN-BFR-512.pth': 'https://huggingface.co/akhaliq/GPEN-BFR-512/resolve/main/GPEN-BFR-512.pth',
    'weights/realesrnet_x2.pth': 'https://huggingface.co/akhaliq/realesrnet_x2/resolve/main/realesrnet_x2.pth',
    'weights/ParseNet-latest.pth': 'https://huggingface.co/akhaliq/ParseNet-latest/resolve/main/ParseNet-latest.pth',
}


def download_models(models=None):
    """Download any model weight files that are missing on disk.

    Args:
        models: optional mapping of local file path -> URL. Defaults to
            ``MODEL_URLS``. Entries whose local path already exists are skipped.

    Raises:
        OSError: (including ``urllib.error.URLError``) if a download fails.
            Nothing is left at the destination path in that case, so a later
            call retries instead of treating a partial file as complete.
    """
    import tempfile
    import urllib.request

    if models is None:
        models = MODEL_URLS
    for local_path, url in models.items():
        if os.path.exists(local_path):
            continue
        print(f"Downloading {local_path}...")
        # dirname() is '' for a bare filename; makedirs('') would raise.
        target_dir = os.path.dirname(local_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        # Stream into a temp file beside the target and rename only once the
        # transfer has finished. An interrupted or failed download never
        # leaves a truncated file that the exists() check would later accept.
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
        try:
            with os.fdopen(fd, "wb") as out, urllib.request.urlopen(url) as resp:
                shutil.copyfileobj(resp, out)
            os.replace(tmp_path, local_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        print(f"{local_path} downloaded.")
# Global face-enhancement model; created on demand by ``initialize_model``.
faceenhancer = None


def initialize_model():
    """Create the global ``FaceEnhancement`` instance if it does not exist yet.

    Downloads any missing weights, then builds a CPU GPEN-BFR-512 enhancer
    with super-resolution disabled. Errors are reported, with a traceback,
    rather than raised, and leave ``faceenhancer`` as ``None`` so the web
    server can still start.

    Returns:
        The global enhancer, or ``None`` if initialization failed.
    """
    import traceback

    global faceenhancer
    if faceenhancer is not None:
        return faceenhancer
    try:
        download_models()
        faceenhancer = FaceEnhancement(
            size=512,
            model="GPEN-BFR-512",
            use_sr=False,
            sr_model='realesrnet_x2',
            channel_multiplier=2,
            narrow=1,
            key=None,
            device='cpu',
        )
        print("✅ FaceEnhancement model initialized successfully.")
    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        traceback.print_exc()
        faceenhancer = None
    return faceenhancer
# Face-enhancement pipeline adapted from the upstream GPEN code: detect faces,
# align and crop each one, enhance it with the GAN, then blend it back into
# the frame through a soft mask derived from the face parser.
class FaceEnhancement(object):
    """Detect faces, enhance each with the face GAN, and composite them back.

    All four sub-models are constructed eagerly in ``__init__``. The SR model
    is built even when ``use_sr`` is False.
    """
    def __init__(self, base_dir='./', size=512, model=None, use_sr=True, sr_model=None, channel_multiplier=2, narrow=1, key=None, device='cuda'):
        """Build the sub-models and precompute blending/alignment constants.

        Args:
            base_dir: forwarded as the first argument to every sub-model
                constructor (presumably the root for weight lookup -- confirm).
            size: side length of the square aligned face crop given to the GAN.
            model: model name forwarded to ``FaceGAN``.
            use_sr: if True, ``process`` first runs ``RealESRNet`` on the whole image.
            sr_model: model name forwarded to ``RealESRNet``.
            channel_multiplier, narrow, key: forwarded unchanged to ``FaceGAN``.
            device: device string forwarded to every sub-model.
        """
        self.facedetector = RetinaFaceDetection(base_dir, device)
        self.facegan = FaceGAN(base_dir, size, model, channel_multiplier, narrow, key, device=device)
        self.srmodel = RealESRNet(base_dir, sr_model, device=device)
        self.faceparser = FaceParse(base_dir, device=device)
        self.use_sr = use_sr
        self.size = size
        # Faces whose detection score (faceb[4] in ``process``) is below this are skipped.
        self.threshold = 0.9
        # Feathered rectangular mask, hard-coded to 512x512 regardless of ``size``.
        # NOTE(review): ``process`` never reads self.mask; possibly vestigial.
        self.mask = np.zeros((512, 512), np.float32)
        cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, cv2.LINE_AA)
        self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
        self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
        # 3x3 binomial smoothing kernel (outer product of [1, 2, 1] / 16; the
        # weights sum to 1), applied to small faces in ``process`` before pasting.
        self.kernel = np.array((
            [0.0625, 0.125, 0.0625],
            [0.125, 0.25, 0.125],
            [0.0625, 0.125, 0.0625]), dtype="float32")
        # Alignment template: reference 5-point landmark positions for a
        # size x size crop (argument semantics are defined in align_faces).
        default_square = True
        inner_padding_factor = 0.25
        outer_padding = (0, 0)
        self.reference_5pts = get_reference_facial_points(
            (self.size, self.size), inner_padding_factor, outer_padding, default_square)

    def mask_postprocess(self, mask, thres=20):
        """Zero a ``thres``-pixel border of ``mask`` and feather the result.

        The border is zeroed in place on the passed array. The return value is
        a new float32 array, Gaussian-blurred twice (101x101 kernel, sigmaX 11).
        """
        mask[:thres, :] = 0; mask[-thres:, :] = 0
        mask[:, :thres] = 0; mask[:, -thres:] = 0
        mask = cv2.GaussianBlur(mask, (101, 101), 11)
        mask = cv2.GaussianBlur(mask, (101, 101), 11)
        return mask.astype(np.float32)

    def process(self, img):
        """Enhance every detected face whose score reaches ``self.threshold``.

        Args:
            img: input image array. In this app it comes from ``cv2.imread``,
                so presumably HxWx3 uint8 in BGR order (confirm for other callers).

        Returns:
            Tuple ``(result, orig_faces, enhanced_faces)``:
            - ``result``: the input, or its super-resolved version when
              ``use_sr`` is on and SR returned an image, with the enhanced
              faces alpha-blended in.
            - ``orig_faces``: the aligned crops before the GAN.
            - ``enhanced_faces``: the same crops after the GAN.
            Both lists hold one entry per processed face, in detection order.
        """
        if self.use_sr:
            img_sr = self.srmodel.process(img)
            if img_sr is not None:
                # Resize the input to the SR output's size so boxes, warps and
                # masks computed below line up with img_sr in the final blend.
                img = cv2.resize(img, img_sr.shape[:2][::-1])
        facebs, landms = self.facedetector.detect(img)
        orig_faces, enhanced_faces = [], []
        height, width = img.shape[:2]
        # Accumulators: per-pixel paste weight and the pasted face pixels.
        full_mask = np.zeros((height, width), dtype=np.float32)
        full_img = np.zeros(img.shape, dtype=np.uint8)
        for i, (faceb, facial5points) in enumerate(zip(facebs, landms)):
            # faceb[4] is compared with the score threshold. faceb[0:4] are
            # used as x1, y1, x2, y2 below (presumably box corners -- confirm).
            if faceb[4]<self.threshold: continue
            fh, fw = (faceb[3]-faceb[1]), (faceb[2]-faceb[0])
            # Assumes the detector yields 10 landmark values laid out as
            # 5 x's followed by 5 y's -- confirm against RetinaFaceDetection.
            facial5points = np.reshape(facial5points, (2, 5))
            # Align to the reference template. tfm_inv maps crop coordinates
            # back into the full image.
            of, tfm_inv = warp_and_crop_face(img, facial5points, reference_pts=self.reference_5pts, crop_size=(self.size, self.size))
            ef = self.facegan.process(of)
            orig_faces.append(of)
            enhanced_faces.append(ef)
            # Paste mask from the face parser, scaled by 1/255 (presumably into
            # [0, 1]), with its border zeroed and feathered.
            tmp_mask = self.mask_postprocess(self.faceparser.process(ef)[0]/255.)
            # cv2.resize takes (width, height). Passing shape[:2] (h, w) only
            # works if ef is square (presumably size x size).
            tmp_mask = cv2.resize(tmp_mask, ef.shape[:2])
            tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=3)
            # Boxes under 100 px on either side get the smoothing kernel before pasting.
            if min(fh, fw)<100:
                ef = cv2.filter2D(ef, -1, self.kernel)
            tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3)
            # Where faces overlap, keep the pixels of whichever face has the
            # larger mask weight.
            mask = tmp_mask - full_mask
            full_mask[np.where(mask>0)] = tmp_mask[np.where(mask>0)]
            full_img[np.where(mask>0)] = tmp_img[np.where(mask>0)]
        # Add a channel axis so the mask broadcasts across colour channels.
        full_mask = full_mask[:, :, np.newaxis]
        # Alpha blend base * (1 - mask) + faces * mask. convertScaleAbs
        # converts the float result back to 8-bit with saturation.
        if self.use_sr and img_sr is not None:
            img = cv2.convertScaleAbs(img_sr*(1-full_mask) + full_img*full_mask)
        else:
            img = cv2.convertScaleAbs(img*(1-full_mask) + full_img*full_mask)
        return img, orig_faces, enhanced_faces
# Flask Routes
@app.route("/")
def index():
    """Serve the single-page upload UI (``html_content``).

    The route decorator was missing, so the page was never registered with
    the app and GET / returned 404.
    """
    return render_template_string(html_content)
@app.route("/process_image", methods=["POST"])
def process_image():
    """Enhance the first face in an uploaded image and return it as a JPEG.

    Expects a multipart form field named ``file``. Responses:
      * 200 ``image/jpeg`` -- the first enhanced face crop;
      * 400 JSON ``{"error": ...}`` -- missing file, unreadable image, or no face;
      * 503 JSON -- the model could not be initialized;
      * 500 JSON -- any other processing error.
    """
    import io
    import tempfile

    if "file" not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files["file"]
    if file.filename == "":
        return jsonify({"error": "No filename"}), 400

    # The model is only initialized under __main__. Initialize it here so the
    # app also works when served another way (e.g. by a WSGI server).
    if faceenhancer is None:
        initialize_model()
    if faceenhancer is None:
        return jsonify({"error": "Model is not available."}), 503

    # Store the upload under a unique temporary name instead of the client's
    # filename. Concurrent uploads cannot clobber each other, and names that
    # secure_filename() reduces to '' (e.g. all non-ASCII) no longer break saving.
    # cv2.imread identifies the format from file content, so no extension is needed.
    fd, input_path = tempfile.mkstemp(suffix=".upload", dir=app.config["UPLOAD_FOLDER"])
    os.close(fd)
    try:
        file.save(input_path)
        im = cv2.imread(input_path, cv2.IMREAD_COLOR)
        if im is None:
            return jsonify({"error": "Could not read the uploaded file as an image."}), 400
        _, _, enhanced_faces = faceenhancer.process(im)
        if not enhanced_faces:
            return jsonify({"error": "No faces detected in the image."}), 400
        # Only the first detected face is returned. It is encoded without any
        # channel conversion, as before. FaceEnhancement.process blends these
        # crops straight into the cv2.imread (BGR) input, which suggests they
        # share its channel order.
        # Encoding to JPEG in memory keeps the body consistent with the
        # image/jpeg mimetype whatever format was uploaded.
        ok, encoded = cv2.imencode(".jpg", enhanced_faces[0])
        if not ok:
            return jsonify({"error": "Failed to encode the result image."}), 500
        return send_file(io.BytesIO(encoded.tobytes()), mimetype="image/jpeg")
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if os.path.exists(input_path):
            os.remove(input_path)
if __name__ == "__main__":
    initialize_model()
    # The Werkzeug debugger lets anyone who can reach the port run arbitrary
    # code, so it must not be on by default for a server bound to 0.0.0.0.
    # Opt in for local development with FLASK_DEBUG=1. Debug mode also starts
    # the reloader, which re-runs this script and loads the model a second time.
    app.run(host="0.0.0.0", port=7860, debug=os.environ.get("FLASK_DEBUG") == "1")