Chittrarasu commited on
Commit
7098798
·
1 Parent(s): 2c4a024
Files changed (2) hide show
  1. Requirements.txt +2 -1
  2. service/prediction_service.py +27 -5
Requirements.txt CHANGED
@@ -6,4 +6,5 @@ scikit-learn
6
  numpy
7
  pandas
8
  openpyxl
9
- numpy
 
 
6
  numpy
7
  pandas
8
  openpyxl
9
+ numpy
10
+ huggingface-hub
service/prediction_service.py CHANGED
@@ -1,12 +1,34 @@
1
  import pickle
2
  from sentence_transformers import SentenceTransformer
3
- import numpy as np
 
4
 
5
- # Load Model and Transformer
6
- with open('models/logistic_regression_model.pkl', 'rb') as f:
7
- logistic_model = pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- model = SentenceTransformer('models/sentence_transformer')
10
 
11
  def predict_label(message: str):
12
  embedding = model.encode([message])
 
1
  import pickle
2
  from sentence_transformers import SentenceTransformer
3
+ import os
4
+ from huggingface_hub import hf_hub_download
5
 
6
+ # Get the Hugging Face token from environment variable
7
+ hf_token = os.getenv('HF_TOKEN')
8
+
9
+ # Hugging Face Model ID and local model directory
10
+ hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
11
+ model_dir = 'models/sentence_transformer'
12
+
13
+ # Create model directory if not exists
14
+ os.makedirs(model_dir, exist_ok=True)
15
+
16
+ # Download model if not already downloaded
17
+ if not os.path.exists(os.path.join(model_dir, 'config.json')):
18
+ print(f"Downloading model '{hf_model_id}' from Hugging Face...")
19
+ model_path = hf_hub_download(
20
+ repo_id=hf_model_id,
21
+ filename='config.json',
22
+ cache_dir=model_dir,
23
+ token=hf_token
24
+ )
25
+ # Load model from Hugging Face with token
26
+ model = SentenceTransformer(hf_model_id, use_auth_token=hf_token, trust_remote_code=True)
27
+ model.save(model_dir)
28
+ else:
29
+ print(f"Loading model from local directory: {model_dir}")
30
+ model = SentenceTransformer(model_dir)
31
 
 
32
 
33
  def predict_label(message: str):
34
  embedding = model.encode([message])