# IMAGEN_PLUS / app.py
# Andro0s's picture
# Update app.py
# c5c5ef2 verified
import os
import sys
import zipfile
import shutil
import subprocess
import cv2
import numpy as np
import gradio as gr
from PIL import Image
import onnxruntime as ort
# Path setup: BASE_DIR is the directory that contains this file. It is appended
# to sys.path so that packages placed next to the script can be imported.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
# =========================================================
# 1) EXTRAER Y CONFIGURAR BasicSR
# =========================================================
def prepare_basicsr(base_dir=None, zip_name="BasicSR-master.zip"):
    """Unpack the bundled BasicSR sources into ``<base_dir>/basicsr`` if missing.

    The archive is extracted into a scratch folder, the first directory named
    ``basicsr`` found inside it is moved to ``<base_dir>/basicsr``, and a
    ``version.py`` file is (re)written inside that package. The scratch folder
    is always removed, even when extraction or the move fails.

    Nothing happens when ``<base_dir>/basicsr`` already exists or when the
    archive is not present.

    Args:
        base_dir: Directory holding the archive and receiving the package.
            Defaults to the module-level ``BASE_DIR``.
        zip_name: File name of the archive, looked up inside ``base_dir``.
    """
    base_dir = BASE_DIR if base_dir is None else base_dir
    target_dir = os.path.join(base_dir, "basicsr")
    if os.path.exists(target_dir):
        return

    # Every path is anchored to base_dir so that the existence check, the
    # archive lookup and the move destination all refer to the same folder,
    # whatever the current working directory is.
    zip_path = os.path.join(base_dir, zip_name)
    if not os.path.exists(zip_path):
        return

    print(f"📦 Extrayendo {zip_name}...")
    extract_dir = os.path.join(base_dir, "temp_extract")
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)

        source_path = None
        for root, dirs, _files in os.walk(extract_dir):
            if "basicsr" in dirs:
                source_path = os.path.join(root, "basicsr")
                break

        if source_path is None:
            # Without the package directory there is nothing to version-patch;
            # report it instead of failing later on a missing folder.
            print(f"⚠️ No se encontró la carpeta 'basicsr' dentro de {zip_name}.")
            return

        shutil.move(source_path, target_dir)
    finally:
        shutil.rmtree(extract_dir, ignore_errors=True)

    # Version patch: write the version module directly into the package.
    version_file = os.path.join(target_dir, "version.py")
    with open(version_file, "w", encoding="utf-8") as f:
        f.write("__version__ = '1.4.2'\n")
        f.write("__gitsha__ = 'unknown'\n")
    print("✅ Carpeta 'basicsr' inyectada y versionada.")
# Runs at import time, before the `from basicsr...` import further down.
prepare_basicsr()
# =========================================================
# 2) INSTALACIÓN DE MOTORES
# =========================================================
def install_deps():
    """Install the gfpgan and realesrgan packages if either fails to import.

    The two packages are imported in order; on the first ImportError both are
    installed with pip (using ``--no-deps``) through the running interpreter.
    A failing pip command propagates as ``subprocess.CalledProcessError``.
    """
    for engine in ("gfpgan", "realesrgan"):
        try:
            __import__(engine)
        except ImportError:
            print("📥 Instalando motores...")
            pip_command = [
                sys.executable,
                "-m",
                "pip",
                "install",
                "gfpgan",
                "realesrgan",
                "--no-deps",
            ]
            subprocess.check_call(pip_command)
            # One install covers both packages; no need to probe the second.
            break
# Runs at import time, before `realesrgan` and `gfpgan` are imported further down.
install_deps()
# =========================================================
# 3) Parches y Carga de Motores
# =========================================================
import types
import torchvision
# Compatibility shim: when the installed torchvision does not provide
# `torchvision.transforms.functional_tensor`, register a stand-in module under
# that name exposing `rgb_to_grayscale`, so later imports of it still resolve.
try:
    import torchvision.transforms.functional_tensor as T_f
except ImportError:
    from torchvision.transforms import functional as F

    module = types.ModuleType("torchvision.transforms.functional_tensor")
    module.rgb_to_grayscale = F.rgb_to_grayscale
    # The stand-in is registered under its own dotted name.
    sys.modules[module.__name__] = module
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from gfpgan import GFPGANer
print("🚀 Inicializando modelos...")
# RRDBNet network instance handed to RealESRGANer below as its `model`.
# NOTE(review): the weight file paths below are relative — presumably resolved
# against the current working directory; confirm for the deployment.
esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23)
# Main upsampler, applied to the WHOLE photo.
upsampler = RealESRGANer(
scale=4,
model_path="RealESRGAN_x4plus.pth",
model=esrgan_model,
tile=1024,
tile_pad=10,
pre_pad=0,
half=False
)
# Face restorer; it receives the upsampler above to handle the background.
face_enhancer = GFPGANer(
model_path="GFPGANv1.4.pth",
upscale=4, # raised so that it matches the overall scaling factor
arch="clean",
channel_multiplier=2,
bg_upsampler=upsampler # linked so that the background is not ignored
)
print("✅ TODO LISTO.")
# ============================
# Segmentador humano (U-2-Net)
# ============================
print("🧠 Cargando segmentador humano...")
# ONNX Runtime session for the human-segmentation model, restricted to the CPU
# execution provider. It is the session used by get_human_mask().
ort_session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
# =========================================================
# Máscara humana
# =========================================================
def get_human_mask(img_bgr):
    """Run the segmentation session on an image and return a mask at its size.

    Args:
        img_bgr: Image array in BGR channel order; its first two dimensions
            are taken as height and width.

    Returns:
        The first channel of the first batch item of the session's first
        output, resized to the input's width and height and clipped to [0, 1].
    """
    height, width = img_bgr.shape[:2]

    # Preprocess: RGB order, 320x320, float32 scaled to [0, 1], then
    # channels-first with a leading batch axis.
    rgb_small = cv2.resize(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB), (320, 320))
    scaled = rgb_small.astype(np.float32) / 255.0
    batch = scaled.transpose(2, 0, 1)[np.newaxis, ...]

    feed = {ort_session.get_inputs()[0].name: batch}
    outputs = ort_session.run(None, feed)
    prediction = outputs[0][0][0]

    resized = cv2.resize(prediction, (width, height))
    return np.clip(resized, 0, 1)
# =========================================================
# 4) Lógica de Procesamiento Completo
# =========================================================
def process(image, fidelity):
    """Enhance a whole photo and return the original alongside the result.

    Args:
        image: The uploaded picture (a PIL image, as the input component is
            declared with ``type="pil"``), or ``None`` when nothing was
            uploaded.
        fidelity: Value forwarded as ``weight`` to ``face_enhancer.enhance``.

    Returns:
        A ``(original, enhanced)`` tuple where ``original`` is ``image``
        unchanged and ``enhanced`` is a PIL image built from the enhancer
        output; ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    # The enhancer works on BGR arrays, so convert from the RGB input.
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # A human mask used to be computed here but was never consumed; that extra
    # segmentation inference per request has been dropped.

    # The whole image is enhanced in one call instead of being split first:
    # face_enhancer was created with the RealESRGAN upsampler as its
    # `bg_upsampler`, and `paste_back=True` asks for the full picture back.
    _, _, final_hd = face_enhancer.enhance(
        img,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
        weight=fidelity,
    )

    # Back to RGB for PIL.
    final_rgb = cv2.cvtColor(final_hd, cv2.COLOR_BGR2RGB)
    return image, Image.fromarray(final_rgb)
# =========================================================
# 5) Interfaz
# =========================================================
# UI definition: a row with an input column (image, slider, button) and an
# output column (result, reference). The button runs process() and routes its
# (original, enhanced) return values to the `before` and `after` images.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 💎 IMAGEN PLUS — Restauración Total HD")
    gr.Markdown("Esta herramienta mejora el fondo, la ropa y los rostros simultáneamente.")

    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Subir Imagen Original")
            fidelity = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                label="Naturalidad del Rostro",
            )
            btn = gr.Button("🚀 Mejorar Foto Completa", variant="primary")

        with gr.Column():
            after = gr.Image(label="Resultado HD (4x)")
            before = gr.Image(label="Referencia Original")

    btn.click(fn=process, inputs=[inp, fidelity], outputs=[before, after])

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()