import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
    HnswAlgorithmConfiguration,
    SearchableField,
    SearchField,
    SearchFieldDataType,
    SearchIndex,
    SimpleField,
    VectorSearch,
    VectorSearchProfile,
)
from azure.search.documents.models import VectorizedQuery
from openai import AzureOpenAI

from core.config import settings
| |
|
# Module-level logger named after this module (standard per-module pattern).
logger = logging.getLogger(__name__)
| |
|
class RAGService:
    """RAG (retrieval-augmented generation) service backed by Azure AI Search.

    Responsibilities:
      * ensure the vector index exists with the expected schema,
      * embed text chunks with Azure OpenAI and upsert them into the index,
      * run per-document vector searches scoped to a user,
      * delete and existence-check indexed documents.

    NOTE(review): constructing this class connects to Azure and may create
    the index — it performs network I/O in ``__init__``.
    """

    # Fields the existing index must expose for this service to reuse it;
    # anything less and _ensure_index_exists() recreates the index.
    _REQUIRED_FIELDS = frozenset({"filename", "doc_id", "user_id", "content_vector"})

    def __init__(self):
        self.search_endpoint = settings.AZURE_SEARCH_ENDPOINT
        self.search_key = settings.AZURE_SEARCH_KEY
        self.index_name = settings.AZURE_SEARCH_INDEX_NAME

        # The configured endpoint may include a full "/openai/..." path;
        # the SDK wants just the resource root, so strip everything after it.
        self.azure_openai_client = AzureOpenAI(
            api_key=settings.AZURE_OPENAI_API_KEY,
            api_version=settings.AZURE_OPENAI_API_VERSION,
            azure_endpoint=settings.AZURE_OPENAI_ENDPOINT.split("/openai/")[0],
        )
        self.embedding_deployment = settings.AZURE_OPENAI_DEPLOYMENT_NAME

        credential = AzureKeyCredential(self.search_key)
        self.search_client = SearchClient(
            endpoint=self.search_endpoint,
            index_name=self.index_name,
            credential=credential,
        )
        self.index_client = SearchIndexClient(
            endpoint=self.search_endpoint,
            credential=credential,
        )

        self._ensure_index_exists()

    @staticmethod
    def _escape_odata(value: Any) -> str:
        """Escape *value* for interpolation into an OData filter literal.

        OData string literals escape a single quote by doubling it; without
        this, values containing ``'`` break the filter (or inject clauses).
        """
        return str(value).replace("'", "''")

    def _ensure_index_exists(self) -> None:
        """Create or recreate the search index if missing or incompatible."""
        try:
            existing_index = self.index_client.get_index(self.index_name)
        except Exception:
            # get_index raises when the index does not exist yet.
            logger.info("Creating index %s...", self.index_name)
            self._create_index()
            return

        existing_fields = {field.name for field in existing_index.fields}
        if self._REQUIRED_FIELDS.issubset(existing_fields):
            logger.info("Index %s exists and is compatible", self.index_name)
        else:
            logger.warning("Index %s is incompatible. Recreating...", self.index_name)
            self.index_client.delete_index(self.index_name)
            self._create_index()

    def _create_index(self) -> None:
        """Create the search index with HNSW vector configuration."""
        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SearchableField(name="content", type=SearchFieldDataType.String),
            SearchableField(name="filename", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="doc_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="chunk_index", type=SearchFieldDataType.Int32),
            SimpleField(name="created_at", type=SearchFieldDataType.DateTimeOffset),
            SearchField(
                name="content_vector",
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                # 1536 matches text-embedding-ada-002 / text-embedding-3-small.
                vector_search_dimensions=1536,
                vector_search_profile_name="my-vector-profile",
            ),
        ]

        vector_search = VectorSearch(
            algorithms=[HnswAlgorithmConfiguration(name="my-hnsw")],
            profiles=[
                VectorSearchProfile(
                    name="my-vector-profile",
                    algorithm_configuration_name="my-hnsw",
                )
            ],
        )

        index = SearchIndex(
            name=self.index_name,
            fields=fields,
            vector_search=vector_search,
        )

        self.index_client.create_index(index)
        logger.info("Created index: %s", self.index_name)

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* with Azure OpenAI.

        Sends the whole batch in a single API call (the embeddings endpoint
        accepts a list) instead of one request per text.

        Returns a list of embedding vectors in the same order as *texts*.
        Raises whatever the OpenAI client raises, after logging.
        """
        if not texts:
            return []
        try:
            response = self.azure_openai_client.embeddings.create(
                input=list(texts),
                model=self.embedding_deployment,
            )
            # Sort by the item index so results align with the input order.
            ordered = sorted(response.data, key=lambda item: item.index)
            return [item.embedding for item in ordered]
        except Exception as e:
            logger.error("Error generating embeddings: %s", e)
            raise

    def index_document(
        self,
        chunks: List[str],
        filename: str,
        user_id: int,
        doc_id: str
    ) -> int:
        """Index document *chunks* with embeddings in Azure Search.

        Returns the number of chunks indexed. Raises on embedding or
        upload failure, after logging.
        """
        try:
            logger.info("Generating embeddings for %d chunks...", len(chunks))
            embeddings = self.generate_embeddings(chunks)

            # One timestamp for the whole batch; timezone-aware UTC
            # (datetime.utcnow() is deprecated).
            created_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

            documents = [
                {
                    "id": f"{doc_id}_{idx}",
                    "content": chunk,
                    "filename": filename,
                    "doc_id": doc_id,
                    "user_id": str(user_id),
                    "chunk_index": idx,
                    "created_at": created_at,
                    "content_vector": embedding,
                }
                for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings))
            ]

            # upload_documents rejects an empty batch; skip it for no chunks.
            if documents:
                self.search_client.upload_documents(documents=documents)
            logger.info("Indexed %d chunks for %s", len(documents), filename)

            return len(documents)

        except Exception as e:
            logger.error("Error indexing document: %s", e)
            raise

    def search_document(
        self,
        query: str,
        doc_id: str,
        user_id: int,
        top_k: int = 3
    ) -> List[Dict[str, Any]]:
        """Vector-search within one user's document.

        Returns up to *top_k* dicts with ``content`` and ``chunk_index``.
        Raises on embedding or search failure, after logging.
        """
        try:
            query_embedding = self.generate_embeddings([query])[0]

            vector_query = VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields="content_vector",
            )

            results = self.search_client.search(
                search_text=None,
                vector_queries=[vector_query],
                filter=(
                    f"doc_id eq '{self._escape_odata(doc_id)}' "
                    f"and user_id eq '{self._escape_odata(user_id)}'"
                ),
                top=top_k,
                select=["content", "filename", "chunk_index"],
            )

            return [
                {
                    "content": result["content"],
                    "chunk_index": result.get("chunk_index", 0),
                }
                for result in results
            ]

        except Exception as e:
            logger.error("Error searching document: %s", e)
            raise

    def delete_document(self, doc_id: str) -> None:
        """Delete every indexed chunk belonging to *doc_id*.

        No ``top`` cap is passed, so the result iterator pages through all
        matches — documents with more than one page of chunks are fully
        removed (the previous top=1000 cap silently left chunks behind).
        """
        try:
            results = self.search_client.search(
                search_text="*",
                filter=f"doc_id eq '{self._escape_odata(doc_id)}'",
                select=["id"],
            )

            keys = [{"id": result["id"]} for result in results]
            if keys:
                self.search_client.delete_documents(documents=keys)
                logger.info("Deleted %d chunks for document %s", len(keys), doc_id)

        except Exception as e:
            logger.error("Error deleting document: %s", e)
            raise

    def document_exists(self, doc_id: str, user_id: int) -> bool:
        """Return True if at least one chunk of *doc_id* is indexed for *user_id*.

        Best-effort: any search failure is logged and reported as False
        (narrowed from the original bare ``except:``, which also swallowed
        KeyboardInterrupt/SystemExit).
        """
        try:
            results = self.search_client.search(
                search_text="*",
                filter=(
                    f"doc_id eq '{self._escape_odata(doc_id)}' "
                    f"and user_id eq '{self._escape_odata(user_id)}'"
                ),
                top=1,
                select=["id"],
            )
            # Pull at most one result rather than materializing the page.
            return next(iter(results), None) is not None
        except Exception as e:
            logger.error("Error checking document existence: %s", e)
            return False
| |
|
# Module-level singleton. NOTE(review): instantiating RAGService performs
# network I/O (connects to Azure and may create the search index) at import
# time of this module.
rag_service = RAGService()
| |
|