# Hugging Face Space: BLIP image captioning + visual question answering demo.
# (The original "Spaces: Sleeping" lines were page-status residue from the
# Spaces UI, not part of the application code.)
import torch
from transformers import (
    pipeline,
    BlipProcessor,
    BlipForConditionalGeneration,
)
import gradio as gr

# Run on GPU when available; all models below are placed on this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# -------- VQA PIPELINE --------
# High-level transformers pipeline for visual question answering.
# pipeline() addresses GPUs by index: 0 = first CUDA device, -1 = CPU.
vqa = pipeline(
    "visual-question-answering",
    model="Salesforce/blip-vqa-base",
    device=0 if DEVICE == "cuda" else -1,
)

# -------- CAPTION MODEL (Manual, more stable) --------
# Captioning uses the processor + model pair directly instead of a pipeline,
# which gives explicit control over tensor placement and generation.
caption_processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
caption_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(DEVICE)
def generate_caption(image):
    """Generate a natural-language caption for an uploaded image with BLIP.

    Parameters
    ----------
    image : PIL.Image.Image | None
        The image from the Gradio widget, or ``None`` when it is cleared.

    Returns
    -------
    str
        The decoded caption, or ``""`` when no image is provided.
    """
    if image is None:
        return ""
    inputs = caption_processor(image, return_tensors="pt").to(DEVICE)
    # inference_mode: no autograd graph is built for this forward pass,
    # saving memory and time. An explicit max_new_tokens bounds generation
    # and silences the transformers default-length warning.
    with torch.inference_mode():
        out = caption_model.generate(**inputs, max_new_tokens=50)
    return caption_processor.decode(out[0], skip_special_tokens=True)
def answer_question(image, question):
    """Answer a free-form question about an image via the BLIP VQA pipeline.

    Returns a user-facing prompt string when either input is missing;
    otherwise returns the pipeline's top-ranked answer.
    """
    # Guard clauses: bail out early with guidance instead of calling the model.
    if image is None:
        return "Please upload an image first."
    blank = question is None or not question.strip()
    if blank:
        return "Please type a question about the image."
    predictions = vqa(question=question, image=image)
    best = predictions[0]
    return best["answer"]
# -------- UI --------
# Layout: image upload on the left; auto-generated caption and VQA answer
# on the right; a question box and Clear/Submit buttons below.
with gr.Blocks() as demo:
    gr.Markdown("# BLIP Captioning + Visual Question Answering")
    gr.Markdown(
        "Upload an image to generate a caption, then ask a question about the image for an answer."
    )
    with gr.Row():
        image_in = gr.Image(type="pil", label="Upload an image")
        with gr.Column():
            caption_out = gr.Textbox(label="Caption (auto-generated)", lines=2)
            answer_out = gr.Textbox(label="Answer", lines=2)
    question_in = gr.Textbox(
        label="Question",
        placeholder="e.g., What is in the image? How many people are there?",
    )
    with gr.Row():
        clear_btn = gr.Button("Clear")
        answer_btn = gr.Button("Submit")
    # Captioning fires automatically whenever the image changes (upload/clear).
    image_in.change(fn=generate_caption, inputs=image_in, outputs=caption_out)
    # VQA runs only on explicit Submit.
    answer_btn.click(fn=answer_question, inputs=[image_in, question_in], outputs=answer_out)
    # Clear resets all four components; the tuple order must match `outputs`.
    clear_btn.click(fn=lambda: (None, "", "", ""), outputs=[image_in, question_in, caption_out, answer_out])
    gr.Markdown("**Note:** Outputs may be incorrect. Do not use for medical/legal decisions.")
# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()