ahmdliaqat commited on
Commit
17685df
·
1 Parent(s): 0ccfa14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -2
app.py CHANGED
@@ -1,2 +1,133 @@
1
- pip uninstall basicsr
2
- pip install basicsr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ from PIL import Image
6
+ import gradio as gr
7
+ from basicsr.archs.rrdbnet_arch import RRDBNet
8
+ from basicsr.utils.download_util import load_file_from_url
9
+ from realesrgan import RealESRGANer
10
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
11
+
12
+ def enhance_image(image, model_name, denoise_strength, outscale, tile, face_enhance, ext):
13
+ # Convert PIL image to OpenCV format (BGR)
14
+ img = np.array(image.convert('RGB'))[:, :, ::-1] # Convert RGB to BGR
15
+
16
+ # Model configuration
17
+ if model_name == 'RealESRGAN_x4plus':
18
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
19
+ netscale = 4
20
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
21
+ elif model_name == 'RealESRNet_x4plus':
22
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
23
+ netscale = 4
24
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
25
+ elif model_name == 'RealESRGAN_x4plus_anime_6B':
26
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
27
+ netscale = 4
28
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
29
+ elif model_name == 'RealESRGAN_x2plus':
30
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
31
+ netscale = 2
32
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
33
+ elif model_name == 'realesr-animevideov3':
34
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
35
+ netscale = 4
36
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
37
+ elif model_name == 'realesr-general-x4v3':
38
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
39
+ netscale = 4
40
+ file_url = [
41
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
42
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
43
+ ]
44
+ else:
45
+ return "Error: Invalid model name."
46
+
47
+ # Download model weights if not available locally
48
+ model_path = os.path.join('weights', model_name + '.pth')
49
+ if not os.path.isfile(model_path):
50
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
51
+ for url in file_url:
52
+ model_path = load_file_from_url(url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
53
+
54
+ # Handle denoise strength
55
+ dni_weight = None
56
+ if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
57
+ wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
58
+ model_path = [model_path, wdn_model_path]
59
+ dni_weight = [denoise_strength, 1 - denoise_strength]
60
+
61
+ # Create upsampler
62
+ upsampler = RealESRGANer(
63
+ scale=netscale,
64
+ model_path=model_path,
65
+ dni_weight=dni_weight,
66
+ model=model,
67
+ tile=tile,
68
+ tile_pad=10,
69
+ pre_pad=0,
70
+ half=False,
71
+ gpu_id=None
72
+ )
73
+
74
+ # Handle face enhancement
75
+ if face_enhance:
76
+ from gfpgan import GFPGANer
77
+ face_enhancer = GFPGANer(
78
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
79
+ upscale=outscale,
80
+ arch='clean',
81
+ channel_multiplier=2,
82
+ bg_upsampler=upsampler
83
+ )
84
+
85
+ # Process the image
86
+ try:
87
+ if face_enhance:
88
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
89
+ else:
90
+ output, _ = upsampler.enhance(img, outscale=outscale)
91
+ except RuntimeError as error:
92
+ return f'Error: {error}'
93
+ else:
94
+ # Convert BGR back to RGB
95
+ output = output[:, :, ::-1] # Convert BGR to RGB
96
+ output = np.clip(output, 0, 255).astype(np.uint8)
97
+ output_image = Image.fromarray(output)
98
+ return output_image
99
+
100
+ # interface-using gradio
101
+ def create_gradio_interface():
102
+ with gr.Blocks() as demo:
103
+ gr.Markdown("## Real-ESRGAN Image Enhancement")
104
+
105
+ with gr.Row():
106
+ with gr.Column():
107
+ image_input = gr.Image(type='pil', label="Input Image")
108
+ model_name = gr.Dropdown(["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B",
109
+ "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"],
110
+ label="Model Name")
111
+ denoise_strength = gr.Slider(0, 1, value=0.5, step=0.1, label="Denoise Strength")
112
+ outscale = gr.Slider(1, 4, value=4, step=1, label="Output Scale")
113
+ tile = gr.Slider(128, 512, value=256, step=64, label="Tile Size")
114
+ face_enhance = gr.Checkbox(False, label="Enable Face Enhancement")
115
+ ext = gr.Dropdown(['auto', 'jpg', 'png'], value='auto', label="Output Extension")
116
+
117
+ generate_button = gr.Button("Generate")
118
+
119
+ with gr.Column():
120
+ output_image = gr.Image(type='pil', label="Output Image")
121
+
122
+ generate_button.click(
123
+ lambda image, model_name, denoise_strength, outscale, tile, face_enhance, ext: enhance_image(
124
+ image, model_name, denoise_strength, outscale, tile, face_enhance, ext
125
+ ),
126
+ inputs=[image_input, model_name, denoise_strength, outscale, tile, face_enhance, ext],
127
+ outputs=[output_image]
128
+ )
129
+
130
+ demo.launch()
131
+
132
+ if __name__ == '__main__':
133
+ create_gradio_interface()