Spaces:
Running
Running
| import abc | |
| import logging | |
| import math | |
| import time | |
| from pathlib import Path | |
| from typing import TypeVar, Generic, cast, Any | |
| import numpy as np | |
| import numpy.typing as npt | |
| from tqdm import tqdm | |
| import faiss | |
| from faiss import IndexIVF, Index | |
| logger = logging.getLogger(__name__) | |
| T = TypeVar("T", bound=Index) | |
| NumpyArray = npt.NDArray[np.float32] | |
| class FaissFeatureIndex(Generic[T], abc.ABC): | |
| def __init__(self, index: T) -> None: | |
| self._index = index | |
| def save(self, filepath: Path, rewrite: bool = False) -> None: | |
| if filepath.exists() and not rewrite: | |
| raise FileExistsError(f"index already exists by path {filepath}") | |
| faiss.write_index(self._index, str(filepath)) | |
| class FaissRetrievableFeatureIndex(FaissFeatureIndex[Index], abc.ABC): | |
| """retrieve voice feature vectors by faiss index""" | |
| def __init__(self, index: T, ratio: float, n_nearest_vectors: int) -> None: | |
| super().__init__(index=index) | |
| if index.metric_type != self.supported_distance: | |
| raise ValueError(f"index metric type {index.metric_type=} is unsupported {self.supported_distance=}") | |
| if 1 > n_nearest_vectors: | |
| raise ValueError("n-retrieval-vectors must be gte 1") | |
| self._n_nearest = n_nearest_vectors | |
| if 0 > ratio > 1: | |
| raise ValueError(f"{ratio=} must be in rage (0, 1)") | |
| self._ratio = ratio | |
| def supported_distance(self) -> Any: | |
| raise NotImplementedError | |
| def _weight_nearest_vectors(self, nearest_vectors: NumpyArray, scores: NumpyArray) -> NumpyArray: | |
| raise NotImplementedError | |
| def retriv(self, features: NumpyArray) -> NumpyArray: | |
| # use method search_and_reconstruct instead of recreating the whole matrix | |
| scores, _, nearest_vectors = self._index.search_and_reconstruct(features, k=self._n_nearest) | |
| weighted_nearest_vectors = self._weight_nearest_vectors(nearest_vectors, scores) | |
| retriv_vector = (1 - self._ratio) * features + self._ratio * weighted_nearest_vectors | |
| return retriv_vector | |
| class FaissRVCRetrievableFeatureIndex(FaissRetrievableFeatureIndex): | |
| """ | |
| retrieve voice encoded features with algorith from RVC repository | |
| https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI | |
| """ | |
| def supported_distance(self) -> Any: | |
| return faiss.METRIC_L2 | |
| def _weight_nearest_vectors(self, nearest_vectors: NumpyArray, scores: NumpyArray) -> NumpyArray: | |
| """ | |
| magic code from original RVC | |
| https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/86ed98aacaa8b2037aad795abd11cdca122cf39f/vc_infer_pipeline.py#L213C18-L213C19 | |
| nearest_vectors dim (n_nearest, vector_dim) | |
| scores dim (num_vectors, n_nearest) | |
| """ | |
| logger.debug("shape: nv=%s sc=%s", nearest_vectors.shape, scores.shape) | |
| weight = np.square(1 / scores) | |
| weight /= weight.sum(axis=1, keepdims=True) | |
| weight = np.expand_dims(weight, axis=2) | |
| weighted_nearest_vectors = np.sum(nearest_vectors * weight, axis=1) | |
| logger.debug( | |
| "shape: nv=%s weight=%s weight_nearest=%s", | |
| nearest_vectors.shape, | |
| weight.shape, | |
| weighted_nearest_vectors.shape, | |
| ) | |
| return cast(NumpyArray, weighted_nearest_vectors) | |
| class FaissIVFTrainableFeatureIndex(FaissFeatureIndex[IndexIVF]): | |
| """IVF faiss index that can train and add feature vectors""" | |
| def __init__(self, index: IndexIVF, batch_size: int) -> None: | |
| super().__init__(index=index) | |
| self._batch_size = batch_size | |
| def _trained_index(self) -> IndexIVF: | |
| if not self._index.is_trained: | |
| raise RuntimeError("index needs to be trained first") | |
| return self._index | |
| def _not_trained_index(self) -> IndexIVF: | |
| if self._index.is_trained: | |
| raise RuntimeError("index is already trained") | |
| return self._index | |
| def _batch_count(self, feature_matrix: NumpyArray) -> int: | |
| return math.ceil(feature_matrix.shape[0] / self._batch_size) | |
| def _split_matrix_by_batch(self, feature_matrix: NumpyArray) -> list[NumpyArray]: | |
| return np.array_split(feature_matrix, indices_or_sections=self._batch_count(feature_matrix), axis=0) | |
| def _train_index(self, train_feature_matrix: NumpyArray) -> None: | |
| start = time.monotonic() | |
| self._not_trained_index.train(train_feature_matrix) | |
| took = time.monotonic() - start | |
| logger.info("index is trained. Took %.2f seconds", took) | |
| def add_to_index(self, feature_matrix: NumpyArray) -> None: | |
| n_batches = self._batch_count(feature_matrix) | |
| logger.info("adding %s batches to index", n_batches) | |
| start = time.monotonic() | |
| for batch in tqdm(self._split_matrix_by_batch(feature_matrix), total=n_batches): | |
| self._trained_index.add(batch) | |
| took = time.monotonic() - start | |
| logger.info("all batches added. Took %.2f seconds", took) | |
| def add_with_train(self, feature_matrix: NumpyArray) -> None: | |
| self._train_index(feature_matrix) | |
| self.add_to_index(feature_matrix) | |
| class FaissIVFFlatTrainableFeatureIndexBuilder: | |
| def __init__(self, batch_size: int, distance: int) -> None: | |
| self._batch_size = batch_size | |
| self._distance = distance | |
| def _build_index(self, num_vectors: int, vector_dim: int) -> IndexIVF: | |
| n_ivf = min(int(16 * np.sqrt(num_vectors)), num_vectors // 39) | |
| factory_string = f"IVF{n_ivf},Flat" | |
| index = faiss.index_factory(vector_dim, factory_string, self._distance) | |
| logger.debug('faiss index built by string "%s" and dimension %s', factory_string, vector_dim) | |
| index_ivf = faiss.extract_index_ivf(index) | |
| index_ivf.nprobe = 1 | |
| return index | |
| def build(self, num_vectors: int, vector_dim: int) -> FaissIVFTrainableFeatureIndex: | |
| return FaissIVFTrainableFeatureIndex( | |
| index=self._build_index(num_vectors, vector_dim), | |
| batch_size=self._batch_size, | |
| ) | |
| def load_retrieve_index(filepath: Path, ratio: float, n_nearest_vectors: int) -> FaissRetrievableFeatureIndex: | |
| return FaissRVCRetrievableFeatureIndex( | |
| index=faiss.read_index(str(filepath)), ratio=ratio, n_nearest_vectors=n_nearest_vectors | |
| ) | |