File size: 6,625 Bytes
12d491a
 
 
437bf1c
 
 
 
 
 
 
 
12d491a
437bf1c
12d491a
437bf1c
5f9b018
12d491a
 
 
0ac8e86
 
 
 
 
 
437bf1c
 
12d491a
437bf1c
 
 
 
 
 
 
12d491a
437bf1c
12d491a
437bf1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12d491a
437bf1c
12d491a
437bf1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12d491a
 
437bf1c
12d491a
437bf1c
 
 
 
0ac8e86
 
 
437bf1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12d491a
437bf1c
12d491a
 
437bf1c
0ac8e86
 
 
 
 
 
 
 
 
 
 
 
 
 
437bf1c
 
 
0ac8e86
437bf1c
 
 
 
 
 
 
 
 
 
552bf54
437bf1c
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# ======================
# --- 0. LIBRERIAS ---
# ======================
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# ========================
# --- 1. CONFIGURATION ---
# ========================
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both images are resized to imsize x imsize before processing. Lower this
# (e.g. 256 or 192) on free-tier Spaces hardware so each run stays fast.
imsize = 384

# Per-channel ImageNet statistics used to normalize the VGG input.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Forward transform: PIL image -> resized, normalized float tensor (C, H, W).
loader = transforms.Compose([
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Inverse transform for display. Normalize(mean=-m/s, std=1/s) computes
# x * s + m, which undoes the forward normalization; the result is clipped to
# [0, 1] and turned back into a PIL image.
unloader = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(_IMAGENET_MEAN, _IMAGENET_STD)],
        std=[1 / s for s in _IMAGENET_STD],
    ),
    transforms.Lambda(lambda t: t.clamp(0, 1)),
    transforms.ToPILImage(),
])

# ===============================
# --- 2. FUNCIONES DE PÉRDIDA ---
# ===============================
def calc_content_loss(gen_features, content_features):
    """Return the mean squared error between generated and target content features."""
    diff = gen_features - content_features
    return diff.pow(2).mean()

def gram_matrix(tensor):
    """Return the channel-by-channel Gram matrix of a feature map.

    The input must have shape (1, C, H, W): the batch dimension is discarded by
    viewing the data as a (C, H*W) matrix, so a batch larger than one raises.
    The result has shape (C, C) and is divided by C*H*W so its scale does not
    grow with the layer's size.
    """
    _, channels, height, width = tensor.size()
    positions = height * width
    flat = tensor.view(channels, positions)
    return (flat @ flat.t()) / (channels * positions)

def calc_style_loss(gen_features, style_features):
    """Return the mean squared difference between the two feature maps' Gram matrices."""
    gram_diff = gram_matrix(gen_features) - gram_matrix(style_features)
    return gram_diff.pow(2).mean()

def calc_tv_loss(img):
    """Return the (squared) total variation of an image batch.

    Sums the squared differences between vertically adjacent pixels (dim 2)
    and horizontally adjacent pixels (dim 3); larger values mean a noisier image.
    """
    vertical = torch.diff(img, dim=2).pow(2).sum()
    horizontal = torch.diff(img, dim=3).pow(2).sum()
    return vertical + horizontal

# ============================
# --- 3. MODELO EXTRACTOR ---
# ============================
class VGGFeatureExtractor(nn.Module):
    """Frozen, pretrained VGG16 that returns intermediate activations by layer name.

    Layers use Keras-style names ``blockB_convC``: ``B`` counts the
    max-pool-separated stages of ``vgg16().features`` and ``C`` counts the
    convolutions within a stage. The index -> name mapping is derived from the
    network itself instead of hard-coded indices, so each name really points at
    the layer it describes. (The previous hard-coded indices 19/28/30 were VGG19
    positions; on VGG16 they are conv4_2, conv5_3 and the final max-pool.)
    """

    # Layers whose Gram matrices define the style target.
    STYLE_LAYER_NAMES = ('block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1')
    # Layer whose activations define the content target.
    CONTENT_LAYER_NAMES = ('block5_conv2',)

    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        # The network is only a fixed feature extractor: gradients flow to the
        # generated image, never to the VGG weights.
        for param in vgg.parameters():
            param.requires_grad = False
        self.model = vgg.to(device).eval()

        conv_names = self._conv_layer_names(self.model)
        wanted = set(self.STYLE_LAYER_NAMES) | set(self.CONTENT_LAYER_NAMES)
        missing = wanted - set(conv_names.values())
        if missing:
            raise ValueError(f"Backbone has no layers named {sorted(missing)}")

        # {module index (str, as in self.model._modules): layer name}, in network order.
        self.style_layers = {idx: name for idx, name in conv_names.items() if name in self.STYLE_LAYER_NAMES}
        self.content_layers = {idx: name for idx, name in conv_names.items() if name in self.CONTENT_LAYER_NAMES}

    @staticmethod
    def _conv_layer_names(features):
        """Map each Conv2d's module index in a Sequential to its ``blockB_convC`` name.

        A new block starts after every MaxPool2d; the conv counter resets there.
        """
        names = {}
        block, conv = 1, 0
        for idx, layer in features._modules.items():
            if isinstance(layer, nn.Conv2d):
                conv += 1
                names[idx] = f"block{block}_conv{conv}"
            elif isinstance(layer, nn.MaxPool2d):
                block += 1
                conv = 0
        return names

    def forward(self, x):
        """Run ``x`` through the whole VGG stack and collect the requested activations.

        ``x`` is presumably a normalized (1, 3, H, W) batch as produced by the
        module-level ``loader`` + ``unsqueeze(0)`` — TODO confirm for other callers.

        NOTE(review): torchvision builds VGG's ReLUs with ``inplace=True`` (verify
        for your torchvision version), so each stored conv output is overwritten
        by the following ReLU during this loop; the returned features are then
        effectively post-activation values.

        Returns:
            (content_features, style_features): dicts keyed by layer name.
        """
        style_features = {}
        content_features = {}
        for idx, layer in self.model._modules.items():
            x = layer(x)
            if idx in self.style_layers:
                style_features[self.style_layers[idx]] = x
            if idx in self.content_layers:
                content_features[self.content_layers[idx]] = x
        return content_features, style_features
        
# ========================================
# --- 4. FUNCIÓN PRINCIPAL PARA GRADIO ---
# ========================================
# Lazily-built feature extractor shared by all requests (see _get_extractor).
_EXTRACTOR = None


def _get_extractor():
    """Return a shared VGGFeatureExtractor, building it on the first call.

    Building the extractor loads the pretrained VGG16 weights; doing it once
    instead of on every button click avoids repeating the most expensive setup
    step. The network's parameters are frozen and it is in eval mode, so one
    instance can serve every request.
    """
    global _EXTRACTOR
    if _EXTRACTOR is None:
        _EXTRACTOR = VGGFeatureExtractor().to(device)
    return _EXTRACTOR


def run_style_transfer(content_img, style_img, content_weight, style_weight, tv_weight, iterations):
    """Blend the structure of one image with the style of another.

    The generated image starts as a copy of the content image and is optimized
    directly with L-BFGS to minimize a weighted sum of content, style and
    total-variation losses computed on VGG16 features.

    Args:
        content_img: PIL image providing the structure.
        style_img: PIL image providing the artistic style.
        content_weight: multiplier for the content loss.
        style_weight: multiplier for the style loss (averaged over style layers).
        tv_weight: multiplier for the total-variation smoothing loss.
        iterations: number of L-BFGS ``step`` calls; each step runs up to 20
            inner iterations (``max_iter=20``). Passed through ``int()`` in case
            the UI delivers a float.

    Returns:
        The stylized PIL image (``imsize`` x ``imsize``), or ``None`` when either
        input is missing.
    """
    if content_img is None or style_img is None:
        return None

    # Force 3-channel RGB: RGBA or grayscale uploads would not match the
    # 3-value mean/std used by `loader`'s Normalize. Resizing to imsize x imsize
    # happens inside `loader`.
    content_tensor = loader(content_img.convert("RGB")).unsqueeze(0).to(device, torch.float)
    style_tensor = loader(style_img.convert("RGB")).unsqueeze(0).to(device, torch.float)

    extractor = _get_extractor()

    # The targets are constants of the optimization, so no autograd graph is needed.
    with torch.no_grad():
        target_content_features, _ = extractor(content_tensor)
        _, target_style_features = extractor(style_tensor)
    n_style_layers = len(target_style_features)

    # The image itself is the optimized parameter.
    gen_img = content_tensor.clone().requires_grad_(True)
    optimizer = optim.LBFGS([gen_img], max_iter=20)

    # Roughly the span a [0, 1] pixel covers after ImageNet normalization
    # (e.g. (0 - 0.485) / 0.229 ≈ -2.12, (1 - 0.406) / 0.225 ≈ 2.64).
    low, high = -2.1, 2.6

    def closure():
        optimizer.zero_grad()
        # Keep pixels in a displayable range; done outside autograd.
        with torch.no_grad():
            gen_img.clamp_(low, high)

        gen_content_features, gen_style_features = extractor(gen_img)

        c_loss = calc_content_loss(gen_content_features['block5_conv2'], target_content_features['block5_conv2'])
        s_loss = sum(
            calc_style_loss(gen_style_features[name], target_style_features[name])
            for name in target_style_features
        ) / n_style_layers
        t_loss = calc_tv_loss(gen_img)

        total_loss = (content_weight * c_loss) + (style_weight * s_loss) + (tv_weight * t_loss)
        total_loss.backward()
        return total_loss

    for _ in range(int(iterations)):
        optimizer.step(closure)

    with torch.no_grad():
        gen_img.clamp_(low, high)

    # Detach so the display conversion does not carry the autograd graph.
    return unloader(gen_img.detach().cpu().squeeze(0))

# =======================================
# --- 5. INTERFAZ DE USUARIO (GRADIO) ---
# =======================================

with gr.Blocks(theme=gr.themes.Soft()) as demo:

    # Header: title, short description and author links.
    gr.Markdown(
        """
        <div style="text-align: center;">
            <h1>🎨 Transferencia de Estilo Neuronal</h1>
            <p>Sube una imagen base y una imagen de estilo para combinarlas. <i>Nota: Las imágenes se redimensionan automáticamente para procesamiento rápido.</i></p>
            <p>
                <a href="https://github.com/MGranados64" target="_blank" style="text-decoration: none;">🐙 <b>Mi GitHub</b></a> &nbsp; | &nbsp; 
                <a href="https://huggingface.co/MGC1991MF" target="_blank" style="text-decoration: none;">🤗 <b>Mi perfil en Hugging Face</b></a>
            </p>
        </div>
        """
    )

    # Image panels: the two inputs on the left, the result on the right.
    with gr.Row():
        with gr.Column():
            content_in = gr.Image(label="Imagen Base (A)", type="pil")
            style_in = gr.Image(label="Imagen de Estilo (B)", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Imagen Resultante (C)", type="pil")

    # Hyper-parameter controls and the run button.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ⚙️ Ajustes del Modelo")
            c_weight = gr.Slider(
                label="Peso del Contenido (Estructura)",
                minimum=0.1, maximum=10.0, step=0.1, value=1.0,
            )
            s_weight = gr.Slider(
                label="Peso del Estilo (Arte)",
                minimum=1000, maximum=1000000, step=1000, value=100000,
            )
            tv_weight = gr.Slider(
                label="Suavizado (Variación Total)",
                minimum=0, maximum=1e-3, step=1e-6, value=1e-6,
            )
            iters = gr.Slider(
                label="Iteraciones",
                minimum=2, maximum=20, step=1, value=5,
            )
            run_btn = gr.Button("¡Mezclar Imágenes!", variant="primary")

    # The order of these components must match run_style_transfer's parameters.
    transfer_inputs = [content_in, style_in, c_weight, s_weight, tv_weight, iters]
    run_btn.click(fn=run_style_transfer, inputs=transfer_inputs, outputs=output_image)

if __name__ == "__main__":
    demo.launch()