File size: 2,411 Bytes
9288345
7098798
 
3369b7c
9288345
7098798
 
 
 
 
4e3ad22
5e0bbc5
7098798
 
 
 
 
 
 
 
 
 
 
4b5415a
35b327b
3369b7c
 
 
 
 
 
 
 
4b5415a
 
81c6189
3369b7c
 
 
 
81c6189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from sentence_transformers import SentenceTransformer
import os
from huggingface_hub import hf_hub_download
from joblib import load  # <-- Import this to load the model

# Get the Hugging Face token from environment variable
hf_token = os.getenv('HF_TOKEN')

# Hugging Face Model ID and local model directory
hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
model_dir = '/tmp/sentence_transformer'  # Use /tmp for write permissions
clf_model_path = 'models/logistic_regression_model.pkl'

# Create model directory if not exists
os.makedirs(model_dir, exist_ok=True)

# Download model if not already downloaded
if not os.path.exists(os.path.join(model_dir, 'config.json')):
    print(f"Downloading model '{hf_model_id}' from Hugging Face...")
    model = SentenceTransformer(hf_model_id, use_auth_token=hf_token, trust_remote_code=True)
    model.save(model_dir)
else:
    print(f"Loading model from local directory: {model_dir}")
    model = SentenceTransformer(model_dir, trust_remote_code=True)  # Added trust_remote_code=True

# ✅ Load the logistic regression model and define clf globally
clf = None  # Initialize as None
if os.path.exists(clf_model_path):
    clf = load(clf_model_path)  # Load the logistic regression model
    print("Logistic Regression model loaded successfully.")
else:
    print("Logistic Regression model not found. Ensure it is saved in /tmp.")

# Define predict_label function
def predict_label(text):
    """Classify *text* using the sentence-transformer encoder + logistic regression.

    Parameters
    ----------
    text : str or list[str]
        Text to classify. A bare string is wrapped into a one-element list.

    Returns
    -------
    tuple[str, float]
        ``(label, probability)`` where ``label`` is the predicted class of the
        FIRST input as a string (e.g. "0" or "1") and ``probability`` is that
        prediction's confidence. On any failure returns ``("Error", 0.0)``.
    """
    try:
        # Guard: the classifier may have failed to load at import time.
        if clf is None:
            raise ValueError("Logistic Regression model is not loaded.")

        # The encoder expects a list of strings.
        if not isinstance(text, list):
            text = [text]

        # Embed the input text(s).
        embeddings = model.encode(text)

        if len(embeddings) == 0:
            raise ValueError("No embeddings generated.")

        prediction = clf.predict(embeddings)
        # BUG FIX: take the max class probability of the FIRST sample only.
        # The original `.max()` ran over the whole (n_samples, n_classes)
        # matrix, so with multiple inputs the returned confidence could belong
        # to a different text than the returned label (which is prediction[0]).
        probability = clf.predict_proba(embeddings)[0].max()

        # Return the label as a string ("0" or "1") plus a plain float.
        label = str(prediction[0])
        return label, float(probability)

    except Exception as e:
        # Boundary handler: never propagate to the caller; log and return a
        # sentinel so the serving layer can respond gracefully.
        print(f"Error in predict_label: {e}")
        return "Error", 0.0