import gradio as gr
import torch
import os
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering, AutoModelForVision2Seq, infer_device
# Log in to the Hugging Face Hub
login(token=os.getenv('HF_TOKEN'))
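# Note: HF_TOKEN is read from the environment (e.g. a Space secret); logging in is only needed if the models require authentication.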
# Define model 1 (BLIP VQA) and its processor
processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = AutoModelForVisualQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
device = infer_device()
model.to(device)
# Define the inference function for model 1
def process_image(image, prompt):
    # Preprocess the image and question
    inputs = processor(image, text=prompt, return_tensors="pt").to(device, model.dtype)
    try:
        # Generate a short answer (at most 10 new tokens)
        output = model.generate(**inputs, max_new_tokens=10)
        # Decode the generated tokens
        decoded_output = processor.batch_decode(output, skip_special_tokens=True)[0].strip()
        # Remove the prompt from the output in case the model echoes it
        if decoded_output.startswith(prompt):
            return decoded_output[len(prompt):].strip()
        return decoded_output
    except IndexError as e:
        print(f"IndexError: {e}")
        return "An error occurred during processing."
# Define model 2 (Idefics3 fine-tuned for VQA) and its processor
processor2 = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
model2 = AutoModelForVision2Seq.from_pretrained("merve/idefics3-llama-vqav2", dtype="auto")
device2 = infer_device()
model2.to(device2)
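# The processor of the base model (HuggingFaceM4/Idefics3-8B-Llama3) is reused here, assuming the fine-tuned checkpoint shares the same tokenizer and image processor.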
# Define the inference function for model 2
def process_image2(image, prompt):
    # Idefics3 expects a chat-formatted prompt with an image placeholder
    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
    text = processor2.apply_chat_template(messages, add_generation_prompt=True)
    # Preprocess the image and chat prompt
    inputs = processor2(text=text, images=[image], return_tensors="pt").to(device2, model2.dtype)
    try:
        # Generate a short answer (at most 10 new tokens)
        output = model2.generate(**inputs, max_new_tokens=10)
        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = output[:, inputs["input_ids"].shape[1]:]
        decoded_output = processor2.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
        return decoded_output
    except IndexError as e:
        print(f"IndexError: {e}")
        return "An error occurred during processing."
# Define the Gradio interface components
inputs_model1 = [
    gr.Image(type="pil"),
    gr.Textbox(label="Prompt", placeholder="Enter your question")
]
inputs_model2 = [
    gr.Image(type="pil"),
    gr.Textbox(label="Prompt", placeholder="Enter your question")
]
outputs_model1 = gr.Textbox(label="Answer")
outputs_model2 = gr.Textbox(label="Answer")
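# Separate input/output component instances are defined for each tab so the two interfaces stay independent.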
# Create the Gradio app with one tab per model
model1_inf = gr.Interface(fn=process_image, inputs=inputs_model1, outputs=outputs_model1, title="Visual Question Answering", description="Upload an image and ask questions to get answers.")
model2_inf = gr.Interface(fn=process_image2, inputs=inputs_model2, outputs=outputs_model2, title="Visual Question Answering", description="Upload an image and ask questions to get answers.")
demo = gr.TabbedInterface([model1_inf, model2_inf], ["Model 1", "Model 2"])
demo.launch(share=True)