# Author: Matt Blackman
# Change: Added batch operations
# Commit: df76d7b
import csv
import os
import os.path
import tempfile
from pathlib import Path

import gradio as gr
import PIL.Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
# BLIP-2 checkpoint used for captioning.
BLIP_MODEL_ID = "Salesforce/blip2-opt-6.7b"
# Directory where generated caption CSVs are written.
CAPTION_CSV_DIR = os.path.join(os.getcwd(), 'csv')
# Create the output directory up front: NamedTemporaryFile(dir=...) fails if
# the directory does not exist.
os.makedirs(CAPTION_CSV_DIR, exist_ok=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type != 'cuda':
    print(f"You are using {device}. This is much slower than using "
          "a CUDA-enabled GPU.")

blip_processor = Blip2Processor.from_pretrained(BLIP_MODEL_ID)
# device_map="auto" lets accelerate dispatch the weights across available
# devices. Calling .to(device) on a dispatched model raises in transformers,
# so the model is left where accelerate placed it.
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    BLIP_MODEL_ID, torch_dtype=torch.float16, device_map="auto")
def captions_images_to_csv(images: list, prompt: str) -> list[str]:
    """Caption a batch of uploaded images and persist the results to a CSV.

    Args:
        images: uploaded file wrappers (each exposing a ``.name`` path),
            forwarded unchanged to ``caption_images``.
        prompt: conditioning prompt for caption generation.

    Returns:
        A two-item list matching the gradio outputs:
        ``[path_to_csv_file, csv_text_preview]``.
    """
    caption_map = caption_images(images, prompt)
    # Guard against the output directory missing (e.g. fresh checkout):
    # NamedTemporaryFile(dir=...) raises if it does not exist.
    os.makedirs(CAPTION_CSV_DIR, exist_ok=True)
    # csv.writer handles quoting, so captions containing commas, quotes or
    # newlines no longer corrupt the file the way naive string joins did.
    # newline='' is required by the csv module for correct line endings.
    with tempfile.NamedTemporaryFile(
            'w', dir=CAPTION_CSV_DIR, delete=False, suffix=".csv",
            newline='') as f:
        writer = csv.writer(f)
        writer.writerows(caption_map.items())
        csv_path = f.name
    preview = "".join(f"{stem},{caption}\n"
                      for stem, caption in caption_map.items())
    return [csv_path, preview]
def caption_images(images: list, prompt: str) -> dict[str, str]:
    """Generate a BLIP-2 caption for each uploaded image.

    Args:
        images: uploaded file wrappers — each exposes a ``.name`` filesystem
            path (gradio File objects, not PIL images).
        prompt: text prompt used to condition caption generation.

    Returns:
        Mapping of file-name stem -> generated caption. Processing is
        best-effort: images that fail are skipped with the error printed.
    """
    results: dict[str, str] = {}
    for image in images:
        try:
            # convert() returns a *new* image, so the `with` must wrap the
            # Image.open() result — otherwise the opened file handle of the
            # original image is never closed.
            with PIL.Image.open(image.name) as img:
                rgb = img.convert('RGB')
            inputs = blip_processor(
                rgb, text=prompt, return_tensors='pt').to(device, torch.float16)
            generated_ids = blip_model.generate(**inputs, max_new_tokens=20)
            caption = blip_processor.batch_decode(
                generated_ids, skip_special_tokens=True)[0].strip()
            results[Path(image.name).stem] = caption
        except Exception as e:
            # Deliberate best-effort batch: report and continue with the rest.
            print("An error occurred while trying to process image: " + image.name)
            print(str(e))
    return results
# --- Gradio UI wiring -------------------------------------------------------
with gr.Blocks() as app:
    with gr.Tab("Image Caption Bot"):
        # Inputs: a batch of image files plus a conditioning prompt.
        upload_box = gr.File(label="Input Images", file_count="multiple",
                             file_types=["image"])
        prompt_box = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        # Outputs: a downloadable CSV and a read-only preview of its contents.
        csv_file_out = gr.File(label="Output CSV")
        csv_text_out = gr.Textbox(label="Output Data", interactive=False)
        submit_btn.click(
            captions_images_to_csv,
            inputs=[upload_box, prompt_box],
            outputs=[csv_file_out, csv_text_out],
        )
print("Launching Gradio")
app.launch()