File size: 14,104 Bytes
adf2fff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""

CodeFormer Flask Application

Deployment on Hugging Face Spaces

"""

import os
import cv2
import torch
import uuid
import numpy as np
import zipfile
import base64
from flask import Flask, render_template, request, send_file, url_for, jsonify, send_from_directory
from flask_cors import CORS
from werkzeug.utils import secure_filename

from torchvision.transforms.functional import normalize
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import gpu_is_available, get_device
from basicsr.utils.realesrgan_utils import RealESRGANer
from basicsr.utils.registry import ARCH_REGISTRY

from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray

# --- Initialization ---
app = Flask(__name__)
# flask_cors with no arguments enables CORS on every route.
CORS(app)
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    RESULT_FOLDER='static/results',
    MAX_CONTENT_LENGTH=100 * 1024 * 1024,  # request size limit: 100 MB
)

# Create the upload/result folders and one folder per model's weights up front.
for _required_dir in (
    app.config['UPLOAD_FOLDER'],
    app.config['RESULT_FOLDER'],
    'weights/CodeFormer',
    'weights/facelib',
    'weights/realesrgan',
):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# Download locations of every pretrained checkpoint, keyed by model role.
# All four are assets of the CodeFormer v0.1.0 GitHub release.
pretrain_model_url = dict(
    codeformer='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
    detection='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
    parsing='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
    realesrgan='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth',
)

def download_weights():
    """Fetch every pretrained checkpoint that is not already under ``weights/``.

    Each entry names a ``pretrain_model_url`` key, the folder the checkpoint
    belongs in, and the file name expected there. Files that already exist
    are left untouched, so repeated startups do not re-download.
    """
    expected_files = (
        ('codeformer', 'weights/CodeFormer', 'codeformer.pth'),
        ('detection', 'weights/facelib', 'detection_Resnet50_Final.pth'),
        ('parsing', 'weights/facelib', 'parsing_parsenet.pth'),
        ('realesrgan', 'weights/realesrgan', 'RealESRGAN_x2plus.pth'),
    )
    for url_key, target_dir, weight_file in expected_files:
        if os.path.exists(f'{target_dir}/{weight_file}'):
            continue
        load_file_from_url(
            url=pretrain_model_url[url_key],
            model_dir=target_dir,
            progress=True,
            file_name=None,
        )

# Make sure every checkpoint is on disk before any model is constructed.
print("Checking weights...")
download_weights()

# Shared inference state. `device` is chosen once at import time; both model
# handles start empty and are assigned by init_models().
device = get_device()
upsampler = codeformer_net = None

def init_models():
    """Build the global Real-ESRGAN upsampler and CodeFormer network.

    Assigns the module-level ``upsampler`` (Real-ESRGAN x2 with tiled
    inference) and ``codeformer_net`` (CodeFormer on ``device``, in eval mode).
    Both are built from the checkpoints that ``download_weights()`` places
    under ``weights/``.
    """
    global upsampler, codeformer_net

    # RealESRGAN: half precision only when a GPU is available.
    half = bool(gpu_is_available())
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    upsampler = RealESRGANer(
        scale=2,
        model_path="weights/realesrgan/RealESRGAN_x2plus.pth",
        model=model,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=half,
    )

    # CodeFormer
    codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layers=9,
        connect_list=["32", "64", "128", "256"],
    ).to(device)

    ckpt_path = "weights/CodeFormer/codeformer.pth"
    # Deserialize onto the CPU so that a checkpoint saved from CUDA tensors
    # still loads on CPU-only hosts; load_state_dict then copies the weights
    # into the parameters that already live on `device`.
    checkpoint = torch.load(ckpt_path, map_location="cpu")["params_ema"]
    codeformer_net.load_state_dict(checkpoint)
    codeformer_net.eval()
    print("Models loaded successfully.")


init_models()

def process_image(img_path, background_enhance, face_upsample, upscale, codeformer_fidelity):
    """Run CodeFormer face restoration on a single image file.

    Args:
        img_path: Path to the input image, read with ``cv2.imread`` in colour mode.
        background_enhance: When truthy, the background is upsampled with the
            global Real-ESRGAN ``upsampler`` before faces are pasted back.
        face_upsample: When truthy, restored faces are also passed through
            ``upsampler`` while being pasted back.
        upscale: Requested output scale factor. It is coerced to ``int`` and
            clamped to 1-4. Inputs whose longest side exceeds 1000 px are
            limited to 2x, and those over 1500 px to 1x with both enhancements
            disabled, to bound memory use.
        codeformer_fidelity: CodeFormer ``w`` weight, clamped to 0-1.

    Returns:
        The restored image array, or ``None`` when the file cannot be read as
        an image or any processing step raises. The error is printed.
    """
    try:
        # Fixed pipeline options; not exposed to callers.
        has_aligned = False
        only_center_face = False
        draw_box = False
        detection_model = "retinaface_resnet50"

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread returns None (rather than raising) for missing or
            # undecodable files; report that clearly instead of failing later
            # with an AttributeError on img.shape.
            print(f"Could not read image: {img_path}")
            return None

        # Memory safety checks. Clamp into [1, 4] first: the JSON API passes
        # any client-supplied integer here, and values below 1 are not a
        # meaningful scale factor.
        upscale = min(max(int(upscale), 1), 4)
        longest_side = max(img.shape[:2])
        if upscale > 2 and longest_side > 1000:
            upscale = 2
        if longest_side > 1500:
            upscale = 1
            background_enhance = False
            face_upsample = False

        # Keep the fidelity weight inside its documented 0-1 range.
        codeformer_fidelity = min(max(float(codeformer_fidelity), 0.0), 1.0)

        face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=detection_model,
            save_ext="png",
            use_parse=True,
            device=device,
        )

        bg_upsampler = upsampler if background_enhance else None
        face_upsampler = upsampler if face_upsample else None

        if has_aligned:
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=5)
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            face_helper.get_face_landmarks_5(only_center_face=only_center_face, resize=640, eye_dist_threshold=5)
            face_helper.align_warp_face()

        # Face restoration: each crop is scaled to [0, 1], then normalized to
        # [-1, 1] (mean/std 0.5) before going through CodeFormer.
        for cropped_face in face_helper.cropped_faces:
            cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

            try:
                with torch.no_grad():
                    output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
                    restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
            except Exception as e:
                # Best effort: keep the unrestored crop rather than dropping the face.
                print(f"Inference error: {e}")
                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

            restored_face = restored_face.astype("uint8")
            face_helper.add_restored_face(restored_face)

        # Paste restored faces back into the (optionally upsampled) full image.
        if not has_aligned:
            bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] if bg_upsampler else None
            face_helper.get_inverse_affine(None)

            if face_upsample and face_upsampler:
                restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box, face_upsampler=face_upsampler)
            else:
                restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
        else:
            restored_img = face_helper.restored_faces[0]

        return restored_img

    except Exception as e:
        # Callers treat None as "this image failed"; log the full traceback so
        # the failure is still diagnosable.
        import traceback
        print(f"Global processing error: {e}")
        traceback.print_exc()
        return None

# --- Routes ---

@app.route('/', methods=['GET'])
def index():
    """Render the ``index.html`` template as the landing page."""
    page = 'index.html'
    return render_template(page)

@app.route('/process', methods=['POST'])
def process():
    """Restore every image posted by the HTML form and render the result page.

    Reads one or more files from the ``image`` form field. It also reads the
    ``fidelity`` value and the ``background_enhance`` / ``face_upsample``
    checkboxes, where a field's presence means "enabled". Each restored image
    is saved as a PNG in ``RESULT_FOLDER``. A preview capped at 1000 px on
    the longest side is saved next to it; the full image is reused as the
    preview when it is already that small. All full-size results are then
    bundled into one ZIP.

    Returns ``result.html`` on success. Otherwise returns a plain-text error:
    status 400 for a missing file or unparsable fidelity, 500 when no image
    could be processed.
    """
    if 'image' not in request.files:
        return "No image uploaded", 400

    uploads = request.files.getlist('image')
    if not uploads or uploads[0].filename == '':
        return "No selected file", 400

    # Batch-wide settings; only the fidelity conversion can fail.
    try:
        fidelity = float(request.form.get('fidelity', 0.5))
    except ValueError:
        return "Invalid parameters", 400
    upscale = 4  # always request 4x; process_image lowers it for large inputs
    background_enhance = 'background_enhance' in request.form
    face_upsample = 'face_upsample' in request.form

    upload_dir = app.config['UPLOAD_FOLDER']
    result_dir = app.config['RESULT_FOLDER']
    results = []

    for upload in uploads:
        if upload.filename == '':
            continue

        # A UUID prefix keeps saved names unique across uploads.
        saved_name = str(uuid.uuid4()) + "_" + secure_filename(upload.filename)
        saved_path = os.path.join(upload_dir, saved_name)
        upload.save(saved_path)

        restored = process_image(saved_path, background_enhance, face_upsample, upscale, fidelity)
        if restored is None:
            continue  # failed image: leave it out of the results

        stem = saved_name.rsplit('.', 1)[0]
        full_name = "result_" + stem + ".png"
        imwrite(restored, os.path.join(result_dir, full_name))

        longest = max(restored.shape[:2])
        if longest > 1000:
            preview_name = "preview_" + full_name
            shrink = 1000 / longest
            thumb = cv2.resize(restored, None, fx=shrink, fy=shrink, interpolation=cv2.INTER_AREA)
            imwrite(thumb, os.path.join(result_dir, preview_name))
        else:
            preview_name = full_name

        results.append({'original': saved_name, 'preview': preview_name, 'download': full_name})

    if not results:
        return "Processing failed for all images", 500

    # Bundle every full-size result into one downloadable archive.
    zip_filename = f"batch_{uuid.uuid4()}.zip"
    with zipfile.ZipFile(os.path.join(result_dir, zip_filename), 'w') as archive:
        for entry in results:
            archive.write(os.path.join(result_dir, entry['download']), entry['download'])

    return render_template('result.html', results=results, zip_filename=zip_filename)

# --- API Routes ---

def _api_bool(value):
    """Interpret a boolean API parameter received as JSON or form data.

    Real booleans are returned unchanged. Any other value is compared, as a
    trimmed lower-case string, against ``true``/``1``/``yes``/``on``. As a
    result, the JSON *string* ``"false"`` is False instead of truthy.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ('true', '1', 'yes', 'on')


@app.route('/api/process', methods=['POST'])
def api_process():
    """Restore one or more images submitted as JSON or multipart form data.

    Accepts:
      - ``multipart/form-data`` with one or more ``image`` files, or
      - ``application/json`` with ``image_base64`` (one base64 string, plain
        or as a data URL) or ``images_base64`` (a list of such strings).

    Parameters (form fields or JSON keys):
      - fidelity: float, expected 0-1, default 0.5.
      - background_enhance: bool, default False.
      - face_upsample: bool, default False.
      - upscale: int, clamped to 1-4, default 2.
      - return_base64: bool, default False; also embed each PNG as base64.
      Booleans may be JSON booleans or strings such as "true"/"false"
      (also "1"/"0", "yes"/"no", "on"/"off").

    Returns:
      On success, JSON ``{"status": "success", "count": N, "results": [...]}``.
      Each result has ``original_name``, ``image_url`` and ``filename``, and
      ``image_base64`` when requested. On failure, JSON
      ``{"status": "error", "message": ...}``: status 400 for malformed input,
      500 when processing fails. Temporary input files are always deleted.
    """
    import traceback

    def fail(message, status):
        return jsonify({"status": "error", "message": message}), status

    # Every entry's 'path' is removed in `finally`; entries are appended
    # before the file is written so partial writes are cleaned up too.
    inputs = []
    try:
        is_json = request.is_json
        # silent=True: malformed JSON yields None (reported as 400 below)
        # instead of an exception that would surface as a 500.
        data = request.get_json(silent=True) if is_json else request.form
        if is_json and not isinstance(data, dict):
            return fail("JSON body must be an object", 400)

        try:
            fidelity = float(data.get('fidelity', 0.5))
            upscale = int(data.get('upscale', 2))
        except (TypeError, ValueError):
            return fail("'fidelity' must be a number and 'upscale' an integer", 400)
        # Keep upscale inside the documented 1-4 range.
        upscale = min(max(upscale, 1), 4)

        background_enhance = _api_bool(data.get('background_enhance', False))
        face_upsample = _api_bool(data.get('face_upsample', False))
        return_base64 = _api_bool(data.get('return_base64', False))

        upload_dir = app.config['UPLOAD_FOLDER']

        # Handle JSON input
        if is_json:
            if 'image_base64' in data:
                payloads = [(data['image_base64'], 'image.png')]
            elif 'images_base64' in data:
                images = data['images_base64']
                if not isinstance(images, list):
                    return fail("'images_base64' must be a list", 400)
                payloads = [(img_b64, f'image_{idx}.png') for idx, img_b64 in enumerate(images)]
            else:
                payloads = []

            for img_b64, name in payloads:
                if not isinstance(img_b64, str):
                    return fail("Base64 image data must be a string", 400)
                try:
                    # Accept both raw base64 and data URLs ("data:...;base64,<data>").
                    image_data = base64.b64decode(img_b64.split(',')[-1])
                except ValueError:  # binascii.Error is a ValueError subclass
                    return fail(f"Invalid base64 data for {name}", 400)
                temp_filename = str(uuid.uuid4())
                input_path = os.path.join(upload_dir, f"{temp_filename}.png")
                inputs.append({'path': input_path, 'name': name, 'temp_id': temp_filename})
                with open(input_path, 'wb') as f:
                    f.write(image_data)

        # Handle Multipart input
        elif 'image' in request.files:
            for file in request.files.getlist('image'):
                if file.filename == '':
                    continue
                temp_filename = str(uuid.uuid4())
                filename = f"{temp_filename}_{secure_filename(file.filename)}"
                input_path = os.path.join(upload_dir, filename)
                inputs.append({'path': input_path, 'name': file.filename, 'temp_id': temp_filename})
                file.save(input_path)

        if not inputs:
            return fail("No images provided", 400)

        processed_images = []
        for inp in inputs:
            result_img = process_image(inp['path'], background_enhance, face_upsample, upscale, fidelity)
            if result_img is None:
                continue  # failed image: omitted from the results

            output_filename = f"api_result_{inp['temp_id']}.png"
            output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename)
            imwrite(result_img, output_path)

            res = {
                "original_name": inp['name'],
                "image_url": url_for('static', filename=f'results/{output_filename}', _external=True),
                "filename": output_filename,
            }
            if return_base64:
                _, buffer = cv2.imencode('.png', result_img)
                res["image_base64"] = base64.b64encode(buffer).decode('utf-8')

            processed_images.append(res)

        if not processed_images:
            return fail("Processing failed for all images", 500)

        return jsonify({
            "status": "success",
            "count": len(processed_images),
            "results": processed_images,
        })

    except Exception as e:
        traceback.print_exc()
        return fail(str(e), 500)

    finally:
        # The response never references the uploaded inputs, so delete them
        # rather than letting UPLOAD_FOLDER grow without bound.
        for inp in inputs:
            try:
                os.remove(inp['path'])
            except OSError:
                pass  # best-effort cleanup; must not mask the response

@app.route('/api/health', methods=['GET'])
def health_check():
    """Report that the service is up and which torch device it runs inference on."""
    payload = {"status": "online", "device": str(device)}
    return jsonify(payload)

if __name__ == '__main__':
    # Entry point for Docker / HF Spaces: listen on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')