Update README.md
Browse files
README.md
CHANGED
|
@@ -13,24 +13,40 @@ base_model: Salesforce/blip-image-captioning-base
|
|
| 13 |
|
| 14 |
# Example Usage
|
| 15 |
```python
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
|
| 22 |
-
```python
|
| 23 |
-
paste_the_code
|
| 24 |
-
```
|
| 25 |
|
| 26 |
-
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
```
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
| Original | Original prompt | Generated prompt by image | Generated image |
|
| 34 |
-
| -------- | --------------- | ------------------------- | --------------- |
|
| 35 |
-
| pass | pass | pass | pass |
|
| 36 |
|
|
|
|
| 13 |
|
| 14 |
# Example Usage
|
| 15 |
```python
|
| 16 |
+
import torch
|
| 17 |
+
import requests
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
| 20 |
|
| 21 |
+
def prepare(text):
    """Normalize spacing in a generated prompt string.

    Removes stray spaces the decode step inserts around punctuation
    ('.', '<', '>', '(', ')', ':', '_'), drops the literal ',(())'
    decode artifact, and collapses runs of three or more consecutive
    parentheses down to at most two.

    Args:
        text: Raw decoded caption/prompt text.

    Returns:
        The cleaned-up prompt string.
    """
    # Strip the space on either side of each punctuation character.
    # Order matches the original per-character replacement sequence.
    for ch in '.<>():_':
        text = text.replace(ch + ' ', ch).replace(' ' + ch, ch)
    # Remove a decode artifact sometimes emitted by the model.
    text = text.replace(',(())', '')
    # Collapse long paren runs to a fixpoint. The original ran a fixed
    # 10 passes (unused loop index); iterating until stable produces the
    # same result for normal inputs and also handles arbitrarily long runs.
    while ')))' in text or '(((' in text:
        text = text.replace(')))', '))').replace('(((', '((')
    return text
|
| 33 |
|
| 34 |
+
# Path (or hub id) of the fine-tuned BLIP image-to-prompt checkpoint.
path_to_model = "blip-image2promt-stable-diffusion-v0.15"

# Load the processor and captioning model; inference runs in fp16 on GPU.
processor = BlipProcessor.from_pretrained(path_to_model)
model = BlipForConditionalGeneration.from_pretrained(path_to_model, torch_dtype=torch.float16).to("cuda")

# Fetch the demo image over HTTP and normalize it to RGB.
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

# unconditional image captioning
inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)

out = model.generate(**inputs, max_new_tokens=100)

out_txt = processor.decode(out[0], skip_special_tokens=True)

# Clean up tokenizer spacing artifacts before printing the prompt.
print(prepare(out_txt))
|
| 50 |
+
```
|
| 51 |
|
|
|
|
|
|
|
|
|
|
| 52 |
|