| import logging |
| import os |
| from dataclasses import dataclass, field |
| from functools import partial |
| from pathlib import Path |
| from tempfile import TemporaryDirectory |
| from typing import List, Optional |
|
|
| import faiss |
| import torch |
| from datasets import Features, Sequence, Value, load_dataset |
|
|
| from transformers import ( |
| DPRContextEncoder, |
| DPRContextEncoderTokenizerFast, |
| HfArgumentParser, |
| RagRetriever, |
| RagSequenceForGeneration, |
| RagTokenizer, |
| ) |
|
|
|
|
logger = logging.getLogger(__name__)  # module-level logger; level is configured under __main__
# This script only runs inference, so disable autograd globally to save memory.
torch.set_grad_enabled(False)
# Prefer the GPU when available; the DPR context encoder is moved to this device in main().
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
def split_text(text: str, n=100, character=" ") -> List[str]:
    """Chunk ``text`` into groups of ``n`` tokens delimited by ``character``."""
    pieces = text.split(character)
    chunks = []
    for start in range(0, len(pieces), n):
        # Re-join each window of ``n`` pieces and trim surrounding whitespace.
        chunks.append(character.join(pieces[start : start + n]).strip())
    return chunks
|
|
|
|
def split_documents(documents: dict) -> dict:
    """Expand each (title, text) document into one record per passage."""
    out_titles: List[str] = []
    out_texts: List[str] = []
    for title, text in zip(documents["title"], documents["text"]):
        # Skip documents with no body text entirely.
        if text is None:
            continue
        for passage in split_text(text):
            # A missing title becomes the empty string so columns stay aligned.
            out_titles.append("" if title is None else title)
            out_texts.append(passage)
    return {"title": out_titles, "text": out_texts}
|
|
|
|
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Encode passages with the DPR context encoder and return a numpy embeddings column."""
    tokenized = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )
    # Run the encoder on the same device the model lives on.
    outputs = ctx_encoder(tokenized["input_ids"].to(device=device), return_dict=True)
    # Detach and move to CPU so the result can be stored in the dataset.
    return {"embeddings": outputs.pooler_output.detach().cpu().numpy()}
|
|
|
|
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    """Build a DPR-embedded knowledge dataset, index it with Faiss, and query it via RAG.

    Steps:
      1. Load a tab-separated (title, text) csv, split documents into passages,
         and embed each passage with a DPR context encoder.
      2. Build an HNSW Faiss index over the embeddings and save it to disk.
      3. Load a RAG model whose retriever is backed by the freshly built dataset.
      4. Answer one sample question and log the result.

    Args:
        rag_example_args: Input csv path, model names, question, and output directory.
        processing_args: ``num_proc`` for passage splitting and ``batch_size`` for embedding.
        index_hnsw_args: Dimension ``d`` and connectivity ``m`` of the HNSW index.

    Raises:
        FileNotFoundError: If ``rag_example_args.csv_path`` is not an existing file.
    """

    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # Validate with a real exception: `assert` is stripped when Python runs with -O.
    if not os.path.isfile(rag_example_args.csv_path):
        raise FileNotFoundError(f"Please provide a valid path to a csv file (got {rag_example_args.csv_path!r})")

    # The expected file format is a tab-separated csv with columns "title" and "text".
    dataset = load_dataset(
        "csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
    )

    # Split each document into passages of roughly 100 words.
    dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # Compute a DPR embedding for every passage; store embeddings as float32 to save space.
    ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
    new_features = Features(
        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
    )
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
        features=new_features,
    )

    # Persist the passages so the retriever can reload them later.
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Inner-product metric matches DPR's dot-product similarity.
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)

    # Save the index next to the passages.
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # Wire the retriever to the in-memory indexed dataset built above.
    retriever = RagRetriever.from_pretrained(
        rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset
    )
    model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    question = rag_example_args.question or "What does Moses' rod turn into ?"
    input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids)
    generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    # Lazy %-style args: the log strings are only built when INFO is enabled.
    logger.info("Q: %s", question)
    logger.info("A: %s", generated_string)
|
|
|
|
@dataclass
class RagExampleArguments:
    """Data paths, model names, and output location for the RAG example."""

    # Tab-separated input file with a "title" and a "text" column.
    csv_path: str = field(
        default=str(Path(__file__).parent / "test_data" / "my_knowledge_dataset.csv"),
        metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
    )
    # Question to run through the model; a sample question is used when omitted.
    question: Optional[str] = field(
        default=None,
        metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
    )
    # Pretrained RAG checkpoint identifier.
    rag_model_name: str = field(
        default="facebook/rag-sequence-nq",
        metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
    )
    # Pretrained DPR context-encoder checkpoint identifier.
    dpr_ctx_encoder_model_name: str = field(
        default="facebook/dpr-ctx_encoder-multiset-base",
        metadata={
            "help": (
                "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or"
                " 'facebook/dpr-ctx_encoder-multiset-base'"
            )
        },
    )
    # Where the passages and the Faiss index are written; a temp dir when omitted.
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
    )
|
|
|
|
@dataclass
class ProcessingArguments:
    """Knobs for the dataset-building phase (splitting and embedding)."""

    # Worker processes for the passage-splitting map; None means single process.
    num_proc: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use to split the documents into passages. Default is single process."
        },
    )
    # How many passages are embedded per forward pass of the DPR encoder.
    batch_size: int = field(
        default=16,
        metadata={
            "help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
        },
    )
|
|
|
|
@dataclass
class IndexHnswArguments:
    """Hyperparameters of the HNSW Faiss index built over the embeddings."""

    # Embedding dimensionality; must match the DPR encoder output size.
    d: int = field(
        default=768,
        metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
    )
    # HNSW connectivity parameter.
    m: int = field(
        default=128,
        metadata={
            "help": (
                "The number of bi-directional links created for every new element during the HNSW index construction."
            )
        },
    )
|
|
|
|
if __name__ == "__main__":
    # Quiet everything by default, but keep this script's own logger verbose.
    logging.basicConfig(level=logging.WARNING)
    logger.setLevel(logging.INFO)

    # Parse the three argument groups from the command line into their dataclasses.
    parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
    rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
    # When no output dir was given, fall back to a temp dir that is removed on exit.
    with TemporaryDirectory() as tmp_dir:
        rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
        main(rag_example_args, processing_args, index_hnsw_args)
|
|