File size: 3,315 Bytes
933c2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import chromadb
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import StorageContext, Settings
from llama_index.core.schema import TextNode
from openai import OpenAI
from modal_client import ModalClient


from structlog import get_logger

logger = get_logger(__name__)

from typing import Any, List


from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.embeddings import BaseEmbedding


class CustomEmbeddings(BaseEmbedding):
    """LlamaIndex embedding adapter for any OpenAI-compatible embeddings API.

    Forwards every embedding request to ``base_url`` through the OpenAI
    client, so it works with any server implementing the OpenAI
    ``/embeddings`` endpoint.
    """

    # BaseEmbedding is a pydantic model: assigning an undeclared instance
    # attribute in __init__ is rejected ("object has no field _client"), so
    # the client must be declared as a private attribute up front.
    _client: OpenAI = PrivateAttr()

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model_name: str,
        **kwargs: Any,
    ) -> None:
        """Create the adapter.

        Args:
            base_url: Root URL of the OpenAI-compatible embeddings server.
            api_key: API key sent with every request.
            model_name: Model identifier passed in each embeddings call.
            **kwargs: Extra ``BaseEmbedding`` options (e.g. ``embed_batch_size``).
        """
        # model_name is a declared BaseEmbedding field; route it through the
        # pydantic constructor so validation sees it, instead of assigning
        # after the fact.
        super().__init__(model_name=model_name, **kwargs)
        self._client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )

    @classmethod
    def class_name(cls) -> str:
        return "custom"

    async def _aget_query_embedding(self, query: str) -> List[float]:
        # The OpenAI client here is synchronous; delegate to the sync path.
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string via one API call."""
        response = self._client.embeddings.create(
            model=self.model_name,
            input=[query],
        )
        return response.data[0].embedding

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document text via one API call."""
        response = self._client.embeddings.create(
            model=self.model_name,
            input=[text],
        )
        return response.data[0].embedding

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts in a single API call (order preserved)."""
        response = self._client.embeddings.create(
            model=self.model_name,
            input=texts,
        )
        return [item.embedding for item in response.data]

class EmbeddingService:
    """In-memory vector search service: ephemeral Chroma store + custom embedder.

    Typical usage: construct, call :meth:`prepare_index` once with the corpus,
    then call :meth:`infer` per query.
    """

    def __init__(self, collection_name):
        """Set up the embedding model and an ephemeral Chroma-backed store.

        Args:
            collection_name: Name of the Chroma collection to create.

        NOTE: this mutates the process-global llama_index ``Settings``
        (embed_model, chunk_size) — every index built in this process is
        affected, not just this service.
        """
        config = ModalClient.embedding_config()
        Settings.embed_model = CustomEmbeddings(
            api_key=config.get("api_key"),
            base_url=config.get("base_url"),
            model_name=config.get("model"),
            embed_batch_size=32,
        )
        Settings.chunk_size = 1024
        # EphemeralClient keeps the collection purely in memory; nothing
        # survives process exit.
        chroma_client = chromadb.EphemeralClient()
        chroma_collection = chroma_client.create_collection(collection_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        self.storage_context = StorageContext.from_defaults(vector_store=vector_store)

    def prepare_index(self, nodes):
        """Build the vector index over the given corpus.

        NOTE(review): the parameter is named ``nodes`` but
        ``VectorStoreIndex.from_documents`` expects Document objects; if
        callers actually pass TextNode objects (TextNode is imported above),
        this should be ``VectorStoreIndex(nodes=nodes, ...)`` — confirm
        against the call sites.
        """
        self.index = VectorStoreIndex.from_documents(nodes, storage_context=self.storage_context)

    def infer(self, query, top_k=10):
        """Retrieve the ``top_k`` nodes most similar to ``query``, formatted as text.

        Args:
            query: Natural-language query string.
            top_k: Number of results to retrieve (default 10).

        Returns:
            One delimited section per result with its name, filename, type,
            namespace and content; empty string when nothing is retrieved.

        Raises:
            AttributeError: if :meth:`prepare_index` was never called.
        """
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        results = retriever.retrieve(query)
        # Accumulate parts and join once — avoids quadratic `+=` string
        # concatenation on large result sets.
        parts = []
        for result in results:
            meta = result.metadata
            parts.append("\n -------------------------- \n")
            parts.append(f"name = {meta['name']}\n")
            parts.append(f"filename = {meta['filename']}\n")
            parts.append(f"type = {meta['type']}\n")
            parts.append(f"namespace = {meta['namespace']}\n")
            parts.append(f"content = {result.text}\n")
        return "".join(parts)