# OneRestore / app.py
# Author: gy65896 -- last commit d049335 ("Update app.py"), 3.6 kB.
import torch
import gradio as gr
from torchvision import transforms
from PIL import Image
import numpy as np
from utils.utils import load_restore_ckpt, load_embedder_ckpt
import os
from gradio_imageslider import ImageSlider
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint files; these are relative paths, so they are resolved against the
# working directory the app is started from.
embedder_model_path = "ckpts/embedder_model.tar"
restorer_model_path = "ckpts/onerestore_cdd-11.tar"
# Load both networks once, at import time, onto `device` (GPU or CPU as chosen
# above). freeze_model=True is requested since the app only runs inference;
# its exact effect is defined in utils.utils -- NOTE(review): confirm there.
embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=embedder_model_path)
restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=restorer_model_path)
# Preprocessing for the embedder's image input: resize the PIL image to a
# fixed 224x224 and convert it to a tensor (torchvision ToTensor).
transform_resize = transforms.Compose(
    [
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ]
)
def postprocess_image(tensor):
    """Convert a restorer output tensor into an 8-bit PIL image.

    Args:
        tensor: A batch-of-one image tensor, ``(1, C, H, W)``. Its values are
            assumed to lie in ``[0, 1]``, the same range ``enhance_image``
            normalizes the input to (it divides pixels by 255).

    Returns:
        A ``PIL.Image`` built from the ``(H, W, C)`` uint8 array. Values outside
        ``[0, 1]`` are clipped; fractional pixel values are truncated by the
        uint8 cast.
    """
    # Drop the batch dimension, detach from autograd and move to host memory.
    image = tensor.squeeze(0).detach().cpu().numpy()
    # Undo the /255 input normalization: [0, 1] -> [0, 255], then clip any
    # overshoot before casting so it cannot wrap around in uint8.
    image = np.clip(image * 255, 0, 255).astype("uint8")
    # PIL expects channels last: (C, H, W) -> (H, W, C).
    return Image.fromarray(image.transpose(1, 2, 0))
# Core restoration routine used by the Gradio callback.
def enhance_image(image, degradation_type=None):
    """Restore a degraded image with OneRestore.

    Args:
        image: Input ``PIL.Image`` (as delivered by ``gr.Image(type="pil")``).
            It is converted to RGB first, so that grayscale (2-D) or RGBA
            (4-channel) uploads neither break the HWC -> CHW transpose nor feed
            an unexpected channel count (the restorer presumably expects RGB).
        degradation_type: ``"auto"``, ``""`` or ``None`` lets the embedder's
            image encoder estimate the degradation from the image. Any other
            value (e.g. ``"low_haze_rain"``) is embedded with the text encoder.

    Returns:
        ``((input_image, restored_image), text)``: the before/after pair for
        the image slider, plus the degradation label returned by the embedder.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image = image.convert("RGB")

    # Full-resolution restorer input: HWC uint8 -> contiguous CHW float32 in [0, 1].
    chw = np.ascontiguousarray((np.asarray(image) / 255).transpose(2, 0, 1), dtype=np.float32)
    input_tensor = torch.from_numpy(chw).unsqueeze(0).to(device)
    # Fixed-size (224x224) copy, only needed by the embedder's image encoder.
    lq_em = transform_resize(image).unsqueeze(0).to(device)

    # Inference only: keep both the embedding and the restoration out of autograd.
    with torch.no_grad():
        if degradation_type in (None, "", "auto"):
            text_embedding, _, [text] = embedder(lq_em, "image_encoder")
        else:
            text_embedding, _, [text] = embedder([degradation_type], "text_encoder")
        enhanced_tensor = restorer(input_tensor, text_embedding)

    return (image, postprocess_image(enhanced_tensor)), text
# Gradio callback: a thin entry point that forwards to enhance_image.
def inference(image, degradation_type=None):
    """Gradio ``fn``: restore *image*, optionally forcing *degradation_type*.

    Returns exactly what ``enhance_image`` returns: the (input, restored) image
    pair for the slider and the degradation label.
    """
    result = enhance_image(image, degradation_type=degradation_type)
    return result
# Example rows for the interface. Each row supplies only the first input
# (the image path); no degradation type is given.
examples = [
    [path]
    for path in (
        "image/low_haze_rain_00469_01_lq.png",
        "image/low_haze_snow_00337_01_lq.png",
    )
]
# Build the Gradio app. Inputs: an image and a degradation-type selector.
# Outputs: a before/after slider and the degradation label.
interface = gr.Interface(
    fn=inference,
    inputs=[
        # Image input, pre-filled with a bundled sample image.
        gr.Image(type="pil", value="image/low_haze_rain_00469_01_lq.png"),
        # "auto" lets the embedder estimate the degradation from the image;
        # any other choice is passed to the embedder's text encoder.
        gr.Dropdown(
            [
                "auto",
                "low", "haze", "rain", "snow",
                "low_haze", "low_rain", "low_snow",
                "haze_rain", "haze_snow",
                "low_haze_rain", "low_haze_snow",
            ],
            label="Degradation Type",
            value="auto",
        ),
    ],
    outputs=[
        # Before/after comparison of the input and the restored image.
        ImageSlider(
            label="Restored Image",
            type="pil",
            show_download_button=True,
        ),
        # Degradation label returned by the embedder.
        gr.Textbox(label="Degradation Type"),
    ],
    title="Image Restoration with OneRestore",
    description="Upload an image and enhance it using OneRestore model. You can choose to let the model automatically estimate the degradation type or set it manually.",
    examples=examples,
)
def main():
    """Start the Gradio server for the OneRestore demo."""
    interface.launch()


# Serve the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()