MGC1991MF commited on
Commit
12d491a
·
verified ·
1 Parent(s): 552bf54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -3
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
@@ -6,11 +9,15 @@ import torchvision.models as models
6
  import torchvision.transforms as transforms
7
  from PIL import Image
8
 
 
9
  # --- 1. CONFIGURACIÓN ---
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- imsize = 192 # Balance perfecto entre velocidad (CPU) y calidad
 
 
 
12
 
13
- # Transformación de entrada (Ahora redimensiona todo a imsize x imsize)
14
  loader = transforms.Compose([
15
  transforms.Resize((imsize, imsize)),
16
  transforms.ToTensor(),
@@ -18,6 +25,7 @@ loader = transforms.Compose([
18
  ])
19
 
20
  # Transformación inversa (Desnormalizar para mostrar la imagen final)
 
21
  unloader = transforms.Compose([
22
  transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
23
  std=[1/0.229, 1/0.224, 1/0.225]),
@@ -25,7 +33,9 @@ unloader = transforms.Compose([
25
  transforms.ToPILImage()
26
  ])
27
 
 
28
  # --- 2. FUNCIONES DE PÉRDIDA ---
 
29
  def calc_content_loss(gen_features, content_features):
30
  return torch.mean((gen_features - content_features) ** 2)
31
 
@@ -44,7 +54,9 @@ def calc_tv_loss(img):
44
  tv_w = torch.sum((img[:, :, :, 1:] - img[:, :, :, :-1]) ** 2)
45
  return tv_h + tv_w
46
 
 
47
  # --- 3. MODELO EXTRACTOR ---
 
48
  class VGGFeatureExtractor(nn.Module):
49
  def __init__(self):
50
  super().__init__()
@@ -63,8 +75,10 @@ class VGGFeatureExtractor(nn.Module):
63
  if name in self.style_layers: style_features[self.style_layers[name]] = x
64
  if name in self.content_layers: content_features[self.content_layers[name]] = x
65
  return content_features, style_features
66
-
 
67
  # --- 4. FUNCIÓN PRINCIPAL PARA GRADIO ---
 
68
  def run_style_transfer(content_img, style_img, content_weight, style_weight, tv_weight, iterations):
69
  if content_img is None or style_img is None:
70
  return None
@@ -108,7 +122,10 @@ def run_style_transfer(content_img, style_img, content_weight, style_weight, tv_
108
  final_image = unloader(gen_img.cpu().squeeze(0))
109
  return final_image
110
 
 
111
  # --- 5. INTERFAZ DE USUARIO (GRADIO) ---
 
 
112
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
113
 
114
  # ENCABEZADO Y ENLACES
 
1
+ # ======================
2
+ # --- 0. LIBRERIAS ---
3
+ # ======================
4
  import gradio as gr
5
  import torch
6
  import torch.nn as nn
 
9
  import torchvision.transforms as transforms
10
  from PIL import Image
11
 
12
+ # ========================
13
  # --- 1. CONFIGURACIÓN ---
14
+ # ========================
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ imsize = 256 # MODIFICAR TAMAÑO DE IMAGEN (384,192,256) SI USAS VERSIONES DE SPACE GRATUITAS PARA NO IR LENTO :P
17
+
18
+
19
+ # Transformación de entrada
20
 
 
21
  loader = transforms.Compose([
22
  transforms.Resize((imsize, imsize)),
23
  transforms.ToTensor(),
 
25
  ])
26
 
27
  # Transformación inversa (Desnormalizar para mostrar la imagen final)
28
+
29
  unloader = transforms.Compose([
30
  transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
31
  std=[1/0.229, 1/0.224, 1/0.225]),
 
33
  transforms.ToPILImage()
34
  ])
35
 
36
+ # ===============================
37
  # --- 2. FUNCIONES DE PÉRDIDA ---
38
+ # ===============================
39
  def calc_content_loss(gen_features, content_features):
40
  return torch.mean((gen_features - content_features) ** 2)
41
 
 
54
  tv_w = torch.sum((img[:, :, :, 1:] - img[:, :, :, :-1]) ** 2)
55
  return tv_h + tv_w
56
 
57
+ # ============================
58
  # --- 3. MODELO EXTRACTOR ---
59
+ # ============================
60
  class VGGFeatureExtractor(nn.Module):
61
  def __init__(self):
62
  super().__init__()
 
75
  if name in self.style_layers: style_features[self.style_layers[name]] = x
76
  if name in self.content_layers: content_features[self.content_layers[name]] = x
77
  return content_features, style_features
78
+
79
+ # ========================================
80
  # --- 4. FUNCIÓN PRINCIPAL PARA GRADIO ---
81
+ # ========================================
82
  def run_style_transfer(content_img, style_img, content_weight, style_weight, tv_weight, iterations):
83
  if content_img is None or style_img is None:
84
  return None
 
122
  final_image = unloader(gen_img.cpu().squeeze(0))
123
  return final_image
124
 
125
+ # =======================================
126
  # --- 5. INTERFAZ DE USUARIO (GRADIO) ---
127
+ # =======================================
128
+
129
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
130
 
131
  # ENCABEZADO Y ENLACES