| | |
| |
|
| | from io import BytesIO |
| | from pathlib import Path |
| | from typing import Any, List, Tuple, Union |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from matplotlib import pyplot as plt |
| | from pandas import DataFrame |
| | from tqdm import tqdm |
| |
|
| | from ultralytics.data.augment import Format |
| | from ultralytics.data.dataset import YOLODataset |
| | from ultralytics.data.utils import check_det_dataset |
| | from ultralytics.models.yolo.model import YOLO |
| | from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks, USER_CONFIG_DIR |
| | from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch |
| |
|
| |
|
class ExplorerDataset(YOLODataset):
    """YOLODataset variant that hands out images and labels without any resizing."""

    def __init__(self, *args, data: dict = None, **kwargs) -> None:
        """Forward every argument, including the `data` dict, to the YOLODataset constructor."""
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
        """Return image `i` plus its original and current (height, width); no resize is applied."""
        cached, img_path, npy_path = self.ims[i], self.im_files[i], self.npy_files[i]

        # Already held in memory: serve the stored image and its recorded shapes.
        if cached is not None:
            return self.ims[i], self.im_hw0[i], self.im_hw[i]

        if npy_path.exists():
            # A saved .npy copy takes precedence over decoding the image file.
            image = np.load(npy_path)
        else:
            image = cv2.imread(img_path)
            if image is None:
                raise FileNotFoundError(f"Image Not Found {img_path}")

        height, width = image.shape[:2]
        return image, (height, width), image.shape[:2]

    def build_transforms(self, hyp: IterableSimpleNamespace = None):
        """Build the label-formatting transform (xyxy boxes, un-normalized) with no image resizing step."""
        format_kwargs = dict(
            bbox_format="xyxy",
            normalize=False,
            return_mask=self.use_segments,
            return_keypoint=self.use_keypoints,
            batch_idx=True,
            mask_ratio=hyp.mask_ratio,
            mask_overlap=hyp.overlap_mask,
        )
        return Format(**format_kwargs)
| |
|
| |
|
class Explorer:
    """
    Explore a dataset through image embeddings stored in a LanceDB table.

    The class embeds every image of a dataset split with a YOLO model, stores the embeddings in a LanceDB table and
    offers similarity search, SQL-like querying and plotting helpers on top of that table.

    Attributes:
        connection: LanceDB connection opened on `uri`.
        table_name (str): Name of the embeddings table, derived from the dataset and model names.
        sim_idx_base_name (str): Prefix used for similarity-index table names.
        model (YOLO): Model used to compute embeddings.
        data (str | Path): Dataset specification passed at construction.
        choice_set (list | None): Image sources of the split used for the last table build.
        table: The embeddings table, or None until `create_embeddings_table` has run.
        progress (float): Fraction (0-1) of the dataset embedded so far.
        sim_index: The last similarity-index table built by `similarity_index`, or None.
    """

    def __init__(
        self,
        data: Union[str, Path] = "coco128.yaml",
        model: str = "yolov8n.pt",
        uri: Union[str, Path] = USER_CONFIG_DIR / "explorer",
    ) -> None:
        """
        Open the LanceDB connection and load the embedding model.

        Args:
            data (str | Path): Dataset YAML to explore. Defaults to 'coco128.yaml'.
            model (str): Weights used to build the YOLO model. Defaults to 'yolov8n.pt'.
            uri (str | Path): Location of the LanceDB database.
        """
        checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
        import lancedb

        self.connection = lancedb.connect(uri)
        self.table_name = Path(data).name.lower() + "_" + model.lower()
        self.sim_idx_base_name = f"{self.table_name}_sim_idx".lower()
        self.model = YOLO(model)
        self.data = data
        self.choice_set = None

        self.table = None
        self.progress = 0
        # Defined up front so the attribute exists before similarity_index() is ever called.
        self.sim_index = None

    def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
        """
        Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
        already exists. Pass force=True to overwrite the existing table.

        Args:
            force (bool): Whether to overwrite the existing table or not. Defaults to False.
            split (str): Split of the dataset to use. Defaults to 'train'.

        Raises:
            ValueError: If no dataset was provided or `split` is not a key of the dataset description.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            ```
        """
        if self.table is not None and not force:
            LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
            return
        if self.table_name in self.connection.table_names() and not force:
            LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
            self.table = self.connection.open_table(self.table_name)
            self.progress = 1
            return
        if self.data is None:
            raise ValueError("Data must be provided to create embeddings table")

        data_info = check_det_dataset(self.data)
        if split not in data_info:
            raise ValueError(
                f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
            )

        choice_set = data_info[split]
        choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
        self.choice_set = choice_set
        dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)

        # Embed the first sample once to learn the vector length required by the table schema.
        batch = dataset[0]
        vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
        table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
        table.add(
            self._yield_batches(
                dataset,
                data_info,
                self.model,
                exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
            )
        )

        self.table = table

    def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
        """Yield one single-row batch per dataset item, with `exclude_keys` removed and the embedding attached."""
        total = len(dataset)
        for i in tqdm(range(total)):
            self.progress = float(i + 1) / total
            batch = dataset[i]
            for k in exclude_keys:
                batch.pop(k, None)
            batch = sanitize_batch(batch, data_info)
            batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
            yield [batch]

    def query(
        self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
    ) -> Any:
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            imgs (str | np.ndarray | list): Image path / array, or a list of them. When several images are given,
                the search uses the mean of their embeddings.
            limit (int): Number of results to return.

        Returns:
            (pyarrow.Table): An arrow table containing the results. Supports converting to:
                - pandas dataframe: `result.to_pandas()`
                - dict of lists: `result.to_pydict()`

        Raises:
            ValueError: If the embeddings table has not been created yet.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.query(imgs='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        # A lone path or a lone array is a one-element query; the signature advertises both.
        if isinstance(imgs, (str, np.ndarray)):
            imgs = [imgs]
        assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
        embeds = self.model.embed(imgs)
        # Collapse multiple query images into a single mean embedding.
        embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
        return self.table.search(embeds).limit(limit).to_arrow()

    def sql_query(
        self, query: str, return_type: str = "pandas"
    ) -> Union[DataFrame, Any, None]:
        """
        Run a SQL-Like query on the table using DuckDB.

        Args:
            query (str): SQL query to run. Either a full `SELECT ...` statement referencing the table as `'table'`,
                or only a `WHERE ...` clause, which is expanded to `SELECT * FROM 'table' WHERE ...`. Leading and
                trailing whitespace is ignored and the leading keyword is matched case-insensitively.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pandas.DataFrame | pyarrow.Table): The query result in the requested format.

        Raises:
            ValueError: If the table has not been created or the query starts with neither SELECT nor WHERE.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.sql_query(query)
            ```
        """
        assert return_type in {
            "pandas",
            "arrow",
        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        import duckdb

        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")

        # Validate before materializing the table so a malformed query fails cheaply.
        query = query.strip()
        leading = query[:6].upper()
        is_select = leading.startswith("SELECT")
        is_where = leading.startswith("WHERE")
        if not is_select and not is_where:
            raise ValueError(
                f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
            )
        if is_where:
            query = f"SELECT * FROM 'table' {query}"
        LOGGER.info(f"Running query: {query}")

        # NOTE: do not rename or remove this local. The queries refer to it as 'table'; DuckDB is expected to
        # resolve that name against Python variables visible at the call site (replacement scan).
        table = self.table.to_arrow()  # noqa: F841

        rs = duckdb.sql(query)
        if return_type == "arrow":
            return rs.arrow()
        elif return_type == "pandas":
            return rs.df()

    def plot_sql_query(self, query: str, labels: bool = True) -> Union[Image.Image, None]:
        """
        Plot the results of a SQL-Like query on the table.

        Args:
            query (str): SQL query to run.
            labels (bool): Whether to plot the labels or not.

        Returns:
            (PIL.Image | None): Image containing the plot, or None when the query returns no rows.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.plot_sql_query(query)
            ```
        """
        result = self.sql_query(query, return_type="arrow")
        if len(result) == 0:
            LOGGER.info("No results found.")
            return None
        img = plot_query_result(result, plot_labels=labels)
        return Image.fromarray(img)

    def get_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        return_type: str = "pandas",
    ) -> Union[DataFrame, Any]:
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            limit (int): Number of results to return. Defaults to 25.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pandas.DataFrame | pyarrow.Table): The similar images in the requested format.

        Raises:
            ValueError: If neither or both of `img` and `idx` are given, or the table has not been created.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        assert return_type in {
            "pandas",
            "arrow",
        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        img = self._check_imgs_or_idxs(img, idx)
        similar = self.query(img, limit=limit)

        if return_type == "arrow":
            return similar
        elif return_type == "pandas":
            return similar.to_pandas()

    def plot_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        labels: bool = True,
    ) -> Union[Image.Image, None]:
        """
        Plot the similar images. Accepts images or indexes.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            limit (int): Number of results to return. Defaults to 25.
            labels (bool): Whether to plot the labels or not.

        Returns:
            (PIL.Image | None): Image containing the plot, or None when nothing similar is found.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        similar = self.get_similar(img, idx, limit, return_type="arrow")
        if len(similar) == 0:
            LOGGER.info("No results found.")
            return None
        img = plot_query_result(similar, plot_labels=labels)
        return Image.fromarray(img)

    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
        """
        Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
        are max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): Maximum distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Fraction (0.0-1.0) of the closest data points to consider when counting. Used to apply
                a limit when running vector search. Defaults to None, which searches the whole table.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image and
                holds its index, file, the count of similar images and the files of those similar images.

        Raises:
            ValueError: If the table has not been created, or `top_k` / `max_dist` are out of range.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            sim_idx = exp.similarity_index()
            ```
        """
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        # Reject bad arguments before touching the database, so they can never reach a cached lookup.
        if top_k is not None and not (0.0 <= top_k <= 1.0):
            raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
        if max_dist < 0.0:
            raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")

        sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
        if sim_idx_table_name in self.connection.table_names() and not force:
            LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
            return self.connection.open_table(sim_idx_table_name).to_pandas()

        # Convert the fraction into a row limit; a falsy top_k (None or 0) means "search every row".
        num_rows = len(self.table)
        top_k = int(top_k * num_rows) if top_k else num_rows
        top_k = max(top_k, 1)
        features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
        im_files = features["im_file"]
        embeddings = features["vector"]

        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")

        def _yield_sim_idx():
            """Yield one single-row batch per image with the neighbours found within `max_dist`."""
            for i, embedding in enumerate(tqdm(embeddings)):
                sim_idx = self.table.search(embedding).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
                yield [
                    {
                        "idx": i,
                        "im_file": im_files[i],
                        "count": len(sim_idx),
                        "sim_im_files": sim_idx["im_file"].tolist(),
                    }
                ]

        sim_table.add(_yield_sim_idx())
        self.sim_index = sim_table
        return sim_table.to_pandas()

    def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image.Image:
        """
        Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
        max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): Maximum distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Fraction (0.0-1.0) of closest data points to consider when counting. Used to apply a
                limit when running vector search. Defaults to None.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (PIL.Image): Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()

            similarity_idx_plot = exp.plot_similarity_index()
            similarity_idx_plot.show()  # view image preview
            similarity_idx_plot.save('path/to/save/similarity_index_plot.png')  # save contents to file
            ```
        """
        sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
        sim_count = np.array(sim_idx["count"].tolist())
        indices = np.arange(len(sim_count))

        # Draw on a dedicated figure and always close it: plotting on the implicit global figure would
        # stack bars from earlier calls onto this plot and keep figures alive between calls.
        fig, ax = plt.subplots()
        try:
            ax.bar(indices, sim_count)
            ax.set_xlabel("data idx")
            ax.set_ylabel("Count")
            ax.set_title("Similarity Count")
            buffer = BytesIO()
            fig.savefig(buffer, format="png")
        finally:
            plt.close(fig)
        buffer.seek(0)

        # Round-trip through an array so the returned image no longer depends on the buffer.
        return Image.fromarray(np.array(Image.open(buffer)))

    def _check_imgs_or_idxs(
        self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
    ) -> List[np.ndarray]:
        """
        Normalize the `img` / `idx` pair into a list of images to query with.

        Args:
            img (str | np.ndarray | list | None): Image(s) supplied directly.
            idx (int | list | None): Row index or indexes in the embeddings table; resolved to the stored
                `im_file` values.

        Returns:
            (list): The images as a list.

        Raises:
            ValueError: If neither or both arguments are provided, or `idx` is used before the table exists.
        """
        if img is None and idx is None:
            raise ValueError("Either img or idx must be provided.")
        if img is not None and idx is not None:
            raise ValueError("Only one of img or idx must be provided.")
        if idx is not None:
            # Index lookup needs the table; fail with the same error the query methods use.
            if self.table is None:
                raise ValueError("Table is not created. Please create the table first.")
            idx = idx if isinstance(idx, list) else [idx]
            img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]

        return img if isinstance(img, list) else [img]

    def ask_ai(self, query):
        """
        Ask AI a question.

        The question is turned into a SQL query by `prompt_sql_query` and then run with `sql_query`.

        Args:
            query (str): Question to ask.

        Returns:
            (pandas.DataFrame | None): A dataframe containing filtered results to the SQL query, or None when the
                generated query could not be executed.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            answer = exp.ask_ai('Show images with 1 person and 2 dogs')
            ```
        """
        result = prompt_sql_query(query)
        try:
            df = self.sql_query(result)
        except Exception as e:
            # Deliberate best-effort: a bad generated query is reported, not propagated.
            LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
            LOGGER.error(e)
            return None
        return df

    def visualize(self, result):
        """
        Visualize the results of a query. TODO.

        Args:
            result (pyarrow.Table): Table containing the results of a query.
        """
        pass

    def generate_report(self, result):
        """
        Generate a report of the dataset.

        TODO
        """
        pass
| |
|