File size: 2,322 Bytes
5e841d6
9252c7e
 
 
 
 
5e841d6
 
9252c7e
a8887c5
9252c7e
 
 
a8887c5
9252c7e
a8887c5
 
9252c7e
 
 
a8887c5
9252c7e
 
 
cbe53f7
a8887c5
 
 
9252c7e
 
 
 
 
5e841d6
 
a8887c5
 
 
5e841d6
9252c7e
5e841d6
9252c7e
a8887c5
 
 
 
cbe53f7
a8887c5
 
 
 
 
 
 
 
 
 
9252c7e
a8887c5
 
 
 
 
 
 
 
cbe53f7
a8887c5
cbe53f7
5e841d6
 
a8887c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from transformers import (
    pipeline,
    BlipProcessor,
    BlipForConditionalGeneration,
)
import gradio as gr

# Use CUDA when torch reports an available GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# -------- VQA PIPELINE --------
# The pipeline is given an integer device index: 0 (first GPU) on CUDA, -1 on CPU.
_vqa_device_index = 0 if DEVICE == "cuda" else -1
vqa = pipeline(
    task="visual-question-answering",
    model="Salesforce/blip-vqa-base",
    device=_vqa_device_index,
)

# -------- CAPTION MODEL (Manual, more stable) --------
# Captioning drives the BLIP processor and model directly instead of a pipeline.
# Both are loaded from the same checkpoint, so it is named once here.
_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-base"
caption_processor = BlipProcessor.from_pretrained(_CAPTION_CHECKPOINT)
caption_model = BlipForConditionalGeneration.from_pretrained(_CAPTION_CHECKPOINT)
caption_model = caption_model.to(DEVICE)

def generate_caption(image):
    """Return a BLIP caption for *image*, or an empty string if there is none.

    Args:
        image: The value of the Gradio ``Image`` input (created with
            ``type="pil"``), or None when the input is empty/cleared.

    Returns:
        The caption decoded from the first generated sequence, with special
        tokens removed; ``""`` when *image* is None.
    """
    if image is not None:
        model_inputs = caption_processor(image, return_tensors="pt").to(DEVICE)
        generated_ids = caption_model.generate(**model_inputs)
        return caption_processor.decode(generated_ids[0], skip_special_tokens=True)
    return ""

def answer_question(image, question):
    """Answer a free-text *question* about *image* using the BLIP VQA pipeline.

    Args:
        image: The value of the Gradio image input, or None if nothing has
            been uploaded.
        question: The user's question text; may be None, empty, or
            whitespace-only.

    Returns:
        The answer string taken from the pipeline's first result, or a
        user-facing prompt when the image or the question is missing.
    """
    if image is None:
        return "Please upload an image first."

    # Normalize once so the text that is validated is the same text sent to
    # the model (stray textbox whitespace is not forwarded).
    question = (question or "").strip()
    if not question:
        return "Please type a question about the image."

    result = vqa(question=question, image=image)
    # The pipeline returns a list of {"answer": ...} dicts; use the first one.
    return result[0]["answer"]

with gr.Blocks() as demo:
    gr.Markdown("# BLIP Captioning + Visual Question Answering")
    gr.Markdown(
        "Upload an image to generate a caption, then ask a question about the image for an answer."
    )

    with gr.Row():
        image_in = gr.Image(type="pil", label="Upload an image")
        with gr.Column():
            caption_out = gr.Textbox(label="Caption (auto-generated)", lines=2)
            answer_out = gr.Textbox(label="Answer", lines=2)

    question_in = gr.Textbox(
        label="Question",
        placeholder="e.g., What is in the image? How many people are there?",
    )

    with gr.Row():
        clear_btn = gr.Button("Clear")
        answer_btn = gr.Button("Submit")

    # Re-caption whenever the image changes (a cleared image yields "").
    image_in.change(fn=generate_caption, inputs=image_in, outputs=caption_out)
    # An answer refers to the image it was asked about; blank it on image
    # change so a stale answer is never shown next to a new caption.
    image_in.change(fn=lambda: "", outputs=answer_out)

    # Both the Submit button and pressing Enter in the question box run VQA.
    vqa_inputs = [image_in, question_in]
    answer_btn.click(fn=answer_question, inputs=vqa_inputs, outputs=answer_out)
    question_in.submit(fn=answer_question, inputs=vqa_inputs, outputs=answer_out)

    clear_btn.click(fn=lambda: (None, "", "", ""), outputs=[image_in, question_in, caption_out, answer_out])

    gr.Markdown("**Note:** Outputs may be incorrect. Do not use for medical/legal decisions.")

def main():
    """Start the Gradio ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    main()