Andro0s commited on
Commit
5e5bf12
·
verified ·
1 Parent(s): a4d2003

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -332
app.py CHANGED
@@ -1,357 +1,86 @@
1
- import os
2
- import sys
3
- import torch
4
- import subprocess
5
- import importlib
6
-
7
- # ================== ARREGLO CRÍTICO ==================
8
- # Parche para torchvision.transforms.functional_tensor
9
- try:
10
- import torchvision.transforms.functional as F
11
- # Crear un alias para compatibilidad
12
- sys.modules['torchvision.transforms.functional_tensor'] = F
13
- # También necesitamos el rgb_to_grayscale
14
- if not hasattr(F, 'rgb_to_grayscale'):
15
- from torchvision.transforms.functional import rgb_to_grayscale
16
- F.rgb_to_grayscale = rgb_to_grayscale
17
- except Exception as e:
18
- print(f"Warning: Could not patch torchvision: {e}")
19
-
20
- # ================== ARREGLO ALTERNATIVO ==================
21
- # O también puedes forzar la importación correcta
22
- def patch_torchvision():
23
- try:
24
- # Esto evita el error directamente
25
- import torchvision.transforms.functional as TF
26
- import types
27
-
28
- # Crear módulo falso
29
- fake_module = types.ModuleType('torchvision.transforms.functional_tensor')
30
- fake_module.rgb_to_grayscale = TF.rgb_to_grayscale
31
- sys.modules['torchvision.transforms.functional_tensor'] = fake_module
32
- except:
33
- pass
34
-
35
- # Ejecutar el parche ANTES de importar basicsr
36
- patch_torchvision()
37
-
38
- # ================== IMPORTS DESPUÉS DEL PARCHE ==================
39
  import gradio as gr
40
  import cv2
41
  import numpy as np
42
  from PIL import Image
 
 
 
 
43
 
44
- # Instalar basicsr específico
45
- def install_deps():
46
- """Instala dependencias con versiones compatibles"""
47
- required_packages = [
48
- 'opencv-python',
49
- 'gradio>=4.0.0',
50
- 'Pillow',
51
- 'numpy',
52
- 'scipy',
53
- 'tqdm',
54
- 'lmdb',
55
- 'yapf',
56
- 'tb-nightly',
57
- 'flake8',
58
- 'yapf',
59
- 'isort',
60
- 'gdown'
61
- ]
62
-
63
- # Intentar instalar basicsr desde GitHub (versión compatible)
64
- try:
65
- subprocess.check_call([sys.executable, "-m", "pip", "install",
66
- "git+https://github.com/xinntao/BasicSR.git"])
67
- except:
68
- # Fallback a pip
69
- subprocess.check_call([sys.executable, "-m", "pip", "install", "basicsr"])
70
-
71
- # Llamar a la instalación (comentada en producción)
72
- # install_deps()
73
-
74
- # Ahora importamos realesrgan
75
- try:
76
- from basicsr.archs.rrdbnet_arch import RRDBNet
77
- from realesrgan import RealESRGANer
78
- except ImportError as e:
79
- print(f"Error importing: {e}")
80
- print("Trying alternative import...")
81
- # Intento alternativo
82
- try:
83
- from realesrgan.realesrgan import RealESRGANer
84
- from realesrgan.archs.rrdbnet_arch import RRDBNet
85
- except:
86
- raise ImportError("Cannot import RealESRGAN. Make sure basicsr is installed.")
87
-
88
- # ================== CONFIGURACIÓN DEL MODELO ==================
89
- MODEL_CONFIGS = {
90
- 'x4': {
91
- 'name': 'RealESRGAN_x4plus',
92
- 'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
93
- 'scale': 4,
94
- 'blocks': 23
95
- },
96
- 'x2': {
97
- 'name': 'RealESRGAN_x2plus',
98
- 'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
99
- 'scale': 2,
100
- 'blocks': 23
101
- },
102
- 'anime': {
103
- 'name': 'RealESRGAN_x4plus_anime',
104
- 'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
105
- 'scale': 4,
106
- 'blocks': 6 # 6 blocks para anime
107
- }
108
- }
109
-
110
- def download_model(model_key='x4'):
111
- """Descarga el modelo si no existe"""
112
- config = MODEL_CONFIGS[model_key]
113
- model_dir = 'models'
114
- os.makedirs(model_dir, exist_ok=True)
115
-
116
- model_filename = f"{config['name']}.pth"
117
- model_path = os.path.join(model_dir, model_filename)
118
-
119
- if not os.path.exists(model_path):
120
- print(f"📥 Descargando modelo {config['name']}...")
121
- try:
122
- subprocess.run(['wget', config['url'], '-O', model_path, '-q'],
123
- check=True, capture_output=True)
124
- print(f"✅ Modelo descargado: {model_filename}")
125
- except:
126
- # Fallback con gdown si wget falla
127
- print("Intentando descarga alternativa...")
128
- try:
129
- import gdown
130
- # Extraer ID de Google Drive si es necesario
131
- gdown.download(config['url'], model_path, quiet=False)
132
- except:
133
- print("❌ Error al descargar el modelo")
134
- return None
135
-
136
- return model_path, config
137
-
138
- # ================== FUNCIÓN PRINCIPAL ==================
139
- def upscale_image(image, model_choice='x4', scale_factor=4):
140
  """
141
- Mejora la resolución de una imagen
142
  """
143
  if image is None:
144
  return None
145
 
146
- try:
147
- # Descargar modelo
148
- model_path, config = download_model(model_choice)
149
- if model_path is None:
150
- return image
151
-
152
- # Crear modelo
153
- model = RRDBNet(
154
- num_in_ch=3,
155
- num_out_ch=3,
156
- num_feat=64,
157
- num_block=config['blocks'],
158
- num_grow_ch=32,
159
- scale=config['scale']
 
 
 
 
 
 
 
160
  )
161
-
162
- # Inicializar upsampler
163
- upsampler = RealESRGANer(
164
- scale=config['scale'],
165
- model_path=model_path,
166
- model=model,
167
- tile=0, # 0 para no usar tiles
168
- tile_pad=10,
169
- pre_pad=0,
170
- half=False # Usar float32 para compatibilidad
171
  )
172
-
173
- # Convertir imagen
174
- if isinstance(image, Image.Image):
175
- img_array = np.array(image)
176
- else:
177
- img_array = image
178
-
179
- # Asegurar que es RGB
180
- if len(img_array.shape) == 2: # Grayscale
181
- img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
182
- elif img_array.shape[2] == 4: # RGBA
183
- img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
184
-
185
- # Convertir a BGR para OpenCV
186
- img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
187
-
188
- # Upscale
189
- output, _ = upsampler.enhance(img_bgr, outscale=scale_factor)
190
-
191
- # Convertir de vuelta a RGB
192
- output_rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
193
-
194
- # Convertir a PIL Image
195
- result = Image.fromarray(output_rgb)
196
-
197
- return result
198
-
199
- except Exception as e:
200
- print(f"❌ Error durante el upscaling: {str(e)}")
201
- import traceback
202
- traceback.print_exc()
203
- return image # Devolver original si hay error
204
 
205
- # ================== INTERFAZ GRADIO SIMPLIFICADA ==================
206
- def create_interface():
207
- with gr.Blocks(title="🖼️ Mejorador de Imágenes AI", theme="soft") as app:
208
- gr.Markdown("""
209
- # 🚀 Mejorador de Imágenes con IA
210
- ### Aumenta la resolución y calidad de tus imágenes usando Real-ESRGAN
211
-
212
- Sube una imagen y selecciona las opciones para mejorarla.
213
- """)
214
 
215
  with gr.Row():
216
- with gr.Column(scale=1):
217
- input_image = gr.Image(
218
- type="pil",
219
- label="📤 Subir Imagen",
220
- height=300
221
- )
222
-
223
- with gr.Accordion("⚙️ Configuración", open=True):
224
- model_select = gr.Dropdown(
225
- choices=[
226
- ("4x General (Recomendado)", "x4"),
227
- ("2x General", "x2"),
228
- ("4x Anime/Dibujos", "anime")
229
- ],
230
- value="x4",
231
- label="Modelo"
232
- )
233
-
234
- scale_slider = gr.Slider(
235
- minimum=1,
236
- maximum=4,
237
- value=4,
238
- step=1,
239
- label="Factor de Escala"
240
- )
241
 
242
- process_btn = gr.Button(
243
- "✨ Mejorar Imagen",
244
- variant="primary",
245
- size="lg"
 
246
  )
247
 
248
- gr.Markdown("""
249
- ### 💡 Consejos:
250
- - Para fotos reales usa "4x General"
251
- - Para dibujos/anime usa "4x Anime"
252
- - El proceso puede tomar algunos segundos
253
- """)
254
 
255
- with gr.Column(scale=1):
256
- output_image = gr.Image(
257
- type="pil",
258
- label="📥 Resultado Mejorado",
259
- height=300
260
- )
261
-
262
- with gr.Accordion("📊 Información", open=False):
263
- info_text = gr.Markdown("Esperando imagen...")
264
 
265
- # Ejemplos
266
- gr.Examples(
267
- examples=[
268
- ["example1.jpg"],
269
- ["example2.png"],
270
- ],
271
- inputs=[input_image],
272
- label="💡 Ejemplos"
273
- )
274
-
275
- # Funciones
276
- def update_info(image, model, scale):
277
- if image is None:
278
- return "Sube una imagen para comenzar"
279
-
280
- orig_w, orig_h = image.size
281
- new_w, new_h = orig_w * scale, orig_h * scale
282
-
283
- return f"""
284
- **Imagen Original:** {orig_w} x {orig_h} px
285
- **Imagen Mejorada:** {new_w} x {new_h} px
286
- **Modelo:** {MODEL_CONFIGS[model]['name']}
287
- **Escala:** {scale}x
288
- """
289
-
290
- def process_image(image, model, scale):
291
- if image is None:
292
- return None, "Por favor, sube una imagen"
293
-
294
- try:
295
- result = upscale_image(image, model, scale)
296
- info = update_info(image, model, scale)
297
- return result, info
298
- except Exception as e:
299
- return None, f"Error: {str(e)}"
300
-
301
- # Conectar eventos
302
- input_image.change(
303
- fn=lambda img, mod, sc: update_info(img, mod, sc),
304
- inputs=[input_image, model_select, scale_slider],
305
- outputs=info_text
306
- )
307
-
308
- model_select.change(
309
- fn=lambda img, mod, sc: update_info(img, mod, sc),
310
- inputs=[input_image, model_select, scale_slider],
311
- outputs=info_text
312
- )
313
-
314
- scale_slider.change(
315
- fn=lambda img, mod, sc: update_info(img, mod, sc),
316
- inputs=[input_image, model_select, scale_slider],
317
- outputs=info_text
318
- )
319
-
320
- process_btn.click(
321
- fn=process_image,
322
- inputs=[input_image, model_select, scale_slider],
323
- outputs=[output_image, info_text]
324
  )
325
 
326
  return app
327
 
328
- # ================== ARCHIVO requirements.txt ==================
329
- # Crea un archivo requirements.txt con esto:
330
- """
331
- torch>=1.9.0,<2.0.0
332
- torchvision>=0.10.0,<0.15.0
333
- gradio>=4.0.0
334
- opencv-python>=4.5.0
335
- numpy>=1.19.0
336
- Pillow>=8.0.0
337
- scipy>=1.6.0
338
- gdown>=4.4.0
339
- tqdm>=4.50.0
340
- basicsr==1.4.2
341
- realesrgan==0.3.0
342
- """
343
-
344
- # ================== MAIN ==================
345
  if __name__ == "__main__":
346
- # Crear la interfaz
347
- app = create_interface()
348
-
349
- # Configurar para Hugging Face Spaces
350
- share = os.getenv('SHARE', 'False').lower() == 'true'
351
-
352
- app.launch(
353
- server_name="0.0.0.0",
354
- server_port=7860,
355
- share=share,
356
- debug=True
357
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
  from PIL import Image
5
+ import torch
6
+ from torchvision import transforms
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
 
10
+ # Función simple de upscaling usando OpenCV
11
+ def simple_upscale(image, scale_factor=4, method='cubic'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  """
13
+ Upscaling simple usando interpolación
14
  """
15
  if image is None:
16
  return None
17
 
18
+ img = np.array(image)
19
+
20
+ # Métodos de interpolación
21
+ methods = {
22
+ 'nearest': cv2.INTER_NEAREST,
23
+ 'linear': cv2.INTER_LINEAR,
24
+ 'cubic': cv2.INTER_CUBIC,
25
+ 'lanczos': cv2.INTER_LANCZOS4
26
+ }
27
+
28
+ # Calcular nuevo tamaño
29
+ height, width = img.shape[:2]
30
+ new_width = int(width * scale_factor)
31
+ new_height = int(height * scale_factor)
32
+
33
+ # Upscale
34
+ if len(img.shape) == 3: # Color
35
+ upscaled = cv2.resize(
36
+ img,
37
+ (new_width, new_height),
38
+ interpolation=methods.get(method, cv2.INTER_CUBIC)
39
  )
40
+ else: # Grayscale
41
+ upscaled = cv2.resize(
42
+ img,
43
+ (new_width, new_height),
44
+ interpolation=methods.get(method, cv2.INTER_CUBIC)
 
 
 
 
 
45
  )
46
+
47
+ # Aplicar sharpening
48
+ kernel = np.array([[-1, -1, -1],
49
+ [-1, 9, -1],
50
+ [-1, -1, -1]])
51
+ sharpened = cv2.filter2D(upscaled, -1, kernel)
52
+
53
+ return Image.fromarray(sharpened)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ # Interfaz simple
56
+ def create_simple_interface():
57
+ with gr.Blocks(title="Mejorador Simple de Imágenes") as app:
58
+ gr.Markdown("# 🖼️ Mejorador de Imágenes Simple")
 
 
 
 
 
59
 
60
  with gr.Row():
61
+ with gr.Column():
62
+ input_img = gr.Image(type="pil", label="Imagen Original")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ scale = gr.Slider(1, 8, 4, step=1, label="Factor de Escala")
65
+ method = gr.Dropdown(
66
+ ['nearest', 'linear', 'cubic', 'lanczos'],
67
+ value='cubic',
68
+ label="Método de Interpolación"
69
  )
70
 
71
+ btn = gr.Button("Mejorar", variant="primary")
 
 
 
 
 
72
 
73
+ with gr.Column():
74
+ output_img = gr.Image(type="pil", label="Imagen Mejorada")
 
 
 
 
 
 
 
75
 
76
+ btn.click(
77
+ fn=simple_upscale,
78
+ inputs=[input_img, scale, method],
79
+ outputs=output_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  )
81
 
82
  return app
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  if __name__ == "__main__":
85
+ app = create_simple_interface()
86
+ app.launch()