Vishwanath77 commited on
Commit
0119a71
·
verified ·
1 Parent(s): ebce1fe

Update src/apps/utils/embeddings.py

Browse files
Files changed (1) hide show
  1. src/apps/utils/embeddings.py +27 -22
src/apps/utils/embeddings.py CHANGED
@@ -1,22 +1,27 @@
1
- import os
2
- from transformers import AutoModel
3
- from numpy.linalg import norm
4
-
5
- cos_sim = lambda a,b: (a @ b.T) / (norm(a)*norm(b))
6
- _model = None
7
-
8
- def get_model():
9
- global _model
10
- if _model is None:
11
- BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
12
- model_path = os.path.join(BASE_DIR, "volumes", "models", "jina-embeddings-v2-base-en")
13
- print(f"Loading the model weights from: {model_path}")
14
- _model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
15
- return _model
16
-
17
- def get_embeddings(text:list):
18
- model = get_model()
19
- embeddings = model.encode(text)
20
- normalized_embeddings = embeddings/norm(embeddings[0])
21
- return normalized_embeddings
22
-
 
 
 
 
 
 
1
+ import os
2
+ from transformers import AutoModel
3
+ from numpy.linalg import norm
4
+
5
+ cos_sim = lambda a,b: (a @ b.T) / (norm(a)*norm(b))
6
+ _model = None
7
+
8
+ def get_model():
9
+ global _model
10
+ if _model is None:
11
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
12
+ model_path = os.path.join(BASE_DIR, "volumes", "models", "jina-embeddings-v2-base-en")
13
+
14
+ if os.path.exists(model_path):
15
+ print(f"Loading the model weights from local path: {model_path}")
16
+ _model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
17
+ else:
18
+ print("Local model weights not found. Downloading from Hugging Face Hub (jinaai/jina-embeddings-v2-base-en)...")
19
+ _model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
20
+ return _model
21
+
22
+ def get_embeddings(text:list):
23
+ model = get_model()
24
+ embeddings = model.encode(text)
25
+ normalized_embeddings = embeddings/norm(embeddings[0])
26
+ return normalized_embeddings
27
+