# EyeRestore / app.py
# (Hugging Face Spaces page-header residue from a copy-paste:
#  author kebincontreras, commit "Update app.py", 54f6aa6 verified)
import gradio as gr
import torch
import numpy as np
import cv2
from torchvision.transforms import ToTensor, Resize, Grayscale
from Resources.Ultris.Ultris_model import UNet
import torch.nn.functional as F
# Configuration
n, m = 2, 0  # Zernike indices of the simulated aberration: radial order n, azimuthal frequency m
IMG_SIZE = 256  # working resolution (model input size), pixels
AMP = 3.0  # amplitude applied to the Zernike phase map
WEIGHT_PATH = "Resources/weights/modelo_final_1.pt" # Upload this file to Hugging Face Spaces
resize_to_256 = Resize((IMG_SIZE, IMG_SIZE))  # NOTE(review): appears unused in this file — cv2.resize is used instead
def load_model(weight_path):
    """Build the single-channel UNet and load trained weights for CPU inference.

    Args:
        weight_path: path to a saved ``state_dict`` checkpoint.

    Returns:
        (model, device) tuple — the network in eval mode and the CPU device.
    """
    cpu = torch.device('cpu')
    net = UNet(in_channels=1, out_channels=1).to(cpu)
    state = torch.load(weight_path, map_location=cpu)
    net.load_state_dict(state)
    # eval(): freeze dropout / batch-norm statistics for inference
    net.eval()
    return net, cpu
# Load the trained model once at import time; everything below runs on this (CPU) device.
model, device = load_model(WEIGHT_PATH)
def generate_zernike_map_local(n, m, size=256, amplitude=1.0):
    """Render the Zernike mode Z_n^m on a ``size`` x ``size`` grid over the unit disk.

    Points outside the unit circle are left at zero; inside, the polynomial
    value is scaled by ``amplitude``.

    Returns:
        A (size, size) float tensor on the module-level ``device``.
    """
    coords = torch.linspace(-1, 1, size, device=device)
    # Keep the original axis assignment: X varies along rows, Y along columns.
    X, Y = torch.meshgrid(coords, coords, indexing='ij')
    radius = torch.sqrt(X ** 2 + Y ** 2)
    angle = torch.atan2(Y, X)
    inside = radius <= 1
    zmap = torch.zeros_like(radius)
    zmap[inside] = amplitude * zernike_local(n, m, radius[inside], angle[inside])
    return zmap
def zernike_local(n, m, rho, theta):
    """Evaluate the Zernike polynomial Z_n^m at polar coordinates (rho, theta).

    The radial part is R_n^|m|(rho) = sum_k (-1)^k (n-k)! /
    (k! ((n+|m|)/2-k)! ((n-|m|)/2-k)!) rho^(n-2k); the angular part is
    cos(m*theta) for m > 0, sin(|m|*theta) for m < 0, and 1 for m = 0.

    Bug fixes vs. the previous version:
      * The (-1)^k sign factor was multiplied into the *log-space* lgamma term
        before exponentiation, so every odd-k term had the wrong magnitude and
        sign (e.g. R_2^0 came out as 2*rho**2 + 1 instead of 2*rho**2 - 1).
        The sign now multiplies the exponentiated coefficient.
      * ``m = abs(m)`` at the top made the ``m < 0`` sine branch unreachable;
        the sign of m is now preserved for the angular term.

    Args:
        n: radial order (non-negative int).
        m: azimuthal frequency (int, |m| <= n, n - |m| even).
        rho: tensor of radii (expected in [0, 1]).
        theta: tensor of angles in radians, same shape as rho.

    Returns:
        Tensor of Z_n^m values, same shape as rho.
    """
    m_abs = abs(m)
    R = torch.zeros_like(rho)
    for k in range((n - m_abs) // 2 + 1):
        # Factorials via lgamma to avoid overflow for larger n.
        log_coef = (
            torch.lgamma(torch.tensor(n - k + 1.0))
            - torch.lgamma(torch.tensor(k + 1.0))
            - torch.lgamma(torch.tensor((n + m_abs) // 2 - k + 1.0))
            - torch.lgamma(torch.tensor((n - m_abs) // 2 - k + 1.0))
        )
        coef = (-1) ** k * torch.exp(log_coef)
        R = R + coef * rho ** (n - 2 * k)
    if m > 0:
        return R * torch.cos(m * theta)
    elif m < 0:
        return R * torch.sin(-m * theta)
    else:
        return R
def generate_psf_local(zernike_map):
    """Build a normalized point-spread function from a Zernike phase map.

    A flat (all-zero) map means an ideal system, for which the PSF is a
    centered unit impulse. Otherwise the PSF is |FFT(pupil)|^2 of the pupil
    function exp(i * 2*pi * phase), shifted to the center and normalized to
    unit sum.

    Returns:
        A real tensor with the same spatial shape as ``zernike_map``.
    """
    if not torch.any(zernike_map != 0):
        # Ideal system: delta impulse at the grid center.
        n_px = zernike_map.shape[0]
        impulse = torch.zeros_like(zernike_map)
        impulse[n_px // 2, n_px // 2] = 1.0
        return impulse
    pupil = torch.exp(2j * torch.pi * zernike_map)
    spectrum = torch.fft.fft2(pupil)
    intensity = torch.fft.fftshift(torch.abs(spectrum) ** 2)
    intensity = intensity / intensity.sum()
    return intensity.real
# Precompute PSF and its FFT
zmap = generate_zernike_map_local(n, m, amplitude=AMP)  # phase map of the configured Zernike mode
psf = generate_psf_local(zmap)
psf_tensor = psf.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, H, W] to broadcast against image batches
# Resample the PSF to the working resolution if the map size differs
if psf_tensor.shape[-1] != IMG_SIZE or psf_tensor.shape[-2] != IMG_SIZE:
    psf_tensor = torch.nn.functional.interpolate(psf_tensor, size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)
    psf_tensor = psf_tensor / psf_tensor.sum()  # re-normalize to unit sum after interpolation
# Shift the centered PSF so its peak sits at the origin, as frequency-domain convolution expects
psf_tensor = torch.fft.fftshift(psf_tensor, dim=(-2, -1))
psf_fft = torch.fft.fft2(psf_tensor)  # frequency-domain OTF, reused for every image
def aberrate_image_fft(img_tensor, psf_fft):
    """Apply the optical aberration to an image by convolution in Fourier space.

    Multiplying the image spectrum by the precomputed OTF (``psf_fft``) is a
    circular convolution with the PSF; the real part of the inverse FFT is the
    aberrated image.

    Fix vs. the previous version: the four per-call ``[DEBUG]`` prints (with
    mojibake text) were leftover debug logging on the hot path and have been
    removed.

    Args:
        img_tensor: image tensor, e.g. [1, 1, H, W], matching psf_fft's shape.
        psf_fft: complex FFT of the (origin-shifted) PSF.

    Returns:
        Real tensor with the same shape as ``img_tensor``.
    """
    img_fft = torch.fft.fft2(img_tensor)
    result_fft = img_fft * psf_fft
    return torch.fft.ifft2(result_fft).real
def restore_image(model, aberrated_tensor):
    """Run the restoration network on an aberrated tensor without gradients.

    Args:
        model: callable network mapping a tensor to a tensor.
        aberrated_tensor: input batch, e.g. [1, 1, H, W].

    Returns:
        NumPy array of the model output with singleton dimensions removed.
    """
    with torch.no_grad():
        prediction = model(aberrated_tensor)
    squeezed = prediction.squeeze()
    return squeezed.cpu().numpy()
def process_image(image):
    """Aberrate an input color image with the precomputed PSF and restore it with the UNet.

    Fix vs. the previous version: all per-call ``[DEBUG]`` prints (leftover
    debug logging) were removed, and the redundant squeeze/np.squeeze dance is
    collapsed into a single full squeeze per channel.

    Args:
        image: H x W x 3 NumPy array (Gradio "numpy" image), or None.

    Returns:
        (aberrated, restored) uint8 arrays cropped by 10 px on each side,
        or (None, None) when no image was supplied.

    Raises:
        ValueError: if the input is not a 3-channel NumPy color image.
    """
    if image is None:
        return None, None
    if not isinstance(image, np.ndarray):
        raise ValueError("Input must be a NumPy image.")
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Input must be a color image (3 channels).")

    # Scale to the network resolution and normalize to [0, 1] floats.
    img_resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    img_np = img_resized.astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)  # [1, C, H, W]

    aberrated_channels = []
    restored_channels = []
    with torch.no_grad():
        # The PSF and UNet are single-channel: process R, G, B independently.
        for c in range(img_tensor.shape[1]):
            channel = img_tensor[:, c:c + 1, :, :]
            aberrated = aberrate_image_fft(channel, psf_fft)
            restored = model(aberrated)
            # squeeze() drops both batch and channel singleton dims -> (H, W)
            aberrated_channels.append(aberrated.squeeze().cpu().numpy())
            restored_channels.append(restored.squeeze().cpu().numpy())

    # Reassemble the channels, clip to the valid range and convert to uint8.
    aberrated_np = np.clip(np.stack(aberrated_channels, axis=-1), 0, 1)
    restored_np = np.clip(np.stack(restored_channels, axis=-1), 0, 1)
    aberrated_u8 = (aberrated_np * 255).astype(np.uint8)
    restored_u8 = (restored_np * 255).astype(np.uint8)

    # Trim a 10 px border to hide circular-convolution wrap-around artifacts.
    return aberrated_u8[10:-10, 10:-10], restored_u8[10:-10, 10:-10]
def process_images(images):
    """Batch variant: aberrate and restore a list of color images.

    Fix vs. the previous version: each restored channel was appended as
    ``squeeze(0)`` -> shape (1, H, W), so the stacked result was (1, H, W, 3)
    and the final ``[10:-10, 10:-10]`` crop sliced the size-1 axis, returning
    an *empty* array. Channels are now squeezed fully to (H, W) before
    stacking, as sibling ``process_image`` effectively does via ``np.squeeze``.

    Args:
        images: iterable of H x W x 3 NumPy arrays, or None.

    Returns:
        (processed_images, restored_images): resized uint8 originals and
        restored uint8 images cropped by 10 px per side — or (None, None)
        when the input is empty.

    Raises:
        ValueError: if any entry is not a 3-channel color image.
    """
    if images is None or len(images) == 0:
        return None, None
    processed_images = []
    restored_images = []
    for image in images:
        # Only 3-channel color inputs are supported.
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise ValueError("Cada entrada debe ser una imagen en color (3 canales).")
        # Scale to the network resolution and normalize to [0, 1] floats.
        resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        as_float = resized.astype(np.float32) / 255.0
        tensor = torch.from_numpy(as_float).permute(2, 0, 1).unsqueeze(0).float().to(device)  # [1, C, H, W]
        with torch.no_grad():
            # Process each color channel independently through PSF + UNet.
            per_channel = []
            for idx in range(tensor.shape[1]):
                single = tensor[:, idx:idx + 1, :, :]
                blurred = aberrate_image_fft(single, psf_fft)
                # Full squeeze -> (H, W), so the stacked image is (H, W, 3).
                per_channel.append(model(blurred).squeeze().cpu().numpy())
        stacked = np.clip(np.stack(per_channel, axis=-1), 0, 1)
        restored_u8 = (stacked * 255).astype(np.uint8)
        processed_images.append(resized.astype(np.uint8))
        # Trim a 10 px border to hide circular-convolution wrap-around artifacts.
        restored_images.append(restored_u8[10:-10, 10:-10])
    return processed_images, restored_images
# Reflect in the title whether a GPU is available (inference itself runs on CPU — see load_model)
device_type = 'GPU' if torch.cuda.is_available() else 'CPU'

# Gradio UI: one input image, two outputs (aberrated and restored versions)
with gr.Blocks() as demo:
    gr.Markdown(f"# Restauraci贸n de Im谩genes con Aberraci贸n de Zernike y UNet ({device_type}, Hugging Face Spaces)")
    with gr.Row():
        input_img = gr.Image(label="Imagen de entrada", type="numpy")
        aberrated = gr.Image(label="Imagen aberrada")
        restored = gr.Image(label="Imagen transformada")
    btn = gr.Button("Procesar")
    btn.click(fn=process_image, inputs=[input_img], outputs=[aberrated, restored])

demo.launch()