File size: 2,593 Bytes
577ffe1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

import torch
from diffusers import DDPMScheduler
import json
from PIL import Image
import numpy as np

class LetterConditionedUnet(torch.nn.Module):
    """UNet denoiser conditioned on a letter class through a learned embedding.

    The class label is embedded into a ``class_emb_size``-dimensional vector,
    broadcast to every spatial position, and concatenated to the image channels
    before being passed to a ``diffusers`` ``UNet2DModel``.

    Args:
        num_classes: Number of distinct class labels (26 by default).
        class_emb_size: Embedding size; also the number of conditioning
            channels appended to the single image channel.
        sample_size: Spatial size passed to ``UNet2DModel``. Defaults to 512,
            the value that was previously hard-coded.
    """

    def __init__(self, num_classes=26, class_emb_size=8, sample_size=512):
        super().__init__()
        from diffusers import UNet2DModel

        self.class_emb = torch.nn.Embedding(num_classes, class_emb_size)

        self.model = UNet2DModel(
            sample_size=sample_size,
            # 1 image channel + the broadcast class-embedding channels.
            in_channels=1 + class_emb_size,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512, 512),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )

    def forward(self, x, t, class_labels):
        """Predict the UNet output for noisy images ``x`` at timestep ``t``.

        Args:
            x: Noisy images of shape (batch, 1, height, width).
            t: Diffusion timestep(s), forwarded unchanged to the UNet.
            class_labels: Integer class indices of shape (batch,).

        Returns:
            The ``.sample`` tensor produced by the underlying ``UNet2DModel``.
        """
        # PyTorch image tensors are (N, C, H, W).
        _, _, height, width = x.shape
        # (batch, emb) -> (batch, emb, height, width): the same conditioning
        # vector is repeated at every pixel (expand creates a view, no copy).
        class_cond = self.class_emb(class_labels)[:, :, None, None].expand(
            -1, -1, height, width
        )
        net_input = torch.cat((x, class_cond), dim=1)
        return self.model(net_input, t).sample

def _letter_to_label(letter):
    """Map a single ASCII letter (case-insensitive) to a class index in 0-25.

    Raises:
        ValueError: if ``letter`` is not exactly one letter A-Z / a-z.
    """
    if isinstance(letter, str) and len(letter) == 1:
        upper = letter.upper()
        # upper() can expand a character (e.g. "ß" -> "SS"); re-check the length
        # so the range test below never compares a multi-char string.
        if len(upper) == 1 and "A" <= upper <= "Z":
            return ord(upper) - ord("A")
    raise ValueError(f"letter must be a single ASCII letter A-Z, got {letter!r}")


def _sample_to_uint8(sample):
    """Convert the first image of a sample batch from [-1, 1] to a uint8 array.

    The values are clipped first: the sampler output is not guaranteed to stay
    inside [-1, 1], and casting an out-of-range float to uint8 is undefined in
    NumPy (in practice it usually wraps, turning bright pixels dark) instead of
    saturating.
    """
    image = sample[0, 0].detach().cpu().numpy()
    image = np.clip((image + 1) / 2, 0.0, 1.0)  # [-1, 1] -> [0, 1]
    return (image * 255).astype(np.uint8)


def generate_letter(letter, model_path="./", num_inference_steps=None):
    """Generate a grayscale image of the given letter with the trained model.

    Args:
        letter: A single ASCII letter; case-insensitive ("a" and "A" are the
            same class).
        model_path: Directory containing ``pytorch_model.bin`` (a state dict for
            ``LetterConditionedUnet``) and ``scheduler_config.json``.
        num_inference_steps: Optional number of denoising steps. ``None``
            (default) keeps the scheduler's own timesteps, as before; a smaller
            value trades quality for speed via ``scheduler.set_timesteps``.

    Returns:
        A 512x512 ``PIL.Image`` in mode "L" (8-bit grayscale).

    Raises:
        ValueError: if ``letter`` is not a single letter A-Z.
    """
    import os

    # Validate before doing any expensive model loading.
    letter_label = _letter_to_label(letter)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the model.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load trusted checkpoints (torch>=1.13 supports weights_only=True).
    model = LetterConditionedUnet()
    state_dict = torch.load(
        os.path.join(model_path, "pytorch_model.bin"), map_location=device
    )
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    # Load the scheduler. from_config tolerates keys that the installed
    # DDPMScheduler.__init__ does not accept (e.g. a config saved by another
    # diffusers version), where DDPMScheduler(**config) would raise.
    config_path = os.path.join(model_path, "scheduler_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        scheduler_config = json.load(f)
    scheduler = DDPMScheduler.from_config(scheduler_config)
    if num_inference_steps is not None:
        scheduler.set_timesteps(num_inference_steps)

    # Start from pure Gaussian noise for a single 1-channel image.
    x = torch.randn(1, 1, 512, 512, device=device)
    labels = torch.tensor([letter_label], device=device)

    # Iteratively denoise.
    with torch.no_grad():
        for t in scheduler.timesteps:
            residual = model(x, t, labels)
            x = scheduler.step(residual, t, x).prev_sample

    # A 2-D uint8 array is converted to a mode "L" image automatically.
    return Image.fromarray(_sample_to_uint8(x))

# Example usage: render the letter "A" and write it to a PNG file.
if __name__ == "__main__":
    generate_letter('A').save('generated_letter_A.png')