File size: 2,418 Bytes
fc830bc
 
 
 
478a4d7
 
 
 
fc830bc
df76d7b
478a4d7
fc830bc
 
478a4d7
fc830bc
 
 
 
 
 
 
df76d7b
 
478a4d7
 
 
 
df76d7b
478a4d7
 
df76d7b
 
fc830bc
df76d7b
 
 
 
 
 
 
 
 
 
 
 
fc830bc
 
 
 
df76d7b
 
fc830bc
df76d7b
 
fc830bc
 
df76d7b
 
 
 
fc830bc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import csv
import io
import os
import os.path
import tempfile
from pathlib import Path

import gradio as gr
import PIL.Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Hugging Face Hub id of the BLIP-2 checkpoint used for captioning.
BLIP_MODEL_ID = "Salesforce/blip2-opt-6.7b"
# Directory (below the current working directory) that receives caption CSVs.
CAPTION_CSV_DIR = os.path.join(os.getcwd(), 'csv')

# Device the model inputs are moved to: the first CUDA GPU when available,
# otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type != 'cuda':
    print(f"You are using {device}. This is much slower than using "
          "a CUDA-enabled GPU.")


blip_processor = Blip2Processor.from_pretrained(BLIP_MODEL_ID)
# device_map="auto" already dispatches the weights onto the available
# hardware. Chaining .to(device) onto a dispatched model is redundant and
# can raise when part of the model was offloaded, so the model is left
# where it was placed.
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    BLIP_MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)

def captions_images_to_csv(images: list, prompt: str) -> list[str]:
    """Caption the uploaded images and persist the result as a CSV file.

    Args:
        images: Uploaded image files, forwarded unchanged to
            ``caption_images``.
        prompt: Text prompt forwarded unchanged to ``caption_images``.

    Returns:
        A two-item list: the path of the CSV file that was written, and the
        CSV content as a string (one ``name,caption`` row per image).
    """
    caption_map = caption_images(images, prompt)

    # Let the csv module handle quoting: captions (or file names) containing
    # commas, quotes or newlines would otherwise produce malformed rows.
    buffer = io.StringIO()
    csv.writer(buffer, lineterminator="\n").writerows(caption_map.items())
    csv_text = buffer.getvalue()

    # NamedTemporaryFile does not create its parent directory.
    os.makedirs(CAPTION_CSV_DIR, exist_ok=True)

    # delete=False keeps the file on disk after the handle is closed so the
    # returned path stays valid; newline="" prevents newline translation of
    # the already-terminated CSV rows.
    with tempfile.NamedTemporaryFile(
        'w',
        dir=CAPTION_CSV_DIR,
        delete=False,
        suffix=".csv",
        encoding="utf-8",
        newline="",
    ) as f:
        f.write(csv_text)

    return [f.name, csv_text]
    
    
def caption_images(images: list, prompt: str) -> dict[str, str]:
    """Generate one caption per uploaded image.

    Args:
        images: Uploaded files. Each item is either a filesystem path
            (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name``. ``None`` or an empty list yields an empty result.
        prompt: Text passed to the processor together with every image.

    Returns:
        Mapping from the image file's stem (file name without directory and
        extension) to its stripped caption. Images that fail to process are
        reported on stdout and left out. Files sharing the same stem
        overwrite each other.
    """
    results = {}

    # `images or []` covers the "nothing uploaded" case, where None arrives.
    for image in images or []:
        # Fallback label so the error report below can never raise itself.
        label = str(image)
        try:
            if isinstance(image, (str, os.PathLike)):
                path = os.fspath(image)
            else:
                path = image.name
            label = str(path)

            # convert() returns an independent in-memory copy, so the file
            # handle opened by PIL can be closed before running the model.
            with PIL.Image.open(path) as source:
                rgb_image = source.convert('RGB')

            inputs = blip_processor(rgb_image, text=prompt, return_tensors='pt').to(device, torch.float16)
            generated_ids = blip_model.generate(**inputs, max_new_tokens=20)
            caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            results[Path(path).stem] = caption

        except Exception as e:
            # Best effort: one bad image must not abort the whole batch.
            print("An error occured while trying to process image: " + label)
            print(str(e))

    return results


# Build the web UI: one tab holding a multi-file image upload, a prompt box,
# a submit button, and two outputs (the generated CSV file and its text).
with gr.Blocks() as demo:
    with gr.Tab("Image Caption Bot"):
        # Upload widget restricted to image files; several files may be selected.
        images_input = gr.File(label= "Input Images", file_count="multiple", file_types=["image"])
        caption_prompt = gr.Textbox(label="Prompt")
        caption_images_button = gr.Button("Submit")
        # Receives the path of the CSV file returned by the click handler.
        image_caption_output_file = gr.File(label="Output CSV")
        # Read-only textbox that receives the CSV content as text.
        image_caption_output_label = gr.Textbox(label="Output Data", interactive=False)
        

    # On click, call captions_images_to_csv(images, prompt); its two return
    # values are routed, in order, to the file output and the text output.
    caption_images_button.click(
        captions_images_to_csv, 
        inputs=[images_input, caption_prompt], 
        outputs=[image_caption_output_file, image_caption_output_label])


# Runs at module level, so the app is launched whenever this file is executed.
print("Launching Gradio")
demo.launch()