Spaces:
Runtime error
Runtime error
| from datasets import load_dataset | |
| from PIL import Image | |
| import os | |
| import pandas as pd | |
| from transformers import AutoFeatureExtractor,AutoModel | |
| from faiss.contrib.inspect_tools import get_flat_data | |
| import pymde | |
| import numpy as np | |
def get_embedding(model_name, viz_dat, index_dir="./indexes"):
    """Return a 2-D layout of ``viz_dat``'s image embeddings as a DataFrame.

    Image embeddings are computed with the Hugging Face model ``model_name``
    (its ``pooler_output``) and stored in a FAISS index named
    ``"embeddings"``. The indexed vectors are then projected to two
    dimensions with ``pymde.preserve_neighbors``. Both the FAISS index
    (``<stem>.faiss``) and the projection (``<stem>.npy``) are cached in
    ``index_dir`` and reused on later calls.

    NOTE: cache files are keyed only by the model name's last path
    component. Delete them if the contents of ``viz_dat`` change.

    Args:
        model_name: Model id such as ``"org/name"``, or a local path. Its
            last path component names the cache files.
        viz_dat: Dataset exposing the ``datasets.Dataset`` map/FAISS API,
            with ``image``, ``gender``, ``masterCategory`` and
            ``subCategory`` columns. As a side effect, the ``"embeddings"``
            FAISS index is attached to this object.
        index_dir: Directory for the cache files; created if missing.

    Returns:
        ``pandas.DataFrame`` with columns ``x`` and ``y`` (the projection),
        followed by ``image``, ``gender``, ``masterCategory`` and
        ``subCategory`` copied from ``viz_dat``.
    """
    # Last component, so both "org/name" and bare "name" ids work.
    stem = model_name.rstrip("/").split("/")[-1]
    os.makedirs(index_dir, exist_ok=True)
    index_file = os.path.join(index_dir, f"{stem}.faiss")

    if not os.path.exists(index_file):
        import torch  # return_tensors="pt" below already requires torch

        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        # For GPU inference, move `model` and each batch of `inputs` to the
        # same device.

        def embed(batch):
            # Non-RGB PIL images (grayscale, RGBA) are converted; the
            # extractor presumably expects 3-channel input.
            images = [
                img.convert("RGB") if getattr(img, "mode", "RGB") != "RGB" else img
                for img in batch["image"]
            ]
            inputs = feature_extractor(images=images, return_tensors="pt")
            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                outputs = model(**inputs)
            # pooler_output is model-specific. Models without a pooling head
            # need another reduction, e.g. outputs.last_hidden_state[:, 0].
            batch["embeddings"] = outputs.pooler_output.cpu().numpy()
            return batch

        with_embeddings = viz_dat.map(embed, batched=True, batch_size=20)
        with_embeddings.add_faiss_index(column="embeddings")
        with_embeddings.save_faiss_index("embeddings", index_file)

    # Always attach the index to the caller's dataset, so the side effect is
    # the same whether or not the index was just built.
    viz_dat.load_faiss_index("embeddings", index_file)

    projection_file = os.path.join(index_dir, f"{stem}.npy")
    if os.path.exists(projection_file):
        coords = np.load(projection_file)
    else:
        vectors = get_flat_data(viz_dat.get_index("embeddings").faiss_index)
        # Two dimensions, matching the "x"/"y" columns below.
        coords = (
            pymde.preserve_neighbors(vectors, embedding_dim=2, verbose=True)
            .embed()
            .numpy()
        )
        np.save(projection_file, coords)

    frame = pd.DataFrame(coords, columns=["x", "y"])
    for column in ("image", "gender", "masterCategory", "subCategory"):
        frame[column] = viz_dat[column]
    return frame