# Image_search / embeddings.py
# Kabila22's picture
# image search is done
# 8c65df5
import io
import logging
import os
import time

import pandas as pd
import requests
import torch
from dotenv import load_dotenv
from PIL import Image
from pinecone import Pinecone, ServerlessSpec
from tqdm import tqdm
from transformers import AutoProcessor, CLIPModel
# Configure root logging once for the whole script: INFO and above,
# each record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Pull settings from a local .env file into the process environment
# before anything below reads os.environ.
load_dotenv()
# Pinecone setup
pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
index_name = "image-index"

# Create the index on first run only; later runs reuse the existing one.
if index_name not in pc.list_indexes().names():
    pc.create_index(
        name=index_name,
        metric="cosine",
        # Must equal the length of every vector upserted into this index.
        dimension=512,
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )

# Poll unconditionally rather than only after creating the index: an index
# that already exists may still be initializing (for example when an earlier
# run was interrupted right after creating it). When the index is already
# ready the loop body never runs.
while not pc.describe_index(index_name).status.get("ready", False):
    logger.info("Waiting for index to be ready...")
    time.sleep(1)

unsplash_index = pc.Index(index_name)
# CLIP setup (loaded once)
# Single source of truth for the checkpoint so the model and its
# preprocessor can never drift apart.
clip_model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(clip_model_name)
processor = AutoProcessor.from_pretrained(clip_model_name)

# Load dataset
# Only two columns and the first 500 rows are used, so ask pandas to parse
# just those instead of loading the whole file and slicing afterwards.
# `usecols` keeps the file's column order, hence the explicit re-selection
# to pin the order to `image_columns`.
image_columns = ["photo_id", "photo_image_url"]
images_df = pd.read_csv("image.csv", usecols=image_columns, nrows=500)[image_columns]
total_images = len(images_df)
logger.info("Total images to process: %d", total_images)
# Sequential processing function
def process_image(row, *, timeout=30):
    """Download one image, embed it with CLIP and upsert it into Pinecone.

    Args:
        row: Mapping-like record (e.g. a pandas Series yielded by
            ``DataFrame.iterrows``) providing ``"photo_image_url"`` and
            ``"photo_id"``.
        timeout: Seconds to wait for the image download before giving up.

    Returns:
        ``"Processed <photo_id>"`` on success, or ``"Error <photo_id>"`` when
        any step fails. Failures are logged and never propagate, so a single
        bad record does not abort the surrounding loop.
    """
    # Bind both names before the try block so the error handler can always
    # reference them, even when the row itself is missing a field.
    photo_id = url = None
    try:
        url = row["photo_image_url"]
        # Pinecone vector ids must be strings; str() is a no-op for values
        # that already are.
        photo_id = str(row["photo_id"])

        # Download image. A timeout prevents one stalled server from hanging
        # the whole run, and raise_for_status() turns HTTP errors into a
        # clear exception instead of an image-decoding failure later on.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Decode from the fully-read body: the connection is released back to
        # the pool and any transfer encoding has already been undone.
        img = Image.open(io.BytesIO(response.content))

        # Generate embeddings. Inference only, so skip autograd bookkeeping.
        inputs = processor(images=img, return_tensors="pt")
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
        embeddings = image_features.cpu().numpy().flatten().tolist()

        # Upsert to Pinecone
        unsplash_index.upsert(
            vectors=[{
                "id": photo_id,
                "values": embeddings,
                "metadata": {"url": url, "photo_id": photo_id},
            }],
            namespace="image-search-dataset",
        )
        return f"Processed {photo_id}"
    except Exception as e:
        # Deliberately broad: this is the per-record boundary of a batch job.
        logger.error("Error processing %s with URL %s: %s", photo_id, url, e)
        return f"Error {photo_id}"
# Walk the dataset one record at a time, showing a progress bar, and log
# the status string reported for each record.
progress = tqdm(images_df.iterrows(), total=total_images, desc="Indexing images")
for _row_label, record in progress:
    logger.info(process_image(record))

logger.info("Indexing complete!")