matsuap's picture
Upload folder using huggingface_hub
792ad00 verified
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
    HnswAlgorithmConfiguration,
    SearchableField,
    SearchField,
    SearchFieldDataType,
    SearchIndex,
    SimpleField,
    VectorSearch,
    VectorSearchProfile,
)
from openai import AzureOpenAI

from core.config import settings
# Module-level logger, named after this module, used by RAGService for
# progress and error messages.
logger: logging.Logger = logging.getLogger(__name__)
class RAGService:
    """Index and query per-user document chunks in Azure AI Search.

    Text chunks are embedded with an Azure OpenAI embeddings deployment and
    stored in a single Azure AI Search index, together with their text and
    ownership metadata (``doc_id`` / ``user_id``).  Queries are vector searches
    restricted to one document of one user.
    """

    # Size of the ``content_vector`` field. Must equal the output size of the
    # configured embeddings deployment (1536 matches e.g. text-embedding-ada-002;
    # confirm against the deployment actually in use).
    _EMBEDDING_DIMENSIONS = 1536
    # Default number of texts sent per embeddings request.
    _EMBEDDING_BATCH_SIZE = 16
    # Documents per upload request. Azure AI Search caps a batch at 1000
    # documents / 16 MB, and every chunk carries a full vector, so stay well
    # below the document limit.
    _UPLOAD_BATCH_SIZE = 100
    # Page size used when enumerating chunk ids to delete.
    _DELETE_PAGE_SIZE = 1000

    def __init__(self):
        """Build the Azure clients from ``settings`` and make sure the index exists.

        Raises whatever ``_ensure_index_exists`` raises (e.g. authentication or
        network errors from the search service).
        """
        # Azure AI Search connection settings.
        self.search_endpoint = settings.AZURE_SEARCH_ENDPOINT
        self.search_key = settings.AZURE_SEARCH_KEY
        self.index_name = settings.AZURE_SEARCH_INDEX_NAME

        # Azure OpenAI client for embeddings. The configured endpoint may be a
        # full deployment URL; only the part before "/openai/" is passed on as
        # the resource endpoint.
        self.azure_openai_client = AzureOpenAI(
            api_key=settings.AZURE_OPENAI_API_KEY,
            api_version=settings.AZURE_OPENAI_API_VERSION,
            azure_endpoint=settings.AZURE_OPENAI_ENDPOINT.split("/openai/")[0],
        )
        # NOTE(review): confirm AZURE_OPENAI_DEPLOYMENT_NAME names an embeddings
        # deployment whose vectors have _EMBEDDING_DIMENSIONS entries.
        self.embedding_deployment = settings.AZURE_OPENAI_DEPLOYMENT_NAME

        credential = AzureKeyCredential(self.search_key)
        self.search_client = SearchClient(
            endpoint=self.search_endpoint,
            index_name=self.index_name,
            credential=credential,
        )
        self.index_client = SearchIndexClient(
            endpoint=self.search_endpoint,
            credential=credential,
        )

        self._ensure_index_exists()

    @staticmethod
    def _odata_literal(value: Any) -> str:
        """Render ``value`` as a single-quoted OData string literal.

        OData escapes an embedded single quote by doubling it, so ``o'brien``
        becomes ``'o''brien'``. Without this, a quote in a caller-supplied id
        could terminate the literal and change the meaning of the filter.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @classmethod
    def _build_filter(cls, doc_id: str, user_id: Optional[int] = None) -> str:
        """Build an OData filter matching chunks of ``doc_id``.

        When ``user_id`` is given, the filter also restricts to that user.
        ``user_id`` is compared as a string because it is stored as one.
        """
        clauses = [f"doc_id eq {cls._odata_literal(doc_id)}"]
        if user_id is not None:
            clauses.append(f"user_id eq {cls._odata_literal(user_id)}")
        return " and ".join(clauses)

    @staticmethod
    def _batched(items, size: int):
        """Yield consecutive slices of the sequence ``items``, each at most ``size`` long.

        Raises ValueError (when iterated) if ``size`` is less than 1.
        """
        if size < 1:
            raise ValueError("batch size must be at least 1")
        for start in range(0, len(items), size):
            yield items[start:start + size]

    def _ensure_index_exists(self):
        """Create the index if it is missing; recreate it if required fields are absent.

        WARNING: recreating deletes the existing index and all documents in it.
        Only a "not found" response triggers creation. Any other error from
        the service (auth, network, failures while deleting/creating)
        propagates to the caller.
        """
        required_fields = {"filename", "doc_id", "user_id", "content_vector"}
        try:
            existing_index = self.index_client.get_index(self.index_name)
        except ResourceNotFoundError:
            logger.info(f"Creating index {self.index_name}...")
            self._create_index()
            return

        existing_fields = {field.name for field in existing_index.fields}
        if required_fields.issubset(existing_fields):
            logger.info(f"Index {self.index_name} exists and is compatible")
            return

        logger.warning(f"Index {self.index_name} is incompatible. Recreating...")
        self.index_client.delete_index(self.index_name)
        self._create_index()

    def _create_index(self):
        """Create the search index: chunk text, ownership metadata and an HNSW vector field."""
        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SearchableField(name="content", type=SearchFieldDataType.String),
            SearchableField(name="filename", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="doc_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="chunk_index", type=SearchFieldDataType.Int32),
            SimpleField(name="created_at", type=SearchFieldDataType.DateTimeOffset),
            SearchField(
                name="content_vector",
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=self._EMBEDDING_DIMENSIONS,
                vector_search_profile_name="my-vector-profile",
            ),
        ]
        vector_search = VectorSearch(
            algorithms=[HnswAlgorithmConfiguration(name="my-hnsw")],
            profiles=[
                VectorSearchProfile(
                    name="my-vector-profile",
                    algorithm_configuration_name="my-hnsw",
                )
            ],
        )
        index = SearchIndex(
            name=self.index_name,
            fields=fields,
            vector_search=vector_search,
        )
        self.index_client.create_index(index)
        logger.info(f"Created index: {self.index_name}")

    def generate_embeddings(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
    ) -> List[List[float]]:
        """Return one embedding vector per input text, in input order.

        Args:
            texts: Texts to embed.
            batch_size: Texts sent per embeddings request. The default is
                ``_EMBEDDING_BATCH_SIZE``. NOTE(review): some older Azure
                deployments accept only one input per request; pass 1 there.

        Raises:
            Any error from the embeddings API (logged first), or ValueError
            for a batch_size below 1.
        """
        if batch_size is None:
            batch_size = self._EMBEDDING_BATCH_SIZE
        try:
            items = list(texts)
            embeddings: List[List[float]] = []
            for batch in self._batched(items, batch_size):
                response = self.azure_openai_client.embeddings.create(
                    input=batch,
                    model=self.embedding_deployment,
                )
                # Order by the per-item index so vectors line up with inputs.
                ordered = sorted(response.data, key=lambda item: item.index)
                embeddings.extend(item.embedding for item in ordered)
            return embeddings
        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            raise

    def index_document(
        self,
        chunks: List[str],
        filename: str,
        user_id: int,
        doc_id: str
    ) -> int:
        """Embed ``chunks`` and upload them to the index as one document each.

        Each chunk is stored with id ``f"{doc_id}_{idx}"``, its position as
        ``chunk_index``, and a shared UTC ``created_at`` timestamp.

        Returns:
            The number of chunks the service reported as successfully indexed.
            Per-chunk failures are logged with their keys.

        Raises:
            Any embedding or upload error (logged first).
        """
        try:
            logger.info(f"Generating embeddings for {len(chunks)} chunks...")
            embeddings = self.generate_embeddings(chunks)

            # One timestamp for the whole document; the format is unchanged.
            created_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
            documents = [
                {
                    "id": f"{doc_id}_{idx}",
                    "content": chunk,
                    "filename": filename,
                    "doc_id": doc_id,
                    "user_id": str(user_id),
                    "chunk_index": idx,
                    "created_at": created_at,
                    "content_vector": embedding,
                }
                for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings))
            ]

            # Upload in bounded batches and keep track of per-document status:
            # a partially failed batch does not raise on its own.
            succeeded = 0
            failed_keys: List[str] = []
            for batch in self._batched(documents, self._UPLOAD_BATCH_SIZE):
                for item in self.search_client.upload_documents(documents=batch):
                    if item.succeeded:
                        succeeded += 1
                    else:
                        failed_keys.append(item.key)

            if failed_keys:
                logger.error(
                    f"Failed to index {len(failed_keys)} chunks for {filename}: {failed_keys}"
                )
            logger.info(f"Indexed {succeeded} chunks for {filename}")
            return succeeded
        except Exception as e:
            logger.error(f"Error indexing document: {e}")
            raise

    def search_document(
        self,
        query: str,
        doc_id: str,
        user_id: int,
        top_k: int = 3
    ) -> List[Dict[str, Any]]:
        """Vector-search the chunks of ``doc_id`` owned by ``user_id``.

        Returns:
            Up to ``top_k`` dicts with keys ``content`` and ``chunk_index``.

        Raises:
            Any embedding or search error (logged first).
        """
        try:
            query_embedding = self.generate_embeddings([query])[0]

            from azure.search.documents.models import VectorizedQuery
            vector_query = VectorizedQuery(
                vector=query_embedding,
                k_nearest_neighbors=top_k,
                fields="content_vector",
            )
            results = self.search_client.search(
                search_text=None,
                vector_queries=[vector_query],
                filter=self._build_filter(doc_id, user_id),
                top=top_k,
                select=["content", "filename", "chunk_index"],
            )
            return [
                {
                    "content": result["content"],
                    "chunk_index": result.get("chunk_index", 0),
                }
                for result in results
            ]
        except Exception as e:
            logger.error(f"Error searching document: {e}")
            raise

    def delete_document(self, doc_id: str):
        """Delete every chunk of ``doc_id`` from the index, regardless of user.

        Chunk ids are fetched in pages of ``_DELETE_PAGE_SIZE`` and deleted page
        by page. Fetching repeats while pages come back full, so documents with
        more chunks than one page are removed completely.

        Raises:
            Any search or delete error (logged first).
        """
        doc_filter = self._build_filter(doc_id)
        try:
            deleted_ids = set()
            while True:
                results = self.search_client.search(
                    search_text="*",
                    filter=doc_filter,
                    select=["id"],
                    top=self._DELETE_PAGE_SIZE,
                )
                page_ids = [r["id"] for r in results]
                # Deletions may not be visible to search immediately, so a
                # later page can repeat ids already deleted; skip those.
                new_ids = [i for i in page_ids if i not in deleted_ids]
                if new_ids:
                    self.search_client.delete_documents(
                        documents=[{"id": i} for i in new_ids]
                    )
                    deleted_ids.update(new_ids)
                if len(page_ids) < self._DELETE_PAGE_SIZE:
                    break
                if not new_ids:
                    logger.warning(
                        f"Stopped deleting chunks for document {doc_id}: search returned "
                        f"only already-deleted ids; some chunks may remain"
                    )
                    break
            if deleted_ids:
                logger.info(f"Deleted {len(deleted_ids)} chunks for document {doc_id}")
        except Exception as e:
            logger.error(f"Error deleting document: {e}")
            raise

    def document_exists(self, doc_id: str, user_id: int) -> bool:
        """Return True if at least one chunk of ``doc_id`` for ``user_id`` is indexed.

        Best-effort: if the search fails, the error is logged and False is
        returned. KeyboardInterrupt/SystemExit are no longer swallowed.
        """
        try:
            results = self.search_client.search(
                search_text="*",
                filter=self._build_filter(doc_id, user_id),
                top=1,
                select=["id"],
            )
            return next(iter(results), None) is not None
        except Exception as e:
            logger.warning(f"Could not check whether document {doc_id} is indexed: {e}")
            return False
# Shared module-level instance. Constructing it at import time builds the
# Azure clients and runs RAGService._ensure_index_exists(), so importing this
# module contacts Azure AI Search and may create, or delete and recreate,
# the search index.
rag_service = RAGService()