aamsko commited on
Commit
51b72e4
Β·
verified Β·
1 Parent(s): 31b0b13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -90
app.py CHANGED
@@ -1,94 +1,19 @@
1
- import os
2
- import cv2
3
- import torch
4
  import gradio as gr
5
- import numpy as np
6
- from PIL import Image
7
- import tempfile
8
- import requests
9
-
10
- from gfpgan import GFPGANer
11
- from basicsr.archs.srvgg_arch import SRVGGNetCompact
12
- from realesrgan.utils import RealESRGANer
13
-
14
- # Create model directory
15
- MODEL_DIR = os.path.join(os.path.expanduser("~"), ".imgen_models")
16
- os.makedirs(MODEL_DIR, exist_ok=True)
17
-
18
- # Download GFPGAN model if missing
19
- GFPGAN_MODEL = os.path.join(MODEL_DIR, "GFPGANv1.4.pth")
20
- GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
21
- if not os.path.exists(GFPGAN_MODEL):
22
- print("Downloading GFPGAN model...")
23
- r = requests.get(GFPGAN_URL, allow_redirects=True)
24
- with open(GFPGAN_MODEL, 'wb') as f:
25
- f.write(r.content)
26
-
27
- # Download RealESRGAN model if missing
28
- ESRGAN_MODEL = os.path.join(MODEL_DIR, "realesr-general-x4v3.pth")
29
- ESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
30
- if not os.path.exists(ESRGAN_MODEL):
31
- print("Downloading Real-ESRGAN model...")
32
- r = requests.get(ESRGAN_URL, allow_redirects=True)
33
- with open(ESRGAN_MODEL, 'wb') as f:
34
- f.write(r.content)
35
-
36
- # Setup RealESRGAN background upsampler
37
- esr_model = SRVGGNetCompact(
38
- num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32,
39
- upscale=4, act_type='prelu'
40
- )
41
-
42
- bg_upsampler = RealESRGANer(
43
- scale=4,
44
- model_path=ESRGAN_MODEL,
45
- model=esr_model,
46
- tile=0,
47
- tile_pad=10,
48
- pre_pad=0,
49
- half=torch.cuda.is_available()
50
- )
51
-
52
- # Setup GFPGAN restorer with bg upsampler
53
- restorer = GFPGANer(
54
- model_path=GFPGAN_MODEL,
55
- upscale=2,
56
- arch='clean',
57
- channel_multiplier=2,
58
- bg_upsampler=bg_upsampler
59
- )
60
-
61
- def enhance(image):
62
- # Convert PIL image to OpenCV BGR
63
- img_np = np.array(image.convert("RGB"))
64
- img = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
65
-
66
- # Enhance faces + upscale background/outfit
67
- _, _, restored_img = restorer.enhance(
68
- img,
69
- has_aligned=False,
70
- only_center_face=False, # Enhance all faces in image
71
- paste_back=True # Paste restored faces back on bg
72
- )
73
-
74
- # Convert back to PIL RGB
75
- restored_pil = Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))
76
-
77
- # Save to temporary JPG file for download
78
- temp_file = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
79
- restored_pil.save(temp_file.name, format="JPEG", quality=95)
80
-
81
- return temp_file.name
82
-
83
- iface = gr.Interface(
84
- fn=enhance,
85
- inputs=gr.Image(type="pil", label="Upload Image (ID, CV, Profile)"),
86
- outputs=gr.File(label="Download Enhanced JPG"),
87
- title="πŸ“Έ IMGEN - AI Photo Enhancer (Face + Outfit)",
88
- description="Upload your photo and enhance both face and outfit with AI (GFPGAN + RealESRGAN). Output is a downloadable JPG file.",
89
- allow_flagging="never"
90
  )
91
 
92
  if __name__ == "__main__":
93
- print("Running on GPU" if torch.cuda.is_available() else "Running on CPU")
94
- iface.launch()
 
 
 
 
1
  import gradio as gr
2
+ from model.face_enhancer import enhance_face
3
+ from model.dress_enhancer import enhance_clothes
4
+
5
+ def enhance_image(input_image):
6
+ face_enhanced = enhance_face(input_image)
7
+ final_image = enhance_clothes(face_enhanced)
8
+ return final_image
9
+
10
+ demo = gr.Interface(
11
+ fn=enhance_image,
12
+ inputs=gr.Image(type="pil", label="Upload an image"),
13
+ outputs=gr.Image(type="pil", label="Enhanced Image"),
14
+ title="Face + Dress Image Enhancer",
15
+ description="Enhances facial details using GFPGAN and clothing using SwinIR.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  )
17
 
18
  if __name__ == "__main__":
19
+ demo.launch()