import os
import logging
import uuid
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone

from azure.search.documents import SearchClient
from azure.search.documents.models import VectorizedQuery
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
    SearchIndex,
    SimpleField,
    SearchableField,
    SearchField,
    VectorSearch,
    HnswAlgorithmConfiguration,
    VectorSearchProfile,
    SearchFieldDataType,
)
from azure.core.credentials import AzureKeyCredential
from openai import AzureOpenAI

from core.config import settings

logger = logging.getLogger(__name__)


def _odata_quote(value: str) -> str:
    """Escape single quotes so *value* is safe inside an OData string literal.

    Azure AI Search filter expressions use OData syntax, where a literal
    single quote is doubled (''). Without this, a quote in doc_id would
    break the filter or allow filter injection.
    """
    return str(value).replace("'", "''")


class RAGService:
    """Retrieval-augmented-generation service backed by Azure AI Search.

    Responsibilities:
      * Create/validate the vector search index.
      * Generate embeddings via Azure OpenAI.
      * Index, search, and delete per-user document chunks.
    """

    def __init__(self):
        # Azure AI Search configuration (from project settings).
        self.search_endpoint = settings.AZURE_SEARCH_ENDPOINT
        self.search_key = settings.AZURE_SEARCH_KEY
        self.index_name = settings.AZURE_SEARCH_INDEX_NAME

        # Azure OpenAI client for embeddings. The configured endpoint may
        # include an "/openai/..." deployment path; strip it down to the
        # resource root, which is what the SDK expects.
        self.azure_openai_client = AzureOpenAI(
            api_key=settings.AZURE_OPENAI_API_KEY,
            api_version=settings.AZURE_OPENAI_API_VERSION,
            azure_endpoint=settings.AZURE_OPENAI_ENDPOINT.split("/openai/")[0],
        )
        self.embedding_deployment = settings.AZURE_OPENAI_DEPLOYMENT_NAME

        # Data-plane client (query/upload documents) and control-plane
        # client (create/delete indexes).
        self.search_client = SearchClient(
            endpoint=self.search_endpoint,
            index_name=self.index_name,
            credential=AzureKeyCredential(self.search_key),
        )
        self.index_client = SearchIndexClient(
            endpoint=self.search_endpoint,
            credential=AzureKeyCredential(self.search_key),
        )

        # Make sure a compatible index exists before first use.
        self._ensure_index_exists()

    def _ensure_index_exists(self):
        """Create or recreate the search index if missing or incompatible.

        An existing index is considered compatible when it contains every
        field this service reads/writes. An incompatible index is dropped
        and rebuilt (destroys its data — acceptable because the schema
        mismatch means it could not be used anyway).
        """
        try:
            existing_index = self.index_client.get_index(self.index_name)
            required_fields = {"filename", "doc_id", "user_id", "content_vector"}
            existing_fields = {field.name for field in existing_index.fields}
            if not required_fields.issubset(existing_fields):
                logger.warning(
                    f"Index {self.index_name} is incompatible. Recreating..."
                )
                self.index_client.delete_index(self.index_name)
                self._create_index()
            else:
                logger.info(f"Index {self.index_name} exists and is compatible")
        except Exception:
            # get_index raises when the index does not exist; treat any
            # lookup failure as "needs creating".
            logger.info(f"Creating index {self.index_name}...")
            self._create_index()

    def _create_index(self):
        """Create the search index with HNSW vector search configuration.

        The vector field is sized for 1536-dimension embeddings
        (text-embedding-ada-002 / text-embedding-3-small default size —
        must match the configured embedding deployment).
        """
        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SearchableField(name="content", type=SearchFieldDataType.String),
            SearchableField(
                name="filename", type=SearchFieldDataType.String, filterable=True
            ),
            SimpleField(
                name="doc_id", type=SearchFieldDataType.String, filterable=True
            ),
            SimpleField(
                name="user_id", type=SearchFieldDataType.String, filterable=True
            ),
            SimpleField(name="chunk_index", type=SearchFieldDataType.Int32),
            SimpleField(name="created_at", type=SearchFieldDataType.DateTimeOffset),
            SearchField(
                name="content_vector",
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=1536,
                vector_search_profile_name="my-vector-profile",
            ),
        ]
        vector_search = VectorSearch(
            algorithms=[HnswAlgorithmConfiguration(name="my-hnsw")],
            profiles=[
                VectorSearchProfile(
                    name="my-vector-profile",
                    algorithm_configuration_name="my-hnsw",
                )
            ],
        )
        index = SearchIndex(
            name=self.index_name, fields=fields, vector_search=vector_search
        )
        self.index_client.create_index(index)
        logger.info(f"Created index: {self.index_name}")

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate one embedding vector per input text via Azure OpenAI.

        Texts are embedded one request at a time (simple, but N round
        trips for N chunks — batching into a single call would be faster
        if volumes grow).

        Raises:
            Any exception from the Azure OpenAI client, after logging.
        """
        try:
            embeddings = []
            for text in texts:
                response = self.azure_openai_client.embeddings.create(
                    input=text, model=self.embedding_deployment
                )
                embeddings.append(response.data[0].embedding)
            return embeddings
        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            raise

    def index_document(
        self,
        chunks: List[str],
        filename: str,
        user_id: int,
        doc_id: str,
    ) -> int:
        """Embed and upload document chunks to the search index.

        Args:
            chunks: Pre-split text chunks of the document.
            filename: Original filename (stored for display/filtering).
            user_id: Owning user; stored as a string field for filtering.
            doc_id: Stable document identifier; chunk keys are
                "{doc_id}_{chunk_index}".

        Returns:
            Number of chunks indexed.
        """
        try:
            # Generate embeddings
            logger.info(f"Generating embeddings for {len(chunks)} chunks...")
            embeddings = self.generate_embeddings(chunks)

            # Prepare documents
            documents = []
            for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                doc = {
                    "id": f"{doc_id}_{idx}",
                    "content": chunk,
                    "filename": filename,
                    "doc_id": doc_id,
                    "user_id": str(user_id),
                    "chunk_index": idx,
                    # Timezone-aware UTC timestamp, ISO-8601 with trailing Z
                    # (datetime.utcnow() is deprecated and naive).
                    "created_at": datetime.now(timezone.utc).strftime(
                        "%Y-%m-%dT%H:%M:%SZ"
                    ),
                    "content_vector": embedding,
                }
                documents.append(doc)

            # Upload to search index
            result = self.search_client.upload_documents(documents=documents)
            logger.info(f"Indexed {len(documents)} chunks for {filename}")
            return len(documents)
        except Exception as e:
            logger.error(f"Error indexing document: {e}")
            raise

    def search_document(
        self,
        query: str,
        doc_id: str,
        user_id: int,
        top_k: int = 3,
    ) -> List[Dict[str, Any]]:
        """Vector-search within one user's document; return top_k chunks.

        Results are restricted to the given doc_id and user_id via an
        OData filter, so users can only hit their own documents.
        """
        try:
            # Generate query embedding
            query_embedding = self.generate_embeddings([query])[0]

            vector_query = VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields="content_vector",
            )
            results = self.search_client.search(
                search_text=None,  # pure vector search, no keyword component
                vector_queries=[vector_query],
                filter=(
                    f"doc_id eq '{_odata_quote(doc_id)}' "
                    f"and user_id eq '{_odata_quote(user_id)}'"
                ),
                top=top_k,
                select=["content", "filename", "chunk_index"],
            )

            # Format results
            search_results = []
            for result in results:
                search_results.append(
                    {
                        "content": result["content"],
                        "chunk_index": result.get("chunk_index", 0),
                    }
                )
            return search_results
        except Exception as e:
            logger.error(f"Error searching document: {e}")
            raise

    def delete_document(self, doc_id: str):
        """Delete all indexed chunks belonging to *doc_id*.

        NOTE(review): a single query capped at top=1000 — documents with
        more than 1000 chunks would leave orphans. Deletions are also
        eventually consistent, so a naive re-query loop could spin;
        revisit if chunk counts approach the cap.
        """
        try:
            # Find every chunk key for this document.
            results = self.search_client.search(
                search_text="*",
                filter=f"doc_id eq '{_odata_quote(doc_id)}'",
                select=["id"],
                top=1000,
            )
            doc_ids = [{"id": r["id"]} for r in results]
            if doc_ids:
                self.search_client.delete_documents(documents=doc_ids)
                logger.info(
                    f"Deleted {len(doc_ids)} chunks for document {doc_id}"
                )
        except Exception as e:
            logger.error(f"Error deleting document: {e}")
            raise

    def document_exists(self, doc_id: str, user_id: int) -> bool:
        """Return True if at least one chunk of the document is indexed.

        Best-effort: any search failure is logged and reported as
        "not indexed" rather than raised.
        """
        try:
            results = self.search_client.search(
                search_text="*",
                filter=(
                    f"doc_id eq '{_odata_quote(doc_id)}' "
                    f"and user_id eq '{_odata_quote(user_id)}'"
                ),
                top=1,
                select=["id"],
            )
            return len(list(results)) > 0
        except Exception as e:
            # Narrowed from a bare except: still best-effort, but no
            # longer swallows KeyboardInterrupt/SystemExit, and the
            # failure is visible in logs.
            logger.error(f"Error checking document existence: {e}")
            return False


# Module-level singleton used by the rest of the application.
rag_service = RAGService()