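"""Gradio Space comparing two visual question answering models:
Salesforce BLIP-VQA and an Idefics3-based fine-tune. Upload an image and
ask a question in either tab to get an answer."""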
import gradio as gr
import torch
import os
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering, AutoModelForVision2Seq, infer_device
# Log in to the Hugging Face Hub using the HF_TOKEN secret
login(token=os.getenv('HF_TOKEN'))
# Define the BLIP VQA model and processor
processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = AutoModelForVisualQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
device = infer_device()
model.to(device)
# Define the inference function for BLIP
def process_image(image, prompt):
    # Process the image and prompt; keep the default dtype, since the
    # model weights are float32 and a float16 cast would mismatch
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    try:
        # Generate an answer from the model
        output = model.generate(**inputs, max_new_tokens=10)
        # Decode the generated tokens
        decoded_output = processor.batch_decode(output, skip_special_tokens=True)[0].strip()
        # Strip the prompt if the model echoes it back
        if decoded_output.startswith(prompt):
            return decoded_output[len(prompt):].strip()
        return decoded_output
    except IndexError as e:
        print(f"IndexError: {e}")
        return "An error occurred during processing."
# Define model 2: an Idefics3 fine-tune, paired with the base model's
# processor (this assumes the fine-tune kept the base processor config)
processor2 = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
model2 = AutoModelForVision2Seq.from_pretrained("merve/idefics3-llama-vqav2", dtype="auto")
device2 = infer_device()
model2.to(device2)
# Define the inference function for Idefics3
def process_image2(image, prompt):
    # Idefics3 expects an image token inside the text, so wrap the prompt
    # in a chat message and apply the processor's chat template
    messages = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
    ]
    text = processor2.apply_chat_template(messages, add_generation_prompt=True)
    # Cast only the floating-point tensors (pixel values) to the model's dtype
    inputs = processor2(images=image, text=text, return_tensors="pt").to(device2, dtype=model2.dtype)
    try:
        # Generate an answer from the model
        output = model2.generate(**inputs, max_new_tokens=10)
        # Decode only the newly generated tokens, skipping the prompt
        generated = output[:, inputs["input_ids"].shape[1]:]
        return processor2.batch_decode(generated, skip_special_tokens=True)[0].strip()
    except IndexError as e:
        print(f"IndexError: {e}")
        return "An error occurred during processing."
# Define the Gradio input and output components (one set per tab)
inputs_model1 = [
gr.Image(type="pil"),
gr.Textbox(label="Prompt", placeholder="Enter your question")
]
inputs_model2 = [
gr.Image(type="pil"),
gr.Textbox(label="Prompt", placeholder="Enter your question")
]
outputs_model1 = gr.Textbox(label="Answer")
outputs_model2 = gr.Textbox(label="Answer")
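# Note: the input/output components are deliberately duplicated above;
# each gr.Interface should own its component instances rather than share them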
# Create the Gradio app with one tab per model
model1_inf = gr.Interface(fn=process_image, inputs=inputs_model1, outputs=outputs_model1, title="Visual Question Answering - BLIP", description="Upload an image and ask a question to get an answer.")
model2_inf = gr.Interface(fn=process_image2, inputs=inputs_model2, outputs=outputs_model2, title="Visual Question Answering - Idefics3", description="Upload an image and ask a question to get an answer.")
demo = gr.TabbedInterface([model1_inf, model2_inf], ["BLIP", "Idefics3"])
demo.launch()