Matt Blackman commited on
Commit
478a4d7
·
1 Parent(s): 1fda43f

Caption images and send to CSV file

Browse files
Files changed (2) hide show
  1. .gitignore +3 -1
  2. app.py +20 -8
.gitignore CHANGED
@@ -162,4 +162,6 @@ cython_debug/
162
  # venv folder
163
  env/
164
 
165
- node_modules/
 
 
 
162
  # venv folder
163
  env/
164
 
165
+ node_modules/
166
+
167
+ csv/
app.py CHANGED
@@ -2,11 +2,16 @@ import gradio as gr
2
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
3
  import torch
4
  import PIL.Image
 
 
 
 
5
 
6
  BLIP_MODEL_ID = "Salesforce/blip2-opt-2.7b"
 
7
 
8
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
9
- if device != 'cuda':
10
  print(f"You are using {device}. This is much slower than using "
11
  "a CUDA-enabled GPU.")
12
 
@@ -14,30 +19,37 @@ if device != 'cuda':
14
  blip_processor = Blip2Processor.from_pretrained(BLIP_MODEL_ID)
15
  blip_model = Blip2ForConditionalGeneration.from_pretrained(BLIP_MODEL_ID, torch_dtype=torch.float16, device_map="auto").to(device)
16
 
17
-
18
- def caption_images(images: list[PIL.Image.Image], prompt: str) -> dict[str, str]:
 
 
 
 
 
 
 
 
19
  image_files = [PIL.Image.open(image.name).convert('RGB') for image in images]
20
  inputs = blip_processor(images=image_files, return_tensors='pt').to(device, torch.float16)
21
 
22
  for image in image_files:
23
  image.close()
24
 
25
- generated_ids = blip_model.generate(**inputs)
26
 
27
  results = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)
28
 
29
- return str(dict(zip([image.name for image in images], results)))
30
 
31
 
32
  with gr.Blocks() as demo:
33
  with gr.Tab("Image Caption Bot"):
34
  images_input = gr.File(file_count="multiple", file_types=["image"])
35
- blip_prompt = gr.Textbox("Prompt")
36
  caption_images_button = gr.Button("Submit")
37
- image_caption_label = gr.Label()
38
 
39
 
40
- caption_images_button.click(caption_images, inputs=[images_input, blip_prompt], outputs=image_caption_label)
41
 
42
 
43
  print("Launching Gradio")
 
2
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
3
  import torch
4
  import PIL.Image
5
+ from pathlib import Path
6
+ import tempfile
7
+ import os
8
+ import os.path
9
 
10
  BLIP_MODEL_ID = "Salesforce/blip2-opt-2.7b"
11
+ CAPTION_CSV_DIR=os.path.join(os.getcwd(), 'csv')
12
 
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
+ if device.type != 'cuda':
15
  print(f"You are using {device}. This is much slower than using "
16
  "a CUDA-enabled GPU.")
17
 
 
19
  blip_processor = Blip2Processor.from_pretrained(BLIP_MODEL_ID)
20
  blip_model = Blip2ForConditionalGeneration.from_pretrained(BLIP_MODEL_ID, torch_dtype=torch.float16, device_map="auto").to(device)
21
 
22
+ def captions_images_to_csv(images: list[PIL.Image.Image]) -> str:
23
+ caption_map = caption_images(images)
24
+ lines = [k + "," + v + "\n" for k, v in caption_map.items()]
25
+
26
+ with tempfile.NamedTemporaryFile('w', dir=CAPTION_CSV_DIR, delete=False, suffix=".csv") as f:
27
+ f.writelines(lines)
28
+ return f.name
29
+
30
+
31
+ def caption_images(images: list[PIL.Image.Image]) -> dict[str, str]:
32
  image_files = [PIL.Image.open(image.name).convert('RGB') for image in images]
33
  inputs = blip_processor(images=image_files, return_tensors='pt').to(device, torch.float16)
34
 
35
  for image in image_files:
36
  image.close()
37
 
38
+ generated_ids = blip_model.generate(**inputs, max_new_tokens=20)
39
 
40
  results = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)
41
 
42
+ return dict(zip([Path(image.name).stem for image in images], [result.replace("\n", "") for result in results]))
43
 
44
 
45
  with gr.Blocks() as demo:
46
  with gr.Tab("Image Caption Bot"):
47
  images_input = gr.File(file_count="multiple", file_types=["image"])
 
48
  caption_images_button = gr.Button("Submit")
49
+ image_caption_output = gr.File()
50
 
51
 
52
+ caption_images_button.click(captions_images_to_csv, inputs=[images_input], outputs=image_caption_output)
53
 
54
 
55
  print("Launching Gradio")