ahmdliaqat commited on
Commit
80c8cf9
·
1 Parent(s): e770ba4

initial commit

Browse files
Files changed (3) hide show
  1. .gitignore +141 -0
  2. app.py +132 -0
  3. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # Virtual environments
30
+ venv/
31
+ env/
32
+ .venv/
33
+ .env/
34
+ ENV/
35
+ .env.bak/
36
+ venv.bak/
37
+ env.bak/
38
+
39
+ # PyInstaller
40
+ # Usually these files are written by a python script from a template
41
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
42
+ *.manifest
43
+ *.spec
44
+
45
+ # Installer logs
46
+ pip-log.txt
47
+ pip-delete-this-directory.txt
48
+
49
+ # Unit test / coverage reports
50
+ htmlcov/
51
+ .tox/
52
+ .nox/
53
+ coverage.xml
54
+ *.cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+ coverage/
59
+ nosetests.xml
60
+ *.log
61
+
62
+ # Translations
63
+ *.mo
64
+ *.pot
65
+
66
+ # Django stuff:
67
+ *.log
68
+ local_settings.py
69
+ db.sqlite3
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ .python-version
91
+
92
+ # celery beat schedule file
93
+ celerybeat-schedule
94
+ celerybeat.pid
95
+
96
+ # SageMath parsed files
97
+ *.sage.py
98
+
99
+ # dotenv
100
+ .env
101
+
102
+ # Virtualenv
103
+ .venv/
104
+ env/
105
+ venv/
106
+ ENV/
107
+ env.bak/
108
+ venv.bak/
109
+
110
+ # Spyder project settings
111
+ .spyderproject
112
+ .spyproject
113
+
114
+ # Rope project settings
115
+ .ropeproject
116
+
117
+ # mkdocs documentation
118
+ /site
119
+
120
+ # mypy
121
+ .mypy_cache/
122
+ .dmypy.json
123
+ dmypy.json
124
+
125
+ # Pyre type checker
126
+ .pyre/
127
+
128
+ # pytype static type analyzer
129
+ .pytype/
130
+
131
+ # Cython debug symbols
132
+ cython_debug/
133
+
134
+ # PyCharm
135
+ .idea/
136
+
137
+ # VS Code
138
+ .vscode/
139
+
140
+ # Local environment
141
+ weights/
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ from PIL import Image
6
+ import gradio as gr
7
+ from basicsr.archs.rrdbnet_arch import RRDBNet
8
+ from basicsr.utils.download_util import load_file_from_url
9
+ from realesrgan import RealESRGANer
10
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
11
+
12
+ def enhance_image(image, model_name, denoise_strength, outscale, tile, face_enhance, ext):
13
+ # Convert PIL image to OpenCV format (BGR)
14
+ img = np.array(image.convert('RGB'))[:, :, ::-1] # Convert RGB to BGR
15
+
16
+ # Model configuration
17
+ if model_name == 'RealESRGAN_x4plus':
18
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
19
+ netscale = 4
20
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
21
+ elif model_name == 'RealESRNet_x4plus':
22
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
23
+ netscale = 4
24
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
25
+ elif model_name == 'RealESRGAN_x4plus_anime_6B':
26
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
27
+ netscale = 4
28
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
29
+ elif model_name == 'RealESRGAN_x2plus':
30
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
31
+ netscale = 2
32
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
33
+ elif model_name == 'realesr-animevideov3':
34
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
35
+ netscale = 4
36
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
37
+ elif model_name == 'realesr-general-x4v3':
38
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
39
+ netscale = 4
40
+ file_url = [
41
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
42
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
43
+ ]
44
+ else:
45
+ return "Error: Invalid model name."
46
+
47
+ # Download model weights if not available locally
48
+ model_path = os.path.join('weights', model_name + '.pth')
49
+ if not os.path.isfile(model_path):
50
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
51
+ for url in file_url:
52
+ model_path = load_file_from_url(url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
53
+
54
+ # Handle denoise strength
55
+ dni_weight = None
56
+ if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
57
+ wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
58
+ model_path = [model_path, wdn_model_path]
59
+ dni_weight = [denoise_strength, 1 - denoise_strength]
60
+
61
+ # Create upsampler
62
+ upsampler = RealESRGANer(
63
+ scale=netscale,
64
+ model_path=model_path,
65
+ dni_weight=dni_weight,
66
+ model=model,
67
+ tile=tile,
68
+ tile_pad=10,
69
+ pre_pad=0,
70
+ half=False,
71
+ gpu_id=None
72
+ )
73
+
74
+ # Handle face enhancement
75
+ if face_enhance:
76
+ from gfpgan import GFPGANer
77
+ face_enhancer = GFPGANer(
78
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
79
+ upscale=outscale,
80
+ arch='clean',
81
+ channel_multiplier=2,
82
+ bg_upsampler=upsampler
83
+ )
84
+
85
+ # Process the image
86
+ try:
87
+ if face_enhance:
88
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
89
+ else:
90
+ output, _ = upsampler.enhance(img, outscale=outscale)
91
+ except RuntimeError as error:
92
+ return f'Error: {error}'
93
+ else:
94
+ # Convert BGR back to RGB
95
+ output = output[:, :, ::-1] # Convert BGR to RGB
96
+ output = np.clip(output, 0, 255).astype(np.uint8)
97
+ output_image = Image.fromarray(output)
98
+ return output_image
99
+
100
+ def create_gradio_interface():
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("## Real-ESRGAN Image Enhancement")
103
+
104
+ with gr.Row():
105
+ with gr.Column():
106
+ image_input = gr.Image(type='pil', label="Input Image")
107
+ model_name = gr.Dropdown(["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B",
108
+ "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"],
109
+ label="Model Name")
110
+ denoise_strength = gr.Slider(0, 1, value=0.5, step=0.1, label="Denoise Strength")
111
+ outscale = gr.Slider(1, 4, value=4, step=1, label="Output Scale")
112
+ tile = gr.Slider(128, 512, value=256, step=64, label="Tile Size")
113
+ face_enhance = gr.Checkbox(False, label="Enable Face Enhancement")
114
+ ext = gr.Dropdown(['auto', 'jpg', 'png'], value='auto', label="Output Extension")
115
+
116
+ generate_button = gr.Button("Generate")
117
+
118
+ with gr.Column():
119
+ output_image = gr.Image(type='pil', label="Output Image")
120
+
121
+ generate_button.click(
122
+ lambda image, model_name, denoise_strength, outscale, tile, face_enhance, ext: enhance_image(
123
+ image, model_name, denoise_strength, outscale, tile, face_enhance, ext
124
+ ),
125
+ inputs=[image_input, model_name, denoise_strength, outscale, tile, face_enhance, ext],
126
+ outputs=[output_image]
127
+ )
128
+
129
+ demo.launch()
130
+
131
+ if __name__ == '__main__':
132
+ create_gradio_interface()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ opencv-python-headless
3
+ numpy
4
+ pillow
5
+ gradio
6
+ basicsr
7
+ realesrgan
8
+ gfpgan
9
+ requests