# GPEN / app.py
# Uploaded by mohamed12ahmed; last commit "Update app.py" (fe66f93, verified).
import os
import shutil
import torch
from flask import Flask, request, jsonify, send_file, render_template_string
from werkzeug.utils import secure_filename
import cv2
import numpy as np
from PIL import Image
# import the necessary libraries from the local project directories
import __init_paths
from face_detect.retinaface_detection import RetinaFaceDetection
from face_parse.face_parsing import FaceParse
from face_model.face_gan import FaceGAN
from sr_model.real_esrnet import RealESRNet
from align_faces import warp_and_crop_face, get_reference_facial_points
# Flask application plus the working directories for uploads and outputs.
app = Flask(__name__)

UPLOAD_FOLDER = 'uploads'
RESULTS_FOLDER = 'results'
app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    RESULTS_FOLDER=RESULTS_FOLDER,
)

# Create both directories up front; exist_ok makes this a no-op on restart.
for _directory in (UPLOAD_FOLDER, RESULTS_FOLDER):
    os.makedirs(_directory, exist_ok=True)
del _directory
# Single-page upload UI served by index(). It is passed through
# render_template_string, so it must not contain Jinja delimiters
# ("{{", "{%", "{#"). The script:
# - revokes stale object URLs so repeated uploads do not leak memory;
# - disables the submit button while a request is in flight;
# - tolerates non-JSON error bodies (e.g. HTML error pages) instead of
#   reporting them as a connection failure.
html_content = """
<!DOCTYPE html>
<html lang="ar" dir="rtl">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>GPEN: تحسين جودة الوجه</title>
<style>
body { text-align:center; font-family: sans-serif; background-color: #f0f2f5; padding: 20px; }
.container { max-width: 800px; margin: auto; background-color: white; padding: 30px; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); }
.image-display { display:flex; justify-content:center; gap:20px; margin-top:20px; }
img { max-width:100%; height: auto; border:1px solid #ddd; border-radius: 4px; }
h3 { margin-bottom: 10px; }
.button { padding: 10px 20px; background-color: #007bff; color: white; border: none; border-radius: 5px; cursor: pointer; }
</style>
</head>
<body>
<div class="container">
<h1>GPEN: تحسين جودة الوجه</h1>
<p>قم بتحميل صورة وسيقوم النموذج بتحسين جودة الوجه فيها.</p>
<form id="uploadForm" enctype="multipart/form-data">
<input type="file" id="fileInput" name="file" accept="image/*" required>
<button type="submit" class="button">معالجة الصورة</button>
</form>
<p id="loading" style="display:none;">... جاري المعالجة، يرجى الانتظار.</p>
<div class="image-display">
<div>
<h3>الأصلية</h3>
<img id="original" style="display:none;">
</div>
<div>
<h3>النتيجة</h3>
<img id="result" style="display:none;">
</div>
</div>
</div>
<script>
const form = document.getElementById("uploadForm");
const fileInput = document.getElementById("fileInput");
const loading = document.getElementById("loading");
const original = document.getElementById("original");
const result = document.getElementById("result");
const submitButton = form.querySelector('button[type="submit"]');
// Point an <img> at a new blob, revoking the previous object URL so repeated uploads do not leak memory.
function showBlob(img, blob) {
if (img.src.startsWith("blob:")) URL.revokeObjectURL(img.src);
img.src = URL.createObjectURL(blob);
img.style.display = "block";
}
fileInput.addEventListener("change", (e) => {
if (e.target.files.length > 0) {
showBlob(original, e.target.files[0]);
result.style.display = "none";
}
});
form.addEventListener("submit", async (e) => {
e.preventDefault();
if (fileInput.files.length === 0) return;
const formData = new FormData();
formData.append("file", fileInput.files[0]);
loading.style.display = "block";
submitButton.disabled = true; // prevent duplicate submissions while the model runs
try {
const response = await fetch("/process_image", {
method: "POST",
body: formData
});
if (response.ok) {
showBlob(result, await response.blob());
} else {
// The app returns JSON errors, but a reverse proxy or Flask's default error pages may return HTML.
let message = response.status + " " + response.statusText;
try {
const error = await response.json();
if (error && error.error) message = error.error;
} catch (_) {
// keep the HTTP status text
}
alert("خطأ: " + message);
}
} catch (err) {
alert("فشل في الاتصال بالخادم: " + err);
} finally {
loading.style.display = "none";
submitButton.disabled = false;
}
});
</script>
</body>
</html>
"""
# Function to download model weights that are not yet present on disk.
def download_models():
    """Download any missing GPEN model weights into ``weights/``.

    Each file is fetched into a temporary ``<name>.part`` file and only
    renamed to its final path once the transfer has completed. A failed or
    interrupted download therefore never leaves a truncated checkpoint
    behind that would be mistaken for a valid one on the next start. (The
    previous ``wget -O`` version created the target file even on failure
    and ignored the exit status.)

    Raises:
        OSError: if a download or the final rename fails
            (``urllib.error.URLError`` is an ``OSError`` subclass).
    """
    import urllib.request

    models_to_download = {
        'weights/RetinaFace-R50.pth': 'https://huggingface.co/akhaliq/RetinaFace-R50/resolve/main/RetinaFace-R50.pth',
        'weights/GPEN-BFR-512.pth': 'https://huggingface.co/akhaliq/GPEN-BFR-512/resolve/main/GPEN-BFR-512.pth',
        'weights/realesrnet_x2.pth': 'https://huggingface.co/akhaliq/realesrnet_x2/resolve/main/realesrnet_x2.pth',
        'weights/ParseNet-latest.pth': 'https://huggingface.co/akhaliq/ParseNet-latest/resolve/main/ParseNet-latest.pth',
    }
    for local_path, url in models_to_download.items():
        if os.path.exists(local_path):
            continue
        print(f"Downloading {local_path}...")
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        tmp_path = local_path + ".part"
        try:
            # No shell involved, and HTTP errors raise instead of being ignored.
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, local_path)  # atomic on the same filesystem
        except BaseException:
            # Also clean up on KeyboardInterrupt, then propagate.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        print(f"{local_path} downloaded.")
# Process-wide FaceEnhancement instance, created lazily by initialize_model().
faceenhancer = None


def initialize_model():
    """Create the global ``faceenhancer`` once, downloading weights if needed.

    Uses the GPEN-BFR-512 checkpoint at 512 px on CPU. The super-resolution
    stage is disabled, although the SR model is still loaded by
    FaceEnhancement. Failures are reported on stdout, with the full
    traceback. They leave ``faceenhancer`` as ``None`` so the web server can
    still start, which means callers should check for ``None``.

    Returns:
        The initialized ``FaceEnhancement`` instance, or ``None`` if
        initialization failed.
    """
    global faceenhancer
    if faceenhancer is not None:
        return faceenhancer
    try:
        download_models()
        faceenhancer = FaceEnhancement(
            size=512,
            model="GPEN-BFR-512",
            use_sr=False,
            sr_model='realesrnet_x2',
            channel_multiplier=2,
            narrow=1,
            key=None,
            device='cpu',
        )
        print("✅ FaceEnhancement model initialized successfully.")
    except Exception as e:
        # Deliberately best-effort: keep the server up, but do not hide the cause.
        import traceback
        print(f"❌ Error initializing model: {e}")
        traceback.print_exc()
        faceenhancer = None
    return faceenhancer
# Face enhancement pipeline, carried over from the original GPEN project code.
class FaceEnhancement(object):
    """Detect faces, enhance each with the face GAN, and blend them back.

    Wraps four project models: a RetinaFace detector, the GPEN face GAN,
    a RealESRNet super-resolution model (only applied when ``use_sr``),
    and a face parser used to build per-face blending masks.
    """
    def __init__(self, base_dir='./', size=512, model=None, use_sr=True, sr_model=None, channel_multiplier=2, narrow=1, key=None, device='cuda'):
        """Load all sub-models and precompute alignment/blending constants.

        Args:
            base_dir: Root directory passed to every sub-model loader.
            size: Side length of the square aligned face crop.
            model: Face GAN checkpoint name (e.g. "GPEN-BFR-512").
            use_sr: If True, ``process`` first runs the SR model on the
                whole image.
            sr_model: SR checkpoint name. The SR model is constructed even
                when ``use_sr`` is False.
            channel_multiplier, narrow, key: Forwarded unchanged to FaceGAN.
            device: Device identifier (e.g. 'cpu' or 'cuda') given to every
                sub-model.
        """
        self.facedetector = RetinaFaceDetection(base_dir, device)
        self.facegan = FaceGAN(base_dir, size, model, channel_multiplier, narrow, key, device=device)
        self.srmodel = RealESRNet(base_dir, sr_model, device=device)
        self.faceparser = FaceParse(base_dir, device=device)
        self.use_sr = use_sr
        self.size = size
        self.threshold = 0.9  # detections scoring below this are skipped in process()
        # Feathered 512x512 rectangular mask (filled rect + two Gaussian blurs).
        # NOTE(review): not referenced by process() in this file, and hard-coded
        # to 512 regardless of `size`.
        self.mask = np.zeros((512, 512), np.float32)
        cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, cv2.LINE_AA)
        self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
        self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
        # 3x3 binomial smoothing kernel (weights sum to 1), applied to small faces.
        self.kernel = np.array((
                [0.0625, 0.125, 0.0625],
                [0.125, 0.25, 0.125],
                [0.0625, 0.125, 0.0625]), dtype="float32")
        # Reference 5-point landmark template for a size x size aligned crop.
        default_square = True
        inner_padding_factor = 0.25
        outer_padding = (0, 0)
        self.reference_5pts = get_reference_facial_points(
                (self.size, self.size), inner_padding_factor, outer_padding, default_square)

    def mask_postprocess(self, mask, thres=20):
        """Feather a face-parsing mask for smooth blending.

        Zeroes a ``thres``-pixel border on all four sides (in place on
        ``mask``), then applies two 101x101 Gaussian blurs.
        NOTE(review): ``thres`` must be > 0. With 0, ``mask[-0:]`` selects
        the whole array and the mask becomes all zeros.

        Returns:
            The blurred mask as float32.
        """
        mask[:thres, :] = 0; mask[-thres:, :] = 0
        mask[:, :thres] = 0; mask[:, -thres:] = 0
        mask = cv2.GaussianBlur(mask, (101, 101), 11)
        mask = cv2.GaussianBlur(mask, (101, 101), 11)
        return mask.astype(np.float32)

    def process(self, img):
        """Enhance every confidently detected face and blend it into the image.

        Args:
            img: Image array; assumed H x W x C uint8 as read by OpenCV
                (TODO confirm for callers other than the Flask route).

        Returns:
            (img, orig_faces, enhanced_faces): the blended full image, the
            list of aligned face crops, and the list of GAN-enhanced crops,
            in the same order.
        """
        if self.use_sr:
            img_sr = self.srmodel.process(img)
            if img_sr is not None:
                # Upscale the input to the SR output size so masks/faces line up.
                img = cv2.resize(img, img_sr.shape[:2][::-1])

        facebs, landms = self.facedetector.detect(img)

        orig_faces, enhanced_faces = [], []
        height, width = img.shape[:2]
        full_mask = np.zeros((height, width), dtype=np.float32)  # per-pixel blend weight
        full_img = np.zeros(img.shape, dtype=np.uint8)  # enhanced faces pasted into image space
        for i, (faceb, facial5points) in enumerate(zip(facebs, landms)):
            # faceb appears to be [x1, y1, x2, y2, score]; skip low-confidence detections.
            if faceb[4]<self.threshold: continue
            fh, fw = (faceb[3]-faceb[1]), (faceb[2]-faceb[0])

            # Presumably 5 x-coords followed by 5 y-coords; confirm against RetinaFaceDetection.
            facial5points = np.reshape(facial5points, (2, 5))

            # Align/crop to the template; tfm_inv maps the crop back into image coordinates.
            of, tfm_inv = warp_and_crop_face(img, facial5points, reference_pts=self.reference_5pts, crop_size=(self.size, self.size))

            ef = self.facegan.process(of)

            # Appended before the small-face blur below, so returned crops are unblurred.
            orig_faces.append(of)
            enhanced_faces.append(ef)

            # Parser output scaled to 0..1 (presumably from 0..255), then feathered.
            tmp_mask = self.mask_postprocess(self.faceparser.process(ef)[0]/255.)
            # NOTE: dsize is (width, height) but shape[:2] is (height, width);
            # equivalent here only because the crop is square.
            tmp_mask = cv2.resize(tmp_mask, ef.shape[:2])
            # flags=3 is OpenCV's INTER_AREA interpolation.
            tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=3)

            if min(fh, fw)<100:
                # Soften small faces slightly before pasting them back.
                ef = cv2.filter2D(ef, -1, self.kernel)

            tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3)

            # Only overwrite pixels where this face's mask is stronger than what
            # earlier faces wrote, so overlaps keep the higher-weight face.
            mask = tmp_mask - full_mask
            full_mask[np.where(mask>0)] = tmp_mask[np.where(mask>0)]
            full_img[np.where(mask>0)] = tmp_img[np.where(mask>0)]

        # Add a channel axis so the mask broadcasts over colour channels.
        full_mask = full_mask[:, :, np.newaxis]
        # img_sr is only read when use_sr is True (short-circuit), so it is always bound here.
        if self.use_sr and img_sr is not None:
            img = cv2.convertScaleAbs(img_sr*(1-full_mask) + full_img*full_mask)
        else:
            img = cv2.convertScaleAbs(img*(1-full_mask) + full_img*full_mask)

        return img, orig_faces, enhanced_faces
# ---- Flask routes ----
@app.route("/")
def index():
    """Serve the single-page upload UI."""
    page = render_template_string(html_content)
    return page
@app.route("/process_image", methods=["POST"])
def process_image():
    """Enhance faces in an uploaded image and return the first enhanced crop.

    Expects a multipart form field ``file``. On success it responds with the
    first entry of ``FaceEnhancement.process``'s enhanced-face list (the
    aligned face crop, not the full blended image), encoded as JPEG.
    Otherwise it responds with a JSON ``{"error": ...}`` body and one of
    these statuses:

    - 400: missing file, undecodable image, or no face found;
    - 503: the model is not initialized;
    - 500: processing failed.
    """
    import io
    import uuid

    if "file" not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files["file"]
    if not file.filename:
        return jsonify({"error": "No filename"}), 400
    if faceenhancer is None:
        # initialize_model() failed or was never called (e.g. not run via __main__).
        return jsonify({"error": "Model is not initialized."}), 503

    # Save under a unique name. secure_filename() drops non-ASCII characters
    # (e.g. Arabic names can shrink to just "jpg"), and a shared name would let
    # concurrent requests overwrite or delete each other's upload.
    _, ext = os.path.splitext(secure_filename(file.filename))
    input_path = os.path.join(app.config["UPLOAD_FOLDER"], uuid.uuid4().hex + ext)
    try:
        file.save(input_path)
        im = cv2.imread(input_path, cv2.IMREAD_COLOR)
        if im is None:
            return jsonify({"error": "Could not decode the uploaded image."}), 400

        # process() returns (blended image, aligned crops, enhanced crops).
        _, _, enhanced_faces = faceenhancer.process(im)
        if not enhanced_faces:
            return jsonify({"error": "No faces detected in the image."}), 400

        # The crop is encoded unchanged, as before; presumably FaceGAN returns the
        # same BGR order that cv2.imread produced. Confirm if colours look swapped.
        # Encoding in memory always yields real JPEG bytes, whatever the upload's
        # extension was, and leaves no result files behind.
        ok, buffer = cv2.imencode(".jpg", enhanced_faces[0])
        if not ok:
            return jsonify({"error": "Failed to encode the result image."}), 500
        return send_file(io.BytesIO(buffer.tobytes()), mimetype="image/jpeg")
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if os.path.exists(input_path):
            os.remove(input_path)
if __name__ == "__main__":
    # Load weights and the model before serving, so the first request doesn't pay for it.
    initialize_model()
    # Do not enable Flask debug mode by default. With host 0.0.0.0 it exposes
    # the interactive Werkzeug debugger to the network. Its reloader also
    # re-executes this script, which loads the model a second time. Opt in
    # locally with FLASK_DEBUG=1.
    debug = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=7860, debug=debug)