Spaces:
Runtime error
Runtime error
| import os | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, Optional, Union | |
| import hydra | |
| from omegaconf import OmegaConf | |
| from relik.retriever.indexers.faiss import FaissDocumentIndex | |
| from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel | |
| from rich.pretty import pprint | |
| from relik.common.log import get_console_logger, get_logger | |
| from relik.common.upload import upload | |
| from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string | |
| from relik.inference.data.objects import EntitySpan, RelikOutput | |
| from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer | |
| from relik.inference.data.window.manager import WindowManager | |
| from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction | |
| from relik.reader.relik_reader import RelikReader | |
| from relik.retriever.data.utils import batch_generator | |
| from relik.retriever.indexers.base import BaseDocumentIndex | |
| from relik.retriever.pytorch_modules.model import GoldenRetriever | |
| logger = get_logger(__name__) | |
| console_logger = get_console_logger() | |
class Relik:
    """
    Relik main class. It is a wrapper around a retriever and a reader.

    Args:
        retriever (`Optional[GoldenRetriever]`, `optional`):
            The retriever to use. If `None`, a retriever will be instantiated from the
            provided `question_encoder`, `passage_encoder` and `document_index`.
            Defaults to `None`.
        question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The question encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The passage encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`):
            The document index to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        reader (`Optional[Union[str, RelikReader]]`, `optional`):
            The reader to use. A string is used to instantiate a
            `RelikReaderForSpanExtraction` (together with `reader_kwargs`); any other
            value is stored as is. Defaults to `None`.
        device (`str`, `optional`, defaults to `cpu`):
            Fallback device used when `retriever_device`, `document_index_device` or
            `reader_device` is not provided.
        retriever_device (`Optional[str]`, `optional`):
            The device to use for the retriever. Defaults to `device`.
        document_index_device (`Optional[str]`, `optional`):
            The device to use for the document index. Defaults to `device`.
        reader_device (`Optional[str]`, `optional`):
            Stored as `self.reader_device`. Defaults to `device`.
        precision (`int`, `optional`, defaults to `32`):
            Fallback precision used when a component-specific precision is not provided.
        retriever_precision (`Optional[int]`, `optional`):
            The precision to use for the retriever. Defaults to `precision`.
        document_index_precision (`Optional[int]`, `optional`):
            The precision to use for the document index. Defaults to `precision`.
        reader_precision (`Optional[int]`, `optional`):
            Stored as `self.reader_precision`. Defaults to `precision`.
        reader_kwargs (`Optional[dict]`, `optional`):
            Extra keyword arguments used when the reader is instantiated from a string.
        retriever_kwargs (`Optional[dict]`, `optional`):
            Extra keyword arguments that override the ones built here when the
            retriever is instantiated by this class.
        candidates_preprocessing_fn (`Optional[Union[str, Callable]]`, `optional`):
            Callable (or dotted path to one) applied to every retrieved candidate
            label. Defaults to the identity function.
        top_k (`Optional[int]`, `optional`):
            Default number of candidates to retrieve for each window.
        window_size (`Optional[int]`, `optional`):
            Default window size used at inference time.
        window_stride (`Optional[int]`, `optional`):
            Default window stride used at inference time.
        **kwargs:
            Accepted for compatibility; not used by this constructor.
    """

    def __init__(
        self,
        retriever: GoldenRetriever | None = None,
        question_encoder: str | GoldenRetrieverModel | None = None,
        passage_encoder: str | GoldenRetrieverModel | None = None,
        document_index: str | BaseDocumentIndex | None = None,
        reader: str | RelikReader | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int = 32,
        retriever_precision: int | None = None,
        document_index_precision: int | None = None,
        reader_precision: int | None = None,
        reader_kwargs: dict | None = None,
        retriever_kwargs: dict | None = None,
        candidates_preprocessing_fn: str | Callable | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        **kwargs,
    ) -> None:
        # retriever
        retriever_device = retriever_device or device
        document_index_device = document_index_device or device
        retriever_precision = retriever_precision or precision
        document_index_precision = document_index_precision or precision
        if retriever is None and question_encoder is None:
            raise ValueError(
                "Either `retriever` or `question_encoder` must be provided"
            )
        # Always define the attribute: `save_pretrained` reads it even when a
        # ready-made retriever was passed in.
        self.retriever_kwargs = retriever_kwargs
        if retriever is None:
            self.retriever_kwargs = dict(
                question_encoder=question_encoder,
                passage_encoder=passage_encoder,
                document_index=document_index,
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
            # overwrite default_retriever_kwargs with retriever_kwargs
            self.retriever_kwargs.update(retriever_kwargs or {})
            retriever = GoldenRetriever(**self.retriever_kwargs)
        # inference only
        retriever.training = False
        retriever.eval()
        self.retriever = retriever

        # reader
        # NOTE(review): device and precision are only stored here, they are not
        # forwarded to the reader constructor below - confirm this is intended.
        self.reader_device = reader_device or device
        self.reader_precision = reader_precision or precision
        self.reader_kwargs = reader_kwargs
        if isinstance(reader, str):
            reader_kwargs = reader_kwargs or {}
            reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
        self.reader = reader

        # windowization stuff
        self.tokenizer = SpacyTokenizer(language="en")
        # created lazily on the first call
        self.window_manager: WindowManager | None = None

        # candidates preprocessing
        # TODO: maybe move this logic somewhere else
        candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x)
        if isinstance(candidates_preprocessing_fn, str):
            candidates_preprocessing_fn = get_callable_from_string(
                candidates_preprocessing_fn
            )
        self.candidates_preprocessing_fn = candidates_preprocessing_fn

        # inference params
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        text: Union[str, list],
        top_k: Optional[int] = None,
        window_size: Optional[int] = None,
        window_stride: Optional[int] = None,
        retriever_batch_size: Optional[int] = 32,
        reader_batch_size: Optional[int] = 32,
        return_also_windows: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window. If `None`, the
                value given at construction time is used, falling back to `100`.
            window_size (`int`, `optional`, defaults to `None`):
                The size of the window. If `None`, the value given at construction
                time is used.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. If `None`, the value given at construction
                time is used.
            retriever_batch_size (`int`, `optional`, defaults to `32`):
                The batch size to use for the retriever.
            reader_batch_size (`int`, `optional`, defaults to `32`):
                The maximum batch size to use for the reader.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to return the windows in the output.
            **kwargs:
                Accepted for compatibility; currently not used.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. A single `RelikOutput` is returned whenever exactly
                one output is produced (a string input, or a list with one element);
                otherwise a list of `RelikOutput` objects is returned. An empty input
                list yields an empty list.

        Raises:
            `NotImplementedError`: If `window_size` is `"sentence"`.
        """
        if top_k is None:
            top_k = self.top_k or 100
        if window_size is None:
            window_size = self.window_size
        if window_stride is None:
            window_stride = self.window_stride

        if isinstance(text, str):
            text = [text]
        if not text:
            # nothing to annotate
            return []

        # The window manager is needed both to create and to merge the windows,
        # whatever the window size is, so it is always instantiated.
        if self.window_manager is None:
            self.window_manager = WindowManager(self.tokenizer)

        if window_size == "sentence":
            # todo: implement sentence windowizer
            raise NotImplementedError("Sentence windowizer not implemented yet")

        # window generator
        windows = [
            window
            for doc_id, t in enumerate(text)
            for window in self.window_manager.create_windows(
                t,
                window_size=window_size,
                stride=window_stride,
                doc_id=doc_id,
            )
        ]

        # retrieve candidates first
        windows_candidates = []
        # TODO: Move batching inside retriever
        for batch in batch_generator(windows, batch_size=retriever_batch_size):
            retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
            windows_candidates.extend(
                [[p.label for p in predictions] for predictions in retriever_out]
            )

        # add passage to the windows
        for window, candidates in zip(windows, windows_candidates):
            window.window_candidates = [
                self.candidates_preprocessing_fn(c) for c in candidates
            ]

        windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size)
        windows = self.window_manager.merge_windows(windows)

        # transform predictions into RelikOutput objects
        output = []
        for w in windows:
            sample_output = RelikOutput(
                text=text[w.doc_id],
                labels=sorted(
                    [
                        EntitySpan(
                            start=ss, end=se, label=sl, text=text[w.doc_id][ss:se]
                        )
                        for ss, se, sl in w.predicted_window_labels_chars
                    ],
                    key=lambda x: x.start,
                ),
            )
            output.append(sample_output)

        if return_also_windows:
            for i, sample_output in enumerate(output):
                sample_output.windows = [w for w in windows if w.doc_id == i]

        # if only one text was provided, return a single RelikOutput object
        if len(output) == 1:
            return output[0]

        return output

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_dir: Union[str, os.PathLike],
        config_kwargs: Optional[Dict] = None,
        config_file_name: str = CONFIG_NAME,
        *args,
        **kwargs,
    ) -> "Relik":
        """
        Instantiate Relik from a saved configuration file.

        Args:
            model_name_or_dir (`str` or `os.PathLike`):
                The name or directory of the model, resolved through `from_cache`.
            config_kwargs (`Optional[Dict]`, `optional`):
                Values merged on top of the loaded configuration. Defaults to `None`.
            config_file_name (`str`, `optional`):
                The name of the configuration file. Defaults to `CONFIG_NAME`.
            *args:
                Additional positional arguments passed to `hydra.utils.instantiate`.
            **kwargs:
                `cache_dir` and `force_download` are consumed here and passed to
                `from_cache`; everything else is passed to `hydra.utils.instantiate`.

        Returns:
            `Relik`: The object instantiated from the configuration.

        Raises:
            `FileNotFoundError`: If the configuration file does not exist.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        model_dir = from_cache(
            model_name_or_dir,
            filenames=[config_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        # overwrite config with config_kwargs
        config = OmegaConf.load(config_path)
        if config_kwargs is not None:
            # TODO: check merging behavior
            config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
        # do we want to print the config? I like it
        pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)

        # load relik from config
        relik = hydra.utils.instantiate(config, *args, **kwargs)

        return relik

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `CONFIG_NAME`.
            save_weights (`bool`, `optional`):
                Whether to save the weights of the model. Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            **kwargs:
                Additional keyword arguments passed to the `save_pretrained` methods
                of the retriever and the reader when `save_weights` is `True`.
        """
        if config is None:
            # create a default config
            config = {
                "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
            }
            if self.retriever is not None:
                if self.retriever.question_encoder is not None:
                    config[
                        "question_encoder"
                    ] = self.retriever.question_encoder.name_or_path
                if self.retriever.passage_encoder is not None:
                    config[
                        "passage_encoder"
                    ] = self.retriever.passage_encoder.name_or_path
                if self.retriever.document_index is not None:
                    config["document_index"] = self.retriever.document_index.name_or_dir
            if self.reader is not None:
                config["reader"] = self.reader.model_path

            config["retriever_kwargs"] = self.retriever_kwargs
            config["reader_kwargs"] = self.reader_kwargs

            # expand the fn as to be able to save it and load it later
            fn = self.candidates_preprocessing_fn
            fn_module = getattr(fn, "__module__", None)
            fn_name = getattr(fn, "__name__", None)
            if fn_module is None or fn_name is None or fn_name == "<lambda>":
                # A lambda (including the default identity) has no importable
                # dotted path: writing "<module>.<lambda>" would produce a config
                # that cannot be loaded back, so the entry is left out and the
                # constructor default applies on load.
                logger.info(
                    "`candidates_preprocessing_fn` is not an importable named "
                    "callable, it will not be saved in the config"
                )
            else:
                config["candidates_preprocessing_fn"] = f"{fn_module}.{fn_name}"

            # these are model-specific and should be saved
            config["top_k"] = self.top_k
            config["window_size"] = self.window_size
            config["window_stride"] = self.window_stride

        config_file_name = config_file_name or CONFIG_NAME

        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving relik config to {output_dir / config_file_name}")
        # pretty print the config
        pprint(config, console=console_logger, expand_all=True)
        OmegaConf.save(config, output_dir / config_file_name)

        if save_weights:
            model_id = model_id or output_dir.name
            retriever_model_id = model_id + "-retriever"
            # save weights
            logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
            self.retriever.save_pretrained(
                output_dir / retriever_model_id,
                question_encoder_name=retriever_model_id + "-question-encoder",
                passage_encoder_name=retriever_model_id + "-passage-encoder",
                document_index_name=retriever_model_id + "-index",
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )
            reader_model_id = model_id + "-reader"
            logger.info(f"Saving reader to {output_dir / reader_model_id}")
            self.reader.save_pretrained(
                output_dir / reader_model_id,
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )

        if push_to_hub:
            # push to hub
            logger.info("Pushing to hub")
            model_id = model_id or output_dir.name
            upload(output_dir, model_id, organization=organization, repo_name=repo_name)
def main():
    """Demo entry point: build a Relik pipeline from local checkpoints and annotate a sample text."""
    # Deliberately shadows the module-level rich `pprint` with the stdlib one
    # for the final print.
    from pprint import pprint

    # Settings used to load the FAISS document index.
    index_config = {
        "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex",
        "index_type": "IVFx,Flat",
    }
    faiss_index = FaissDocumentIndex.from_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
        config_kwargs=index_config,
    )

    # Constructor arguments of the pipeline, collected in one place.
    pipeline_options = dict(
        question_encoder="/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
        document_index=faiss_index,
        reader="/root/relik-spaces/models/relik-reader-aida-deberta-small",
        device="cuda",
        precision=16,
        top_k=100,
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    )
    annotator = Relik(**pipeline_options)

    sample_text = """
    Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore.
    The 92-year-old billionaire did not disclose the trust to the government in July 2015.
    Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty.
    Ecclestone had been due to go on trial next month.
    """

    pprint(annotator(sample_text))


if __name__ == "__main__":
    main()