Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| from PIL import Image | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from modules.postprocess.realesrgan_model_arch import SRVGGNetCompact | |
| from modules.upscaler import Upscaler | |
| from modules.shared import opts, device, log | |
| from modules import devices | |
class UpscalerRealESRGAN(Upscaler):
    """Upscaler backend for the Real-ESRGAN model family.

    Each scaler discovered by ``find_scalers()`` is matched by name against
    ``_known_models`` and given a zero-argument network factory (and, for the
    2x model, a scale override). ``do_upscale`` lazily builds a
    ``RealESRGANer`` per model file and caches it in ``self.models``.
    """

    # Known scaler names -> (network factory, scale override or None).
    # Factories are zero-argument callables so the network is only built when
    # do_upscale() actually needs it. A scale of None leaves whatever scale
    # find_scalers() assigned untouched.
    _known_models = {
        'RealESRGAN 2x+': (
            lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2),
            2,
        ),
        'RealESRGAN 4x+': (
            lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
            None,
        ),
        'RealESRGAN 4x+ Anime6B': (
            lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4),
            None,
        ),
        'RealESRGAN 4x General V3': (
            lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'),
            None,
        ),
        'RealESRGAN 4x General WDN V3': (
            lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'),
            None,
        ),
        'RealESRGAN AnimeVideo V3': (
            lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'),
            None,
        ),
    }

    def __init__(self, dirname):
        """Discover Real-ESRGAN scalers under *dirname* and attach network factories.

        Args:
            dirname: directory stored as ``self.user_path`` before the base
                initializer runs (set first in case the base class reads it).
        """
        self.name = "RealESRGAN"
        self.user_path = dirname
        super().__init__()
        self.scalers = self.find_scalers()
        self.models = {}  # local_data_path -> cached RealESRGANer
        for scaler in self.scalers:
            entry = self._known_models.get(scaler.name)
            if entry is None:
                # Left without a factory; do_upscale() refuses to run it.
                log.error(f"Upscaler unrecognized model: type={self.name} model={scaler.name}")
                continue
            factory, scale = entry
            scaler.model = factory
            if scale is not None:
                scaler.scale = scale

    def load_model(self, path): # pylint: disable=unused-argument
        """No-op: models are built lazily inside do_upscale()."""
        pass

    def do_upscale(self, img, selected_model):
        """Upscale *img* with the Real-ESRGAN model named *selected_model*.

        Returns the input image unchanged when the upscaler is disabled, the
        Real-ESRGAN runtime cannot be imported, the model file is missing, or
        the model is not one of the recognized architectures. Otherwise
        returns a new PIL image built from the upsampler output.
        """
        if not self.enable:
            return img
        try:
            from modules.postprocess.realesrgan_model_arch import RealESRGANer
        except Exception as e:
            # Include the cause; the bare message gave no way to diagnose the failure.
            log.error(f"Error importing Real-ESRGAN: {e}")
            return img
        info = self.find_model(selected_model)
        if info is None or not os.path.exists(info.local_data_path):
            return img
        if not callable(getattr(info, 'model', None)):
            # Unrecognized scaler: __init__ never attached a factory, so calling
            # info.model() would raise TypeError.
            log.error(f"Upscaler unrecognized model: type={self.name} model={info.name}")
            return img
        upsampler = self.models.get(info.local_data_path)
        if upsampler is not None:
            log.debug(f"Upscaler cached: type={self.name} model={info.local_data_path}")
        else:
            upsampler = RealESRGANer(
                name=info.name,
                scale=info.scale,
                model_path=info.local_data_path,
                model=info.model(),
                half=not opts.no_half and not opts.upcast_sampling,
                tile=opts.upscaler_tile_size,
                tile_pad=opts.upscaler_tile_overlap,
                device=device,
            )
            self.models[info.local_data_path] = upsampler
        # enhance() result is indexed with [0]; the first element is used as the image array.
        upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
        if opts.upscaler_unload and info.local_data_path in self.models:
            del self.models[info.local_data_path]
            log.debug(f"Upscaler unloaded: type={self.name} model={selected_model}")
            devices.torch_gc(force=True)
        return Image.fromarray(upsampled)