|
|
|
|
|
import os |
|
|
import json |
|
|
import numpy as np |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
from datasets import concatenate_datasets, load_dataset |
|
|
from transformers import AutoProcessor, AutoModel |
|
|
|
|
|
|
|
|
# --- Configuration -----------------------------------------------------------
# Hugging Face Hub id of the model whose text tower embeds the corpus.
MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"

# The corpus is split across 15 dataset repos, numbered 1..15.
DATASET_NAMES = [f"EYEDOL/AGRILLAVA-image-text{shard}" for shard in range(1, 16)]

# Number of texts sent through the model per forward pass.
BATCH_SIZE = 16

# Output folder and the two artifacts written into it.
OUT_DIR = "faiss_free_data"
EMBEDS_FILE, TEXTS_JSONL = (
    os.path.join(OUT_DIR, "text_embeds.npy"),
    os.path.join(OUT_DIR, "texts.jsonl"),
)

# Prefer the GPU when one is visible to torch; kept as a plain string because
# later code compares it against "cuda".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Create the output folder up front (no error if it already exists).
os.makedirs(OUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
# --- Data --------------------------------------------------------------------
print("Loading datasets...")

# Collect the "train" split of every shard, then join them into one dataset
# whose rows follow the order of DATASET_NAMES.
all_splits = []
for shard_name in DATASET_NAMES:
    all_splits.append(load_dataset(shard_name)["train"])
dataset = concatenate_datasets(all_splits)

# Pull the "text" column out as a plain Python list.
texts = list(dataset["text"])
print(f"Got {len(texts)} texts.")
|
|
|
|
|
|
|
|
# --- Model -------------------------------------------------------------------
print("Loading model & processor...")

# Processor used below to tokenize the texts for MODEL_ID.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Load the weights, move them to the chosen device, and put the module in
# evaluation mode (no training-time behavior such as dropout).
model = AutoModel.from_pretrained(MODEL_ID)
model = model.to(DEVICE)
model.eval()
|
|
|
|
|
|
|
|
# --- Encode texts ------------------------------------------------------------
# Embed the corpus in batches of BATCH_SIZE and collect L2-normalised float32
# rows, one per entry of `texts`, in order.
#
# Padding: the Hugging Face SigLIP docs state the text tower was trained with
# padding="max_length" and ask callers to pass it. With padding=True the
# padded length depends on the longest text in each batch, so a text's
# embedding could change with its batch-mates and differ from the same text
# encoded alone (e.g. at query time). Padding to the tokenizer's max length
# keeps the embeddings batch-independent.
# NOTE(review): assumes MODEL_ID is a SigLIP-family checkpoint, as its name
# suggests; use the same padding when encoding queries.
all_embeds = []

for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="Encoding texts"):
    batch = texts[i:i + BATCH_SIZE]
    inputs = processor(
        text=batch,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        embeds = model.get_text_features(**inputs)

    # Unit-normalise rows so a plain dot product equals cosine similarity.
    # F.normalize clamps the norm with a small eps, so an all-zero vector
    # yields zeros instead of NaNs.
    embeds = torch.nn.functional.normalize(embeds, p=2, dim=-1)

    all_embeds.append(embeds.cpu().numpy().astype("float32"))

    # Drop references to this batch's tensors before the next iteration.
    del inputs, embeds
    if DEVICE == "cuda":
        torch.cuda.empty_cache()

# np.concatenate([]) would fail with an opaque message; say what went wrong.
if not all_embeds:
    raise ValueError("No texts to encode; check DATASET_NAMES and the 'text' column.")

all_embeds = np.concatenate(all_embeds, axis=0)
print("Embeddings shape:", all_embeds.shape)
|
|
|
|
|
|
|
|
# --- Persist -----------------------------------------------------------------
# Embedding matrix: one row per entry of `texts`, in the same order.
np.save(file=EMBEDS_FILE, arr=all_embeds)
print(f"Saved embeddings to {EMBEDS_FILE}")

# Companion JSONL: line k holds the text whose embedding is row k above.
# ensure_ascii=False writes non-ASCII characters as-is (file is UTF-8).
jsonl_lines = (json.dumps({"text": t}, ensure_ascii=False) + "\n" for t in texts)
with open(TEXTS_JSONL, "w", encoding="utf-8") as fh:
    fh.writelines(jsonl_lines)
print(f"Saved texts to {TEXTS_JSONL}")
|
|
|
|
|
# Take the folder name from OUT_DIR so this hint stays correct if the output
# directory constant is ever changed.
print(f"Done. Upload the folder '{OUT_DIR}' to your Space repository (git lfs or upload_file).")
|
|
|