|
|
|
|
| from io import BytesIO
|
| from pathlib import Path
|
| from typing import Any, List, Tuple, Union
|
|
|
| import cv2
|
| import numpy as np
|
| import torch
|
| from PIL import Image
|
| from matplotlib import pyplot as plt
|
| from pandas import DataFrame
|
| from tqdm import tqdm
|
|
|
| from ultralytics.data.augment import Format
|
| from ultralytics.data.dataset import YOLODataset
|
| from ultralytics.data.utils import check_det_dataset
|
| from ultralytics.models.yolo.model import YOLO
|
| from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks, USER_CONFIG_DIR
|
| from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
|
|
|
|
|
class ExplorerDataset(YOLODataset):
    """YOLO dataset variant that serves images at their native resolution for embedding extraction."""

    def __init__(self, *args, data: dict = None, **kwargs) -> None:
        """Initialize the dataset, forwarding all arguments to YOLODataset."""
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
        """Loads 1 image from dataset index 'i' without any resize ops."""
        cached = self.ims[i]
        if cached is not None:
            # Image is already in the RAM cache; reuse it together with its stored shapes.
            return cached, self.im_hw0[i], self.im_hw[i]

        img_path, npy_path = self.im_files[i], self.npy_files[i]
        if npy_path.exists():
            # Prefer the cached .npy array when one exists on disk.
            image = np.load(npy_path)
        else:
            image = cv2.imread(img_path)  # BGR
            if image is None:
                raise FileNotFoundError(f"Image Not Found {img_path}")
        hw = image.shape[:2]  # (height, width) — no resizing, so original and current shapes match
        return image, hw, hw

    def build_transforms(self, hyp: IterableSimpleNamespace = None):
        """Creates transforms for dataset images without resizing."""
        return Format(
            bbox_format="xyxy",
            normalize=False,
            return_mask=self.use_segments,
            return_keypoint=self.use_keypoints,
            batch_idx=True,
            mask_ratio=hyp.mask_ratio,
            mask_overlap=hyp.overlap_mask,
        )
|
|
|
|
|
class Explorer:
    """Semantic search and SQL-style querying over a dataset via LanceDB-stored YOLO embeddings."""

    def __init__(
        self,
        data: Union[str, Path] = "coco128.yaml",
        model: str = "yolov8n.pt",
        uri: str = USER_CONFIG_DIR / "explorer",
    ) -> None:
        """Verify dependencies, open the LanceDB connection, and load the YOLO model."""
        checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
        import lancedb

        self.connection = lancedb.connect(uri)
        # Table name is derived from dataset file name + model name so each pairing gets its own table.
        self.table_name = Path(data).name.lower() + "_" + model.lower()
        self.sim_idx_base_name = f"{self.table_name}_sim_idx".lower()
        self.model = YOLO(model)
        self.data = data
        self.choice_set = None

        self.table = None  # populated by create_embeddings_table()
        self.progress = 0  # embedding-build progress in [0, 1]
|
|
|
    def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
        """
        Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
        already exists. Pass force=True to overwrite the existing table.

        Args:
            force (bool): Whether to overwrite the existing table or not. Defaults to False.
            split (str): Split of the dataset to use. Defaults to 'train'.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            ```
        """
        # Reuse the in-memory table handle from a previous call unless a rebuild is forced.
        if self.table is not None and not force:
            LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
            return
        # Reopen a table persisted in the LanceDB database by an earlier session.
        if self.table_name in self.connection.table_names() and not force:
            LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
            self.table = self.connection.open_table(self.table_name)
            self.progress = 1
            return
        if self.data is None:
            raise ValueError("Data must be provided to create embeddings table")

        data_info = check_det_dataset(self.data)
        if split not in data_info:
            raise ValueError(
                f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
            )

        # Normalize the split entry to a list of image paths/dirs for the dataset loader.
        choice_set = data_info[split]
        choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
        self.choice_set = choice_set
        dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)

        # Embed one sample first just to discover the embedding vector size for the table schema.
        batch = dataset[0]
        vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
        table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
        # table.add accepts a generator, so embeddings are streamed batch-by-batch rather than held in memory.
        table.add(
            self._yield_batches(
                dataset,
                data_info,
                self.model,
                exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
            )
        )

        self.table = table
|
|
|
| def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
|
| """Generates batches of data for embedding, excluding specified keys."""
|
| for i in tqdm(range(len(dataset))):
|
| self.progress = float(i + 1) / len(dataset)
|
| batch = dataset[i]
|
| for k in exclude_keys:
|
| batch.pop(k, None)
|
| batch = sanitize_batch(batch, data_info)
|
| batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
|
| yield [batch]
|
|
|
    def query(
        self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
    ) -> Any:
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            imgs (str or list): Path to the image or a list of paths to the images.
            limit (int): Number of results to return.

        Returns:
            (pyarrow.Table): An arrow table containing the results. Supports converting to:
                - pandas dataframe: `result.to_pandas()`
                - dict of lists: `result.to_pydict()`

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.query(imgs='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        if isinstance(imgs, str):
            imgs = [imgs]
        assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
        embeds = self.model.embed(imgs)

        # When multiple images are given, search with the mean of their embeddings.
        embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
        return self.table.search(embeds).limit(limit).to_arrow()
|
|
|
    def sql_query(
        self, query: str, return_type: str = "pandas"
    ) -> Union[DataFrame, Any, None]:
        """
        Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.

        Args:
            query (str): SQL query to run.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pyarrow.Table): An arrow table containing the results.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.sql_query(query)
            ```
        """
        assert return_type in {
            "pandas",
            "arrow",
        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
        import duckdb

        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")

        # NOTE: this local MUST be named `table` — duckdb.sql() resolves the 'table' name in the
        # query against local Python variables (replacement scan), so renaming it breaks the query.
        table = self.table.to_arrow()
        if not query.startswith("SELECT") and not query.startswith("WHERE"):
            raise ValueError(
                f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
            )
        if query.startswith("WHERE"):
            # Allow callers to pass only a WHERE clause; wrap it into a full SELECT.
            query = f"SELECT * FROM 'table' {query}"
        LOGGER.info(f"Running query: {query}")

        rs = duckdb.sql(query)
        if return_type == "arrow":
            return rs.arrow()
        elif return_type == "pandas":
            return rs.df()
|
|
|
| def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
|
| """
|
| Plot the results of a SQL-Like query on the table.
|
| Args:
|
| query (str): SQL query to run.
|
| labels (bool): Whether to plot the labels or not.
|
|
|
| Returns:
|
| (PIL.Image): Image containing the plot.
|
|
|
| Example:
|
| ```python
|
| exp = Explorer()
|
| exp.create_embeddings_table()
|
| query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
|
| result = exp.plot_sql_query(query)
|
| ```
|
| """
|
| result = self.sql_query(query, return_type="arrow")
|
| if len(result) == 0:
|
| LOGGER.info("No results found.")
|
| return None
|
| img = plot_query_result(result, plot_labels=labels)
|
| return Image.fromarray(img)
|
|
|
| def get_similar(
|
| self,
|
| img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
| idx: Union[int, List[int]] = None,
|
| limit: int = 25,
|
| return_type: str = "pandas",
|
| ) -> Union[DataFrame, Any]:
|
| """
|
| Query the table for similar images. Accepts a single image or a list of images.
|
|
|
| Args:
|
| img (str or list): Path to the image or a list of paths to the images.
|
| idx (int or list): Index of the image in the table or a list of indexes.
|
| limit (int): Number of results to return. Defaults to 25.
|
| return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
|
|
|
| Returns:
|
| (pandas.DataFrame): A dataframe containing the results.
|
|
|
| Example:
|
| ```python
|
| exp = Explorer()
|
| exp.create_embeddings_table()
|
| similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
|
| ```
|
| """
|
| assert return_type in {
|
| "pandas",
|
| "arrow",
|
| }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
|
| img = self._check_imgs_or_idxs(img, idx)
|
| similar = self.query(img, limit=limit)
|
|
|
| if return_type == "arrow":
|
| return similar
|
| elif return_type == "pandas":
|
| return similar.to_pandas()
|
|
|
| def plot_similar(
|
| self,
|
| img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
| idx: Union[int, List[int]] = None,
|
| limit: int = 25,
|
| labels: bool = True,
|
| ) -> Image.Image:
|
| """
|
| Plot the similar images. Accepts images or indexes.
|
|
|
| Args:
|
| img (str or list): Path to the image or a list of paths to the images.
|
| idx (int or list): Index of the image in the table or a list of indexes.
|
| labels (bool): Whether to plot the labels or not.
|
| limit (int): Number of results to return. Defaults to 25.
|
|
|
| Returns:
|
| (PIL.Image): Image containing the plot.
|
|
|
| Example:
|
| ```python
|
| exp = Explorer()
|
| exp.create_embeddings_table()
|
| similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
|
| ```
|
| """
|
| similar = self.get_similar(img, idx, limit, return_type="arrow")
|
| if len(similar) == 0:
|
| LOGGER.info("No results found.")
|
| return None
|
| img = plot_query_result(similar, plot_labels=labels)
|
| return Image.fromarray(img)
|
|
|
    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
        """
        Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
        are max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
                vector search. Defaults: None.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
                include indices of similar images and their respective distances.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            sim_idx = exp.similarity_index()
            ```
        """
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        # Index table name encodes the parameters so different (max_dist, top_k) combos are cached separately.
        sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
        if sim_idx_table_name in self.connection.table_names() and not force:
            LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
            return self.connection.open_table(sim_idx_table_name).to_pandas()

        if top_k and not (1.0 >= top_k >= 0.0):
            raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
        if max_dist < 0.0:
            raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")

        # Convert the top_k fraction into an absolute row count; default to the whole table, minimum 1.
        top_k = int(top_k * len(self.table)) if top_k else len(self.table)
        top_k = max(top_k, 1)
        features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
        im_files = features["im_file"]
        embeddings = features["vector"]

        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")

        def _yield_sim_idx():
            """Generates a dataframe with similarity indices and distances for images."""
            for i in tqdm(range(len(embeddings))):
                # Vector-search the top_k nearest rows, then keep only those within max_dist.
                sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
                yield [
                    {
                        "idx": i,
                        "im_file": im_files[i],
                        "count": len(sim_idx),
                        "sim_im_files": sim_idx["im_file"].tolist(),
                    }
                ]

        sim_table.add(_yield_sim_idx())
        self.sim_index = sim_table
        return sim_table.to_pandas()
|
|
|
| def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
|
| """
|
| Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
|
| max_dist or closer to the image in the embedding space at a given index.
|
|
|
| Args:
|
| max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
| top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
|
| running vector search. Defaults to 0.01.
|
| force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
|
|
| Returns:
|
| (PIL.Image): Image containing the plot.
|
|
|
| Example:
|
| ```python
|
| exp = Explorer()
|
| exp.create_embeddings_table()
|
|
|
| similarity_idx_plot = exp.plot_similarity_index()
|
| similarity_idx_plot.show() # view image preview
|
| similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
|
| ```
|
| """
|
| sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
|
| sim_count = sim_idx["count"].tolist()
|
| sim_count = np.array(sim_count)
|
|
|
| indices = np.arange(len(sim_count))
|
|
|
|
|
| plt.bar(indices, sim_count)
|
|
|
|
|
| plt.xlabel("data idx")
|
| plt.ylabel("Count")
|
| plt.title("Similarity Count")
|
| buffer = BytesIO()
|
| plt.savefig(buffer, format="png")
|
| buffer.seek(0)
|
|
|
|
|
| return Image.fromarray(np.array(Image.open(buffer)))
|
|
|
| def _check_imgs_or_idxs(
|
| self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
|
| ) -> List[np.ndarray]:
|
| if img is None and idx is None:
|
| raise ValueError("Either img or idx must be provided.")
|
| if img is not None and idx is not None:
|
| raise ValueError("Only one of img or idx must be provided.")
|
| if idx is not None:
|
| idx = idx if isinstance(idx, list) else [idx]
|
| img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
|
|
|
| return img if isinstance(img, list) else [img]
|
|
|
| def ask_ai(self, query):
|
| """
|
| Ask AI a question.
|
|
|
| Args:
|
| query (str): Question to ask.
|
|
|
| Returns:
|
| (pandas.DataFrame): A dataframe containing filtered results to the SQL query.
|
|
|
| Example:
|
| ```python
|
| exp = Explorer()
|
| exp.create_embeddings_table()
|
| answer = exp.ask_ai('Show images with 1 person and 2 dogs')
|
| ```
|
| """
|
| result = prompt_sql_query(query)
|
| try:
|
| df = self.sql_query(result)
|
| except Exception as e:
|
| LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
|
| LOGGER.error(e)
|
| return None
|
| return df
|
|
|
    def visualize(self, result):
        """
        Visualize the results of a query. TODO.

        Args:
            result (pyarrow.Table): Table containing the results of a query.
        """
        # Placeholder: intentionally unimplemented; see TODO in the docstring.
        pass
|
|
|
    def generate_report(self, result):
        """
        Generate a report of the dataset.

        TODO
        """
        # Placeholder: intentionally unimplemented; see TODO in the docstring.
        pass
|
|
|