Update README.md
Browse files
README.md
CHANGED
|
@@ -51,7 +51,8 @@ from transformers import AutoModel, AutoTokenizer
|
|
| 51 |
import torch
|
| 52 |
import torch.nn.functional as F
|
| 53 |
from PIL import Image
|
| 54 |
-
import os
|
|
|
|
| 55 |
|
| 56 |
def weighted_mean_pooling(hidden, attention_mask):
|
| 57 |
attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
|
|
@@ -83,20 +84,25 @@ def encode(text_or_image_list):
|
|
| 83 |
embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
|
| 84 |
return embeddings
|
| 85 |
|
| 86 |
-
tokenizer = AutoTokenizer.from_pretrained("
|
| 87 |
-
model = AutoModel.from_pretrained("
|
| 88 |
model.eval()
|
| 89 |
|
| 90 |
-
script_dir = os.path.dirname(os.path.realpath(__file__))
|
| 91 |
queries = ["What does a dog look like?"]
|
| 92 |
-
passages = [
|
| 93 |
-
Image.open(os.path.join(script_dir, 'test_image/cat.jpeg')).convert('RGB'),
|
| 94 |
-
Image.open(os.path.join(script_dir, 'test_image/dog.jpg')).convert('RGB'),
|
| 95 |
-
]
|
| 96 |
-
|
| 97 |
INSTRUCTION = "Represent this query for retrieving relevant documents: "
|
| 98 |
queries = [INSTRUCTION + query for query in queries]
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
embeddings_query = encode(queries)
|
| 101 |
embeddings_doc = encode(passages)
|
| 102 |
|
|
|
|
| 51 |
import torch
|
| 52 |
import torch.nn.functional as F
|
| 53 |
from PIL import Image
|
| 54 |
+
import requests
|
| 55 |
+
from io import BytesIO
|
| 56 |
|
| 57 |
def weighted_mean_pooling(hidden, attention_mask):
|
| 58 |
attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
|
|
|
|
| 84 |
embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
|
| 85 |
return embeddings
|
| 86 |
|
| 87 |
+
tokenizer = AutoTokenizer.from_pretrained("/mnt/data/user/tc_agi/klara/datasets/visrag_ret/visrag_ret", trust_remote_code=True)
|
| 88 |
+
model = AutoModel.from_pretrained("/mnt/data/user/tc_agi/klara/datasets/visrag_ret/visrag_ret", torch_dtype=torch.bfloat16, trust_remote_code=True)
|
| 89 |
model.eval()
|
| 90 |
|
|
|
|
| 91 |
queries = ["What does a dog look like?"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
INSTRUCTION = "Represent this query for retrieving relevant documents: "
|
| 93 |
queries = [INSTRUCTION + query for query in queries]
|
| 94 |
|
| 95 |
+
print("Downloading images...")
|
| 96 |
+
passages = [
|
| 97 |
+
Image.open(BytesIO(requests.get(
|
| 98 |
+
'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/cat.jpeg'
|
| 99 |
+
).content)).convert('RGB'),
|
| 100 |
+
Image.open(BytesIO(requests.get(
|
| 101 |
+
'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/dog.jpg'
|
| 102 |
+
).content)).convert('RGB')
|
| 103 |
+
]
|
| 104 |
+
print("Images downloaded.")
|
| 105 |
+
|
| 106 |
embeddings_query = encode(queries)
|
| 107 |
embeddings_doc = encode(passages)
|
| 108 |
|