# app.py — BLIP image captioning + YOLO object detection Gradio app.
# (Hugging Face Space by scmlewis; commit 8f847e7, verified.)
# Custom CSS injected into the Gradio app via gr.Blocks(css=custom_css) below.
# Styles the centered content column (#main-app-area), the title, the
# instructions banner, and the gradient "Generate" button.
custom_css = """
/* Center main content and lock max width to 900px, with responsive shrink */
#main-app-area {
max-width: 900px;
margin-left: auto;
margin-right: auto;
padding: 0 16px;
}
/* Responsive for mobile (<950px) */
@media (max-width: 950px) {
#main-app-area {
max-width: 99vw;
padding: 0 2vw;
}
}
#app-title {
text-align: center;
font-size: 38px;
color: #53c9fc;
font-weight: bold;
padding-top: 12px;
}
#instructions {
text-align: center;
font-size: 19px;
margin: 14px 0 22px 0;
}
#generate-btn {
background: linear-gradient(90deg, #31b2fd 0%, #98f972 100%);
color: white;
font-size: 18px;
font-weight: bold;
border: none;
border-radius: 11px;
margin-top: 8px;
margin-bottom: 14px;
transition: 0.2s;
}
#generate-btn:hover {
filter: brightness(1.08);
box-shadow: 0 2px 16px #9efbc344;
}
"""
from transformers import BlipProcessor, BlipForConditionalGeneration
from ultralytics import YOLO
import torch
import gradio as gr
from PIL import Image
from collections import deque
import numpy as np
# ---------------------------------------------------------------------------
# Model and state initialization (runs once at import time; downloads weights
# on first run).
# ---------------------------------------------------------------------------

# BLIP captioning model and its paired preprocessor.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# YOLOv5 small model for object detection.
detect_model = YOLO('yolov5s.pt')

# Rolling history shown in the gallery; the deques silently drop the oldest
# entry once MEMORY_SIZE results have accumulated.
MEMORY_SIZE = 10
last_images = deque(maxlen=MEMORY_SIZE)
last_captions = deque(maxlen=MEMORY_SIZE)
def preprocess_image(image):
    """Return *image* in RGB mode; images already in RGB pass through unchanged."""
    if image.mode == "RGB":
        return image
    return image.convert("RGB")
def detect_objects(image):
    """Run YOLO on a PIL image and return the unique detected class labels.

    The class id is the last element of each detection row; it is mapped to a
    human-readable name via the model's `names` table. Duplicates are collapsed
    with a set, so label order in the returned list is unspecified.
    """
    frame = np.array(image)
    labels = {
        detect_model.names[int(detection[-1])]
        for result in detect_model(frame)
        for detection in result.boxes.data.tolist()
    }
    return list(labels)
def generate_caption(image):
    """Caption *image* with BLIP, tag it with YOLO detections, update history.

    Returns a ``(result_text, gallery)`` pair: ``result_text`` is the caption
    plus detected-object tags for this image, and ``gallery`` pairs each image
    in the rolling history with an annotated caption string.
    """
    image = preprocess_image(image)

    # Beam-search caption from BLIP (max 30 tokens, 5 beams).
    encoded = processor(image, return_tensors="pt")
    generated = model.generate(**encoded, max_length=30, num_beams=5, early_stopping=True)
    caption = processor.decode(generated[0], skip_special_tokens=True)

    detected = detect_objects(image)
    tags = ", ".join(detected) if detected else "None"

    # Push onto the rolling history; the deques cap it at MEMORY_SIZE entries.
    last_images.append(image)
    last_captions.append(caption)

    # NOTE(review): every gallery entry is labeled with the CURRENT image's
    # tags (only the caption varies per entry) — preserved as-is.
    gallery = [
        (past_image, f"Detected objects: {tags}\nCaption: {past_caption}")
        for past_image, past_caption in zip(last_images, last_captions)
    ]
    result_text = f"Detected objects: {tags}\nCaption: {caption}"
    return result_text, gallery
# ---------------------------------------------------------------------------
# Gradio UI definition and event wiring.
# ---------------------------------------------------------------------------
with gr.Blocks(css=custom_css) as iface:
    # NOTE(review): each gr.HTML call renders as its own standalone element,
    # so this opening <div> does not actually wrap the components below —
    # confirm the #main-app-area CSS takes effect as intended.
    gr.HTML('<div id="main-app-area">') # Start content region
    gr.HTML('<div id="app-title">🖼️ Image Captioning with Object Detection</div>')
    gr.HTML(
        '<div id="instructions">'
        '🙌 <b>Welcome!</b> Instantly analyze images using AI.<br>'
        '1️⃣ <b>Upload</b> your image.<br>'
        '2️⃣ Click <b>⭐ Generate Caption</b>.<br>'
        '3️⃣ View and scroll through your history below.<br>'
        '📜 <i>Last 10 results are stored for you.</i>'
        '</div>'
    )
    # Input/output widgets.
    image_input = gr.Image(type="pil", label="Upload Image")
    generate_btn = gr.Button("⭐ Generate Caption", elem_id="generate-btn")
    caption_output = gr.Textbox(label="📝 Caption and Detected Objects", lines=5, interactive=True)
    gallery = gr.Gallery(label="Last 10 Images and Captions", scale=3)

    def on_generate(image):
        """Button handler: validate input, then delegate to generate_caption."""
        # Guard: clicking Generate with no image uploaded.
        if image is None:
            return "Please upload an image.", []
        return generate_caption(image)

    generate_btn.click(
        fn=on_generate,
        inputs=image_input,
        outputs=[caption_output, gallery]
    )
    gr.HTML('</div>') # End content region

# Launch the app only when executed directly (not when imported).
if __name__ == "__main__":
    iface.launch()