# transformers/examples/research_projects/rag/use_own_knowledge_dataset.py
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import faiss
import torch
from datasets import Features, Sequence, Value, load_dataset

from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    HfArgumentParser,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)


logger = logging.getLogger(__name__)

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"


def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
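
# For illustration (hypothetical input): split_text("a b c d e f", n=2) returns ["a b", "c d", "e f"].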


def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}


def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages"""
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.detach().cpu().numpy()}


def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # The dataset needed for RAG must have three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage

    # Let's say you have documents in tab-separated csv files with columns "title" and "text"
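    # For illustration, each line of such a file holds one title and one text separated by a tab,
    # e.g. (hypothetical content):
    # My document title<TAB>The full text of my document ...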
    assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # You can load a Dataset object this way
    dataset = load_dataset(
        "csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
    )

    # More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets?highlight=csv#csv-files

    # Then split the documents into passages of 100 words
    dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)
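    # With batched=True, split_documents receives each batch as a dict of lists
    # ({"title": [...], "text": [...]}), which matches its signature above.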

    # And compute the embeddings
    ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
    new_features = Features(
        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
    )  # optional, save as float32 instead of float64 to save space
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
        features=new_features,
    )

    # And finally save your dataset
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)
    # from datasets import load_from_disk
    # dataset = load_from_disk(passages_path)  # to reload the dataset

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
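    # d must match the dimension of the DPR embeddings (768 for the base encoders), m is the number of
    # bi-directional links per node in the HNSW graph, and METRIC_INNER_PRODUCT is used because DPR
    # embeddings are compared with a dot product.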
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)

    # And save the index
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)
    # dataset.load_faiss_index("embeddings", index_path)  # to reload the index

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # Easy way to load the model
    retriever = RagRetriever.from_pretrained(
        rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset
    )
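    # index_name="custom" makes the retriever query the indexed_dataset built above instead of the
    # default wiki_dpr index.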
    model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    # For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
    # retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    question = rag_example_args.question or "What does Moses' rod turn into ?"
    input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids)
    generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    logger.info("Q: " + question)
    logger.info("A: " + generated_string)


@dataclass
class RagExampleArguments:
    csv_path: str = field(
        default=str(Path(__file__).parent / "test_data" / "my_knowledge_dataset.csv"),
        metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
    )
    question: Optional[str] = field(
        default=None,
        metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
    )
    rag_model_name: str = field(
        default="facebook/rag-sequence-nq",
        metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
    )
    dpr_ctx_encoder_model_name: str = field(
        default="facebook/dpr-ctx_encoder-multiset-base",
        metadata={
            "help": (
                "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or"
                " 'facebook/dpr-ctx_encoder-multiset-base'"
            )
        },
    )
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
    )


@dataclass
class ProcessingArguments:
    num_proc: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use to split the documents into passages. Default is single process."
        },
    )
    batch_size: int = field(
        default=16,
        metadata={
            "help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
        },
    )


@dataclass
class IndexHnswArguments:
    d: int = field(
        default=768,
        metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
    )
    m: int = field(
        default=128,
        metadata={
            "help": (
                "The number of bi-directional links created for every new element during the HNSW index construction."
            )
        },
    )
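

# Example invocation (paths are hypothetical; the flags are generated by HfArgumentParser from the
# dataclass field names above):
#   python use_own_knowledge_dataset.py \
#       --csv_path path/to/my_knowledge_dataset.csv \
#       --output_dir path/to/my_knowledge_dataset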


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    logger.setLevel(logging.INFO)

    parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
    rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
    with TemporaryDirectory() as tmp_dir:
        rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
        main(rag_example_args, processing_args, index_hnsw_args)