yashsharmaa committed on
Commit
a982e1b
·
verified ·
1 Parent(s): 97708b6

Update models/caption.py

Browse files
Files changed (1) hide show
  1. models/caption.py +14 -14
models/caption.py CHANGED
@@ -1,14 +1,14 @@
1
- from transformers import BlipProcessor, BlipForConditionalGeneration
2
- import torch
3
-
4
- # Load BLIP model and processor once
5
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
6
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda")
7
-
8
- @torch.no_grad()
9
- def generate_caption(image):
10
- inputs = processor(images=image, return_tensors="pt").to("cuda")
11
- output = model.generate(**inputs, max_new_tokens=50)
12
- caption = processor.tokenizer.decode(output[0], skip_special_tokens=True)
13
- return caption
14
-
 
1
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch

# Pick the GPU when CUDA is available, otherwise fall back to CPU so the
# module still imports and runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the BLIP captioning processor and model once at import time
# (first call downloads the weights from the Hugging Face Hub) and move
# the model to the selected device. Module-level globals are reused by
# generate_caption() below so the model is loaded exactly once.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
9
@torch.no_grad()  # inference only — disable autograd bookkeeping
def generate_caption(image, max_new_tokens=50):
    """Generate a text caption for ``image`` with the module-level BLIP model.

    Args:
        image: Input image in any form the BLIP processor accepts as
            ``images=`` (e.g. a PIL image) — presumably RGB; TODO confirm
            with callers.
        max_new_tokens: Upper bound on the caption length in generated
            tokens. Defaults to 50, the previously hard-coded value, so
            existing callers are unaffected.

    Returns:
        str: The decoded caption with special tokens stripped.
    """
    # Preprocess into tensors and move them to the same device as the model.
    inputs = processor(images=image, return_tensors="pt").to(device)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode the first (only) sequence of token ids back into plain text.
    caption = processor.tokenizer.decode(output[0], skip_special_tokens=True)
    return caption