aamsko commited on
Commit
4c97d0f
·
verified ·
1 Parent(s): 6123473

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -49
app.py CHANGED
@@ -4,43 +4,33 @@ import torch
4
  import gradio as gr
5
  import numpy as np
6
  from PIL import Image
7
- import tempfile
8
- import requests
9
-
10
  from gfpgan import GFPGANer
11
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
12
  from realesrgan.utils import RealESRGANer
13
 
14
- MODEL_DIR = os.path.join(os.path.expanduser("~"), ".imgen_models")
15
- os.makedirs(MODEL_DIR, exist_ok=True)
16
-
17
- def download_model(url, path):
18
- if not os.path.exists(path):
19
- print(f"Downloading model from {url}...")
20
- r = requests.get(url, allow_redirects=True)
21
- with open(path, 'wb') as f:
22
- f.write(r.content)
23
 
24
- download_model(
25
- "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
26
- os.path.join(MODEL_DIR, "GFPGANv1.4.pth")
27
- )
28
 
29
- download_model(
30
- "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
31
- os.path.join(MODEL_DIR, "realesr-general-x4v3.pth")
32
- )
33
 
34
- GFPGAN_MODEL = os.path.join(MODEL_DIR, "GFPGANv1.4.pth")
35
- ESRGAN_MODEL = os.path.join(MODEL_DIR, "realesr-general-x4v3.pth")
36
-
37
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
38
- print(f"Using device: {device}")
39
 
 
40
  esr_model = SRVGGNetCompact(
41
  num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32,
42
  upscale=4, act_type='prelu'
43
- ).to(device)
44
 
45
  bg_upsampler = RealESRGANer(
46
  scale=4,
@@ -49,45 +39,39 @@ bg_upsampler = RealESRGANer(
49
  tile=0,
50
  tile_pad=10,
51
  pre_pad=0,
52
- half=(device=='cuda')
53
  )
54
 
 
55
  restorer = GFPGANer(
56
  model_path=GFPGAN_MODEL,
57
  upscale=2,
58
  arch='clean',
59
  channel_multiplier=2,
60
- bg_upsampler=bg_upsampler,
61
- device=device
62
  )
63
 
 
64
  def enhance(image):
65
- try:
66
- img_np = np.array(image.convert("RGB"))
67
- img = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
68
-
69
- _, _, restored_img = restorer.enhance(
70
- img,
71
- has_aligned=False,
72
- only_center_face=False,
73
- paste_back=True
74
- )
75
- restored_pil = Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))
76
 
77
- temp_file = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
78
- restored_pil.save(temp_file.name, format="JPEG", quality=95)
 
 
 
 
79
 
80
- return temp_file.name
81
- except Exception as e:
82
- return f"Error: {e}"
83
 
 
84
  iface = gr.Interface(
85
  fn=enhance,
86
- inputs=gr.Image(type="pil", label="Upload Image"),
87
- outputs=gr.File(label="Download Enhanced JPG"),
88
  title="IMGEN - AI Photo Enhancer (Face + Outfit)",
89
- description="Enhance face and outfit with GFPGAN + RealESRGAN.",
90
- allow_flagging="never"
91
  )
92
 
93
  if __name__ == "__main__":
 
4
  import gradio as gr
5
  import numpy as np
6
  from PIL import Image
 
 
 
7
  from gfpgan import GFPGANer
8
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
9
  from realesrgan.utils import RealESRGANer
10
 
11
+ # Download GFPGAN model if not already present
12
+ GFPGAN_MODEL = "GFPGANv1.4.pth"
13
+ GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
 
 
 
 
 
 
14
 
15
+ if not os.path.exists(GFPGAN_MODEL):
16
+ import requests
17
+ r = requests.get(GFPGAN_URL, allow_redirects=True)
18
+ open(GFPGAN_MODEL, 'wb').write(r.content)
19
 
20
+ # Download RealESRGAN model
21
+ ESRGAN_MODEL = "realesr-general-x4v3.pth"
22
+ ESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
 
23
 
24
+ if not os.path.exists(ESRGAN_MODEL):
25
+ import requests
26
+ r = requests.get(ESRGAN_URL, allow_redirects=True)
27
+ open(ESRGAN_MODEL, 'wb').write(r.content)
 
28
 
29
+ # Initialize Real-ESRGAN as background upsampler
30
  esr_model = SRVGGNetCompact(
31
  num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32,
32
  upscale=4, act_type='prelu'
33
+ )
34
 
35
  bg_upsampler = RealESRGANer(
36
  scale=4,
 
39
  tile=0,
40
  tile_pad=10,
41
  pre_pad=0,
42
+ half=torch.cuda.is_available()
43
  )
44
 
45
+ # Initialize GFPGAN with Real-ESRGAN as background upsampler
46
  restorer = GFPGANer(
47
  model_path=GFPGAN_MODEL,
48
  upscale=2,
49
  arch='clean',
50
  channel_multiplier=2,
51
+ bg_upsampler=bg_upsampler
 
52
  )
53
 
54
+ # Enhancement function
55
  def enhance(image):
56
+ img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
 
 
 
57
 
58
+ _, _, restored_img = restorer.enhance(
59
+ img,
60
+ has_aligned=False,
61
+ only_center_face=False,
62
+ paste_back=True
63
+ )
64
 
65
+ restored_pil = Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))
66
+ return restored_pil
 
67
 
68
+ # Gradio UI
69
  iface = gr.Interface(
70
  fn=enhance,
71
+ inputs=gr.Image(type="pil"),
72
+ outputs=gr.Image(type="pil"),
73
  title="IMGEN - AI Photo Enhancer (Face + Outfit)",
74
+ description="Upload your photo (ID, CV, profile) and enhance both the face and outfit with AI using GFPGAN + RealESRGAN."
 
75
  )
76
 
77
  if __name__ == "__main__":