# AGRO / build_index.py - author: EYEDOL
# Last commit: 73ef8da "Update build_index.py" (verified)
# build_index_no_faiss.py
import os
import json
import numpy as np
import torch
from tqdm import tqdm
from datasets import concatenate_datasets, load_dataset
from transformers import AutoProcessor, AutoModel
# ---------------------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------------------
# Hugging Face Hub id of the model used to embed the texts.
MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
# The corpus is spread over 15 Hub dataset repos, suffixed 1..15.
DATASET_NAMES = [f"EYEDOL/AGRILLAVA-image-text{shard}" for shard in range(1, 16)]
BATCH_SIZE = 16  # texts per forward pass

# Output artifacts; this whole folder is what gets uploaded into your Space.
OUT_DIR = "faiss_free_data"
EMBEDS_FILE = os.path.join(OUT_DIR, "text_embeds.npy")
TEXTS_JSONL = os.path.join(OUT_DIR, "texts.jsonl")

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

os.makedirs(OUT_DIR, exist_ok=True)
# Pull the "train" split of every dataset shard and stitch them into a single
# corpus; only its "text" column is kept for indexing.
print("Loading datasets...")
all_splits = []
for repo_id in DATASET_NAMES:
    shard = load_dataset(repo_id)
    all_splits.append(shard["train"])
dataset = concatenate_datasets(all_splits)
texts = list(dataset["text"])
print(f"Got {len(texts)} texts.")
# Fetch the processor (tokenization) and model for MODEL_ID, move the model to
# the chosen device and switch it to evaluation mode (disables dropout etc.).
print("Loading model & processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)
model.to(DEVICE)  # nn.Module.to moves parameters in place
model.eval()
# Compute text embeddings in batches and L2-normalize them, so that a plain
# dot product between two rows equals their cosine similarity.
#
# NOTE: padding="max_length" (instead of padding=True) makes each text's
# embedding independent of which other texts share its batch. The Transformers
# SigLIP docs recommend padding="max_length" because that is how SigLIP was
# trained. SigLIP pools the hidden state at the last sequence position, so with
# dynamic per-batch padding that position (and the result) shifts from batch
# to batch. This assumes MODEL_ID is a SigLIP checkpoint, as its name suggests.
# Whatever encodes queries at search time must use the same tokenizer settings.
if not texts:
    raise ValueError("No texts to encode: the loaded datasets are empty.")

all_embeds = []
with torch.no_grad():  # inference only; hoisted out of the loop
    for start in tqdm(range(0, len(texts), BATCH_SIZE), desc="Encoding texts"):
        batch = texts[start:start + BATCH_SIZE]
        inputs = processor(
            text=batch,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(DEVICE)
        embeds = model.get_text_features(**inputs)  # (bs, dim)
        # F.normalize clamps the norm with a small epsilon, so an all-zero row
        # stays zero instead of turning into NaNs.
        embeds = torch.nn.functional.normalize(embeds, p=2, dim=-1)
        all_embeds.append(embeds.cpu().numpy().astype("float32"))

all_embeds = np.concatenate(all_embeds, axis=0)  # shape (N, D)
print("Embeddings shape:", all_embeds.shape)
# Save the artifacts. The embedding matrix goes to EMBEDS_FILE (.npy). The texts
# go to TEXTS_JSONL as one {"text": ...} JSON object per line, in the same order
# as `texts`. That is the order the embedding rows were produced in, so line i
# of the JSONL describes row i of the matrix.
np.save(EMBEDS_FILE, all_embeds)
print(f"Saved embeddings to {EMBEDS_FILE}")
with open(TEXTS_JSONL, "w", encoding="utf-8") as f:
    # ensure_ascii=False keeps non-ASCII text readable in the file (UTF-8).
    f.writelines(json.dumps({"text": t}, ensure_ascii=False) + "\n" for t in texts)
print(f"Saved texts to {TEXTS_JSONL}")
# Report the configured folder rather than a hard-coded copy of its name.
print(f"Done. Upload the folder '{OUT_DIR}' to your Space repository (git lfs or upload_file).")