Spaces:
Sleeping
Sleeping
| import os | |
| from tqdm.auto import tqdm | |
| from utils.utils import create_client | |
| from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, utility | |
| from utils.get_embeddings import preprocess_image, extract_features, create_resnet18_model | |
# Name of the Milvus collection holding the image embeddings.
COLLECTION_NAME = "Resnet18"
# Length of each embedding vector; assumed to match the ResNet-18 features
# produced by utils.get_embeddings.extract_features — TODO confirm.
EMBEDDING_DIM = 512
# Folder of images to index. Defaults to the original developer path but can be
# overridden with the IMAGE_FOLDER environment variable so the script runs elsewhere.
IMAGE_FOLDER = os.environ.get(
    "IMAGE_FOLDER", "/home/nampham/Desktop/image-retrieval/data/images_mr"
)
# Created at import time; presumably establishes the Milvus connection that the
# pymilvus calls below rely on — confirm in utils.utils.create_client.
client = create_client()
def load_collection():
    """Return the ``COLLECTION_NAME`` collection, building it on first use.

    If the collection already exists on the Milvus server, it is loaded and its
    load state is printed. Otherwise a new collection is created and filled with
    embeddings from a ResNet-18 model for the images in ``IMAGE_FOLDER``. It is
    then indexed by ``create_index``, which also loads it.

    Returns:
        The ready-to-search ``pymilvus.Collection``.
    """
    if utility.has_collection(COLLECTION_NAME):
        print("Load and use collection right now!")
        existing = Collection(COLLECTION_NAME)
        existing.load()
        print(utility.load_state(COLLECTION_NAME))
        return existing

    # First run: nothing on the server yet, so build the collection from scratch.
    print("Please create a collection and insert data!")
    fresh = create_collection()
    insert_data(create_resnet18_model(), fresh, IMAGE_FOLDER)
    # A vector index is needed before the collection can be searched.
    create_index(fresh)
    return fresh
def create_collection():
    """Create the ``COLLECTION_NAME`` collection and return it.

    Schema:
        * ``image_id`` — INT64 primary key, supplied by the caller (not auto-generated).
        * ``image_embedding`` — FLOAT_VECTOR of length ``EMBEDDING_DIM``.

    Returns:
        The newly created ``pymilvus.Collection``.
    """
    image_id = FieldSchema(
        name="image_id",
        dtype=DataType.INT64,
        is_primary=True,
        description="Image ID",
    )
    image_embedding = FieldSchema(
        name="image_embedding",
        dtype=DataType.FLOAT_VECTOR,
        # Vector fields must declare their dimension; pymilvus rejects them otherwise.
        dim=EMBEDDING_DIM,
        description="Image Embedding",
    )
    schema = CollectionSchema(
        fields=[image_id, image_embedding],
        # IDs are provided explicitly (insert_data derives them from image file
        # names), so search hits can be mapped back to files. With auto_id=True
        # Milvus would refuse the supplied primary keys.
        auto_id=False,
        description="Image Retrieval using Resnet18",
    )
    return Collection(name=COLLECTION_NAME, schema=schema)
def insert_data(model, collection, image_folder):
    """Embed every image in *image_folder* and insert the vectors into *collection*.

    Each file name is expected to start with an integer id (the part before the
    first ``.``). That integer is used as the primary key, and the image is read
    back from ``<id>.jpg`` in the same folder.

    Args:
        model: Feature extractor passed to ``extract_features``.
        collection: Target ``pymilvus.Collection`` whose fields are
            ``[image_id, image_embedding]``.
        image_folder: Directory containing the images.

    Returns:
        The result of ``collection.insert``, or ``None`` if the folder is empty.

    Raises:
        ValueError: if a file name does not start with an integer.
    """
    image_ids = sorted(
        int(image_name.split('.')[0]) for image_name in os.listdir(image_folder)
    )
    if not image_ids:
        # Nothing to insert; skip the insert/flush round trip.
        return None

    image_embeddings = []
    for image_id in tqdm(image_ids):
        # NOTE(review): assumes every image is stored as "<id>.jpg"; files with
        # other extensions will fail to open here.
        image_path = os.path.join(image_folder, f"{image_id}.jpg")
        processed_image = preprocess_image(image_path)
        # Presumably a 1-D vector of length EMBEDDING_DIM — confirm against extract_features.
        image_embeddings.append(extract_features(model, processed_image))

    # Column-ordered data matching the schema fields [image_id, image_embedding].
    ins_resp = collection.insert([image_ids, image_embeddings])
    collection.flush()
    return ins_resp
def create_index(collection):
    """Build an IVF_FLAT / L2 index on the embedding field, then load *collection*.

    Args:
        collection: The ``pymilvus.Collection`` created by ``create_collection``.
    """
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        # IVF_FLAT's build parameter is the number of cluster units; 128 is the
        # value Milvus documents as its default.
        "params": {"nlist": 128},
    }
    collection.create_index(
        # Must match the vector field name declared in create_collection.
        field_name="image_embedding",
        index_params=index_params,
    )
    # Load into memory so the collection can be searched immediately.
    collection.load()