| import torch |
| from PIL import Image |
| from RealESRGAN import RealESRGAN |
| import gradio as gr |
| import os |
| import spaces |
|
|
# Select the compute device once at import time: the first CUDA GPU when
# PyTorch can see one, otherwise the CPU.
if not torch.cuda.is_available():
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")
else:
    print(f"CUDA is available. GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
|
|
class LazyRealESRGAN:
    """Wrap a RealESRGAN upscaler so it is only built when first needed.

    Construction just records the device and scale. The underlying
    ``RealESRGAN`` model is created, and its weights loaded, on the first
    call to ``load_model`` (which ``predict`` calls automatically).
    """

    def __init__(self, device, scale):
        """Record the target device and upscale factor without loading anything.

        Args:
            device: device passed through to ``RealESRGAN`` (e.g. a ``torch.device``).
            scale: upscale factor; also selects the weights file
                ``weights/RealESRGAN_x{scale}.pth``.
        """
        self.device = device
        self.scale = scale
        self.model = None  # set by load_model() once weights are loaded

    def load_model(self):
        """Create the model and load its weights, unless that already happened.

        The model is stored on ``self.model`` only after ``load_weights``
        returns successfully. If weight loading raises, the exception
        propagates and ``self.model`` stays ``None``. The next call then
        retries, instead of silently reusing a model whose weights were never
        loaded.
        """
        if self.model is not None:
            return
        model = RealESRGAN(self.device, scale=self.scale)
        model.load_weights(f'weights/RealESRGAN_x{self.scale}.pth', download=True)
        self.model = model

    def predict(self, img):
        """Upscale ``img`` with the (lazily loaded) model and return its output."""
        self.load_model()
        return self.model.predict(img)
|
|
# One wrapper per upscale factor offered in the UI. Creating them is cheap:
# LazyRealESRGAN only builds the real model on its first predict() call.
model2, model4, model8 = (LazyRealESRGAN(device, scale=factor) for factor in (2, 4, 8))
|
|
@spaces.GPU
def inference(image, size):
    """Upscale an uploaded image by the factor selected in the UI.

    Args:
        image: the uploaded PIL image, or ``None`` when nothing was uploaded.
        size: ``'2x'`` or ``'4x'`` select those models. Any other value (the
            UI only offers ``'8x'``) uses the 8x model.

    Returns:
        The selected model's ``predict`` output for the RGB-converted image.

    Raises:
        gr.Error: if no image was uploaded, if an 8x request has a side of
            5000 px or more, if the GPU runs out of memory, or if upscaling
            fails for any other reason.
    """
    if image is None:
        raise gr.Error("Image not uploaded")

    try:
        if torch.cuda.is_available():
            # Free unused cached GPU memory before running the model.
            torch.cuda.empty_cache()

        if size == '2x':
            model = model2
        elif size == '4x':
            model = model4
        else:
            width, height = image.size
            if width >= 5000 or height >= 5000:
                raise gr.Error("The image is too large.")
            model = model8

        result = model.predict(image.convert('RGB'))
        print(f"Image size ({device}): {size} ... OK")
        return result
    except gr.Error:
        # Already a user-facing message (e.g. the size limit above); re-raise it
        # unchanged instead of wrapping it in the generic "An error occurred".
        raise
    except torch.cuda.OutOfMemoryError as e:
        raise gr.Error("GPU out of memory. Try a smaller image or lower upscaling factor.") from e
    except Exception as e:
        raise gr.Error(f"An error occurred: {e}") from e
|
|
# Page heading and blurb, passed to gr.Interface below.
title = "Face Real ESRGAN UpScale: 2x 4x 8x"
description = (
    "This is an unofficial demo for Real-ESRGAN. "
    "Scales the resolution of a photo. "
    "This model shows better results on faces compared to the original version."
)
|
|
|
|
|
|
# Web UI: a PIL image upload plus a scale selector feed inference(); the
# returned image is shown in the "Output" component.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil"),
        gr.Radio(
            ["2x", "4x", "8x"],
            type="value",
            value="2x",
            label="Resolution model",
        ),
    ],
    outputs=gr.Image(type="pil", label="Output"),
    title=title,
    description=description,
    flagging_mode="never",
    cache_examples=True,
)
|
|
|
|
if __name__ == "__main__":
    # Serve the app only when run as a script, not when imported.
    iface.launch(show_error=True, debug=True)