"""Gradio demo: BLIP image captioning and visual question answering.

Importing this module loads both BLIP checkpoints and builds the ``demo``
Blocks app; running it as a script launches the web UI.
"""

import torch
from transformers import (
    pipeline,
    BlipProcessor,
    BlipForConditionalGeneration,
)
import gradio as gr

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -------- VQA PIPELINE --------
# The pipeline API takes a device index: 0 for the first GPU, -1 for CPU.
vqa = pipeline(
    "visual-question-answering",
    model="Salesforce/blip-vqa-base",
    device=0 if DEVICE == "cuda" else -1,
)

# -------- CAPTION MODEL (Manual, more stable) --------
caption_processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
caption_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(DEVICE)


def generate_caption(image):
    """Return a BLIP-generated caption for *image*.

    Args:
        image: The value of the image component (configured with
            ``type="pil"``), or ``None`` when no image is present.

    Returns:
        The decoded caption, or an empty string when *image* is ``None``.
    """
    # The change event also fires when the image is cleared; an empty string
    # then blanks the caption box instead of raising.
    if image is None:
        return ""

    inputs = caption_processor(image, return_tensors="pt").to(DEVICE)
    out = caption_model.generate(**inputs)
    caption = caption_processor.decode(out[0], skip_special_tokens=True)
    return caption


def answer_question(image, question):
    """Answer *question* about *image* with the VQA pipeline.

    Args:
        image: The value of the image component, or ``None``.
        question: The text typed by the user; may be empty or ``None``.

    Returns:
        The answer of the first pipeline result, or a short instruction /
        fallback message when an input is missing or no result came back.
    """
    if image is None:
        return "Please upload an image first."

    # Validate and forward the same cleaned text, so surrounding whitespace
    # never reaches the model.
    cleaned = question.strip() if question else ""
    if not cleaned:
        return "Please type a question about the image."

    result = vqa(question=cleaned, image=image)

    # The pipeline is documented to return either one dict or a list of
    # dicts; normalize so both shapes are handled the same way.
    if isinstance(result, dict):
        result = [result]
    if not result:
        return "No answer could be generated for this question."
    return result[0]["answer"]


# NOTE(review): the nesting of the layout containers below was reconstructed
# from a whitespace-collapsed source -- confirm it matches the intended design.
with gr.Blocks() as demo:
    gr.Markdown("# BLIP Captioning + Visual Question Answering")
    gr.Markdown(
        "Upload an image to generate a caption, then ask a question about the image for an answer."
    )

    with gr.Row():
        image_in = gr.Image(type="pil", label="Upload an image")
        with gr.Column():
            caption_out = gr.Textbox(label="Caption (auto-generated)", lines=2)
            answer_out = gr.Textbox(label="Answer", lines=2)

    question_in = gr.Textbox(
        label="Question",
        placeholder="e.g., What is in the image? How many people are there?",
    )

    with gr.Row():
        clear_btn = gr.Button("Clear")
        answer_btn = gr.Button("Submit")

    # Re-caption whenever the image value changes (including when cleared).
    image_in.change(fn=generate_caption, inputs=image_in, outputs=caption_out)
    answer_btn.click(
        fn=answer_question,
        inputs=[image_in, question_in],
        outputs=answer_out,
    )
    # The tuple order must match the order of the components in `outputs`.
    clear_btn.click(
        fn=lambda: (None, "", "", ""),
        outputs=[image_in, question_in, caption_out, answer_out],
    )

    gr.Markdown("**Note:** Outputs may be incorrect. Do not use for medical/legal decisions.")


if __name__ == "__main__":
    demo.launch()