Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from typing import List | |
| from PIL import Image | |
| import io | |
| import base64 | |
| from sentence_transformers import SentenceTransformer, util | |
class Item(BaseModel):
    """Request body carrying the images to compare.

    Attributes:
        images: One entry per image; ``predict`` base64-decodes each entry
            and opens the resulting bytes with PIL.
    """

    images: List[bytes]
class ScoreResponse(BaseModel):
    """Response body returned by ``predict``.

    Attributes:
        score: Highest pairwise image similarity, multiplied by 100 and
            rounded to three decimal places.
    """

    score: float
# FastAPI application object for this service.
# NOTE(review): in the code visible in this file no route is registered on it
# (the handlers at the bottom of the file are commented out) — confirm.
app = FastAPI()
# Process-wide CLIP model, loaded on first use by ``_get_clip_model``.
_clip_model = None


def _get_clip_model():
    """Return the shared CLIP model, loading it only on the first call."""
    global _clip_model
    if _clip_model is None:
        # Loading the model is expensive; doing it once instead of per request.
        print('Loading CLIP Model...')
        _clip_model = SentenceTransformer('clip-ViT-B-32')
    return _clip_model


# NOTE(review): in the code visible here ``predict`` carries no ``@app.post``/
# ``@app.put`` decorator, so FastAPI does not expose it as an endpoint —
# confirm the intended route and register it.
def predict(item: Item) -> ScoreResponse:
    """Score how similar the submitted images are to one another.

    Each entry of ``item.images`` is base64-decoded and opened with PIL, the
    images are embedded with the ``clip-ViT-B-32`` CLIP model, and every image
    is compared against every other with
    ``util.paraphrase_mining_embeddings``. The highest pairwise similarity is
    reported as a percentage rounded to three decimal places.

    Args:
        item: Request body whose ``images`` are base64-encoded image data.

    Returns:
        ScoreResponse with ``score`` = top pairwise similarity * 100, or 0
        when fewer than two images are supplied (there is no pair to score).

    Raises:
        Propagates errors from ``base64.b64decode`` and ``Image.open`` when an
        entry is malformed or not a readable image.
    """
    if len(item.images) < 2:
        # No pairs can be formed; skip loading the model and encoding.
        print("Score: 0")
        return ScoreResponse(score=0)

    model = _get_clip_model()

    # Decode each base64 payload and let PIL parse the raw image bytes.
    opened_images = [
        Image.open(io.BytesIO(base64.b64decode(image_data)))
        for image_data in item.images
    ]
    embeddings = model.encode(
        opened_images,
        batch_size=128,
        convert_to_tensor=True,
        show_progress_bar=True,
    )

    # Compare every image with every other; results are
    # (score, image_id1, image_id2) triples.
    pairs = util.paraphrase_mining_embeddings(embeddings)

    print('Finding near duplicate images...')
    if pairs:
        # Deliberately no `score < 1.0` filter: exact duplicates can score on
        # or just above 1.0 through float rounding and would otherwise be
        # dropped, reporting 0 for identical images. Clamp rounding noise.
        top_score = min(max(float(pair[0]) for pair in pairs), 1.0)
    else:
        top_score = 0.0

    formatted_score = round(top_score * 100, 3)  # percentage, 3 decimals
    print("Score: " + str(formatted_score))
    return ScoreResponse(score=formatted_score)
| #@app.get("/") | |
| #def read_root(): | |
| # return {"Hello": "World!"} | |
| #@app.put("/return/") | |
| #def return_item(item: Item): | |
| # return item |