import logging
import weakref
import re
from dataclasses import dataclass, field
from typing import Any, List, Tuple, Dict
from uuid import uuid4

import adalflow as adal

from api.tools.embedder import get_embedder
from api.prompts import RAG_SYSTEM_PROMPT as system_prompt, RAG_TEMPLATE

# Create our own implementation of the conversation classes
@dataclass
class UserQuery:
    query_str: str


@dataclass
class AssistantResponse:
    response_str: str


@dataclass
class DialogTurn:
    id: str
    user_query: UserQuery
    assistant_response: AssistantResponse


class CustomConversation:
    """Custom implementation of Conversation to fix the 'list assignment index out of range' error."""

    def __init__(self):
        self.dialog_turns = []

    def append_dialog_turn(self, dialog_turn):
        """Safely append a dialog turn to the conversation."""
        if not hasattr(self, 'dialog_turns'):
            self.dialog_turns = []
        self.dialog_turns.append(dialog_turn)
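
# Illustrative sketch of the data model above (comments only; uses just the
# classes defined here and the stdlib):
#
#   turn = DialogTurn(
#       id=str(uuid4()),
#       user_query=UserQuery(query_str="What does this repo do?"),
#       assistant_response=AssistantResponse(response_str="It generates wikis."),
#   )
#   conversation = CustomConversation()
#   conversation.append_dialog_turn(turn)
#   assert len(conversation.dialog_turns) == 1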

# Import other adalflow components
from adalflow.components.retriever.faiss_retriever import FAISSRetriever

from api.config import configs
from api.data_pipeline import DatabaseManager

# Configure logging
logger = logging.getLogger(__name__)

# Maximum token limit for embedding models
MAX_INPUT_TOKENS = 7500  # Safe threshold below the 8192-token limit


class Memory(adal.core.component.DataComponent):
    """Simple conversation management with a list of dialog turns."""

    def __init__(self):
        super().__init__()
        # Use our custom implementation instead of the original Conversation class
        self.current_conversation = CustomConversation()

    def call(self) -> Dict:
        """Return the conversation history as a dictionary keyed by turn ID."""
        all_dialog_turns = {}
        try:
            # Check whether dialog_turns exists and is a non-empty list
            if hasattr(self.current_conversation, 'dialog_turns'):
                if self.current_conversation.dialog_turns:
                    logger.info(f"Memory content: {len(self.current_conversation.dialog_turns)} turns")
                    for i, turn in enumerate(self.current_conversation.dialog_turns):
                        if hasattr(turn, 'id') and turn.id is not None:
                            all_dialog_turns[turn.id] = turn
                            logger.info(f"Added turn {i+1} with ID {turn.id} to memory")
                        else:
                            logger.warning(f"Skipping invalid turn object in memory: {turn}")
                else:
                    logger.info("Dialog turns list exists but is empty")
            else:
                logger.info("No dialog_turns attribute in current_conversation")
                # Try to initialize it
                self.current_conversation.dialog_turns = []
        except Exception as e:
            logger.error(f"Error accessing dialog turns: {str(e)}")
            # Try to recover by starting a fresh conversation
            try:
                self.current_conversation = CustomConversation()
                logger.info("Recovered by creating new conversation")
            except Exception as e2:
                logger.error(f"Failed to recover: {str(e2)}")
        logger.info(f"Returning {len(all_dialog_turns)} dialog turns from memory")
        return all_dialog_turns

    def add_dialog_turn(self, user_query: str, assistant_response: str) -> bool:
        """
        Add a dialog turn to the conversation history.

        Args:
            user_query: The user's query
            assistant_response: The assistant's response

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            # Create a new dialog turn using our custom implementation
            dialog_turn = DialogTurn(
                id=str(uuid4()),
                user_query=UserQuery(query_str=user_query),
                assistant_response=AssistantResponse(response_str=assistant_response),
            )
            # Make sure the current_conversation has the append_dialog_turn method
            if not hasattr(self.current_conversation, 'append_dialog_turn'):
                logger.warning("current_conversation does not have append_dialog_turn method, creating new one")
                # Initialize a new conversation if needed
                self.current_conversation = CustomConversation()
            # Ensure dialog_turns exists
            if not hasattr(self.current_conversation, 'dialog_turns'):
                logger.warning("dialog_turns not found, initializing empty list")
                self.current_conversation.dialog_turns = []
            # Safely append the dialog turn
            self.current_conversation.dialog_turns.append(dialog_turn)
            logger.info(f"Successfully added dialog turn, now have {len(self.current_conversation.dialog_turns)} turns")
            return True
        except Exception as e:
            logger.error(f"Error adding dialog turn: {str(e)}")
            # Try to recover by creating a new conversation
            try:
                self.current_conversation = CustomConversation()
                dialog_turn = DialogTurn(
                    id=str(uuid4()),
                    user_query=UserQuery(query_str=user_query),
                    assistant_response=AssistantResponse(response_str=assistant_response),
                )
                self.current_conversation.dialog_turns.append(dialog_turn)
                logger.info("Recovered from error by creating new conversation")
                return True
            except Exception as e2:
                logger.error(f"Failed to recover from error: {str(e2)}")
                return False
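

# Hedged usage sketch for Memory (comments only). adalflow components are
# callable, so memory() dispatches to call(), as the RAG class below relies on:
#
#   memory = Memory()
#   memory.add_dialog_turn("What does this repo do?", "It generates wikis.")
#   turns = memory()  # {turn_id: DialogTurn, ...}
#   assert len(turns) == 1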


@dataclass
class RAGAnswer(adal.DataClass):
    rationale: str = field(default="", metadata={"desc": "Chain of thoughts for the answer."})
    answer: str = field(
        default="",
        metadata={
            "desc": (
                "Answer to the user query, formatted in markdown for beautiful rendering with "
                "react-markdown. DO NOT include ``` triple backticks fences at the beginning or "
                "end of your answer."
            )
        },
    )

    __output_fields__ = ["rationale", "answer"]
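
# Hedged sketch of how RAGAnswer drives structured output (comments only;
# mirrors the parser configured in RAG.__init__ below):
#
#   parser = adal.DataClassParser(data_class=RAGAnswer, return_data_class=True)
#   print(parser.get_output_format_str())  # the schema shown to the model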


class RAG(adal.Component):
    """RAG over a single repository.

    To load a new repository, call prepare_retriever(repo_url_or_path) first.
    """

    def __init__(self, provider="google", model=None, use_s3: bool = False):  # noqa: F841 - use_s3 is kept for compatibility
        """
        Initialize the RAG component.

        Args:
            provider: Model provider to use (google, openai, openrouter, ollama)
            model: Model name to use with the provider
            use_s3: Whether to use S3 for database storage (default: False)
        """
        super().__init__()
        self.provider = provider
        self.model = model

        # Import the helper functions
        from api.config import get_embedder_config, is_ollama_embedder

        # Determine whether we're using the Ollama embedder based on configuration
        self.is_ollama_embedder = is_ollama_embedder()

        # Check that the Ollama model exists before proceeding
        if self.is_ollama_embedder:
            from api.ollama_patch import check_ollama_model_exists

            embedder_config = get_embedder_config()
            if embedder_config and embedder_config.get("model_kwargs", {}).get("model"):
                model_name = embedder_config["model_kwargs"]["model"]
                if not check_ollama_model_exists(model_name):
                    raise Exception(
                        f"Ollama model '{model_name}' not found. "
                        f"Please run 'ollama pull {model_name}' to install it."
                    )

        # Initialize components
        self.memory = Memory()
        self.embedder = get_embedder()

        self_weakref = weakref.ref(self)

        # Patch: ensure the query embedding input is always a single string for Ollama
        def single_string_embedder(query):
            # Accepts either a string or a one-element list; always embeds a single string
            if isinstance(query, list):
                if len(query) != 1:
                    raise ValueError("Ollama embedder only supports a single string")
                query = query[0]
            instance = self_weakref()
            assert instance is not None, "RAG instance is no longer available, but the query embedder was called."
            return instance.embedder(input=query)

        # Use the single-string embedder for Ollama, the regular embedder otherwise
        self.query_embedder = single_string_embedder if self.is_ollama_embedder else self.embedder
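
        # Illustrative behaviour of the patch above (comments only):
        #   self.query_embedder("hello")     -> embeds the string "hello"
        #   self.query_embedder(["hello"])   -> also embeds "hello"
        #   self.query_embedder(["a", "b"])  -> raises ValueError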

        self.initialize_db_manager()

        # Set up the output parser
        data_parser = adal.DataClassParser(data_class=RAGAnswer, return_data_class=True)

        # Format instructions to ensure proper output structure
        format_instructions = data_parser.get_output_format_str() + """
IMPORTANT FORMATTING RULES:
1. DO NOT include your thinking or reasoning process in the output
2. Provide only the final, polished answer
3. DO NOT include ```markdown fences at the beginning or end of your answer
4. DO NOT wrap your response in any kind of fences
5. Start your response directly with the content
6. The content will already be rendered as markdown
7. Do not use backslashes before special characters like [ ] { } in your answer
8. When listing tags or similar items, write them as plain text without escape characters
9. For pipe characters (|) in text, write them directly without escaping them"""

        # Get the model configuration based on provider and model
        from api.config import get_model_config

        generator_config = get_model_config(self.provider, self.model)

        # Set up the main generator
        self.generator = adal.Generator(
            template=RAG_TEMPLATE,
            prompt_kwargs={
                "output_format_str": format_instructions,
                "conversation_history": self.memory(),
                "system_prompt": system_prompt,
                "contexts": None,
            },
            model_client=generator_config["model_client"](),
            model_kwargs=generator_config["model_kwargs"],
            output_processors=data_parser,
        )
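
        # Note: self.memory() is evaluated once here, so the prompt is seeded with
        # a snapshot of the (initially empty) history; presumably callers supply
        # fresh conversation_history via prompt_kwargs when invoking the generator.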

    def initialize_db_manager(self):
        """Initialize the database manager with local storage."""
        self.db_manager = DatabaseManager()
        self.transformed_docs = []

    def _validate_and_filter_embeddings(self, documents: List) -> List:
        """
        Validate embeddings and filter out documents with invalid or mismatched embedding sizes.

        Args:
            documents: List of documents with embeddings

        Returns:
            List of documents with valid embeddings of consistent size
        """
        if not documents:
            logger.warning("No documents provided for embedding validation")
            return []

        valid_documents = []
        embedding_sizes = {}

        # First pass: collect all embedding sizes and count occurrences
        for i, doc in enumerate(documents):
            if not hasattr(doc, 'vector') or doc.vector is None:
                logger.warning(f"Document {i} has no embedding vector, skipping")
                continue
            try:
                if isinstance(doc.vector, list):
                    embedding_size = len(doc.vector)
                elif hasattr(doc.vector, 'shape'):
                    embedding_size = doc.vector.shape[0] if len(doc.vector.shape) == 1 else doc.vector.shape[-1]
                elif hasattr(doc.vector, '__len__'):
                    embedding_size = len(doc.vector)
                else:
                    logger.warning(f"Document {i} has invalid embedding vector type: {type(doc.vector)}, skipping")
                    continue
                if embedding_size == 0:
                    logger.warning(f"Document {i} has empty embedding vector, skipping")
                    continue
                embedding_sizes[embedding_size] = embedding_sizes.get(embedding_size, 0) + 1
            except Exception as e:
                logger.warning(f"Error checking embedding size for document {i}: {str(e)}, skipping")
                continue

        if not embedding_sizes:
            logger.error("No valid embeddings found in any documents")
            return []

        # Find the most common embedding size (this should be the correct one)
        target_size = max(embedding_sizes.keys(), key=lambda k: embedding_sizes[k])
        logger.info(f"Target embedding size: {target_size} (found in {embedding_sizes[target_size]} documents)")

        # Log all other embedding sizes found
        for size, count in embedding_sizes.items():
            if size != target_size:
                logger.warning(f"Found {count} documents with incorrect embedding size {size}, will be filtered out")

        # Second pass: keep only documents with the target embedding size
        for i, doc in enumerate(documents):
            if not hasattr(doc, 'vector') or doc.vector is None:
                continue
            try:
                if isinstance(doc.vector, list):
                    embedding_size = len(doc.vector)
                elif hasattr(doc.vector, 'shape'):
                    embedding_size = doc.vector.shape[0] if len(doc.vector.shape) == 1 else doc.vector.shape[-1]
                elif hasattr(doc.vector, '__len__'):
                    embedding_size = len(doc.vector)
                else:
                    continue
                if embedding_size == target_size:
                    valid_documents.append(doc)
                else:
                    # Log which document is being filtered out
                    file_path = getattr(doc, 'meta_data', {}).get('file_path', f'document_{i}')
                    logger.warning(
                        f"Filtering out document '{file_path}' due to embedding size mismatch: "
                        f"{embedding_size} != {target_size}"
                    )
            except Exception as e:
                file_path = getattr(doc, 'meta_data', {}).get('file_path', f'document_{i}')
                logger.warning(f"Error validating embedding for document '{file_path}': {str(e)}, skipping")
                continue

        logger.info(f"Embedding validation complete: {len(valid_documents)}/{len(documents)} documents have valid embeddings")
        if len(valid_documents) == 0:
            logger.error("No documents with valid embeddings remain after filtering")
        elif len(valid_documents) < len(documents):
            filtered_count = len(documents) - len(valid_documents)
            logger.warning(f"Filtered out {filtered_count} documents due to embedding issues")
        return valid_documents
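
    # Hedged sketch of the majority-size filtering above (comments only;
    # `_Doc` is a hypothetical stand-in for the real document type):
    #
    #   class _Doc:
    #       def __init__(self, vector):
    #           self.vector = vector
    #           self.meta_data = {"file_path": "example.py"}
    #
    #   docs = [_Doc([0.0] * 768), _Doc([0.0] * 768), _Doc([0.0] * 384)]
    #   kept = rag._validate_and_filter_embeddings(docs)
    #   assert all(len(d.vector) == 768 for d in kept)  # the 384-dim doc is dropped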

    def prepare_retriever(self, repo_url_or_path: str, type: str = "github", access_token: str = None,
                          excluded_dirs: List[str] = None, excluded_files: List[str] = None,
                          included_dirs: List[str] = None, included_files: List[str] = None):
        """
        Prepare the retriever for a repository.

        Will load the database from local storage if available.

        Args:
            repo_url_or_path: URL or local path to the repository
            type: Repository platform type (e.g. "github")
            access_token: Optional access token for private repositories
            excluded_dirs: Optional list of directories to exclude from processing
            excluded_files: Optional list of file patterns to exclude from processing
            included_dirs: Optional list of directories to include exclusively
            included_files: Optional list of file patterns to include exclusively
        """
        self.initialize_db_manager()
        self.repo_url_or_path = repo_url_or_path
        self.transformed_docs = self.db_manager.prepare_database(
            repo_url_or_path,
            type,
            access_token,
            is_ollama_embedder=self.is_ollama_embedder,
            excluded_dirs=excluded_dirs,
            excluded_files=excluded_files,
            included_dirs=included_dirs,
            included_files=included_files,
        )
        logger.info(f"Loaded {len(self.transformed_docs)} documents for retrieval")

        # Validate and filter embeddings to ensure consistent sizes
        self.transformed_docs = self._validate_and_filter_embeddings(self.transformed_docs)
        if not self.transformed_docs:
            raise ValueError("No valid documents with embeddings found. Cannot create retriever.")
        logger.info(f"Using {len(self.transformed_docs)} documents with valid embeddings for retrieval")

        try:
            # Use the appropriate embedder for retrieval
            retrieve_embedder = self.query_embedder if self.is_ollama_embedder else self.embedder
            self.retriever = FAISSRetriever(
                **configs["retriever"],
                embedder=retrieve_embedder,
                documents=self.transformed_docs,
                document_map_func=lambda doc: doc.vector,
            )
            logger.info("FAISS retriever created successfully")
        except Exception as e:
            logger.error(f"Error creating FAISS retriever: {str(e)}")
            # Try to provide more specific error information
            if "All embeddings should be of the same size" in str(e):
                logger.error("Embedding size validation failed. This suggests there are still inconsistent embedding sizes.")
                # Log sample embedding sizes for debugging
                sizes = []
                for i, doc in enumerate(self.transformed_docs[:10]):  # Check the first 10 docs
                    if hasattr(doc, 'vector') and doc.vector is not None:
                        try:
                            if isinstance(doc.vector, list):
                                size = len(doc.vector)
                            elif hasattr(doc.vector, 'shape'):
                                size = doc.vector.shape[0] if len(doc.vector.shape) == 1 else doc.vector.shape[-1]
                            elif hasattr(doc.vector, '__len__'):
                                size = len(doc.vector)
                            else:
                                size = "unknown"
                            sizes.append(f"doc_{i}: {size}")
                        except Exception:
                            sizes.append(f"doc_{i}: error")
                logger.error(f"Sample embedding sizes: {', '.join(sizes)}")
            raise

    def call(self, query: str, language: str = "en") -> List:
        """
        Process a query using RAG.

        Args:
            query: The user's query
            language: Answer language code (accepted for API compatibility; not
                used during retrieval itself)

        Returns:
            The retriever output list with the first result's source documents
            filled in; on failure, a (RAGAnswer, []) tuple describing the error.
        """
        try:
            retrieved_documents = self.retriever(query)
            # Attach the full documents behind the retrieved indices
            retrieved_documents[0].documents = [
                self.transformed_docs[doc_index]
                for doc_index in retrieved_documents[0].doc_indices
            ]
            return retrieved_documents
        except Exception as e:
            logger.error(f"Error in RAG call: {str(e)}")
            # Create an error response
            error_response = RAGAnswer(
                rationale="Error occurred while processing the query.",
                answer="I apologize, but I encountered an error while processing your question. Please try again or rephrase your question.",
            )
            return error_response, []
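

if __name__ == "__main__":
    # Minimal smoke test -- a hedged sketch, not part of the API. It assumes
    # provider credentials are configured in the environment and that the
    # repository URL below (illustrative) is reachable; adjust both as needed.
    logging.basicConfig(level=logging.INFO)
    rag = RAG(provider="google")
    rag.prepare_retriever("https://github.com/AsyncFuncAI/deepwiki-open")
    results = rag.call("What does this project do?")
    if isinstance(results, tuple):  # error path returns (RAGAnswer, [])
        print(results[0].answer)
    else:
        print(f"Retrieved {len(results[0].documents)} documents")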