Spaces:
Sleeping
Sleeping
File size: 12,869 Bytes
7f1bf24 ab770ce 7f1bf24 1922819 5d7b79c ab770ce 5d7b79c 0d31062 7f1bf24 ab770ce 0d31062 7f1bf24 5d7b79c 1922819 5d7b79c 3ba2c8f 5d7b79c 54d2d7e 0d31062 54d2d7e 5d7b79c 54d2d7e 5d7b79c 54d2d7e 1922819 ab770ce c321f55 4d8fafe 1922819 ab770ce 7f1bf24 9132a89 7f1bf24 4d8fafe 7f1bf24 5d7b79c 1922819 5d7b79c 9132a89 758398e 9132a89 4d8fafe 9132a89 ab770ce 9132a89 ab770ce 1922819 5d7b79c ab770ce 1922819 ab770ce 5d7b79c ab770ce 9132a89 a0d5d89 9132a89 758398e 9132a89 1922819 9132a89 0d31062 5d7b79c 9132a89 0d31062 9132a89 1922819 9132a89 7f1bf24 9132a89 1922819 9132a89 0d31062 9132a89 1922819 9132a89 1922819 9132a89 0d31062 9132a89 1922819 9132a89 1922819 0d31062 9132a89 0d31062 1922819 9132a89 05fd58a 9132a89 212dd68 9132a89 1922819 9132a89 d8ab7fe 9132a89 d8ab7fe 609e827 d8ab7fe 9132a89 1922819 9132a89 1922819 9132a89 ab770ce 0d31062 7f1bf24 9132a89 0d31062 9132a89 0d31062 1922819 0d31062 9132a89 0d31062 9132a89 0d31062 7f1bf24 9132a89 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 
import os
import sys
import subprocess
# --- Make the working directory importable right away ---
sys.path.append(os.getcwd())

# --- PART 1: ENVIRONMENT SETUP ---
# Everything below runs at import time: install missing packages, clone
# CodeFormer, patch BasicSR's packaging metadata and install it. Each step
# checks whether its work is already done so a restart does not redo it.
print("⏳ Đang thiết lập môi trường...")


def _is_importable(module_name):
    """Return True if *module_name* can be located on the import path.

    ``importlib.util.find_spec`` only locates the module; it does not execute
    it, so an ImportError raised by one of the package's own dependencies is
    not mistaken for "package missing".
    """
    import importlib.util
    return importlib.util.find_spec(module_name) is not None


def _pip_install(*args, cwd=None):
    """Run ``<this python> -m pip install <args>``; raise CalledProcessError on failure."""
    subprocess.run([sys.executable, "-m", "pip", "install", *args], cwd=cwd, check=True)


# 0. Install libraries missing from requirements.txt (mandatory).
#    DRCT_arch needs einops, which requirements.txt does not list.
#    Only packages that cannot be found are installed, so restarts skip pip.
_missing = [name for name in ("einops", "scipy") if not _is_importable(name)]
if _missing:
    print(" + Installing missing dependencies (einops)...")
    _pip_install(*_missing)

# 1. Clone CodeFormer.
if not os.path.exists("CodeFormer"):
    print(" + Cloning CodeFormer...")
    subprocess.run(["git", "clone", "https://github.com/sczhou/CodeFormer.git"], check=True)

# 2. Create placeholder version files so BasicSR's setup.py does not fail.
#    BasicSR's setup.py requires a VERSION file at a fixed location.
print(" + Creating dummy version files...")
# VERSION file (fixes FileNotFoundError: './basicsr/VERSION').
if not os.path.exists("CodeFormer/basicsr/VERSION"):
    with open("CodeFormer/basicsr/VERSION", "w", encoding="utf-8") as f:
        f.write("1.4.2")
# Complete version.py (fixes ImportError: cannot import name '__gitsha__').
# Rewritten on every start so all three names are always defined.
version_py_path = "CodeFormer/basicsr/version.py"
with open(version_py_path, "w", encoding="utf-8") as f:
    f.write("version = '1.4.2'\n")
    f.write("__gitsha__ = 'unknown'\n")
    f.write("__version__ = '1.4.2'\n")
# Defensive patch of setup.py: pin the version instead of calling get_version().
setup_file_path = "CodeFormer/basicsr/setup.py"
if os.path.exists(setup_file_path):
    with open(setup_file_path, "r", encoding="utf-8") as f:
        content = f.read()
    patched = content.replace("version=get_version(),", "version='1.4.2',")
    if patched != content:  # only touch the file when the patch changes something
        with open(setup_file_path, "w", encoding="utf-8") as f:
            f.write(patched)

# 3. Install BasicSR.
print(" + Installing BasicSR...")
# NOTE(review): pip runs with cwd="CodeFormer/basicsr", so an egg-info (if any)
# would appear under that directory rather than at this path; this guard may
# never be true and the install may rerun every start -- confirm.
if not os.path.exists("CodeFormer/basicsr.egg-info"):
    try:
        # --no-build-isolation: build against the already-installed torch.
        # --no-deps: do not reinstall torch (or any other dependency).
        _pip_install(".", "--no-build-isolation", "--no-deps", cwd="CodeFormer/basicsr")
    except subprocess.CalledProcessError:
        print("⚠️ Cài đặt BasicSR thất bại. Chuyển sang chế độ chạy trực tiếp (Pure Python).")

# 4. Install GFPGAN (without dependencies).
print(" + Installing GFPGAN...")
# Probe with find_spec rather than `import gfpgan`: importing executes the
# package, and an ImportError from one of its dependencies (presumably basicsr,
# whose checkout is only put on sys.path below) would trigger a reinstall.
if not _is_importable("gfpgan"):
    _pip_install("gfpgan", "--no-deps")

# Add the CodeFormer checkout to the import path.
sys.path.append(os.path.join(os.getcwd(), "CodeFormer"))
# -----------------------------------------------------------
import gradio as gr
import torch
import cv2
import time
import numpy as np
from PIL import Image, ImageEnhance
from torchvision.transforms.functional import normalize
# Import the BasicSR / facelib helpers defensively.
# Failures are only logged, never re-raised: names bound before the failing
# line stay usable, and code that needs a missing name fails where it uses it
# (e.g. the RealESRGANer subclass below raises NameError if RealESRGANer is unbound).
try:
    from basicsr.utils import img2tensor, tensor2img
    from basicsr.utils.realesrgan_utils import RealESRGANer
    from basicsr.utils.download_util import load_file_from_url
    from basicsr.archs.codeformer_arch import CodeFormer
    from facelib.utils.face_restoration_helper import FaceRestoreHelper
except ImportError as e:
    print(f"⚠️ Lỗi Import BasicSR: {e}. Đang kiểm tra lại path...")
    # NOTE(review): the setup section already appends this same CodeFormer path
    # to sys.path, so this adds a duplicate entry at the end of the search path
    # and the retry below will most likely fail the same way. Prepending the
    # path (and purging partially imported modules) would be needed to change
    # the outcome -- confirm before relying on this retry.
    sys.path.append(os.path.join(os.getcwd(), "CodeFormer"))
    try:
        from basicsr.utils import img2tensor, tensor2img
        from basicsr.utils.realesrgan_utils import RealESRGANer
        from basicsr.utils.download_util import load_file_from_url
        from basicsr.archs.codeformer_arch import CodeFormer
        from facelib.utils.face_restoration_helper import FaceRestoreHelper
    except ImportError as e2:
        # Logged only; startup continues.
        print(f"❌ Lỗi Import nghiêm trọng: {e2}")
# --- CONFIGURATION ---
# Real-DRCT GAN x4 weights, looked up relative to the working directory.
DRCT_MODEL_PATH = "Real_DRCT_GAN_SRx4_finetuned_from_mse_net_g_latest.pth"
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# --- CUSTOM CLASS ---
class RealESRGANer_Custom(RealESRGANer):
    """RealESRGANer variant that wraps an already-built model.

    The parent constructor is deliberately not called, so that this class can
    accept ``model_path=None`` together with an in-memory ``model`` (the DRCT
    network). It also always pads inputs to a multiple of ``mod_scale``. The
    inherited ``enhance`` (called by process_image) is expected to read the
    attributes set here -- verify against basicsr's RealESRGANer.
    """

    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None, gpu_id=None):
        """Store inference settings and move ``model`` to the target device.

        Args:
            scale: upscale factor of ``model``.
            model_path: optional checkpoint path or ``https://`` URL; when given,
                its 'params_ema' (else 'params') entry is loaded strictly into
                ``model``. Pass None to use ``model`` as-is.
            model: the network instance (required).
            tile, tile_pad, pre_pad: stored as ``tile_size``, ``tile_pad`` and
                ``pre_pad``; presumably the parent's tiling settings.
            half: convert the model (and inputs in ``pre_process``) to float16.
            device: target torch device. Defaults to CUDA when available
                (``cuda:<gpu_id>`` if ``gpu_id`` is given), otherwise CPU.
            gpu_id: CUDA ordinal used only when ``device`` is None.

        Raises:
            ValueError: if ``model`` is None.
        """
        if model is None:
            raise ValueError("RealESRGANer_Custom requires a model instance")
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        # 16 equals the window_size=16 used to build DRCT in load_drct_model;
        # presumably so padded H/W divide evenly into attention windows.
        self.mod_scale = 16
        self.half = half
        if device is not None:
            self.device = device
        elif torch.cuda.is_available():
            self.device = torch.device(f'cuda:{gpu_id}' if gpu_id is not None else 'cuda')
        else:
            self.device = torch.device('cpu')
        if model_path is not None:
            if model_path.startswith('https://'):
                model_path = load_file_from_url(
                    url=model_path, model_dir='weights/realesrgan', progress=True, file_name=None)
            loadnet = torch.load(model_path, map_location=torch.device('cpu'))
            keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
            model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    @staticmethod
    def _pad_right_bottom(x, pad_w, pad_h):
        """Pad a 4-D tensor on the right by ``pad_w`` and at the bottom by ``pad_h``.

        'reflect' is used when possible. It requires each pad to be smaller than
        the matching input size, so tiny inputs fall back to 'replicate'
        instead of raising.
        """
        _, _, h, w = x.size()
        mode = 'reflect' if pad_w < w and pad_h < h else 'replicate'
        return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode)

    def pre_process(self, img):
        """Convert an HWC numpy image into a padded 1xCxHxW tensor in ``self.img``.

        Applies ``pre_pad`` first. Then pads H and W up to a multiple of
        ``mod_scale``, recording the amounts in ``self.mod_pad_h`` /
        ``self.mod_pad_w`` (presumably cropped again by the parent's
        ``post_process``).
        """
        chw = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = chw.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()
        if self.pre_pad != 0:
            self.img = self._pad_right_bottom(self.img, self.pre_pad, self.pre_pad)
        if self.mod_scale is not None:
            _, _, h, w = self.img.size()
            # (-n) % m is the distance from n up to the next multiple of m (0 if already aligned).
            self.mod_pad_h = (-h) % self.mod_scale
            self.mod_pad_w = (-w) % self.mod_scale
            self.img = self._pad_right_bottom(self.img, self.mod_pad_w, self.mod_pad_h)
def load_drct_model(model_path, device):
    """Build the Real-DRCT x4 network and load its weights.

    Args:
        model_path: checkpoint path; the loaded object must contain a
            'params_ema' or 'params' entry holding the state dict.
        device: torch device the checkpoint is mapped to and the model moved to.

    Returns:
        The DRCT model in eval mode on ``device``.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist. This is checked
            first, before the architecture is imported or built.
        ImportError: if ``DRCT_arch.DRCT`` cannot be imported.
        KeyError: if the checkpoint has neither 'params_ema' nor 'params'.
    """
    # Fail fast: the existence check is cheap, building the network is not.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Thiếu file model weights: {model_path}")
    try:
        from DRCT_arch import DRCT
    except ImportError as e:
        print(f"Lỗi import DRCT: {e}")
        # einops may have been pip-installed by this process after startup.
        # Re-run site setup (picks up newly created site dirs) and clear the
        # import system's cached directory listings before retrying.
        import importlib
        import site
        site.main()
        importlib.invalidate_caches()
        try:
            from DRCT_arch import DRCT
        except ImportError as err:
            raise ImportError("❌ Không thể import class 'DRCT'. Đảm bảo đã cài 'einops'.") from err
    model = DRCT(
        upscale=4, in_chans=3, img_size=64, window_size=16,
        compress_ratio=3, squeeze_factor=30, conv_scale=0.01, overlap_ratio=0.5,
        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffle',
        resi_connection='1conv'
    )
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['params_ema'] if 'params_ema' in checkpoint else checkpoint['params']
    # strict=False tolerates key mismatches. Report them so that a partially
    # loaded (partly random) network does not go unnoticed.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"⚠️ DRCT weights: {len(incompatible.missing_keys)} missing / "
              f"{len(incompatible.unexpected_keys)} unexpected keys (strict=False)")
    model.eval()
    return model.to(device)
# --- LOAD MODELS ---
# Runs once at import time. On failure the globals stay None (or keep
# whatever loaded successfully); process_image returns an error message when
# drct_model is None.
print("⏳ Đang tải Model...")
drct_model = None
codeformer = None
try:
    drct_model = load_drct_model(DRCT_MODEL_PATH, device)
    if not os.path.exists('weights/CodeFormer/codeformer.pth'):
        os.makedirs('weights/CodeFormer', exist_ok=True)
        load_file_from_url(url='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
                           model_dir='weights/CodeFormer', progress=True, file_name='codeformer.pth')
    # Build into a temporary name and publish to the global only after the
    # weights are loaded. Otherwise a failed load would leave `codeformer`
    # bound to a randomly initialised network that process_image would use.
    _codeformer_net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
                                 connect_list=['32', '64', '128', '256']).to(device)
    # map_location matches the DRCT load, so the checkpoint can be read on CPU-only hosts.
    ckpt = torch.load('weights/CodeFormer/codeformer.pth', map_location=device)['params_ema']
    _codeformer_net.load_state_dict(ckpt)
    _codeformer_net.eval()
    codeformer = _codeformer_net
    print("✅ Model Ready!")
except Exception as e:
    print(f"⚠️ Lỗi khởi tạo Model: {e}")
    import traceback
    traceback.print_exc()
# --- IMAGE PROCESSING ---
def _reset_cuda_peak():
    """Reset CUDA peak-memory statistics (no-op without CUDA)."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()


def _peak_vram_gb():
    """Return peak CUDA memory allocated since the last reset, in GiB (0 without CUDA)."""
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.max_memory_allocated() / (1024 ** 3)


def _run_drct(img):
    """Upscale a BGR image x4 with the global DRCT model and return the BGR result."""
    upsampler = RealESRGANer_Custom(
        scale=4, model_path=None, model=drct_model,
        tile=512, tile_pad=32, pre_pad=0, half=False, device=device
    )
    if device.type == 'cuda':
        # On GPU, run inference under float16 autocast.
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            bg_img, _ = upsampler.enhance(img, outscale=4)
    else:
        bg_img, _ = upsampler.enhance(img, outscale=4)
    return bg_img


def _run_codeformer(img, bg_img, w):
    """Restore faces detected in ``img`` with CodeFormer and paste them onto ``bg_img``.

    Returns the composited image as an RGB numpy array. Any failure (including
    a CodeFormer model that failed to load) propagates to the caller.
    """
    # NOTE(review): the helper (and the detection/parsing networks it
    # presumably loads) is rebuilt on every request; caching it could save
    # time if requests stay serialized.
    face_helper = FaceRestoreHelper(
        upscale_factor=4, face_size=512, crop_ratio=(1, 1),
        det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=device
    )
    face_helper.clean_all()
    face_helper.read_image(img)
    face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
    face_helper.align_warp_face()
    for cropped_face in face_helper.cropped_faces:
        face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1] (assuming uint8 crops), the
        # range that tensor2img's min_max=(-1, 1) converts back below.
        normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        face_t = face_t.unsqueeze(0).to(device)
        with torch.no_grad():
            output = codeformer(face_t, w=w, adain=True)[0]
        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        face_helper.add_restored_face(restored_face)
    face_helper.get_inverse_affine(None)
    final_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
    # BGR (OpenCV) -> RGB for PIL processing and for correct colours in the UI.
    final_img_pil = Image.fromarray(cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB))
    # Optional colour/contrast adjustment hook; a factor of 1.0 keeps the image as-is.
    final_img_pil = ImageEnhance.Color(final_img_pil).enhance(1.0)
    final_img_pil = ImageEnhance.Contrast(final_img_pil).enhance(1.0)
    return np.array(final_img_pil)


def process_image(input_img, w=0.7):
    """Upscale ``input_img`` with DRCT, then also restore its faces with CodeFormer.

    Args:
        input_img: image from the Gradio input (converted with np.array and
            treated as RGB), or None.
        w: CodeFormer weight, passed through as ``codeformer(..., w=w)``.

    Returns:
        ``(res_drct, res_hybrid, stats_drct, stats_hybrid)``: the DRCT-only and
        DRCT+CodeFormer RGB images (None on early failure) and two status
        strings. If the CodeFormer stage fails, the DRCT image is returned in
        both image slots.
    """
    if drct_model is None: return None, None, "Lỗi Model (Xem Logs)", ""
    if input_img is None: return None, None, "Thiếu ảnh input", ""
    img = cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)

    # 1. DRCT
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    _reset_cuda_peak()
    # perf_counter is monotonic, unlike time.time(), so it suits elapsed-time measurement.
    start_time = time.perf_counter()
    try:
        bg_img = _run_drct(img)
    except Exception as e:
        return None, None, f"Lỗi DRCT: {str(e)}", ""
    drct_time = time.perf_counter() - start_time
    drct_vram = _peak_vram_gb()
    res_drct = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
    stats_drct = f"⏱️ {drct_time:.2f}s | 💾 {drct_vram:.2f} GB | 📏 {bg_img.shape[1]}x{bg_img.shape[0]}"

    # 2. CODEFORMER
    _reset_cuda_peak()
    start_time_cf = time.perf_counter()
    try:
        res_hybrid = _run_codeformer(img, bg_img, w)
    except Exception as e:
        print(f"CodeFormer Error/No Face: {e}")
        stats_hybrid = f"⚠️ Lỗi CF/Không có mặt: {str(e)}"
        return res_drct, res_drct, stats_drct, stats_hybrid
    cf_time = time.perf_counter() - start_time_cf
    total_time = drct_time + cf_time
    max_vram = max(drct_vram, _peak_vram_gb())
    stats_hybrid = (f"⏱️ Tổng: {total_time:.2f}s\n"
                    f" (DRCT: {drct_time:.2f}s + CF: {cf_time:.2f}s)\n"
                    f"💾 VRAM Peak: {max_vram:.2f} GB")
    return res_drct, res_hybrid, stats_drct, stats_hybrid
# --- UI ---
# Two-column comparison: DRCT output alone vs DRCT followed by CodeFormer.
title = "So sánh Upscale: DRCT vs Hybrid"
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input")
            # The minimum is 0.0 so the "0=Restore" end advertised in the label is selectable.
            w_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1,
                                 label="CodeFormer Weight (0=Restore, 1=Identity)")
            run_btn = gr.Button("🚀 Chạy", variant="primary")
    with gr.Row():
        with gr.Column():
            output_drct = gr.Image(label="DRCT Only")
            stats_drct_box = gr.Textbox(label="Stats")
        with gr.Column():
            output_hybrid = gr.Image(label="DRCT + CodeFormer")
            stats_hybrid_box = gr.Textbox(label="Stats")
    run_btn.click(process_image, [input_image, w_slider], [output_drct, output_hybrid, stats_drct_box, stats_hybrid_box])

if __name__ == "__main__":
    demo.queue().launch()