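"""Gradio demo comparing two visual question answering (VQA) models.

Model 1 is Salesforce/blip-vqa-base; Model 2 is merve/paligemma_vqav2, a
PaliGemma fine-tune on the VQAv2 dataset. Each model gets its own tab:
upload an image, type a question, and the model generates a short answer.
"""
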
import os

import gradio as gr
import torch
from huggingface_hub import login
from transformers import (
    AutoModelForVisualQuestionAnswering,
    AutoProcessor,
    PaliGemmaForConditionalGeneration,
    infer_device,
)
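
# PaliGemma checkpoints are gated on the Hugging Face Hub, so authenticate
# when a token is available (this assumes the standard HF_TOKEN environment
# variable, which Spaces injects when it is configured as a secret).
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)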

# Set the device
device = infer_device()

# MODEL 1: BLIP-VQA

processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = AutoModelForVisualQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)

# Define the inference function for Model 1
def process_image(image, prompt):
    # The BLIP model is loaded with its default float32 weights, so only move
    # the inputs to the device; casting them to float16 here would cause a
    # dtype mismatch with the model.
    inputs = processor(image, text=prompt, return_tensors="pt").to(device)

    try:
        # Generate output from the model
        with torch.inference_mode():
            output = model.generate(**inputs, max_new_tokens=10)

        # Decode and return the output
        decoded_output = processor.batch_decode(output, skip_special_tokens=True)[0].strip()

        # Remove the prompt from the output if the model echoed it
        if decoded_output.startswith(prompt):
            return decoded_output[len(prompt):].strip()
        return decoded_output
    except Exception as e:
        print(f"Error in Model 1: {e}")
        return "An error occurred during processing for Model 1."


# MODEL 2: PaliGemma

# The processor comes from the base PaliGemma checkpoint, while the weights
# are a community fine-tune on VQAv2; move the model to the same device as
# the inputs.
processor2 = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
model2 = PaliGemmaForConditionalGeneration.from_pretrained("merve/paligemma_vqav2").to(device)


# Define the inference function for Model 2
def process_image2(image, prompt):
    inputs2 = processor2(
        text=prompt,
        images=image,
        return_tensors="pt"
    ).to(device, model2.dtype)

    try:
        with torch.inference_mode():
            output = model2.generate(**inputs2, max_new_tokens=10)

        # PaliGemma echoes the prompt tokens, so slice them off before decoding
        decoded_output = processor2.batch_decode(
            output[:, inputs2["input_ids"].shape[1]:],
            skip_special_tokens=True
        )[0].strip()

        return decoded_output
    except Exception as e:
        print(f"Error in Model 2: {e}")
        return "An error occurred during processing for Model 2."


# GRADIO INTERFACE

inputs_model1 = [
    gr.Image(type="pil"),
    gr.Textbox(label="Prompt", placeholder="Enter your question")
]
inputs_model2 = [
    gr.Image(type="pil"),
    gr.Textbox(label="Prompt", placeholder="Enter your question")
]

outputs_model1 = gr.Textbox(label="Answer")
outputs_model2 = gr.Textbox(label="Answer")

# Create the Gradio apps for each model
model1_inf = gr.Interface(
    fn=process_image, 
    inputs=inputs_model1, 
    outputs=outputs_model1, 
    title="Model 1: BLIP-VQA-Base", 
    description="Ask a question about the uploaded image using BLIP."
)

model2_inf = gr.Interface(
    fn=process_image2, 
    inputs=inputs_model2, 
    outputs=outputs_model2, 
    title="Model 2: PaliGemma VQAv2", 
    description="Ask a question about the uploaded image using PaliGemma (fine-tuned VQA)."
)

# Combine both demos into a tabbed interface; share=True also exposes a
# temporary public URL alongside the local one.
demo = gr.TabbedInterface([model1_inf, model2_inf], ["Model 1 (BLIP)", "Model 2 (PaliGemma)"])
demo.launch(share=True)