Chittrarasu commited on
Commit
4e3ad22
·
1 Parent(s): 7098798
Files changed (1) hide show
  1. service/prediction_service.py +1 -17
service/prediction_service.py CHANGED
@@ -1,4 +1,3 @@
1
- import pickle
2
  from sentence_transformers import SentenceTransformer
3
  import os
4
  from huggingface_hub import hf_hub_download
@@ -8,7 +7,7 @@ hf_token = os.getenv('HF_TOKEN')
8
 
9
  # Hugging Face Model ID and local model directory
10
  hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
11
- model_dir = 'models/sentence_transformer'
12
 
13
  # Create model directory if not exists
14
  os.makedirs(model_dir, exist_ok=True)
@@ -16,23 +15,8 @@ os.makedirs(model_dir, exist_ok=True)
16
  # Download model if not already downloaded
17
  if not os.path.exists(os.path.join(model_dir, 'config.json')):
18
  print(f"Downloading model '{hf_model_id}' from Hugging Face...")
19
- model_path = hf_hub_download(
20
- repo_id=hf_model_id,
21
- filename='config.json',
22
- cache_dir=model_dir,
23
- token=hf_token
24
- )
25
- # Load model from Hugging Face with token
26
  model = SentenceTransformer(hf_model_id, use_auth_token=hf_token, trust_remote_code=True)
27
  model.save(model_dir)
28
  else:
29
  print(f"Loading model from local directory: {model_dir}")
30
  model = SentenceTransformer(model_dir)
31
-
32
-
33
- def predict_label(message: str):
34
- embedding = model.encode([message])
35
- prediction = logistic_model.predict(embedding)[0]
36
- probability = logistic_model.predict_proba(embedding)[0].max()
37
-
38
- return prediction, float(probability)
 
 
1
  from sentence_transformers import SentenceTransformer
2
  import os
3
  from huggingface_hub import hf_hub_download
 
7
 
8
  # Hugging Face Model ID and local model directory
9
  hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
10
+ model_dir = '/tmp/sentence_transformer' # Use /tmp for write permissions
11
 
12
  # Create model directory if not exists
13
  os.makedirs(model_dir, exist_ok=True)
 
15
  # Download model if not already downloaded
16
  if not os.path.exists(os.path.join(model_dir, 'config.json')):
17
  print(f"Downloading model '{hf_model_id}' from Hugging Face...")
 
 
 
 
 
 
 
18
  model = SentenceTransformer(hf_model_id, use_auth_token=hf_token, trust_remote_code=True)
19
  model.save(model_dir)
20
  else:
21
  print(f"Loading model from local directory: {model_dir}")
22
  model = SentenceTransformer(model_dir)