manirafy10-spec commited on
Commit
01d2b1b
·
1 Parent(s): 8715ff4

Add PicPro Gradio backend with modular services and weights

Browse files
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import sys
4
+
5
+ # Add current directory to path so we can import services
6
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
7
+
8
+ from services.enhancer import ImageEnhancer
9
+ from services.background_remover import BackgroundRemover
10
+ from utils.download_weights import main as download_weights
11
+
12
+ # --- Initialization ---
13
+ print("Initializing PicPro AI Backend...")
14
+
15
+ # Ensure weights are present
16
+ try:
17
+ download_weights()
18
+ except Exception as e:
19
+ print(f"Weight download warning: {e}")
20
+
21
+ # Global Instances
22
+ enhancer = ImageEnhancer()
23
+ bg_remover = BackgroundRemover()
24
+
25
+ # Pre-load models (optional, can be lazy)
26
+ # enhancer.load_models()
27
+ # bg_remover.load_model()
28
+
29
+ # --- Wrapper Functions for Gradio ---
30
+
31
+ def enhance_image(image):
32
+ """
33
+ Args:
34
+ image: numpy array (RGB) from Gradio
35
+ Returns:
36
+ numpy array (RGB)
37
+ """
38
+ if image is None:
39
+ return None
40
+
41
+ try:
42
+ result = enhancer.process(image)
43
+ return result
44
+ except Exception as e:
45
+ raise gr.Error(f"Enhancement failed: {str(e)}")
46
+
47
+ def remove_background(image):
48
+ """
49
+ Args:
50
+ image: PIL Image from Gradio
51
+ Returns:
52
+ PIL Image
53
+ """
54
+ if image is None:
55
+ return None
56
+
57
+ try:
58
+ result = bg_remover.process(image)
59
+ return result
60
+ except Exception as e:
61
+ raise gr.Error(f"Background removal failed: {str(e)}")
62
+
63
+ # --- Gradio Interface ---
64
+
65
+ with gr.Blocks(title="PicPro AI Backend") as app:
66
+ gr.Markdown("# PicPro AI Backend Services")
67
+ gr.Markdown("GPU-accelerated Image Enhancement and Background Removal.")
68
+
69
+ with gr.Tab("Enhance Image"):
70
+ gr.Markdown("### Real-ESRGAN + GFPGAN")
71
+ with gr.Row():
72
+ enhance_input = gr.Image(label="Input Image", type="numpy")
73
+ enhance_output = gr.Image(label="Enhanced Image", type="numpy")
74
+
75
+ enhance_btn = gr.Button("Enhance", variant="primary")
76
+ enhance_btn.click(
77
+ fn=enhance_image,
78
+ inputs=enhance_input,
79
+ outputs=enhance_output,
80
+ api_name="enhance_image"
81
+ )
82
+
83
+ with gr.Tab("Remove Background"):
84
+ gr.Markdown("### MODNet / U-2-Net (rembg)")
85
+ with gr.Row():
86
+ bg_input = gr.Image(label="Input Image", type="pil")
87
+ bg_output = gr.Image(label="No Background", type="pil", image_mode="RGBA")
88
+
89
+ bg_btn = gr.Button("Remove Background", variant="primary")
90
+ bg_btn.click(
91
+ fn=remove_background,
92
+ inputs=bg_input,
93
+ outputs=bg_output,
94
+ api_name="remove_background"
95
+ )
96
+
97
+ # Launch
98
+ if __name__ == "__main__":
99
+ app.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ numpy
5
+ opencv-python-headless
6
+ Pillow
7
+ gfpgan>=1.3.8
8
+ realesrgan>=0.3.0
9
+ basicsr>=1.4.2
10
+ rembg[gpu]
11
+ requests
12
+ aiofiles
services/__init__.py ADDED
File without changes
services/background_remover.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ try:
3
+ from rembg import remove, new_session
4
+ except ImportError:
5
+ print("Warning: rembg not found.")
6
+ remove = None
7
+ new_session = None
8
+
9
+ class BackgroundRemover:
10
+ def __init__(self):
11
+ self.session = None
12
+
13
+ def load_model(self):
14
+ if self.session is not None:
15
+ return
16
+
17
+ if new_session is None:
18
+ raise ImportError("rembg not installed")
19
+
20
+ # Initialize session
21
+ # 'u2net' is robust.
22
+ self.session = new_session("u2net")
23
+ print("Background Remover loaded")
24
+
25
+ def process(self, input_image: Image.Image) -> Image.Image:
26
+ """
27
+ Process a PIL Image.
28
+ Returns: PIL Image with background removed.
29
+ """
30
+ if self.session is None:
31
+ self.load_model()
32
+
33
+ # Process
34
+ output_image = remove(input_image, session=self.session)
35
+ return output_image
services/enhancer.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ try:
6
+ from basicsr.archs.rrdbnet_arch import RRDBNet
7
+ from realesrgan import RealESRGANer
8
+ from gfpgan import GFPGANer
9
+ except ImportError:
10
+ print("Warning: AI modules not found. Ensure requirements are installed.")
11
+ RRDBNet = None
12
+ RealESRGANer = None
13
+ GFPGANer = None
14
+
15
+ class ImageEnhancer:
16
+ def __init__(self):
17
+ self.upsampler = None
18
+ self.face_enhancer = None
19
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+
21
+ def load_models(self):
22
+ if self.upsampler is not None:
23
+ return
24
+
25
+ if RRDBNet is None:
26
+ raise ImportError("AI modules not installed")
27
+
28
+ # Paths
29
+ # In HF Spaces, we can download weights if missing or rely on cache
30
+ weights_dir = 'weights'
31
+ os.makedirs(weights_dir, exist_ok=True)
32
+
33
+ realesrgan_path = os.path.join(weights_dir, 'RealESRGAN_x4plus.pth')
34
+ gfpgan_path = os.path.join(weights_dir, 'GFPGANv1.4.pth')
35
+
36
+ # Simple check/download if missing (handled by utils/download_weights.py usually)
37
+ # For robustness in Gradio Space, we assume they exist or we rely on the helper
38
+ if not os.path.exists(realesrgan_path) or not os.path.exists(gfpgan_path):
39
+ print("Weights not found locally. Ensure they are downloaded.")
40
+ # Fallback or error handled by caller
41
+
42
+ # Load Real-ESRGAN
43
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
44
+
45
+ self.upsampler = RealESRGANer(
46
+ scale=4,
47
+ model_path=realesrgan_path,
48
+ model=model,
49
+ tile=0,
50
+ tile_pad=10,
51
+ pre_pad=0,
52
+ half=True if self.device.type == 'cuda' else False,
53
+ device=self.device,
54
+ )
55
+
56
+ # Load GFPGAN
57
+ self.face_enhancer = GFPGANer(
58
+ model_path=gfpgan_path,
59
+ upscale=4,
60
+ arch='clean',
61
+ channel_multiplier=2,
62
+ bg_upsampler=self.upsampler
63
+ )
64
+ print(f"Enhancer loaded on {self.device}")
65
+
66
+ def process(self, img_rgb):
67
+ """
68
+ Process an image (numpy array, RGB).
69
+ Returns: Enhanced image (numpy array, RGB).
70
+ """
71
+ if self.face_enhancer is None:
72
+ self.load_models()
73
+
74
+ # Convert RGB (Gradio) to BGR (OpenCV/GFPGAN)
75
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
76
+
77
+ # Inference
78
+ # has_aligned=False means it will detect faces
79
+ _, _, output_bgr = self.face_enhancer.enhance(
80
+ img_bgr,
81
+ has_aligned=False,
82
+ only_center_face=False,
83
+ paste_back=True
84
+ )
85
+
86
+ # Convert BGR back to RGB for Gradio
87
+ output_rgb = cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)
88
+ return output_rgb
utils/__init__.py ADDED
File without changes
utils/download_weights.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import sys
4
+
5
+ def download_file(url, dest_path):
6
+ if os.path.exists(dest_path):
7
+ print(f"File already exists: {dest_path}")
8
+ return
9
+
10
+ print(f"Downloading {url} to {dest_path}...")
11
+ try:
12
+ response = requests.get(url, stream=True)
13
+ response.raise_for_status()
14
+ with open(dest_path, 'wb') as f:
15
+ for chunk in response.iter_content(chunk_size=8192):
16
+ f.write(chunk)
17
+ print("Download complete.")
18
+ except Exception as e:
19
+ print(f"Error downloading {url}: {e}")
20
+ if os.path.exists(dest_path):
21
+ os.remove(dest_path)
22
+
23
+ def main():
24
+ weights_dir = os.path.join(os.path.dirname(__file__), '..', 'weights')
25
+ os.makedirs(weights_dir, exist_ok=True)
26
+
27
+ # Real-ESRGAN
28
+ realesrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
29
+ download_file(realesrgan_url, os.path.join(weights_dir, "RealESRGAN_x4plus.pth"))
30
+
31
+ # GFPGAN
32
+ gfpgan_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
33
+ download_file(gfpgan_url, os.path.join(weights_dir, "GFPGANv1.4.pth"))
34
+
35
+ # rembg will download its own models to ~/.u2net by default,
36
+ # but for HF Spaces we might want to cache it.
37
+ # For now, rembg handles it on first run.
38
+
39
+ if __name__ == "__main__":
40
+ main()
weights/GFPGANv1.4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f63b495992c0874c28e6c627fce803896264687beed2b44b9957bcfd890edc78
3
+ size 4243456
weights/RealESRGAN_x4plus.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1
3
+ size 67040989