MGC1991MF committed on
Commit
437bf1c
·
verified ·
1 Parent(s): 02bb4ba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torchvision.models as models
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image
8
+
9
# --- 1. CONFIGURATION ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet statistics used by the VGG preprocessing pipeline.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Inverse transform: undo the ImageNet normalization, clamp into [0, 1]
# and convert the tensor back to a PIL image for display.
unloader = transforms.Compose([
    transforms.Normalize(mean=[-m / s for m, s in zip(_IMAGENET_MEAN, _IMAGENET_STD)],
                         std=[1 / s for s in _IMAGENET_STD]),
    transforms.Lambda(lambda x: x.clamp(0, 1)),
    transforms.ToPILImage()
])
19
+
20
# --- 2. LOSS FUNCTIONS ---
def calc_content_loss(gen_features, content_features):
    """Mean squared error between generated and target content feature maps."""
    diff = gen_features - content_features
    return (diff * diff).mean()
23
+
24
def gram_matrix(tensor):
    """Return the normalized Gram matrix of a (1, C, H, W) feature tensor.

    NOTE(review): the leading (batch) dimension is discarded by the view,
    so this effectively assumes batch size 1 — which is how the app calls it.
    """
    _, channels, height, width = tensor.size()
    flat = tensor.view(channels, height * width)
    gram = flat @ flat.t()
    return gram / (channels * height * width)
28
+
29
def calc_style_loss(gen_features, style_features):
    """MSE between the Gram matrices of generated and style feature maps."""
    diff = gram_matrix(gen_features) - gram_matrix(style_features)
    return (diff * diff).mean()
33
+
34
def calc_tv_loss(img):
    """Total-variation loss of a (N, C, H, W) image.

    Sum of squared differences between vertically and horizontally
    neighbouring pixels; penalizing it smooths the generated image.
    """
    vertical = img[:, :, 1:, :] - img[:, :, :-1, :]
    horizontal = img[:, :, :, 1:] - img[:, :, :, :-1]
    return (vertical ** 2).sum() + (horizontal ** 2).sum()
38
+
39
# --- 3. FEATURE-EXTRACTOR MODEL ---
class VGGFeatureExtractor(nn.Module):
    """Frozen VGG16 that exposes intermediate activations for style transfer.

    ``forward`` returns two dicts: content features (``block5_conv2``) and
    style features (the first conv of each of the five VGG blocks).
    """

    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        # Used purely as a fixed feature extractor — freeze all weights.
        for param in vgg.parameters():
            param.requires_grad = False
        self.model = vgg.to(device).eval()
        # BUG FIX: the previous indices (…, '19', '28' and content '30') are
        # VGG19 positions. In torchvision's VGG16 ``features`` the first conv
        # of blocks 1-5 sits at 0, 5, 10, 17 and 24, and conv5_2 at 26
        # (index 30 is a MaxPool2d, not a conv). Keys are unchanged.
        self.style_layers = {'0': 'block1_conv1', '5': 'block2_conv1',
                             '10': 'block3_conv1', '17': 'block4_conv1',
                             '24': 'block5_conv1'}
        self.content_layers = {'26': 'block5_conv2'}

    def forward(self, x):
        """Run ``x`` through VGG16, collecting the tagged activations."""
        style_features = {}
        content_features = {}
        for name, layer in self.model._modules.items():
            x = layer(x)
            if name in self.style_layers:
                style_features[self.style_layers[name]] = x
            if name in self.content_layers:
                content_features[self.content_layers[name]] = x
        return content_features, style_features
58
+
59
# --- 4. MAIN FUNCTION FOR GRADIO ---
def run_style_transfer(content_img, style_img, content_weight, style_weight, tv_weight, iterations):
    """Run Gatys-style neural style transfer and return the stylized image.

    Args:
        content_img: PIL image whose structure (and resolution) is kept.
        style_img: PIL image whose texture/colors are transferred.
        content_weight, style_weight, tv_weight: loss-term weights (UI sliders).
        iterations: number of outer L-BFGS steps (each runs up to 20 inner
            iterations, see ``max_iter`` below).

    Returns:
        The stylized PIL image, or None when either input is missing.
    """
    if content_img is None or style_img is None:
        return None

    # Use the ORIGINAL content-image size; the output keeps this resolution.
    original_width, original_height = content_img.size

    # Content transform: NO resizing — tensor conversion + ImageNet normalization only.
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Style transform: resize the style image to match the content image.
    # Note: transforms.Resize expects (height, width).
    style_transform = transforms.Compose([
        transforms.Resize((original_height, original_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Apply the transforms and add a batch dimension.
    content_tensor = content_transform(content_img).unsqueeze(0).to(device, torch.float)
    style_tensor = style_transform(style_img).unsqueeze(0).to(device, torch.float)

    # Optimize the pixels of a copy of the content image directly.
    gen_img = content_tensor.clone().requires_grad_(True)
    # NOTE(review): the extractor (and thus the VGG16 weights) is rebuilt on
    # every call; caching it at module level would avoid the reload.
    extractor = VGGFeatureExtractor().to(device)

    # Fixed targets: content features of the content image, style features of the style image.
    target_content_features, _ = extractor(content_tensor)
    _, target_style_features = extractor(style_tensor)

    optimizer = optim.LBFGS([gen_img], max_iter=20)

    for i in range(int(iterations)):
        def closure():
            # L-BFGS re-evaluates the objective several times per step, so all
            # loss computation must live inside this closure.
            optimizer.zero_grad()
            # Keep pixels inside the normalized range; [-2.1, 2.6] roughly
            # corresponds to [0, 1] under the ImageNet mean/std used above.
            gen_img.data.clamp_(-2.1, 2.6)

            gen_content_features, gen_style_features = extractor(gen_img)

            c_loss = calc_content_loss(gen_content_features['block5_conv2'], target_content_features['block5_conv2'])

            # Average the style loss over every tracked VGG layer.
            s_loss = 0
            for layer_name in target_style_features:
                s_loss += calc_style_loss(gen_style_features[layer_name], target_style_features[layer_name])
            s_loss /= len(target_style_features)

            # Total-variation term smooths the generated image.
            t_loss = calc_tv_loss(gen_img)

            total_loss = (content_weight * c_loss) + (style_weight * s_loss) + (tv_weight * t_loss)
            total_loss.backward()
            return total_loss

        optimizer.step(closure)

    # Final clamp before converting back to an image.
    gen_img.data.clamp_(-2.1, 2.6)

    # Convert back to a PIL image (no axes, original resolution).
    final_image = unloader(gen_img.cpu().squeeze(0))
    return final_image
121
+
122
# --- 5. USER INTERFACE (GRADIO) ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Transferencia de Estilo Neuronal")
    gr.Markdown("Sube una imagen base (A) y una imagen de estilo (B) para combinarlas. **La imagen resultante mantendrá la resolución de tu imagen base.**")

    # Input/output images: the content image dictates the output resolution.
    with gr.Row():
        with gr.Column():
            content_input = gr.Image(type="pil", label="Imagen Base (A) - Dicta el tamaño")
            style_input = gr.Image(type="pil", label="Imagen de Estilo (B)")
        with gr.Column():
            result_output = gr.Image(type="pil", label="Imagen Resultante (C)")

    # Sliders exposing the loss weights and the optimization budget.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ⚙️ Ajustes del Modelo")
            content_weight_slider = gr.Slider(minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Peso del Contenido (Estructura)")
            style_weight_slider = gr.Slider(minimum=1000, maximum=1000000, value=100000, step=1000, label="Peso del Estilo (Arte)")
            tv_weight_slider = gr.Slider(minimum=0, maximum=0.001, value=0.000001, step=0.000001, label="Suavizado (Variación Total)")
            iterations_slider = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Iteraciones (¡Cuidado con imágenes grandes!)")

    mix_button = gr.Button("¡Mezclar Imágenes!", variant="primary")

    mix_button.click(
        fn=run_style_transfer,
        inputs=[content_input, style_input, content_weight_slider, style_weight_slider, tv_weight_slider, iterations_slider],
        outputs=result_output
    )
149
+
150
# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()