Davide Panza commited on
Commit
36db4f7
·
verified ·
1 Parent(s): 770aec3

Update app/backend/chromadb_utils.py

Browse files
Files changed (1) hide show
  1. app/backend/chromadb_utils.py +42 -0
app/backend/chromadb_utils.py CHANGED
@@ -1,8 +1,49 @@
1
  import chromadb
2
  from chromadb.utils import embedding_functions
3
  from .text_processing import text_chunking
 
 
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def initialize_chromadb(EMBEDDING_MODEL, local_model_path=None):
7
  """
8
  Initialize ChromaDB client and embedding function, using a local model path if provided.
@@ -19,6 +60,7 @@ def initialize_chromadb(EMBEDDING_MODEL, local_model_path=None):
19
  )
20
 
21
  return client, embedding_func
 
22
 
23
 
24
  def initialize_collection(client, embedding_func, collection_name):
 
1
  import chromadb
2
  from chromadb.utils import embedding_functions
3
  from .text_processing import text_chunking
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModel
6
 
7
 
8
+ class SimpleEmbeddingFunction:
9
+ def __init__(self, model_path):
10
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
11
+ self.model = AutoModel.from_pretrained(model_path, local_files_only=True)
12
+ self.model.eval()
13
+
14
+ def __call__(self, texts):
15
+ if isinstance(texts, str):
16
+ texts = [texts]
17
+
18
+ inputs = self.tokenizer(
19
+ texts,
20
+ padding=True,
21
+ truncation=True,
22
+ return_tensors='pt',
23
+ max_length=384
24
+ )
25
+
26
+ with torch.no_grad():
27
+ outputs = self.model(**inputs)
28
+ embeddings = outputs.last_hidden_state.mean(dim=1)
29
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
30
+
31
+ return embeddings.numpy().tolist()
32
+
33
+
34
+ def initialize_chromadb(model_name, model_path):
35
+ import chromadb
36
+
37
+ if not os.path.exists(model_path):
38
+ raise Exception(f"Local model not found at {model_path}")
39
+
40
+ embedding_func = SimpleEmbeddingFunction(model_path)
41
+ client = chromadb.Client()
42
+
43
+ return client, embedding_func
44
+
45
+
46
+ """
47
  def initialize_chromadb(EMBEDDING_MODEL, local_model_path=None):
48
  """
49
  Initialize ChromaDB client and embedding function, using a local model path if provided.
 
60
  )
61
 
62
  return client, embedding_func
63
+ """
64
 
65
 
66
  def initialize_collection(client, embedding_func, collection_name):