File size: 2,939 Bytes
c252082
 
fb068e1
d1f1944
11fbf23
d1f1944
0871fe0
d1f1944
c252082
0871fe0
9dc4906
 
fc2eb0c
9dc4906
717bbba
 
 
884a96b
e472a2b
717bbba
 
884a96b
e472a2b
717bbba
f6b4289
e472a2b
0871fe0
e472a2b
 
 
717bbba
 
 
9dc4906
1dfb058
 
 
6c7a133
1dfb058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
717bbba
f253ea9
717bbba
 
 
f253ea9
 
 
 
 
47ca826
 
 
717bbba
f253ea9
 
1dfb058
5723f8a
47ca826
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os

import gradio as gr
import torch
from huggingface_hub import login
from transformers import (
    AutoModel,
    AutoModelForVision2Seq,
    AutoModelForVisualQuestionAnswering,
    AutoProcessor,
    infer_device,
)

# Authenticate with the Hugging Face Hub (HF_TOKEN must be set in the environment).
login(token=os.getenv('HF_TOKEN'))

# Model 1: BLIP VQA — the processor encodes (image, question) pairs for the model.
processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = AutoModelForVisualQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
device = infer_device()
# Move the model onto the inferred device; process_image() sends its inputs to
# the same `device`, so without this the two would mismatch on GPU.
model.to(device)

# Define inference function for model 1 (BLIP VQA)
def process_image(image, prompt):
    """Answer a free-text question about an image using BLIP-VQA.

    Args:
        image: PIL image provided by the Gradio ``gr.Image`` component.
        prompt: The question to ask about the image.

    Returns:
        The decoded answer string, with the echoed prompt stripped if present,
        or a generic error message if decoding fails.
    """
    # Encode image + question and move to the model's device. Bug fix: the
    # previous force-cast to float16 mismatched the model's float32 weights.
    inputs = processor(image, text=prompt, return_tensors="pt").to(device)

    try:
        # VQA answers are short, so a small generation budget is enough.
        output = model.generate(**inputs, max_new_tokens=10)

        # Decode the first (and only) sequence in the batch.
        decoded_output = processor.batch_decode(output, skip_special_tokens=True)[0].strip()

        # Some checkpoints echo the prompt; strip it so only the answer remains.
        if decoded_output.startswith(prompt):
            return decoded_output[len(prompt):].strip()
        return decoded_output
    except IndexError as e:
        # batch_decode yielded no sequences — report instead of crashing the UI.
        print(f"IndexError: {e}")
        return "An error occurred during processing."


# Model 2: Idefics3-based VQA fine-tune.
# NOTE(review): the processor comes from the base Idefics3 repo while the
# weights come from the fine-tuned repo — presumably they share the same
# tokenizer/image processor; verify against the model card.
processor2 = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
# Bug fix: AutoModel returns the bare backbone without a generate() method,
# but process_image2() calls model2.generate(); the Vision2Seq auto class
# loads the conditional-generation head.
model2 = AutoModelForVision2Seq.from_pretrained("merve/idefics3-llama-vqav2", dtype="auto")
device2 = infer_device()
# Keep model and inputs on the same device.
model2.to(device2)

# Define inference function for model 2 (Idefics3 VQA)
def process_image2(image, prompt):
    """Answer a free-text question about an image using the Idefics3 fine-tune.

    Args:
        image: PIL image provided by the Gradio ``gr.Image`` component.
        prompt: The question to ask about the image.

    Returns:
        The decoded answer string, with the echoed prompt stripped if present,
        or a generic error message if decoding fails.
    """
    # Bug fix: use this model's device (device2), not model 1's `device`, and
    # do not force float16 — the checkpoint is loaded with dtype="auto".
    inputs = processor2(image, text=prompt, return_tensors="pt").to(device2)

    try:
        # Short generation budget is sufficient for single-phrase VQA answers.
        output = model2.generate(**inputs, max_new_tokens=10)

        # Decode the first (and only) sequence in the batch.
        decoded_output = processor2.batch_decode(output, skip_special_tokens=True)[0].strip()

        # Strip an echoed prompt so only the answer is returned.
        if decoded_output.startswith(prompt):
            return decoded_output[len(prompt):].strip()
        return decoded_output
    except IndexError as e:
        # batch_decode yielded no sequences — report instead of crashing the UI.
        print(f"IndexError: {e}")
        return "An error occurred during processing."



# Define the Gradio interface.
def _vqa_inputs():
    """Fresh (image, prompt) input components — each Interface needs its own."""
    return [
        gr.Image(type="pil"),
        gr.Textbox(label="Prompt", placeholder="Enter your question"),
    ]

# One tab per model; identical layout, different inference backends.
model1_inf = gr.Interface(
    fn=process_image,
    inputs=_vqa_inputs(),
    outputs=gr.Textbox(label="Answer"),
    title="Visual Question Answering",
    description="Upload an image and ask questions to get answers.",
)
model2_inf = gr.Interface(
    fn=process_image2,
    inputs=_vqa_inputs(),
    outputs=gr.Textbox(label="Answer"),
    title="Visual Question Answering",
    description="Upload an image and ask questions to get answers.",
)

demo = gr.TabbedInterface([model1_inf, model2_inf], ["Model 1", "Model 2"])
# Bug fix: `share` expects a bool; the string "true" only worked by truthiness.
demo.launch(share=True)