File size: 3,787 Bytes
bfd5ab7
8f847e7
 
 
 
 
 
 
 
 
 
 
 
 
 
bfd5ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ee7d17
8f847e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0f418e
 
 
 
 
3ee7d17
a51fe1b
 
 
 
 
 
 
 
 
0f88d29
7a9f97a
f0f418e
3e22b2b
3ee7d17
98c6b6d
3ee7d17
00eff89
 
4671d3e
00eff89
 
 
672e664
 
8f847e7
bfd5ab7
672e664
bfd5ab7
 
 
 
00eff89
f026d0c
bfd5ab7
8566c6a
f2791fd
 
00eff89
f026d0c
98c6b6d
3ee7d17
00eff89
 
3ee7d17
 
98c6b6d
00eff89
9517d28
8f847e7
9517d28
a51fe1b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# CSS injected into the Gradio Blocks app (see `gr.Blocks(css=custom_css)` below).
# Styles the #main-app-area container, title, instructions, and generate button;
# the element IDs are attached via elem_id= / raw HTML in the UI section.
custom_css = """
/* Center main content and lock max width to 900px, with responsive shrink */
#main-app-area {
    max-width: 900px;
    margin-left: auto;
    margin-right: auto;
    padding: 0 16px;
}
/* Responsive for mobile (<950px) */
@media (max-width: 950px) {
    #main-app-area {
        max-width: 99vw;
        padding: 0 2vw;
    }
}
#app-title {
    text-align: center;
    font-size: 38px;
    color: #53c9fc;
    font-weight: bold;
    padding-top: 12px;
}
#instructions {
    text-align: center;
    font-size: 19px;
    margin: 14px 0 22px 0;
}
#generate-btn {
    background: linear-gradient(90deg, #31b2fd 0%, #98f972 100%);
    color: white;
    font-size: 18px;
    font-weight: bold;
    border: none;
    border-radius: 11px;
    margin-top: 8px;
    margin-bottom: 14px;
    transition: 0.2s;
}
#generate-btn:hover {
    filter: brightness(1.08);
    box-shadow: 0 2px 16px #9efbc344;
}
"""

from transformers import BlipProcessor, BlipForConditionalGeneration
from ultralytics import YOLO
import torch
import gradio as gr
from PIL import Image
from collections import deque
import numpy as np

# Load models once at import time: BLIP (base) for image captioning and
# YOLOv5-small for object detection. Both download weights on first run.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
detect_model = YOLO('yolov5s.pt')

# Rolling history shown in the gallery: the deques drop the oldest entry
# automatically once MEMORY_SIZE results have accumulated.
MEMORY_SIZE = 10
last_images = deque([], maxlen=MEMORY_SIZE)
last_captions = deque([], maxlen=MEMORY_SIZE)

def preprocess_image(image):
    """Return *image* coerced to RGB mode (a no-op when it already is RGB)."""
    return image if image.mode == "RGB" else image.convert("RGB")

def detect_objects(image):
    """Run the module-level YOLO model on *image* (a PIL image) and return a
    list of the unique class labels detected (order not guaranteed)."""
    frame = np.array(image)
    labels = set()
    for result in detect_model(frame):
        # Each detection row ends with the class id; map it to its name.
        for det in result.boxes.data.tolist():
            labels.add(detect_model.names[int(det[-1])])
    return list(labels)

def generate_caption(image):
    """Caption *image* with BLIP, list YOLO-detected objects, update history.

    Args:
        image: PIL image (any mode; converted to RGB before inference).

    Returns:
        (result_text, gallery) where result_text is the formatted label for
        the current image and gallery is a list of (image, label) tuples for
        the last MEMORY_SIZE results, suitable for gr.Gallery.
    """
    image = preprocess_image(image)
    inputs = processor(image, return_tensors="pt")
    # Inference only: no_grad skips autograd bookkeeping (less memory, faster).
    with torch.no_grad():
        out = model.generate(**inputs, max_length=30, num_beams=5, early_stopping=True)
    caption = processor.decode(out[0], skip_special_tokens=True)
    detected_objs = detect_objects(image)
    tags = ", ".join(detected_objs) if detected_objs else "None"
    result_text = f"Detected objects: {tags}\nCaption: {caption}"
    # Bug fix: the gallery previously labelled *every* history entry with the
    # current image's tags. Store the complete per-image label instead, so
    # each gallery entry keeps the tags that were detected for its own image.
    last_images.append(image)
    last_captions.append(result_text)
    gallery = list(zip(last_images, last_captions))
    return result_text, gallery

# Build the UI. NOTE(fix): each gr.HTML renders as a self-contained fragment,
# so an opening '<div id="main-app-area">' in one gr.HTML call cannot wrap the
# Gradio components that follow it — the browser auto-closes the dangling tag
# and the #main-app-area CSS never applied. A gr.Column with elem_id is the
# supported way to get a styleable container around components.
with gr.Blocks(css=custom_css) as iface:
    with gr.Column(elem_id="main-app-area"):
        gr.HTML('<div id="app-title">🖼️ Image Captioning with Object Detection</div>')
        gr.HTML(
            '<div id="instructions">'
            '🙌 <b>Welcome!</b> Instantly analyze images using AI.<br>'
            '1️⃣ <b>Upload</b> your image.<br>'
            '2️⃣ Click <b>⭐ Generate Caption</b>.<br>'
            '3️⃣ View and scroll through your history below.<br>'
            '📜 <i>Last 10 results are stored for you.</i>'
            '</div>'
        )
        image_input = gr.Image(type="pil", label="Upload Image")
        generate_btn = gr.Button("⭐ Generate Caption", elem_id="generate-btn")
        caption_output = gr.Textbox(label="📝 Caption and Detected Objects", lines=5, interactive=True)
        gallery = gr.Gallery(label="Last 10 Images and Captions", scale=3)

        def on_generate(image):
            # Guard against clicking Generate before an image is uploaded.
            if image is None:
                return "Please upload an image.", []
            return generate_caption(image)

        generate_btn.click(
            fn=on_generate,
            inputs=image_input,
            outputs=[caption_output, gallery]
        )

# Start the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()