import os

import gradio as gr
import torch
from huggingface_hub import login
from transformers import (
    AutoProcessor,
    AutoModelForVisualQuestionAnswering,
    Idefics3ForConditionalGeneration,
    infer_device,
)

# Log in to the Hugging Face Hub with the token stored in the Space's secrets
login(token=os.getenv('HF_TOKEN'))
# Model 1: BLIP VQA processor and model (float32 on CPU)
processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = AutoModelForVisualQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
device = "cpu"
model.to(device)
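# Optional sketch (an assumption, not part of this app as written): on a GPU
# Space, BLIP could run on the accelerator in half precision instead, e.g.:
#
#   device = infer_device()  # "cuda" on GPU hardware, "cpu" otherwise
#   model.to(device)
#   if device != "cpu":
#       model.half()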
# Inference function for Model 1
def process_image(image, prompt):
    # Process the image and prompt; keep inputs in float32 so their dtype
    # matches the CPU model (casting to float16 here would fail on CPU)
    inputs = processor(image, text=prompt, return_tensors="pt").to(device)
    try:
        # Generate the answer tokens
        output = model.generate(**inputs, max_new_tokens=10)
        # Decode the generated tokens into text
        decoded_output = processor.batch_decode(output, skip_special_tokens=True)[0].strip()
        # Strip the prompt in case the model echoes it back
        if decoded_output.startswith(prompt):
            return decoded_output[len(prompt):].strip()
        return decoded_output
    except IndexError as e:
        print(f"IndexError: {e}")
        return "An error occurred during processing."
def format_idefics_prompt(prompt):
    """Format the user's question with the required <image> placeholder token."""
    return f"User: <image> {prompt}\nAssistant:"
# Model 2: Idefics3 fine-tuned on VQAv2
processor2 = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
model2 = Idefics3ForConditionalGeneration.from_pretrained(
    "merve/idefics3-llama-vqav2", dtype=torch.bfloat16
)
device2 = infer_device()
# Move the model to the inferred device so it matches the inputs prepared below
model2.to(device2)
# Inference function for Model 2
def process_image2(image, prompt):
    formatted_prompt = format_idefics_prompt(prompt)
    inputs2 = processor2(
        text=formatted_prompt,
        images=[image],
        return_tensors="pt",
    ).to(device2, model2.dtype)
    try:
        generated_ids = model2.generate(
            **inputs2,
            max_new_tokens=10,
            do_sample=False,
        )
        # Decode only the newly generated tokens; slicing past the input length
        # already removes the prompt from the output
        decoded_output = processor2.batch_decode(
            generated_ids[:, inputs2["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )[0].strip()
        return decoded_output
    except Exception as e:
        print(f"Error in Model 2 during generation: {e}")
        return "An error occurred during Model 2 processing. Please check the console for memory issues."
# Define the Gradio interface: one input set and one output box per tab
inputs_model1 = [
    gr.Image(type="pil"),
    gr.Textbox(label="Prompt", placeholder="Enter your question"),
]
inputs_model2 = [
    gr.Image(type="pil"),
    gr.Textbox(label="Prompt", placeholder="Enter your question"),
]
outputs_model1 = gr.Textbox(label="Answer")
outputs_model2 = gr.Textbox(label="Answer")
# Create the Gradio app with one tab per model
model1_inf = gr.Interface(fn=process_image, inputs=inputs_model1, outputs=outputs_model1, title="Visual Question Answering", description="Upload an image and ask questions to get answers.")
model2_inf = gr.Interface(fn=process_image2, inputs=inputs_model2, outputs=outputs_model2, title="Visual Question Answering", description="Upload an image and ask questions to get answers.")
demo = gr.TabbedInterface([model1_inf, model2_inf], ["Model 1", "Model 2"])
demo.launch(share=True)  # share expects a boolean, not the string "True"