庞喆
added original results for comparison
1f0013e
import torch
import clip
from PIL import Image
import requests
import io
import gradio as gr
from pinecone import Pinecone
import logging
import os
# Logging: records at ERROR level and above are written to
# 'embedding_errors.log' as "<timestamp>:<level>:<message>".
logging.basicConfig(filename='embedding_errors.log', level=logging.ERROR,
format='%(asctime)s:%(levelname)s:%(message)s')
# Constants
# Torch device used for the model and every input tensor: GPU when CUDA is
# available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize CLIP model and Pinecone
# NOTE: everything below runs at import time. clip.load returns the model
# together with its matching image preprocessing callable.
model, preprocess = clip.load("ViT-B/32", device=DEVICE)
# Two independent Pinecone clients/indexes, each configured from its own
# environment variables. os.getenv returns None for an unset variable, so a
# missing setting is passed straight through to Pinecone as None.
pinecone_staging = Pinecone(api_key=os.getenv('PINECONE_API_KEY'))
index_staging = pinecone_staging.Index(os.getenv('PINECONE_INDEX_NAME'))
pinecone_prod = Pinecone(api_key=os.getenv('PINECONE_API_KEY_PROD'))
index_prod = pinecone_prod.Index(os.getenv('PINECONE_INDEX_NAME_PROD'))
def generate_combined_embeddings(image_url, text_description):
    """Generate CLIP embeddings for an image URL and a text description.

    Downloads the image, encodes the image and the text with the module-level
    CLIP ``model`` and builds two vectors:

    * combined vector: L2-normalised text features concatenated with
      L2-normalised image features (text first);
    * image-only vector: the raw, un-normalised image features.

    Args:
        image_url: URL of the image to download and embed.
        text_description: Text to embed alongside the image.

    Returns:
        ``(combined, image_only)`` as flat Python lists, or ``(None, None)``
        when any step fails (the error is written to the log).
    """
    try:
        # Bound the download so a stalled host cannot hang the worker, and
        # fail on HTTP errors instead of handing an error page to PIL.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image = preprocess(Image.open(io.BytesIO(response.content))).unsqueeze(0).to(DEVICE)
        # NOTE(review): clip.tokenize presumably rejects text longer than the
        # model's context length; such inputs end up in the except branch.
        text = clip.tokenize([text_description]).to(DEVICE)
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)

        # Snapshot the raw image features as a plain list right away.
        image_embedding = image_features.squeeze().cpu().numpy().tolist()

        # Normalise OUT of place. An in-place `/=` would also rewrite the raw
        # image vector on CPU (where .cpu().numpy() shares memory with the
        # tensor) but not on CUDA (where .cpu() copies), making the image-only
        # vector device-dependent.
        image_unit = image_features / image_features.norm(dim=-1, keepdim=True)
        text_unit = text_features / text_features.norm(dim=-1, keepdim=True)
        combined_features = torch.cat([text_unit, image_unit], dim=1)
        return combined_features.squeeze().cpu().numpy().tolist(), image_embedding
    except Exception as e:
        # Best-effort boundary: callers treat (None, None) as "no embedding".
        logging.error(f"Error processing {image_url} with description '{text_description}': {e}")
        return None, None
def fetch_product_details(image_url, timeout=60):
    """Fetch AI-generated product details for an image from the staging API.

    POSTs ``{"imageURL": image_url}`` as JSON to the product-details endpoint
    and returns the decoded JSON response.

    Args:
        image_url: URL of the image to describe.
        timeout: Seconds to wait for the service before giving up. Without a
            timeout ``requests`` can block indefinitely on a stalled server.

    Returns:
        The decoded JSON body on success, or ``None`` if the request fails,
        times out, returns an HTTP error status or the body cannot be decoded.
        Every failure is logged; none is raised to the caller.
    """
    api_url = "https://microtools-staging.pietrastudio.com/ai-design-product-details"
    headers = {
        "Content-Type": "application/json",
        # NOTE(review): API key is hard-coded in source; consider loading it
        # from the environment like the Pinecone keys.
        "PIETRA-API-KEY": "hellopietra-123"
    }
    try:
        response = requests.post(
            api_url,
            json={"imageURL": image_url},
            headers=headers,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as e:
        # Non-2xx status reported by raise_for_status().
        logging.error(f"HTTP error occurred while fetching product details for {image_url}: {e}")
    except requests.exceptions.RequestException as e:
        # Connection problems, timeouts and other request-level failures.
        logging.error(f"Request exception occurred while fetching product details for {image_url}: {e}")
    except Exception as e:
        # Anything else, e.g. a body that is not valid JSON.
        logging.error(f"An error occurred while processing the image URL '{image_url}': {e}")
    # Reached only after a logged failure.
    return None
def query_pinecone(image_url, text_description, top_k=10):
    """Run the staging and production Pinecone queries for one input.

    The staging index is queried with the combined text+image vector and the
    production index with the image-only vector, both produced by
    ``generate_combined_embeddings``.

    Returns:
        ``(staging_matches, prod_matches)``: two lists of dicts with the keys
        ``imageUrl``, ``itemId`` (converted to ``int``) and ``score``. Both
        lists are empty when no embedding could be generated.
    """
    def summarize(response):
        # Keep only the fields the UI needs from each match.
        return [
            {
                "imageUrl": match["metadata"]['imageUrl'],
                "itemId": int(match["metadata"]["itemId"]),
                "score": match["score"],
            }
            for match in response['matches']
        ]

    combined_vector, image_vector = generate_combined_embeddings(image_url, text_description)
    if combined_vector is None:
        return [], []

    # Staging is queried and summarised first, then production.
    staging_matches = summarize(
        index_staging.query(vector=combined_vector, top_k=top_k, include_metadata=True)
    )
    prod_matches = summarize(
        index_prod.query(vector=image_vector, top_k=top_k, include_metadata=True)
    )
    return staging_matches, prod_matches
def format_output(image_url, results_staging, results_prod):
    """Render the input image and both result sets as an HTML fragment.

    Args:
        image_url: URL of the query image (user supplied, so it is escaped).
        results_staging: Matches shown under "Test Results"; each is a mapping
            with ``itemId``, ``imageUrl`` and ``score``.
        results_prod: Matches shown under "Original Results", same shape.

    Returns:
        An HTML string: the input image, a hint line, then one flex section
        per result set in which every image links to its catalog page.
    """
    # Local import keeps this block self-contained.
    from html import escape

    def render_cards(results):
        # One linked thumbnail with a caption per match.
        cards = []
        for result in results:
            item_id = escape(str(result['itemId']))
            # Construct the store link using a constant pattern and the item_id
            store_link = f"https://creators.pietrastudio.com/browse-suppliers/catalog-item/{item_id}"
            caption = f"Item ID: {item_id}, Score: {result['score']:.3f}"
            cards.append(
                f"<div style='margin: 10px;'><a href='{store_link}' target='_blank'>"
                f"<img src='{escape(str(result['imageUrl']))}' style='height: 300px;'></a>"
                f"<p>{caption}</p></div>"
            )
        return "".join(cards)

    parts = [
        "<div style='text-align: center; margin-bottom: 20px;'>",
        # Values are interpolated into single-quoted attributes, so quotes
        # must be escaped to prevent markup/attribute injection.
        f"<img src='{escape(str(image_url))}' style='height: 300px; display: block; margin: 0 auto;'>",
        "<p style='font-size: 16px; color: grey;'>Input Image</p>",  # Caption for the input image
        "</div><hr>",  # Horizontal line separator
        # Wrapper whose text styling applies to everything below it; it is
        # closed at the very end of the fragment.
        "<div style='text-align: center; margin-bottom: 10px; font-size: 16px; color: grey;'>",
        "Click on any image below to visit the product page.",
        # Section for staging results
        "<div style='display: flex; flex-wrap: wrap; justify-content: center; margin-bottom: 20px'>",
        "<h2 style='width: 100%; text-align: center; color: grey;'>Test Results</h2>",
        render_cards(results_staging),
        "</div>",
        # Section for production results
        "<div style='display: flex; flex-wrap: wrap; justify-content: center;'>",
        "<h2 style='width: 100%; text-align: center; color: grey;'>Original Results</h2>",
        render_cards(results_prod),
        "</div>",
        # Close the hint wrapper, which was previously left unbalanced.
        "</div>",
    ]
    return "".join(parts)
def search(image_url, text_description, include_ai_description):
    """Handle a search request from the Gradio interface.

    Args:
        image_url: URL of the query image.
        text_description: Free-text description typed by the user.
        include_ai_description: When true, fetch product details for the
            image and append them to the text before querying.

    Returns:
        ``(ai_description, html)``: the AI-generated text that was appended
        (empty string if none) and the rendered results fragment.
    """
    ai_description = ""
    if include_ai_description:
        # fetch_product_details returns None when the request fails; fall
        # back to an empty mapping so the search still runs on the user's
        # own text instead of raising TypeError on the membership tests.
        product_detail = fetch_product_details(image_url) or {}
        if 'designDescription' in product_detail:
            # NOTE(review): assumes designDescription is a list of strings —
            # joining a plain string would space out its characters; confirm.
            ai_description = " ".join(product_detail['designDescription']) + ", "
        if 'pantoneColors' in product_detail and product_detail['pantoneColors']:
            color_descriptions = [f"{color['name']}" for color in product_detail['pantoneColors']]
            ai_description += " Colors: " + ", ".join(color_descriptions)
        print(ai_description)
        # Only extend the query text when there is something to add, so the
        # embedding input does not end in a dangling ", ".
        if ai_description:
            text_description += ", " + ai_description
    result_staging, result_prod = query_pinecone(image_url, text_description)
    return ai_description, format_output(image_url, result_staging, result_prod)
def main():
    """Build the Gradio interface around ``search`` and launch it."""
    # Inputs, in the order of the ``search`` parameters.
    input_components = [
        gr.Textbox(lines=1, placeholder="Image URL", label="Image URL"),
        gr.Textbox(lines=2, placeholder="Text Description", label="Text Description"),
        gr.Checkbox(label="Include AI-generated description", value=True),
    ]
    # Outputs, in the order of the ``search`` return values.
    output_components = [
        gr.Textbox(label="AI Description"),  # AI-generated description text
        gr.HTML(label="Search Results"),  # Input image plus both result sets
    ]
    app = gr.Interface(
        fn=search,
        inputs=input_components,
        outputs=output_components,
        title="Image and Text Search",
        description="Search similar items based on image URL and text description.",
        concurrency_limit=20,
    )
    app.launch(share=True)


# Launch the app only when this file is run as a script.
if __name__ == '__main__':
    main()