Matt Blackman committed on
Commit
df76d7b
·
1 Parent(s): 478a4d7

Added batch operations

Browse files
Files changed (1) hide show
  1. app.py +26 -18
app.py CHANGED
@@ -7,7 +7,7 @@ import tempfile
7
  import os
8
  import os.path
9
 
10
- BLIP_MODEL_ID = "Salesforce/blip2-opt-2.7b"
11
  CAPTION_CSV_DIR=os.path.join(os.getcwd(), 'csv')
12
 
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -19,37 +19,45 @@ if device.type != 'cuda':
19
  blip_processor = Blip2Processor.from_pretrained(BLIP_MODEL_ID)
20
  blip_model = Blip2ForConditionalGeneration.from_pretrained(BLIP_MODEL_ID, torch_dtype=torch.float16, device_map="auto").to(device)
21
 
22
- def captions_images_to_csv(images: list[PIL.Image.Image]) -> str:
23
- caption_map = caption_images(images)
24
  lines = [k + "," + v + "\n" for k, v in caption_map.items()]
25
 
26
  with tempfile.NamedTemporaryFile('w', dir=CAPTION_CSV_DIR, delete=False, suffix=".csv") as f:
27
  f.writelines(lines)
28
- return f.name
29
 
30
 
31
- def caption_images(images: list[PIL.Image.Image]) -> dict[str, str]:
32
- image_files = [PIL.Image.open(image.name).convert('RGB') for image in images]
33
- inputs = blip_processor(images=image_files, return_tensors='pt').to(device, torch.float16)
34
 
35
- for image in image_files:
36
- image.close()
37
-
38
- generated_ids = blip_model.generate(**inputs, max_new_tokens=20)
39
-
40
- results = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)
41
-
42
- return dict(zip([Path(image.name).stem for image in images], [result.replace("\n", "") for result in results]))
 
 
 
 
43
 
44
 
45
  with gr.Blocks() as demo:
46
  with gr.Tab("Image Caption Bot"):
47
- images_input = gr.File(file_count="multiple", file_types=["image"])
 
48
  caption_images_button = gr.Button("Submit")
49
- image_caption_output = gr.File()
 
50
 
51
 
52
- caption_images_button.click(captions_images_to_csv, inputs=[images_input], outputs=image_caption_output)
 
 
 
53
 
54
 
55
  print("Launching Gradio")
 
7
  import os
8
  import os.path
9
 
10
+ BLIP_MODEL_ID = "Salesforce/blip2-opt-6.7b"
11
  CAPTION_CSV_DIR=os.path.join(os.getcwd(), 'csv')
12
 
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
19
  blip_processor = Blip2Processor.from_pretrained(BLIP_MODEL_ID)
20
  blip_model = Blip2ForConditionalGeneration.from_pretrained(BLIP_MODEL_ID, torch_dtype=torch.float16, device_map="auto").to(device)
21
 
22
def captions_images_to_csv(images: list[PIL.Image.Image], prompt: str) -> list[str]:
    """Caption the given images and save the results as a CSV file.

    Args:
        images: Uploaded file objects, forwarded unchanged to
            ``caption_images``.
        prompt: Text prompt forwarded to ``caption_images``.

    Returns:
        A two-item list ``[csv_path, csv_text]``: the path of the CSV file
        created under ``CAPTION_CSV_DIR`` and the same CSV content as a
        string (one ``name,caption`` row per captioned image).
    """
    # Function-scope imports: nothing else in the file needs these modules.
    import csv
    import io

    caption_map = caption_images(images, prompt)

    # csv.writer quotes any field containing a comma or a quote, so a caption
    # such as "a cat, sitting on a sofa" stays in a single column. Plain
    # fields are written exactly as before, one row per line ending in "\n".
    buffer = io.StringIO()
    csv.writer(buffer, lineterminator="\n").writerows(caption_map.items())
    csv_text = buffer.getvalue()

    # The temp file cannot be created if the target directory is missing.
    os.makedirs(CAPTION_CSV_DIR, exist_ok=True)

    # delete=False: the file must still exist after this function returns,
    # because its path is handed back to the caller.
    with tempfile.NamedTemporaryFile('w', dir=CAPTION_CSV_DIR, delete=False, suffix=".csv") as f:
        f.write(csv_text)

    return [f.name, csv_text]
29
 
30
 
31
def caption_images(images: list[PIL.Image.Image], prompt: str) -> dict[str, str]:
    """Generate one caption per image with the BLIP-2 processor and model.

    Images are handled one at a time. If anything fails for an image, the
    error is printed and that image is skipped, so the rest of the batch is
    still captioned.

    Args:
        images: Uploaded file objects; each one is opened from its ``.name``
            path. ``None`` or an empty list yields an empty result.
        prompt: Text passed to the processor together with each image.

    Returns:
        Mapping from each image's file stem to its stripped caption. Images
        that failed are absent; files sharing a stem overwrite one another.
    """
    results = {}

    # Nothing to caption (presumably the UI passes None when no file was
    # uploaded -- confirm against the Gradio version in use).
    if not images:
        return results

    for image in images:
        try:
            # convert() returns a new image, so the file opened here has to
            # be closed on its own; wrapping only the converted copy would
            # leave the original handle to the garbage collector.
            with PIL.Image.open(image.name) as source:
                rgb_image = source.convert('RGB')

            with rgb_image:
                inputs = blip_processor(rgb_image, text=prompt, return_tensors='pt').to(device, torch.float16)
                generated_ids = blip_model.generate(**inputs, max_new_tokens=20)

            caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            results[Path(image.name).stem] = caption

        except Exception as e:
            # Deliberate best-effort: one bad image must not abort the batch.
            print("An error occured while trying to process image: " + image.name)
            print(str(e))

    return results
46
 
47
 
48
  with gr.Blocks() as demo:
49
  with gr.Tab("Image Caption Bot"):
50
+ images_input = gr.File(label= "Input Images", file_count="multiple", file_types=["image"])
51
+ caption_prompt = gr.Textbox(label="Prompt")
52
  caption_images_button = gr.Button("Submit")
53
+ image_caption_output_file = gr.File(label="Output CSV")
54
+ image_caption_output_label = gr.Textbox(label="Output Data", interactive=False)
55
 
56
 
57
+ caption_images_button.click(
58
+ captions_images_to_csv,
59
+ inputs=[images_input, caption_prompt],
60
+ outputs=[image_caption_output_file, image_caption_output_label])
61
 
62
 
63
  print("Launching Gradio")