File size: 12,869 Bytes
7f1bf24
 
ab770ce
7f1bf24
1922819
 
 
5d7b79c
ab770ce
 
5d7b79c
 
 
 
 
0d31062
7f1bf24
ab770ce
0d31062
7f1bf24
5d7b79c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1922819
5d7b79c
 
3ba2c8f
 
 
 
 
 
 
5d7b79c
54d2d7e
0d31062
54d2d7e
5d7b79c
 
54d2d7e
5d7b79c
54d2d7e
 
 
 
1922819
ab770ce
c321f55
4d8fafe
 
 
 
 
 
1922819
ab770ce
7f1bf24
 
 
9132a89
 
 
 
 
 
 
7f1bf24
4d8fafe
7f1bf24
 
 
 
 
 
 
5d7b79c
1922819
5d7b79c
 
 
 
 
 
 
 
9132a89
 
758398e
9132a89
 
4d8fafe
9132a89
 
 
 
 
 
ab770ce
9132a89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab770ce
 
1922819
5d7b79c
 
 
 
ab770ce
1922819
ab770ce
5d7b79c
ab770ce
9132a89
a0d5d89
9132a89
758398e
 
9132a89
 
 
 
1922819
9132a89
 
 
 
 
 
 
0d31062
 
5d7b79c
 
 
9132a89
 
 
 
 
 
 
 
 
 
 
 
 
0d31062
9132a89
1922819
 
 
9132a89
7f1bf24
9132a89
1922819
 
9132a89
 
 
0d31062
9132a89
1922819
 
9132a89
 
 
 
 
 
 
 
1922819
 
 
 
9132a89
 
 
0d31062
9132a89
 
1922819
 
 
9132a89
1922819
0d31062
9132a89
0d31062
1922819
 
9132a89
 
 
 
05fd58a
9132a89
 
 
 
212dd68
9132a89
 
 
1922819
 
 
 
 
 
 
 
 
 
9132a89
 
 
d8ab7fe
 
 
 
9132a89
d8ab7fe
609e827
 
d8ab7fe
 
 
9132a89
 
1922819
 
 
 
9132a89
 
 
1922819
 
 
9132a89
ab770ce
0d31062
7f1bf24
9132a89
 
 
0d31062
 
9132a89
 
 
 
0d31062
1922819
0d31062
9132a89
 
0d31062
 
9132a89
0d31062
 
7f1bf24
 
9132a89
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import os
import sys
import subprocess

# --- CONFIGURE IMPORT PATH IMMEDIATELY ---
# Make modules sitting in the current working directory importable
# (presumably DRCT_arch, which is imported later but never pip-installed).
sys.path.append(os.getcwd())

# --- PART 1: ENVIRONMENT SETUP (FINAL FIX) ---
# Everything below runs at import time and has side effects: it invokes pip,
# may clone a git repository, and writes/patches files under ./CodeFormer.
print("⏳ Đang thiết lập môi trường...")

# 0. INSTALL MISSING LIBRARIES (REQUIRED)
# DRCT_arch requires einops, which is not listed in requirements.txt.
# scipy is installed in the same pip call even though the log line only
# mentions einops. check=True aborts the script if pip fails.
print("   + Installing missing dependencies (einops)...")
subprocess.run([sys.executable, "-m", "pip", "install", "einops", "scipy"], check=True)

# 1. Clone CodeFormer (skipped when a ./CodeFormer directory already exists).
if not os.path.exists("CodeFormer"):
    print("   + Cloning CodeFormer...")
    subprocess.run(["git", "clone", "https://github.com/sczhou/CodeFormer.git"], check=True)

# 2. CREATE STUB VERSION FILES SO BASICSR'S SETUP.PY DOES NOT FAIL
# BasicSR's setup.py is strict: it expects a VERSION file at a specific location.
print("   + Creating dummy version files...")

# Create the VERSION file if missing (works around FileNotFoundError: './basicsr/VERSION').
if not os.path.exists("CodeFormer/basicsr/VERSION"):
    with open("CodeFormer/basicsr/VERSION", "w", encoding="utf-8") as f:
        f.write("1.4.2")

# Write a complete version.py (works around ImportError: cannot import name '__gitsha__').
# This file is overwritten unconditionally on every run.
version_py_path = "CodeFormer/basicsr/version.py"
with open(version_py_path, "w", encoding="utf-8") as f:
    f.write("version = '1.4.2'\n")
    f.write("__gitsha__ = 'unknown'\n")
    f.write("__version__ = '1.4.2'\n")

# Patch setup.py (precaution): replace the get_version() call with a fixed
# version string so the build does not depend on the version helpers.
setup_file_path = "CodeFormer/basicsr/setup.py"
if os.path.exists(setup_file_path):
    with open(setup_file_path, "r", encoding="utf-8") as f:
        content = f.read()
    content = content.replace("version=get_version(),", "version='1.4.2',")
    with open(setup_file_path, "w", encoding="utf-8") as f:
        f.write(content)

# 3. INSTALL BASICSR
print("   + Installing BasicSR...")
# NOTE(review): pip runs with cwd=CodeFormer/basicsr, so it is not obvious that
# CodeFormer/basicsr.egg-info is ever created; if it is not, this guard never
# skips the install — confirm.
if not os.path.exists("CodeFormer/basicsr.egg-info"):
    try:
        # --no-build-isolation: build against the already-installed torch
        # --no-deps: do not reinstall torch (or any other dependency)
        subprocess.run(
            [sys.executable, "-m", "pip", "install", ".", "--no-build-isolation", "--no-deps"], 
            cwd="CodeFormer/basicsr", 
            check=True
        )
    except subprocess.CalledProcessError:
        # Best effort: on failure, basicsr is imported straight from the
        # CodeFormer checkout, which is added to sys.path at the end of this section.
        print("⚠️ Cài đặt BasicSR thất bại. Chuyển sang chế độ chạy trực tiếp (Pure Python).")

# 4. INSTALL GFPGAN (only if it is not already importable; again without deps)
print("   + Installing GFPGAN...")
try:
    import gfpgan
except ImportError:
    subprocess.run([sys.executable, "-m", "pip", "install", "gfpgan", "--no-deps"], check=True)

# Add the CodeFormer checkout to the import path so its `basicsr` package
# (and presumably `facelib`) can be imported from source.
sys.path.append(os.path.join(os.getcwd(), "CodeFormer"))

# -----------------------------------------------------------

import gradio as gr
import torch
import cv2
import time
import numpy as np
from PIL import Image, ImageEnhance
from torchvision.transforms.functional import normalize

# Import module an toàn
try:
    from basicsr.utils import img2tensor, tensor2img
    from basicsr.utils.realesrgan_utils import RealESRGANer
    from basicsr.utils.download_util import load_file_from_url
    from basicsr.archs.codeformer_arch import CodeFormer
    from facelib.utils.face_restoration_helper import FaceRestoreHelper
except ImportError as e:
    print(f"⚠️ Lỗi Import BasicSR: {e}. Đang kiểm tra lại path...")
    sys.path.append(os.path.join(os.getcwd(), "CodeFormer"))
    try:
        from basicsr.utils import img2tensor, tensor2img
        from basicsr.utils.realesrgan_utils import RealESRGANer
        from basicsr.utils.download_util import load_file_from_url
        from basicsr.archs.codeformer_arch import CodeFormer
        from facelib.utils.face_restoration_helper import FaceRestoreHelper
    except ImportError as e2:
        print(f"❌ Lỗi Import nghiêm trọng: {e2}")

# --- CONFIGURATION ---
# Fine-tuned Real-DRCT (GAN, x4) weights; a relative path, so it is resolved
# against the current working directory.
DRCT_MODEL_PATH = "Real_DRCT_GAN_SRx4_finetuned_from_mse_net_g_latest.pth"

# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# --- CUSTOM CLASS ---
class RealESRGANer_Custom(RealESRGANer):
    """Tiled upsampler that runs an already-built network through RealESRGANer.

    The parent constructor is deliberately not called: this subclass takes a
    ready ``model`` (e.g. the DRCT network) and only loads weights when
    ``model_path`` is given. ``pre_process`` pads the input on the right/bottom
    to a multiple of ``mod_scale`` (16); the inherited ``enhance`` /
    ``post_process`` are assumed to crop that padding off again —
    NOTE(review): confirm against basicsr.utils.realesrgan_utils.
    """

    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None, gpu_id=None):
        """Configure the upsampler around ``model``.

        Args:
            scale: Upscaling factor of ``model``.
            model_path: Optional checkpoint path or ``https://`` URL. When given,
                its ``params_ema`` (or ``params``) entry is loaded strictly into
                ``model``. ``None`` keeps ``model``'s current weights.
            model: The network to run. Required despite the ``None`` default,
                which only exists to keep the parent's signature.
            tile: Tile size for tiled inference (0 disables tiling).
            tile_pad: Extra context around each tile.
            pre_pad: Padding added on the right/bottom before inference.
            half: Convert the model and inputs to fp16.
            device: Target device. Defaults to CUDA (``cuda:<gpu_id>`` when
                ``gpu_id`` is truthy) if available, otherwise CPU.
            gpu_id: CUDA device index; only consulted when ``device`` is None.

        Raises:
            ValueError: If ``model`` is None.
        """
        if model is None:
            raise ValueError("RealESRGANer_Custom requires an already-constructed `model`.")

        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        # Fixed pad multiple; presumably chosen to match DRCT's window_size=16.
        self.mod_scale = 16
        self.half = half

        if device is not None:
            self.device = device
        elif torch.cuda.is_available():
            # Same rule as the parent class: a falsy gpu_id means the default GPU.
            self.device = torch.device(f'cuda:{gpu_id}' if gpu_id else 'cuda')
        else:
            self.device = torch.device('cpu')

        if model_path is not None:
            if model_path.startswith('https://'):
                model_path = load_file_from_url(
                    url=model_path, model_dir='weights/realesrgan', progress=True, file_name=None)
            loadnet = torch.load(model_path, map_location=torch.device('cpu'))
            keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
            model.load_state_dict(loadnet[keyname], strict=True)

        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    @staticmethod
    def _pad_bottom_right(tensor, pad_w, pad_h):
        """Pad an (N, C, H, W) tensor by ``pad_w`` columns and ``pad_h`` rows.

        Uses 'reflect' padding like the parent class. torch's reflect mode
        requires each pad to be smaller than the padded dimension, so tiny
        inputs fall back to 'replicate' instead of raising.
        """
        if pad_w == 0 and pad_h == 0:
            return tensor
        _, _, height, width = tensor.shape
        mode = 'reflect' if (pad_h < height and pad_w < width) else 'replicate'
        return torch.nn.functional.pad(tensor, (0, pad_w, 0, pad_h), mode)

    def pre_process(self, img):
        """Convert an HWC numpy image into a padded NCHW tensor stored in ``self.img``.

        Also sets ``self.mod_pad_h`` / ``self.mod_pad_w`` to the rows/columns
        added so that height and width become multiples of ``self.mod_scale``.
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        if self.pre_pad != 0:
            self.img = self._pad_bottom_right(self.img, self.pre_pad, self.pre_pad)

        if self.mod_scale is not None:
            _, _, h, w = self.img.size()
            # (-n) % m is 0 when n is already a multiple of m, else m - n % m.
            self.mod_pad_h = (-h) % self.mod_scale
            self.mod_pad_w = (-w) % self.mod_scale
            self.img = self._pad_bottom_right(self.img, self.mod_pad_w, self.mod_pad_h)

def _select_state_dict(checkpoint):
    """Pick the weights out of a loaded checkpoint.

    BasicSR-style checkpoints nest the weights under ``params_ema`` (preferred)
    or ``params``; anything else is assumed to already be a bare state dict.
    """
    for key in ('params_ema', 'params'):
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


def load_drct_model(model_path, device):
    """Build the DRCT x4 network, load its weights and move it to ``device``.

    Args:
        model_path: Path to a local checkpoint file.
        device: ``torch.device`` the returned model is moved to.

    Returns:
        The DRCT model in eval mode on ``device``.

    Raises:
        ImportError: If ``DRCT_arch.DRCT`` cannot be imported (e.g. einops missing).
        FileNotFoundError: If ``model_path`` does not exist.
    """
    try:
        from DRCT_arch import DRCT
    except ImportError as e:
        print(f"Lỗi import DRCT: {e}")
        # einops may have been pip-installed earlier in this same process;
        # refresh site paths and the import system's finder caches, then retry.
        import importlib
        import site
        site.main()
        importlib.invalidate_caches()
        try:
            from DRCT_arch import DRCT
        except ImportError as retry_err:
            raise ImportError("❌ Không thể import class 'DRCT'. Đảm bảo đã cài 'einops'.") from retry_err

    model = DRCT(
        upscale=4, in_chans=3, img_size=64, window_size=16,
        compress_ratio=3, squeeze_factor=30, conv_scale=0.01, overlap_ratio=0.5,
        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffle',
        resi_connection='1conv'
    )

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Thiếu file model weights: {model_path}")

    # Deserialize on the CPU: the weights are copied into the freshly built
    # model and only then moved, so a device-side copy of the checkpoint is
    # never needed (avoids doubling peak GPU memory).
    checkpoint = torch.load(model_path, map_location='cpu')
    state_dict = _select_state_dict(checkpoint)

    # strict=False tolerates key mismatches, but report them so a wrong
    # checkpoint does not silently leave layers randomly initialized.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            "⚠️ DRCT checkpoint key mismatch: "
            f"{len(incompatible.missing_keys)} missing {incompatible.missing_keys[:5]}, "
            f"{len(incompatible.unexpected_keys)} unexpected {incompatible.unexpected_keys[:5]}"
        )

    model.eval()
    return model.to(device)

# --- LOAD MODELS ---
# Both models are loaded once at import time. On any failure the error is
# logged and the corresponding global stays None; process_image checks these
# globals before use, so each one is published only once it is fully loaded.
print("⏳ Đang tải Model...")
drct_model = None
codeformer = None

try:
    drct_model = load_drct_model(DRCT_MODEL_PATH, device)

    codeformer_ckpt_path = 'weights/CodeFormer/codeformer.pth'
    if not os.path.exists(codeformer_ckpt_path):
        os.makedirs('weights/CodeFormer', exist_ok=True)
        load_file_from_url(url='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
                           model_dir='weights/CodeFormer', progress=True, file_name='codeformer.pth')

    cf_net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
                        connect_list=['32', '64', '128', '256']).to(device)
    # map_location='cpu' keeps loading independent of the device the
    # checkpoint was saved from (e.g. a CUDA-saved file on a CPU-only host).
    # The state dict is not bound to a module-level name, so it can be freed.
    cf_net.load_state_dict(torch.load(codeformer_ckpt_path, map_location='cpu')['params_ema'])
    cf_net.eval()
    # Publish only after the weights are in; otherwise a failed load would
    # leave a randomly initialized CodeFormer in use.
    codeformer = cf_net
    print("✅ Model Ready!")
except Exception as e:
    print(f"⚠️ Lỗi khởi tạo Model: {e}")
    import traceback
    traceback.print_exc()

# --- IMAGE PROCESSING ---
def _run_drct(img):
    """Upscale a BGR image x4 with the global DRCT model.

    Returns:
        (bg_img, elapsed_seconds, peak_vram_gb). bg_img is the upscaled BGR
        image returned by the upsampler's ``enhance``; peak VRAM is 0 without CUDA.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()

    upsampler = RealESRGANer_Custom(
        scale=4, model_path=None, model=drct_model,
        tile=512, tile_pad=32, pre_pad=0, half=False, device=device
    )
    if device.type == 'cuda':
        # Mixed precision on the GPU; the model itself stays fp32 (half=False).
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            bg_img, _ = upsampler.enhance(img, outscale=4)
    else:
        bg_img, _ = upsampler.enhance(img, outscale=4)

    elapsed = time.perf_counter() - start
    peak_vram = torch.cuda.max_memory_allocated() / (1024 ** 3) if torch.cuda.is_available() else 0
    return bg_img, elapsed, peak_vram


def _restore_faces(img, bg_img, w):
    """Restore faces detected in ``img`` with CodeFormer and paste them onto ``bg_img``.

    Args:
        img: Original (not upscaled) BGR image used for face detection.
        bg_img: x4-upscaled BGR background the restored faces are pasted onto.
        w: CodeFormer fidelity weight.

    Returns:
        The composited image converted to RGB.

    Raises:
        RuntimeError: If the CodeFormer model failed to load.
        Any exception raised by face detection/restoration is propagated.
    """
    if codeformer is None:
        raise RuntimeError("CodeFormer model is not loaded")

    face_helper = FaceRestoreHelper(
        upscale_factor=4, face_size=512, crop_ratio=(1, 1),
        det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=device
    )
    face_helper.clean_all()
    face_helper.read_image(img)
    face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
    face_helper.align_warp_face()

    for cropped_face in face_helper.cropped_faces:
        # Scale to [0, 1], BGR->RGB, then normalize to [-1, 1] for CodeFormer.
        face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        face_t = face_t.unsqueeze(0).to(device)

        with torch.no_grad():
            output = codeformer(face_t, w=w, adain=True)[0]
            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        face_helper.add_restored_face(restored_face)

    face_helper.get_inverse_affine(None)
    final_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
    # OpenCV works in BGR; Gradio/PIL expect RGB.
    return cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)


def process_image(input_img, w=0.7, color_factor=1.0, contrast_factor=1.0):
    """Upscale an image with DRCT, then restore faces with CodeFormer.

    Args:
        input_img: PIL image from the Gradio input (RGB), or None.
        w: CodeFormer fidelity weight.
        color_factor: PIL ``ImageEnhance.Color`` factor applied to the hybrid
            result; 1.0 (default) is the identity and skips the PIL step.
        contrast_factor: PIL ``ImageEnhance.Contrast`` factor; 1.0 skips it.

    Returns:
        (res_drct, res_hybrid, stats_drct, stats_hybrid): two RGB numpy
        images (or None) plus two status strings for the UI. If face
        restoration fails, res_hybrid is the DRCT-only image and
        stats_hybrid carries the error.
    """
    if drct_model is None:
        return None, None, "Lỗi Model (Xem Logs)", ""
    if input_img is None:
        return None, None, "Thiếu ảnh input", ""

    # Gradio delivers RGB; the OpenCV-based pipeline works in BGR.
    img = cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)

    # 1. DRCT
    try:
        bg_img, drct_time, drct_vram = _run_drct(img)
    except Exception as e:
        return None, None, f"Lỗi DRCT: {str(e)}", ""

    res_drct = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
    stats_drct = f"⏱️ {drct_time:.2f}s | 💾 {drct_vram:.2f} GB | 📏 {bg_img.shape[1]}x{bg_img.shape[0]}"

    # 2. CODEFORMER
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    start_cf = time.perf_counter()

    try:
        res_hybrid = _restore_faces(img, bg_img, w)
        # Enhancement at factor 1.0 is the identity, so only pay for the PIL
        # round-trip when a non-default factor is requested.
        if color_factor != 1.0 or contrast_factor != 1.0:
            pil_img = Image.fromarray(res_hybrid)
            pil_img = ImageEnhance.Color(pil_img).enhance(color_factor)
            pil_img = ImageEnhance.Contrast(pil_img).enhance(contrast_factor)
            res_hybrid = np.array(pil_img)
    except Exception as e:
        print(f"CodeFormer Error/No Face: {e}")
        stats_hybrid = f"⚠️ Lỗi CF/Không có mặt: {str(e)}"
        return res_drct, res_drct, stats_drct, stats_hybrid

    cf_time = time.perf_counter() - start_cf
    total_time = drct_time + cf_time
    max_vram = drct_vram
    if torch.cuda.is_available():
        max_vram = max(drct_vram, torch.cuda.max_memory_allocated() / (1024 ** 3))

    stats_hybrid = (f"⏱️ Tổng: {total_time:.2f}s\n"
                    f"   (DRCT: {drct_time:.2f}s + CF: {cf_time:.2f}s)\n"
                    f"💾 VRAM Peak: {max_vram:.2f} GB")

    return res_drct, res_hybrid, stats_drct, stats_hybrid

# --- UI ---
# Two-column comparison: DRCT-only upscale vs DRCT + CodeFormer face restoration.
title = "So sánh Upscale: DRCT vs Hybrid"
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input")
            # The label advertises 0 as one end of the range, so the slider
            # must be able to reach 0 (previously it started at 0.1).
            w_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.7, step=0.1,
                label="CodeFormer Weight (0=Restore, 1=Identity)",
            )
            run_btn = gr.Button("🚀 Chạy", variant="primary")
    with gr.Row():
        with gr.Column():
            output_drct = gr.Image(label="DRCT Only")
            stats_drct_box = gr.Textbox(label="Stats")
        with gr.Column():
            output_hybrid = gr.Image(label="DRCT + CodeFormer")
            stats_hybrid_box = gr.Textbox(label="Stats")

    run_btn.click(
        fn=process_image,
        inputs=[input_image, w_slider],
        outputs=[output_drct, output_hybrid, stats_drct_box, stats_hybrid_box],
    )

if __name__ == "__main__":
    # queue() routes requests through Gradio's request queue before launching.
    demo.queue().launch()