# Spaces status: Runtime error
# Build a tiny example passage dataset, embed every passage with the RAG
# generator's encoder (mask-aware mean pooling over tokens), and persist both
# the dataset and a flat L2 FAISS index to disk.
#
# NOTE(review): if this index is meant to be consumed by `RagRetriever`
# (imported below but not used in this part of the file), the retriever
# presumably expects DPR context-encoder embeddings plus "title" / "text" /
# "embeddings" columns in the saved dataset -- TODO confirm against the
# consumer before relying on these vectors for retrieval.
from datasets import Dataset, load_from_disk
import faiss
import numpy as np
from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration

# Example: Create a dataset
data = {"text": ["This is a sample text.", "Another sample text."]}
dataset = Dataset.from_dict(data)

# Save the dataset to disk
dataset_path = "path/to/your/dataset"
dataset.save_to_disk(dataset_path)

# Create FAISS index
# Materialise the column as a plain list of strings so the tokenizer receives
# a type it accepts regardless of how `datasets` represents a column.
passages = list(dataset["text"])
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

# Calling `tokenizer(...)` directly would use the question-encoder tokenizer,
# whose vocabulary (and extra `token_type_ids` output) does not match the
# generator's encoder. Tokenize with the generator's own tokenizer instead.
passage_inputs = tokenizer.generator(
    passages, return_tensors="pt", padding=True, truncation=True
)

# The seq2seq encoder lives on the wrapped generator; pass the tensors as
# keyword arguments rather than handing the whole encoding object to the
# encoder as a single positional argument.
encoder = model.rag.generator.get_encoder()
hidden_states = encoder(
    input_ids=passage_inputs["input_ids"],
    attention_mask=passage_inputs["attention_mask"],
).last_hidden_state

# Mean-pool over real tokens only: with padding=True a plain mean(dim=1)
# would let padding positions dilute the embeddings of shorter passages.
token_mask = passage_inputs["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
pooled = (hidden_states * token_mask).sum(dim=1) / token_mask.sum(dim=1).clamp(min=1)

# FAISS requires a C-contiguous float32 matrix of shape (n_passages, dim).
passage_embeddings = np.ascontiguousarray(
    pooled.detach().cpu().numpy(), dtype=np.float32
)

index = faiss.IndexFlatL2(passage_embeddings.shape[1])
index.add(passage_embeddings)

# Save the index to disk
index_path = "path/to/your/index"
faiss.write_index(index, index_path)