Davide Panza commited on
Commit
cf868d7
·
verified ·
1 Parent(s): 582ee3f

Update app/backend/chromadb_utils.py

Browse files
Files changed (1) hide show
  1. app/backend/chromadb_utils.py +1 -43
app/backend/chromadb_utils.py CHANGED
@@ -1,49 +1,8 @@
1
  import chromadb
2
  from chromadb.utils import embedding_functions
3
  from .text_processing import text_chunking
4
- import torch
5
- from transformers import AutoTokenizer, AutoModel
6
 
7
 
8
- class SimpleEmbeddingFunction:
9
- def __init__(self, model_path):
10
- self.tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
11
- self.model = AutoModel.from_pretrained(model_path, local_files_only=True)
12
- self.model.eval()
13
-
14
- def __call__(self, texts):
15
- if isinstance(texts, str):
16
- texts = [texts]
17
-
18
- inputs = self.tokenizer(
19
- texts,
20
- padding=True,
21
- truncation=True,
22
- return_tensors='pt',
23
- max_length=384
24
- )
25
-
26
- with torch.no_grad():
27
- outputs = self.model(**inputs)
28
- embeddings = outputs.last_hidden_state.mean(dim=1)
29
- embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
30
-
31
- return embeddings.numpy().tolist()
32
-
33
-
34
- def initialize_chromadb(model_name, model_path):
35
- import chromadb
36
-
37
- if not os.path.exists(model_path):
38
- raise Exception(f"Local model not found at {model_path}")
39
-
40
- embedding_func = SimpleEmbeddingFunction(model_path)
41
- client = chromadb.Client()
42
-
43
- return client, embedding_func
44
-
45
-
46
- """
47
  def initialize_chromadb(EMBEDDING_MODEL, local_model_path=None):
48
  """
49
  Initialize ChromaDB client and embedding function, using a local model path if provided.
@@ -60,8 +19,7 @@ def initialize_chromadb(EMBEDDING_MODEL, local_model_path=None):
60
  )
61
 
62
  return client, embedding_func
63
- """
64
-
65
 
66
  def initialize_collection(client, embedding_func, collection_name):
67
  """
 
1
  import chromadb
2
  from chromadb.utils import embedding_functions
3
  from .text_processing import text_chunking
 
 
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def initialize_chromadb(EMBEDDING_MODEL, local_model_path=None):
7
  """
8
  Initialize ChromaDB client and embedding function, using a local model path if provided.
 
19
  )
20
 
21
  return client, embedding_func
22
+
 
23
 
24
  def initialize_collection(client, embedding_func, collection_name):
25
  """