|
|
import chromadb |
|
|
from llama_index.core import VectorStoreIndex |
|
|
from llama_index.vector_stores.chroma import ChromaVectorStore |
|
|
from llama_index.core import StorageContext, Settings |
|
|
from llama_index.core.schema import TextNode |
|
|
from openai import OpenAI |
|
|
from modal_client import ModalClient |
|
|
|
|
|
|
|
|
from structlog import get_logger |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
from typing import Any, List |
|
|
|
|
|
|
|
|
from llama_index.core.bridge.pydantic import PrivateAttr |
|
|
from llama_index.core.embeddings import BaseEmbedding |
|
|
|
|
|
|
|
|
class CustomEmbeddings(BaseEmbedding):
    """Embedding model backed by an OpenAI-compatible ``/embeddings`` endpoint.

    Wraps an ``openai.OpenAI`` client pointed at an arbitrary ``base_url`` so
    any service exposing the OpenAI embeddings API (e.g. a Modal-hosted model)
    can be used through llama-index's ``BaseEmbedding`` interface.
    """

    # BaseEmbedding is a pydantic model: runtime-only state must be declared
    # as a private attribute, otherwise assigning ``self._client`` in
    # __init__ raises ``ValueError: object has no field "_client"``.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model_name: str,
        **kwargs: Any,
    ) -> None:
        """Create the embedding adapter.

        Args:
            base_url: Root URL of the OpenAI-compatible embeddings service.
            api_key: API key sent with each request.
            model_name: Identifier of the embedding model to request.
            **kwargs: Extra ``BaseEmbedding`` options (e.g. ``embed_batch_size``).
        """
        # ``model_name`` is a declared BaseEmbedding field, so route it
        # through the pydantic constructor instead of assigning afterwards.
        super().__init__(model_name=model_name, **kwargs)
        self._client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the registry name used by llama-index serialization."""
        return "custom"

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async query embedding; delegates to the synchronous client call."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async text embedding; delegates to the synchronous client call."""
        return self._get_text_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string via one API call."""
        return (
            self._client.embeddings.create(
                model=self.model_name,
                input=[query],
            )
            .data[0]
            .embedding
        )

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document chunk via one API call."""
        return (
            self._client.embeddings.create(
                model=self.model_name,
                input=[text],
            )
            .data[0]
            .embedding
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts in one API call, preserving input order."""
        response = self._client.embeddings.create(
            model=self.model_name,
            input=texts,
        )
        return [item.embedding for item in response.data]
|
|
|
|
|
class EmbeddingService:
    """In-memory semantic search over code nodes using Chroma + llama-index.

    Builds an ephemeral (process-local) Chroma collection, indexes nodes with
    the Modal-configured embedding model, and formats retrieval results as a
    readable text report.
    """

    def __init__(self, collection_name):
        """Configure the embed model and an ephemeral Chroma-backed store.

        Args:
            collection_name: Name for the fresh in-memory Chroma collection.
        """
        config = ModalClient.embedding_config()
        # NOTE(review): this mutates the process-global llama-index Settings;
        # any index built elsewhere in this process afterwards will share
        # this embed model and chunk size.
        Settings.embed_model = CustomEmbeddings(
            api_key=config.get("api_key"),
            base_url=config.get("base_url"),
            model_name=config.get("model"),
            embed_batch_size=32,
        )
        Settings.chunk_size = 1024
        # EphemeralClient: vectors live only for this process, nothing persisted.
        chroma_client = chromadb.EphemeralClient()
        chroma_collection = chroma_client.create_collection(collection_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        self.storage_context = StorageContext.from_defaults(vector_store=vector_store)

    def prepare_index(self, nodes):
        """Embed *nodes* and build the vector index used by :meth:`infer`.

        Args:
            nodes: Documents/nodes to index.
        """
        # NOTE(review): ``from_documents`` expects Document objects and will
        # re-chunk them; the parameter name (and the TextNode import at the
        # top of this file) suggests pre-built nodes, for which
        # ``VectorStoreIndex(nodes=...)`` is the intended API — confirm
        # against the callers before changing.
        self.index = VectorStoreIndex.from_documents(nodes, storage_context=self.storage_context)

    def infer(self, query, top_k=10):
        """Retrieve the ``top_k`` most similar nodes and format them as text.

        Args:
            query: Natural-language search query.
            top_k: Number of results to retrieve (default 10).

        Returns:
            A single string with one delimited section per result, listing the
            node's ``name``/``filename``/``type``/``namespace`` metadata and
            its content. Raises ``AttributeError`` if called before
            :meth:`prepare_index`, and ``KeyError`` if a result lacks any of
            those metadata keys.
        """
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        results = retriever.retrieve(query)
        # Accumulate into a list and join once — avoids quadratic str +=.
        sections = []
        for result in results:
            meta = result.metadata
            sections.append("\n -------------------------- \n")
            sections.append(f"name = {meta['name']}\n")
            sections.append(f"filename = {meta['filename']}\n")
            sections.append(f"type = {meta['type']}\n")
            sections.append(f"namespace = {meta['namespace']}\n")
            sections.append(f"content = {result.text}\n")
        return "".join(sections)
|
|
|
|
|
|
|
|
|