Update generate_embeddings.py
Browse files- generate_embeddings.py +18 -2
generate_embeddings.py
CHANGED
|
@@ -1,10 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
try:
|
| 2 |
embeddings_df = pd.read_pickle('image_embeddings.pickle')
|
| 3 |
index = embeddings_df.shape[0]
|
| 4 |
except:
|
| 5 |
index=0
|
| 6 |
embeddings_df = pd.DataFrame(columns=['image_embedding'])
|
| 7 |
-
formats = []
|
| 8 |
while index<tasks_df.shape[0]:
|
| 9 |
image = load_image(tasks_df['image_path'][index])
|
| 10 |
inputs = processor(images=[image], return_tensors="pt").to(model.device)
|
|
@@ -13,6 +29,6 @@ while index<tasks_df.shape[0]:
|
|
| 13 |
image_embeddings = model.get_image_features(**inputs)
|
| 14 |
new_row = {'image_embedding': image_embeddings}
|
| 15 |
embeddings_df = pd.concat([embeddings_df, pd.DataFrame([new_row])], ignore_index=True)
|
| 16 |
-
if index %
|
| 17 |
embeddings_df.to_pickle('image_embeddings.pickle')
|
| 18 |
index+=1
|
|
|
|
| 1 |
+
# Given a DataFrame tasks_df whose 'image_path' column contains all the image paths, this script produces an 'image_embeddings.pickle'
|
| 2 |
+
# file that contains all the embeddings. You can stop and resume whenever you want; it resumes from the number of embeddings already stored in the saved pickle file.
|
| 3 |
+
#
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import sys
|
| 6 |
+
from transformers import pipeline
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import AutoModel, AutoProcessor
|
| 9 |
+
from transformers.image_utils import load_image
|
| 10 |
+
|
| 11 |
+
ckpt = "google/siglip2-so400m-patch16-512"
|
| 12 |
+
model = AutoModel.from_pretrained(ckpt, device_map="auto").eval()
|
| 13 |
+
processor = AutoProcessor.from_pretrained(ckpt)
|
| 14 |
+
|
| 15 |
+
tasks_df = # load DataFrame with 'image_path' col that contains all images paths
|
| 16 |
+
save_interval = 100 # save embeddings file every save_interval images
|
| 17 |
+
|
| 18 |
try:
|
| 19 |
embeddings_df = pd.read_pickle('image_embeddings.pickle')
|
| 20 |
index = embeddings_df.shape[0]
|
| 21 |
except:
|
| 22 |
index=0
|
| 23 |
embeddings_df = pd.DataFrame(columns=['image_embedding'])
|
|
|
|
| 24 |
while index<tasks_df.shape[0]:
|
| 25 |
image = load_image(tasks_df['image_path'][index])
|
| 26 |
inputs = processor(images=[image], return_tensors="pt").to(model.device)
|
|
|
|
| 29 |
image_embeddings = model.get_image_features(**inputs)
|
| 30 |
new_row = {'image_embedding': image_embeddings}
|
| 31 |
embeddings_df = pd.concat([embeddings_df, pd.DataFrame([new_row])], ignore_index=True)
|
| 32 |
+
if index % save_interval==0:
|
| 33 |
embeddings_df.to_pickle('image_embeddings.pickle')
|
| 34 |
index+=1
|