# NOTE: removed page-scrape artifacts ("Spaces:" / "Runtime error" status
# lines) that were captured from the hosting page and are not source code.
from typing import Dict, List, Optional, Union

from datasets import load_dataset
from haystack import Document, Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import Secret
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

from .config import DATASET_CONFIGS, MODEL_CONFIG, DatasetConfig
| class RAGPipeline: | |
| def __init__( | |
| self, | |
| google_api_key: str, | |
| dataset_config: Union[str, DatasetConfig], | |
| documents: Optional[List[Union[str, Document]]] = None, | |
| embedding_model: Optional[str] = None, | |
| llm_model: Optional[str] = None | |
| ): | |
| """ | |
| Initialize the RAG Pipeline. | |
| Args: | |
| google_api_key: API key for Google AI services | |
| dataset_config: Either a string key from DATASET_CONFIGS or a DatasetConfig object | |
| documents: Optional list of documents to use instead of loading from a dataset | |
| embedding_model: Optional override for embedding model | |
| llm_model: Optional override for LLM model | |
| """ | |
| # Load configuration | |
| if isinstance(dataset_config, str): | |
| if dataset_config not in DATASET_CONFIGS: | |
| raise ValueError(f"Dataset config '{dataset_config}' not found. Available configs: {list(DATASET_CONFIGS.keys())}") | |
| self.config = DATASET_CONFIGS[dataset_config] | |
| else: | |
| self.config = dataset_config | |
| # Load documents either from provided list or dataset | |
| if documents is not None: | |
| self.documents = documents | |
| else: | |
| dataset = load_dataset(self.config.name, split=self.config.split) | |
| # Create documents with metadata based on configuration | |
| self.documents = [] | |
| for doc in dataset: | |
| # Create metadata dictionary from configured fields | |
| meta = {} | |
| if self.config.fields: | |
| for meta_key, dataset_field in self.config.fields.items(): | |
| if dataset_field in doc: | |
| meta[meta_key] = doc[dataset_field] | |
| # Create document with content and metadata | |
| document = Document( | |
| content=doc[self.config.content_field], | |
| meta=meta | |
| ) | |
| self.documents.append(document) | |
| # print 10 documents | |
| for doc in self.documents[:10]: | |
| print(f"Content: {doc.content}") | |
| print(f"Metadata: {doc.meta}") | |
| print("-"*100) | |
| # Initialize components | |
| self.document_store = InMemoryDocumentStore() | |
| self.doc_embedder = SentenceTransformersDocumentEmbedder( | |
| model=embedding_model or MODEL_CONFIG["embedding_model"] | |
| ) | |
| self.text_embedder = SentenceTransformersTextEmbedder( | |
| model=embedding_model or MODEL_CONFIG["embedding_model"] | |
| ) | |
| self.retriever = InMemoryEmbeddingRetriever(self.document_store) | |
| # Warm up the document embedder | |
| self.doc_embedder.warm_up() | |
| # Initialize prompt template | |
| template = [ | |
| ChatMessage.from_user(self.config.prompt_template) | |
| ] | |
| self.prompt_builder = ChatPromptBuilder(template=template) | |
| # Initialize the generator | |
| self.generator = GoogleAIGeminiChatGenerator( | |
| model=llm_model or MODEL_CONFIG["llm_model"] | |
| ) | |
| # Index documents | |
| self._index_documents(self.documents) | |
| # Build pipeline | |
| self.pipeline = self._build_pipeline() | |
| def from_preset(cls, google_api_key: str, preset_name: str): | |
| """ | |
| Create a pipeline from a preset configuration. | |
| Args: | |
| google_api_key: API key for Google AI services | |
| preset_name: Name of the preset configuration to use | |
| """ | |
| return cls(google_api_key=google_api_key, dataset_config=preset_name) | |
| def _index_documents(self, documents): | |
| # Embed and index documents | |
| docs_with_embeddings = self.doc_embedder.run(documents) | |
| self.document_store.write_documents(docs_with_embeddings["documents"]) | |
| def _build_pipeline(self): | |
| pipeline = Pipeline() | |
| pipeline.add_component("text_embedder", self.text_embedder) | |
| pipeline.add_component("retriever", self.retriever) | |
| pipeline.add_component("prompt_builder", self.prompt_builder) | |
| pipeline.add_component("llm", self.generator) | |
| # Connect components | |
| pipeline.connect("text_embedder.embedding", "retriever.query_embedding") | |
| pipeline.connect("retriever", "prompt_builder") | |
| pipeline.connect("prompt_builder.prompt", "llm.messages") | |
| return pipeline | |
| def answer_question(self, question: str) -> str: | |
| """Run the RAG pipeline to answer a question""" | |
| result = self.pipeline.run({ | |
| "text_embedder": {"text": question}, | |
| "prompt_builder": {"question": question} | |
| }) | |
| return result["llm"]["replies"][0].text |