Merlimhhs commited on
Commit
1be7ff1
·
verified ·
1 Parent(s): 51523af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -33
app.py CHANGED
@@ -2,40 +2,42 @@ import os
2
  import zipfile
3
  import tempfile
4
  from pathlib import Path
 
5
 
6
  import cv2
7
  import gradio as gr
8
  import numpy as np
9
  import torch
10
- from huggingface_hub import hf_hub_download
11
  from basicsr.archs.rrdbnet_arch import RRDBNet
12
  from realesrgan import RealESRGANer
13
 
14
  # =========================
15
  # CONFIG
16
  # =========================
17
- OUTSCALE = 2 # saída final em 2x
18
 
19
- HF_REPO_ID = os.getenv("HF_REPO_ID", "xinntao/Real-ESRGAN")
20
- HF_FILENAME = os.getenv("HF_FILENAME", "RealESRGAN_x4plus_anime_6B.pth")
21
- CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf-cache")
 
22
 
23
 
24
- def download_model() -> str:
25
- return hf_hub_download(
26
- repo_id=HF_REPO_ID,
27
- filename=HF_FILENAME,
28
- cache_dir=CACHE_DIR,
29
- )
 
 
30
 
31
 
32
  def build_upsampler():
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
  use_half = device == "cuda"
35
 
36
- model_path = download_model()
37
 
38
- # Modelo anime 6B
39
  model = RRDBNet(
40
  num_in_ch=3,
41
  num_out_ch=3,
@@ -45,7 +47,7 @@ def build_upsampler():
45
  scale=4,
46
  )
47
 
48
- upsampler = RealESRGANer(
49
  scale=4,
50
  model_path=model_path,
51
  model=model,
@@ -55,7 +57,6 @@ def build_upsampler():
55
  half=use_half,
56
  device=device,
57
  )
58
- return upsampler
59
 
60
 
61
  UPSAMPLER = build_upsampler()
@@ -69,7 +70,6 @@ def upscale_one_image(image: np.ndarray) -> np.ndarray:
69
  image = np.clip(image, 0, 1)
70
  image = (image * 255).astype(np.uint8)
71
 
72
- # Suporte a alpha
73
  if image.ndim == 3 and image.shape[2] == 4:
74
  rgb = image[:, :, :3]
75
  alpha = image[:, :, 3]
@@ -92,11 +92,6 @@ def upscale_one_image(image: np.ndarray) -> np.ndarray:
92
 
93
 
94
  def process_batch(files):
95
- """
96
- Recebe uma lista de arquivos, processa um por um e retorna:
97
- - galeria com previews
98
- - caminho do zip final
99
- """
100
  if not files:
101
  return [], None
102
 
@@ -114,7 +109,6 @@ def process_batch(files):
114
  if image is None:
115
  continue
116
 
117
- # OpenCV lê em BGR/BGRA; converter para RGB/RGBA para o pipeline
118
  if image.ndim == 3 and image.shape[2] == 4:
119
  image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
120
  elif image.ndim == 3:
@@ -128,14 +122,11 @@ def process_batch(files):
128
  out_path = out_dir / out_name
129
 
130
  if result.ndim == 3 and result.shape[2] == 4:
131
- # RGBA -> BGRA para salvar via OpenCV
132
  save_img = cv2.cvtColor(result, cv2.COLOR_RGBA2BGRA)
133
  else:
134
- # RGB -> BGR para salvar via OpenCV
135
  save_img = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
136
 
137
  cv2.imwrite(str(out_path), save_img)
138
-
139
  previews.append((result, out_name))
140
 
141
  zip_path = tmpdir / "upscaled_images.zip"
@@ -143,7 +134,6 @@ def process_batch(files):
143
  for img_file in out_dir.iterdir():
144
  zf.write(img_file, arcname=img_file.name)
145
 
146
- # Copia o zip para um caminho persistente temporário do Gradio
147
  final_zip = Path(tempfile.gettempdir()) / "upscaled_images.zip"
148
  final_zip.write_bytes(zip_path.read_bytes())
149
 
@@ -151,14 +141,13 @@ def process_batch(files):
151
 
152
 
153
  with gr.Blocks() as demo:
154
- gr.Markdown("# Anime Upscaler 2x\nUpload em lote com saída em ZIP.")
155
 
156
- with gr.Row():
157
- files_in = gr.Files(
158
- label="Envie várias imagens",
159
- file_types=["image"],
160
- file_count="multiple",
161
- )
162
 
163
  run_btn = gr.Button("Processar")
164
  gallery_out = gr.Gallery(label="Prévia", columns=2, height=420)
 
2
  import zipfile
3
  import tempfile
4
  from pathlib import Path
5
+ from urllib.request import urlretrieve
6
 
7
  import cv2
8
  import gradio as gr
9
  import numpy as np
10
  import torch
 
11
  from basicsr.archs.rrdbnet_arch import RRDBNet
12
  from realesrgan import RealESRGANer
13
 
14
  # =========================
15
  # CONFIG
16
  # =========================
17
+ OUTSCALE = 2
18
 
19
+ # Peso oficial do anime model mostrado no README do Real-ESRGAN
20
+ MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
21
+ MODEL_DIR = Path("weights")
22
+ MODEL_PATH = MODEL_DIR / "RealESRGAN_x4plus_anime_6B.pth"
23
 
24
 
25
+ def ensure_model() -> str:
26
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
27
+
28
+ if MODEL_PATH.exists() and MODEL_PATH.stat().st_size > 0:
29
+ return str(MODEL_PATH)
30
+
31
+ urlretrieve(MODEL_URL, MODEL_PATH)
32
+ return str(MODEL_PATH)
33
 
34
 
35
  def build_upsampler():
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
37
  use_half = device == "cuda"
38
 
39
+ model_path = ensure_model()
40
 
 
41
  model = RRDBNet(
42
  num_in_ch=3,
43
  num_out_ch=3,
 
47
  scale=4,
48
  )
49
 
50
+ return RealESRGANer(
51
  scale=4,
52
  model_path=model_path,
53
  model=model,
 
57
  half=use_half,
58
  device=device,
59
  )
 
60
 
61
 
62
  UPSAMPLER = build_upsampler()
 
70
  image = np.clip(image, 0, 1)
71
  image = (image * 255).astype(np.uint8)
72
 
 
73
  if image.ndim == 3 and image.shape[2] == 4:
74
  rgb = image[:, :, :3]
75
  alpha = image[:, :, 3]
 
92
 
93
 
94
  def process_batch(files):
 
 
 
 
 
95
  if not files:
96
  return [], None
97
 
 
109
  if image is None:
110
  continue
111
 
 
112
  if image.ndim == 3 and image.shape[2] == 4:
113
  image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
114
  elif image.ndim == 3:
 
122
  out_path = out_dir / out_name
123
 
124
  if result.ndim == 3 and result.shape[2] == 4:
 
125
  save_img = cv2.cvtColor(result, cv2.COLOR_RGBA2BGRA)
126
  else:
 
127
  save_img = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
128
 
129
  cv2.imwrite(str(out_path), save_img)
 
130
  previews.append((result, out_name))
131
 
132
  zip_path = tmpdir / "upscaled_images.zip"
 
134
  for img_file in out_dir.iterdir():
135
  zf.write(img_file, arcname=img_file.name)
136
 
 
137
  final_zip = Path(tempfile.gettempdir()) / "upscaled_images.zip"
138
  final_zip.write_bytes(zip_path.read_bytes())
139
 
 
141
 
142
 
143
  with gr.Blocks() as demo:
144
+ gr.Markdown("# Anime Upscaler 2x\nUpload em lote e baixe um ZIP.")
145
 
146
+ files_in = gr.Files(
147
+ label="Envie várias imagens",
148
+ file_types=["image"],
149
+ file_count="multiple",
150
+ )
 
151
 
152
  run_btn = gr.Button("Processar")
153
  gallery_out = gr.Gallery(label="Prévia", columns=2, height=420)