# cconklin's picture
# Update app.py
# 9252c7e verified
import torch
from transformers import (
pipeline,
BlipProcessor,
BlipForConditionalGeneration,
)
import gradio as gr
# Use the GPU when torch reports CUDA as available; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -------- VQA PIPELINE --------
# The pipeline takes a device index: 0 when CUDA was selected, -1 otherwise.
vqa = pipeline(
    task="visual-question-answering",
    model="Salesforce/blip-vqa-base",
    device=-1 if DEVICE != "cuda" else 0,
)

# -------- CAPTION MODEL (Manual, more stable) --------
# The processor and the model are loaded from the same checkpoint; the model
# is then moved to the selected device.
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
caption_model = caption_model.to(DEVICE)
def generate_caption(image) -> str:
    """Return a caption for *image* produced by the BLIP captioning model.

    Args:
        image: The image to describe (the UI supplies it via
            ``gr.Image(type="pil")``), or ``None`` when no image is set.

    Returns:
        The decoded caption with special tokens removed, or an empty string
        when ``image`` is ``None``.
    """
    # No image (e.g. the input was cleared): blank the caption instead of
    # calling the model.
    if image is None:
        return ""

    # Preprocess into PyTorch tensors and move them to the model's device.
    inputs = caption_processor(image, return_tensors="pt").to(DEVICE)
    out = caption_model.generate(**inputs)

    # Only the first generated sequence is decoded.
    caption = caption_processor.decode(out[0], skip_special_tokens=True)
    return caption
def answer_question(image, question) -> str:
    """Answer *question* about *image* with the VQA pipeline.

    Args:
        image: The image to query, or ``None`` when nothing has been uploaded.
        question: The user's question; may be ``None``, empty, or whitespace.

    Returns:
        The ``"answer"`` field of the pipeline's first result, or a prompt
        telling the user what is missing when either input is absent.
    """
    if image is None:
        return "Please upload an image first."

    # Strip once so the emptiness check and the model see the same text.
    cleaned = question.strip() if question else ""
    if not cleaned:
        return "Please type a question about the image."

    result = vqa(question=cleaned, image=image)
    # Only the first entry of the pipeline output is used.
    return result[0]["answer"]
# Build the Gradio interface and wire its events.
# NOTE(review): the source carried no indentation, so the nesting below is
# reconstructed. In particular, confirm whether the question box belongs
# under the row (as here) or inside the output column.
with gr.Blocks() as demo:
    gr.Markdown("# BLIP Captioning + Visual Question Answering")
    gr.Markdown(
        "Upload an image to generate a caption, then ask a question about the image for an answer."
    )

    # Image input alongside a column holding the two text outputs.
    with gr.Row():
        image_in = gr.Image(type="pil", label="Upload an image")
        with gr.Column():
            caption_out = gr.Textbox(label="Caption (auto-generated)", lines=2)
            answer_out = gr.Textbox(label="Answer", lines=2)

    question_in = gr.Textbox(
        label="Question",
        placeholder="e.g., What is in the image? How many people are there?",
    )

    with gr.Row():
        clear_btn = gr.Button("Clear")
        answer_btn = gr.Button("Submit")

    # Regenerate the caption whenever the image value changes.
    image_in.change(fn=generate_caption, inputs=image_in, outputs=caption_out)

    # Answer the typed question about the current image on Submit.
    answer_btn.click(fn=answer_question, inputs=[image_in, question_in], outputs=answer_out)

    # Reset every field; the tuple supplies one value per entry in `outputs`,
    # in the same order.
    clear_btn.click(
        fn=lambda: (None, "", "", ""),
        outputs=[image_in, question_in, caption_out, answer_out],
    )

    gr.Markdown("**Note:** Outputs may be incorrect. Do not use for medical/legal decisions.")
# Start the Gradio server only when this file is run as a script, so that
# importing the module does not launch the app.
if __name__ == "__main__":
    demo.launch()