cconklin committed on
Commit
a8887c5
·
verified ·
1 Parent(s): 5e841d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -31
app.py CHANGED
@@ -2,40 +2,82 @@ import torch
2
  from transformers import pipeline
3
  import gradio as gr
4
 
5
- # Choose device: GPU if available, otherwise CPU. On Hugging Face Spaces, unless you explicitly pick a GPU runtime, you’re on CPU only
6
- if torch.cuda.is_available():
7
- vqa = pipeline(
8
- task="visual-question-answering",
9
- model="Salesforce/blip-vqa-base",
10
- torch_dtype=torch.float16,#newer versions of TRANSFORMERS in Hugging face is torch_dtype not dtype. dtype is still working fine in Google Colab space
11
- device=0, # GPU
12
- use_fast=False,
13
- )
14
- else:
15
- vqa = pipeline(
16
- task="visual-question-answering",
17
- model="Salesforce/blip-vqa-base",
18
- device=-1, # CPU
19
- use_fast=False,
20
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def answer_question(image, question):
23
- if not question:
 
 
 
24
  return "Please type a question about the image."
25
- # vqa returns a list of dicts like [{'score':..., 'answer':...}]
26
  result = vqa(question=question, image=image)
27
- return result[0]["answer"]
28
-
29
- demo = gr.Interface(
30
- fn=answer_question,
31
- inputs=[
32
- gr.Image(type="pil", label="Upload an image"),
33
- gr.Textbox(label="Question", placeholder="e.g. What is the weather in this image?"),
34
- ],
35
- outputs=gr.Textbox(label="Answer"),
36
- title="BLIP Visual Question Answering",
37
- description="Ask a question about the uploaded image using Salesforce/blip-vqa-base.",
38
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  if __name__ == "__main__":
41
- demo.launch()
 
2
  from transformers import pipeline
3
  import gradio as gr
4
 
5
# Pick the inference device once: CUDA GPU index 0 when available, otherwise CPU (-1).
DEVICE = 0 if torch.cuda.is_available() else -1
7
+
8
# --- Load pipelines ---
# VQA pipeline: (image, question) -> short textual answer.
vqa = pipeline(
    model="Salesforce/blip-vqa-base",
    task="visual-question-answering",
    # Half precision is only worthwhile on the GPU path; CPU stays at full precision.
    torch_dtype=torch.float16 if DEVICE == 0 else None,
    device=DEVICE,
    use_fast=False,
)
17
+
18
# Captioning pipeline: image -> descriptive sentence.
captioner = pipeline(
    model="Salesforce/blip-image-captioning-base",
    task="image-to-text",
    # Half precision is only worthwhile on the GPU path; CPU stays at full precision.
    torch_dtype=torch.float16 if DEVICE == 0 else None,
    device=DEVICE,
    use_fast=False,
)
26
+
27
# --- App functions ---
def generate_caption(image):
    """Return a short auto-generated caption for *image*, or '' when no image is given."""
    if image is None:
        return ""
    # The captioning pipeline yields a list like [{'generated_text': '...'}].
    outputs = captioner(image)
    caption = outputs[0].get("generated_text", "")
    return caption.strip()
35
 
36
def answer_question(image, question):
    """Return the VQA model's answer to *question* about *image*.

    Falls back to a user-facing prompt string when either the image
    or the question is missing/blank.
    """
    if image is None:
        return "Please upload an image first."
    # Guard against None as well as empty/whitespace-only questions.
    if not (question and question.strip()):
        return "Please type a question about the image."
    # vqa returns a list of dicts like [{'score': ..., 'answer': ...}].
    predictions = vqa(question=question, image=image)
    return predictions[0].get("answer", "")
44
+
45
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# BLIP Captioning + Visual Question Answering")
    gr.Markdown(
        "1) Upload an image to generate a caption. \n"
        "2) Ask a question about the image to get an answer. \n"
        "Models: `Salesforce/blip-image-captioning-base` and `Salesforce/blip-vqa-base`."
    )

    # Layout: image input on the left, caption/answer boxes stacked on the right.
    with gr.Row():
        image_in = gr.Image(type="pil", label="Upload an image")
        with gr.Column():
            caption_out = gr.Textbox(label="Caption (auto-generated)", lines=2)
            answer_out = gr.Textbox(label="Answer", lines=2)

    question_in = gr.Textbox(
        label="Question",
        placeholder="e.g., What is in the image? How many people are there? What color is the car?",
    )

    with gr.Row():
        clear_btn = gr.Button("Clear")
        answer_btn = gr.Button("Submit")

    def _reset():
        # Blank out every input/output widget in one shot.
        return None, "", "", ""

    # Wire events: auto-caption on image change, VQA on submit, reset on clear.
    image_in.change(fn=generate_caption, inputs=image_in, outputs=caption_out)
    answer_btn.click(fn=answer_question, inputs=[image_in, question_in], outputs=answer_out)
    clear_btn.click(fn=_reset, inputs=None, outputs=[image_in, question_in, caption_out, answer_out])

    gr.Markdown(
        "**Note:** This demo may produce incorrect outputs. Do not use for medical/legal decisions."
    )
81
 
82
  if __name__ == "__main__":
83
+ demo.launch()