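# Gradio demo: loads a trained CaptioningTransformer checkpoint and serves
# a web interface that generates captions for uploaded images.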
import torch
from torchvision import transforms
from PIL import Image
import gradio as gr
from transformers import AutoTokenizer
from model import CaptioningTransformer

css_str = """
body {
    background-color: #121212;
    color: #e0e0e0;
    font-family: Arial, sans-serif;
}

.container {
    max-width: 700px;
    margin: 15px auto;
}

h1 {
    font-size: 36px;
    font-weight: bold;
    text-align: center;
    color: #ffffff;
}

.description {
    font-size: 18px;
    text-align: center;
    color: #b0b0b0;
}
"""

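# Model hyperparameters; these must match the configuration used at training
# time, or load_state_dict() below will fail on mismatched weight shapes.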
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 128
patch_size = 8
d_model = 192
n_layers = 6
n_heads = 8

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
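# Resize and center-crop to the training resolution, then normalize with
# ImageNet channel statistics.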
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

model = CaptioningTransformer(
    image_size=image_size,
    in_channels=3,
    vocab_size=tokenizer.vocab_size,
    device=device,
    patch_size=patch_size,
    n_layers=n_layers,
    d_model=d_model,
    n_heads=n_heads,
).to(device)

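# Load the trained weights and switch to evaluation mode for inference.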
model_path = "image_captioning_model.pt"
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()


def make_prediction(
    model, sos_token, eos_token, image, max_len=50, temp=0.5, device=device
):
    """Autoregressively sample caption tokens until EOS or max_len is reached."""
    log_tokens = [sos_token]
    with torch.inference_mode():
        # Encode the image once; the embedding is reused at every decode step.
        image_embedding = model.encoder(image.to(device))
        for _ in range(max_len):
            input_tokens = torch.cat(log_tokens, dim=1)
            data_pred = model.decoder(input_tokens.to(device), image_embedding)
            # Sample the next token from the temperature-scaled logits of the
            # last position (temp < 1 sharpens the distribution).
            dist = torch.distributions.Categorical(logits=data_pred[:, -1] / temp)
            next_tokens = dist.sample().reshape(1, 1)
            log_tokens.append(next_tokens.cpu())
            if next_tokens.item() == eos_token:
                break
    return torch.cat(log_tokens, dim=1)


def predict(image: Image.Image):
    # Gradio may hand over RGBA or grayscale images, so force three channels.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0)
    # For BERT-style tokenizers, [CLS] serves as the start-of-sequence token
    # and [SEP] as the end-of-sequence token (ids 101 and 102 here).
    sos_token = tokenizer.cls_token_id * torch.ones(1, 1).long().to(device)
    tokens = make_prediction(model, sos_token, tokenizer.sep_token_id, img_tensor)
    caption = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return caption


with gr.Blocks(css=css_str) as demo:
    # A pair of gr.HTML components cannot wrap other components in a <div>
    # (each HTML block renders in isolation), so the container class is
    # applied through elem_classes instead.
    with gr.Column(elem_classes="container"):
        gr.Markdown("<h1>Image Captioning</h1>")
        gr.Markdown(
            "<div class='description'>Upload an image to generate a descriptive caption.</div>"
        )

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Your Image")
                generate_button = gr.Button("Generate Caption")
            with gr.Column():
                caption_output = gr.Textbox(
                    label="Caption Output",
                    placeholder="Your generated caption will appear here...",
                )

        generate_button.click(fn=predict, inputs=image_input, outputs=caption_output)

if __name__ == "__main__":
    demo.launch(share=True)