| |
|
|
| from io import BytesIO |
| from pathlib import Path |
| from typing import Any, List, Tuple, Union |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| from PIL import Image |
| from matplotlib import pyplot as plt |
| from pandas import DataFrame |
| from tqdm import tqdm |
|
|
| from ultralytics.data.augment import Format |
| from ultralytics.data.dataset import YOLODataset |
| from ultralytics.data.utils import check_det_dataset |
| from ultralytics.models.yolo.model import YOLO |
| from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks, USER_CONFIG_DIR |
| from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch |
|
|
|
|
class ExplorerDataset(YOLODataset):
    """YOLODataset variant used by Explorer: images are loaded without resizing and labels are formatted as
    absolute (un-normalized) xyxy boxes."""

    def __init__(self, *args, data: dict = None, **kwargs) -> None:
        """Forward all arguments, including the dataset info dict ``data``, to ``YOLODataset``."""
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
        """
        Load image ``i`` from the dataset without any resize ops.

        Args:
            i (int): Index into ``self.im_files``.

        Returns:
            (tuple): ``(image, original_hw, current_hw)``. For images read from disk both hw tuples are equal
                because no resize is applied.

        Raises:
            FileNotFoundError: If the image is not cached, has no .npy file, and cannot be read by OpenCV.
        """
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is not None:
            # Image already held in self.ims; return the stored shapes alongside it.
            return self.ims[i], self.im_hw0[i], self.im_hw[i]

        if fn.exists():
            im = np.load(fn)
        else:
            im = cv2.imread(f)
            if im is None:  # cv2.imread signals failure by returning None rather than raising
                raise FileNotFoundError(f"Image Not Found {f}")
        h0, w0 = im.shape[:2]
        return im, (h0, w0), im.shape[:2]

    def build_transforms(self, hyp: IterableSimpleNamespace = None):
        """
        Create the output-formatting transform for dataset images, without resizing.

        Args:
            hyp (IterableSimpleNamespace, optional): Hyperparameters. When given, ``mask_ratio`` and
                ``overlap_mask`` are read from it. When None, the mask options are not passed and ``Format``'s own
                defaults apply (previously this raised AttributeError).

        Returns:
            (Format): Transform producing un-normalized xyxy boxes with a batch index.
        """
        mask_kwargs = {} if hyp is None else {"mask_ratio": hyp.mask_ratio, "mask_overlap": hyp.overlap_mask}
        return Format(
            bbox_format="xyxy",
            normalize=False,
            return_mask=self.use_segments,
            return_keypoint=self.use_keypoints,
            batch_idx=True,
            **mask_kwargs,
        )
|
|
|
|
class Explorer:
    """
    Build and query a LanceDB table of YOLO image embeddings for a dataset.

    Supports similarity search by image or table row index, SQL-like filtering through DuckDB, similarity-index
    statistics, and natural-language questions translated to SQL by ``prompt_sql_query``.
    """

    def __init__(
        self,
        data: Union[str, Path] = "coco128.yaml",
        model: str = "yolov8n.pt",
        uri: Union[str, Path] = USER_CONFIG_DIR / "explorer",
    ) -> None:
        """
        Args:
            data (str | Path): Dataset spec passed to ``check_det_dataset``. Its file name is part of the table name.
            model (str): YOLO model used to compute embeddings. Also part of the table name.
            uri (str | Path): Location passed to ``lancedb.connect``.
        """
        checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
        import lancedb

        self.connection = lancedb.connect(uri)
        self.table_name = Path(data).name.lower() + "_" + model.lower()
        self.sim_idx_base_name = f"{self.table_name}_sim_idx".lower()
        self.model = YOLO(model)
        self.data = data
        self.choice_set = None

        self.table = None
        self.sim_index = None  # populated by similarity_index()
        self.progress = 0  # fraction of images embedded, updated while the embeddings table is built

    def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
        """
        Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
        already exists. Pass force=True to overwrite the existing table.

        Args:
            force (bool): Whether to overwrite the existing table or not. Defaults to False.
            split (str): Split of the dataset to use. Defaults to 'train'.

        Raises:
            ValueError: If ``data`` is None or ``split`` is not a key of the checked dataset info.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            ```
        """
        if self.table is not None and not force:
            LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
            return
        if self.table_name in self.connection.table_names() and not force:
            LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
            self.table = self.connection.open_table(self.table_name)
            self.progress = 1
            return
        if self.data is None:
            raise ValueError("Data must be provided to create embeddings table")

        data_info = check_det_dataset(self.data)
        if split not in data_info:
            raise ValueError(
                f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
            )

        choice_set = data_info[split]
        choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
        self.choice_set = choice_set
        dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)

        # Embed the first image once to discover the vector length needed by the table schema.
        batch = dataset[0]
        vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
        table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
        self.progress = 0  # reset so a forced rebuild does not report a stale value from a previous build
        table.add(
            self._yield_batches(
                dataset,
                data_info,
                self.model,
                exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
            )
        )

        self.table = table

    def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
        """Yield one single-row batch per dataset item, with ``exclude_keys`` removed and its embedding attached."""
        total = len(dataset)
        for i in tqdm(range(total)):
            self.progress = float(i + 1) / total
            batch = dataset[i]
            for k in exclude_keys:
                batch.pop(k, None)
            batch = sanitize_batch(batch, data_info)
            batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
            yield [batch]

    def query(
        self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
    ) -> Any:
        """
        Query the table for similar images. Accepts a single image or a list of images.

        When several images are given, their embeddings are averaged into a single query vector.

        Args:
            imgs (str | np.ndarray | list): Path/array of an image, or a list of them.
            limit (int): Number of results to return.

        Returns:
            (pyarrow.Table): An arrow table containing the results. Supports converting to:
                - pandas dataframe: `result.to_pandas()`
                - dict of lists: `result.to_pydict()`

        Raises:
            ValueError: If the table has not been created or an empty list of images is given.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.query(imgs='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        # A single image (path or array) is wrapped so both forms advertised in the signature are accepted.
        if isinstance(imgs, (str, np.ndarray)):
            imgs = [imgs]
        assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
        if not imgs:
            raise ValueError("At least one image must be provided to query.")
        embeds = self.model.embed(imgs)
        embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
        return self.table.search(embeds).limit(limit).to_arrow()

    def sql_query(
        self, query: str, return_type: str = "pandas"
    ) -> Union[DataFrame, Any, None]:
        """
        Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.

        The query may be a full statement starting with SELECT or just a clause starting with WHERE, in which case
        ``SELECT * FROM 'table'`` is prepended. The keyword check ignores case and leading whitespace.

        Args:
            query (str): SQL query to run.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pandas.DataFrame | pyarrow.Table): The query result in the requested format.

        Raises:
            ValueError: If the table has not been created or the query does not start with SELECT or WHERE.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.sql_query(query)
            ```
        """
        assert return_type in {
            "pandas",
            "arrow",
        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        import duckdb

        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")

        # DuckDB's Python replacement scan resolves the name `table` in the query to this local Arrow table, so the
        # variable must keep this exact name and live in the same frame as the duckdb.sql() call below.
        table = self.table.to_arrow()  # noqa: F841
        query = query.strip()
        keyword_check = query.upper()
        if not keyword_check.startswith(("SELECT", "WHERE")):
            raise ValueError(
                f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
            )
        if keyword_check.startswith("WHERE"):
            query = f"SELECT * FROM 'table' {query}"
        LOGGER.info(f"Running query: {query}")

        rs = duckdb.sql(query)
        if return_type == "arrow":
            return rs.arrow()
        return rs.df()

    def plot_sql_query(self, query: str, labels: bool = True) -> Union[Image.Image, None]:
        """
        Plot the results of a SQL-Like query on the table.

        Args:
            query (str): SQL query to run.
            labels (bool): Whether to plot the labels or not.

        Returns:
            (PIL.Image | None): Image containing the plot, or None when the query returns no rows.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.plot_sql_query(query)
            ```
        """
        result = self.sql_query(query, return_type="arrow")
        if len(result) == 0:
            LOGGER.info("No results found.")
            return None
        img = plot_query_result(result, plot_labels=labels)
        return Image.fromarray(img)

    def get_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        return_type: str = "pandas",
    ) -> Union[DataFrame, Any]:
        """
        Query the table for similar images. Accepts a single image or a list of images, or table row indexes.

        Args:
            img (str | np.ndarray | list): Path to the image or a list of paths to the images.
            idx (int | list): Index of the image in the table or a list of indexes.
            limit (int): Number of results to return. Defaults to 25.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pandas.DataFrame | pyarrow.Table): The results in the requested format.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        assert return_type in {
            "pandas",
            "arrow",
        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        img = self._check_imgs_or_idxs(img, idx)
        similar = self.query(img, limit=limit)
        return similar if return_type == "arrow" else similar.to_pandas()

    def plot_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        labels: bool = True,
    ) -> Union[Image.Image, None]:
        """
        Plot the similar images. Accepts images or indexes.

        Args:
            img (str | np.ndarray | list): Path to the image or a list of paths to the images.
            idx (int | list): Index of the image in the table or a list of indexes.
            limit (int): Number of results to return. Defaults to 25.
            labels (bool): Whether to plot the labels or not.

        Returns:
            (PIL.Image | None): Image containing the plot, or None when no results are found.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        similar = self.get_similar(img, idx, limit, return_type="arrow")
        if len(similar) == 0:
            LOGGER.info("No results found.")
            return None
        img = plot_query_result(similar, plot_labels=labels)
        return Image.fromarray(img)

    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
        """
        Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
        are max_dist or closer to the image in the embedding space at a given index.

        The result is stored in its own LanceDB table, named after ``max_dist`` and ``top_k``. That table is reused
        unless ``force`` is True.

        Args:
            max_dist (float): Maximum distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Fraction (0.0-1.0) of the table to use as the vector-search limit per image. When None
                (or 0), the whole table is searched. Defaults to None.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, with
                columns ``idx``, ``im_file``, ``count`` and ``sim_im_files``.

        Raises:
            ValueError: If the table has not been created, ``top_k`` is outside [0, 1] or ``max_dist`` is negative.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            sim_idx = exp.similarity_index()
            ```
        """
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        # Validate before looking up a cached index so bad arguments are always reported.
        if top_k and not (1.0 >= top_k >= 0.0):
            raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
        if max_dist < 0.0:
            raise ValueError(f"max_dist must be greater than or equal to 0. Got {max_dist}")

        sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
        if sim_idx_table_name in self.connection.table_names() and not force:
            LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
            return self.connection.open_table(sim_idx_table_name).to_pandas()

        # Convert the fraction into an absolute search limit of at least one row.
        top_k = int(top_k * len(self.table)) if top_k else len(self.table)
        top_k = max(top_k, 1)
        features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
        im_files = features["im_file"]
        embeddings = features["vector"]

        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")

        def _yield_sim_idx():
            """Yield one row per image: its neighbours within max_dist among the top_k nearest results."""
            for i in tqdm(range(len(embeddings))):
                sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
                yield [
                    {
                        "idx": i,
                        "im_file": im_files[i],
                        "count": len(sim_idx),
                        "sim_im_files": sim_idx["im_file"].tolist(),
                    }
                ]

        sim_table.add(_yield_sim_idx())
        self.sim_index = sim_table
        return sim_table.to_pandas()

    def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image.Image:
        """
        Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
        max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): Maximum distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Fraction of the table to use as the vector-search limit per image. Defaults to None
                (search the whole table).
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (PIL.Image): Bar chart of the per-image similarity count.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()

            similarity_idx_plot = exp.plot_similarity_index()
            similarity_idx_plot.show() # view image preview
            similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
            ```
        """
        sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
        sim_count = np.array(sim_idx["count"].tolist())
        indices = np.arange(len(sim_count))

        # Draw on a dedicated figure and close it afterwards. Otherwise repeated calls would add bars to pyplot's
        # shared "current" axes, and the open figures would accumulate.
        fig, ax = plt.subplots()
        try:
            ax.bar(indices, sim_count)
            ax.set_xlabel("data idx")
            ax.set_ylabel("Count")
            ax.set_title("Similarity Count")
            buffer = BytesIO()
            fig.savefig(buffer, format="png")
        finally:
            plt.close(fig)
        buffer.seek(0)

        return Image.fromarray(np.array(Image.open(buffer)))

    def _check_imgs_or_idxs(
        self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
    ) -> List[Union[str, np.ndarray]]:
        """
        Normalise exactly one of ``img``/``idx`` into a list of images.

        Indexes are resolved to the ``im_file`` values stored in the table.

        Raises:
            ValueError: If neither or both arguments are given, or if ``idx`` is used before the table exists.
        """
        if img is None and idx is None:
            raise ValueError("Either img or idx must be provided.")
        if img is not None and idx is not None:
            raise ValueError("Only one of img or idx must be provided.")
        if idx is not None:
            if self.table is None:
                raise ValueError("Table is not created. Please create the table first.")
            idx = idx if isinstance(idx, list) else [idx]
            img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]

        return img if isinstance(img, list) else [img]

    def ask_ai(self, query):
        """
        Ask AI a question.

        The question is converted to SQL by ``prompt_sql_query`` and run through ``sql_query``. Any error raised
        while running the generated query is logged and None is returned.

        Args:
            query (str): Question to ask.

        Returns:
            (pandas.DataFrame | None): A dataframe containing filtered results to the SQL query, or None on failure.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            answer = exp.ask_ai('Show images with 1 person and 2 dogs')
            ```
        """
        result = prompt_sql_query(query)
        try:
            df = self.sql_query(result)
        except Exception as e:  # the generated SQL is untrusted; report the failure instead of raising
            LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
            LOGGER.error(e)
            return None
        return df

    def visualize(self, result):
        """
        Visualize the results of a query. TODO.

        Args:
            result (pyarrow.Table): Table containing the results of a query.
        """
        pass

    def generate_report(self, result):
        """
        Generate a report of the dataset.

        TODO
        """
        pass
|
|