mblackman committed on
Commit
fc830bc
·
1 Parent(s): f2c1eba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
3
+ import torch
4
+ import PIL.Image
5
+ from huggingface_hub import HfApi
6
+
7
# BLIP-2 checkpoint used for caption generation.
BLIP_MODEL_ID = "Salesforce/blip2-opt-2.7b"

# Prefer the first CUDA device; fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Compare the device *type*, not the device object: the original
# `device != 'cuda'` was always true because str(device) is "cuda:0",
# so the slow-CPU warning fired even when running on a GPU.
if device.type != "cuda":
    print(f"You are using {device}. This is much slower than using "
          "a CUDA-enabled GPU.")

# Load the processor (tokenizer + image preprocessor) and the model once
# at import time so requests don't pay the load cost.
blip_processor = Blip2Processor.from_pretrained(BLIP_MODEL_ID)
# NOTE(review): device_map="auto" already dispatches the weights; the
# extra .to(device) is kept for behavioral parity but may be redundant
# (or error with newer accelerate versions) — confirm before removing.
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    BLIP_MODEL_ID, torch_dtype=torch.float16, device_map="auto"
).to(device)

print("Loading API")
api = HfApi()
22
def caption_images(images: list, prompt: str) -> str:
    """Generate BLIP-2 captions for a batch of uploaded images.

    Args:
        images: Uploaded file objects from a ``gr.File`` component; each
            exposes the temp-file path as ``.name`` (they are not
            ``PIL.Image.Image`` instances, despite the original hint).
        prompt: Optional text used to condition generation; ignored when
            empty. (The original accepted this parameter but never used it.)

    Returns:
        The ``str()`` form of a dict mapping each uploaded file path to
        its generated caption.
    """
    image_files = [PIL.Image.open(image.name).convert('RGB') for image in images]

    # Condition on the prompt when one was provided — replicate it per
    # image so the batch sizes of `images` and `text` match.
    if prompt:
        inputs = blip_processor(
            images=image_files,
            text=[prompt] * len(image_files),
            return_tensors='pt',
        ).to(device, torch.float16)
    else:
        inputs = blip_processor(
            images=image_files, return_tensors='pt'
        ).to(device, torch.float16)

    # Pixel data is already copied into tensors; release the file handles.
    for image in image_files:
        image.close()

    generated_ids = blip_model.generate(**inputs)
    results = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)

    return str(dict(zip([image.name for image in images], results)))
34
+
35
+
36
# Gradio UI: upload one or more images, optionally give a prompt, and
# display the caption results.
with gr.Blocks() as demo:
    with gr.Tab("Image Caption Bot"):
        images_input = gr.File(file_count="multiple", file_types=["image"])
        # Use label= explicitly: the first positional argument of Textbox
        # is its *value*, so the original gr.Textbox("Prompt") pre-filled
        # the box with the text "Prompt" instead of labeling it.
        blip_prompt = gr.Textbox(label="Prompt")
        caption_images_button = gr.Button("Submit")
        image_caption_label = gr.Label()

    # Event wiring must happen inside the Blocks context.
    caption_images_button.click(
        caption_images,
        inputs=[images_input, blip_prompt],
        outputs=image_caption_label,
    )

print("Launching Gradio")
demo.launch()