File size: 3,094 Bytes
e3b4f0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513c09e
e3b4f0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513c09e
e3b4f0e
 
 
 
 
 
 
 
 
 
 
 
513c09e
 
 
e3b4f0e
513c09e
 
 
 
 
 
e3b4f0e
513c09e
 
 
e3b4f0e
513c09e
 
 
 
 
 
 
 
 
54676ac
513c09e
 
e3b4f0e
 
 
513c09e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import os

# Identifier handed to `from_pretrained` for both the tokenizer and the model.
model_name = "SCANSKY/distilbertTourism-multilingual-rclassifier"

# Both globals are filled in by load_model_components(); they are declared
# here so each name exists at module level even before loading has run.
model = None
tokenizer = None

def load_model_components():
    """Populate the module-level ``tokenizer`` and ``model`` globals.

    Both are created with ``from_pretrained(model_name)``; the model is
    switched to evaluation mode before a confirmation message is printed.
    """
    global tokenizer, model
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    # `eval()` returns the module itself, so the call can be chained.
    model = DistilBertForSequenceClassification.from_pretrained(model_name).eval()
    print("Model and tokenizer loaded successfully.")

# Load the model and tokenizer once, as a side effect of importing this
# module, so the globals are populated before any prediction is requested.
load_model_components()

def predict_relevance(text):
    """Predict whether a single text is relevant or not.

    Args:
        text: The text to classify. ``None``, an empty string, or a
            whitespace-only string is treated as empty input.

    Returns:
        ``{"error": "Empty text provided."}`` for empty input; otherwise a
        dict with:
          * ``"prediction"``: predicted class index (1 for relevant,
            0 for not relevant),
          * ``"confidence"``: softmax probability of that class expressed
            as a percentage (0-100),
          * ``"text"``: the input text, unchanged.
    """
    # Guard against None as well as blank strings so `.strip()` cannot raise.
    if text is None or not text.strip():
        return {"error": "Empty text provided."}

    # Use the GPU if available; `.to()` is a no-op once the model is there.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=64,  # Should match training max_length
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits

    # One text in -> one row out. Select that row explicitly so the argmax
    # is taken over classes rather than over a flattened (batch, classes)
    # tensor.
    class_probs = torch.softmax(logits, dim=1)[0]
    predicted_class = int(torch.argmax(class_probs).item())
    confidence = class_probs[predicted_class].item()

    return {
        "prediction": predicted_class,  # 1 for relevant, 0 for not relevant
        "confidence": float(confidence) * 100,  # Convert to percentage
        "text": text,
    }

class EndpointHandler:
    """Callable request handler: split the input into lines, classify each
    line with ``predict_relevance``, and format the results.

    Calling an instance with a request dict runs ``preprocess`` ->
    ``inference`` -> ``postprocess`` and returns the final list.
    """

    def __init__(self, model_dir=None):
        # Model and tokenizer are loaded globally, so no need to reinitialize here
        # The `model_dir` argument is required by Hugging Face's inference toolkit
        pass

    def preprocess(self, data):
        """Extract the non-empty, stripped lines from ``data["inputs"]``.

        Args:
            data: Request dict. ``"inputs"`` may be a single string, a list
                or tuple of strings, ``None``, or absent.

        Returns:
            A flat list of stripped, non-empty lines. Every string is split
            on ``"\\n"``; a missing or ``None`` payload yields ``[]``.
        """
        payload = data.get("inputs", "")
        if payload is None:
            return []

        # Normalise to a sequence of strings so one code path handles both
        # a single text and a batch of texts.
        chunks = payload if isinstance(payload, (list, tuple)) else [payload]

        lines = []
        for chunk in chunks:
            stripped_parts = (part.strip() for part in chunk.split("\n"))
            lines.extend(part for part in stripped_parts if part)
        return lines

    def inference(self, lines):
        """Run ``predict_relevance`` on each line and return the raw results in order."""
        return [predict_relevance(line) for line in lines]

    def postprocess(self, outputs):
        """Convert raw prediction dicts into the response format.

        Args:
            outputs: Iterable of dicts as produced by ``inference``.

        Returns:
            A list with one dict per output. Error outputs become
            ``{"text", "error", "confidence": 0}``; successful outputs become
            ``{"text", "confidence", "relevance"}`` where ``"relevance"`` is
            ``"RELEVANT"`` when the prediction equals 1, else ``"IRRELEVANT"``.
        """
        processed_results = []
        for output in outputs:
            if "error" in output:
                # Error results may lack "text", so fall back to "".
                entry = {
                    "text": output.get("text", ""),
                    "error": output["error"],
                    "confidence": 0,
                }
            else:
                label = "RELEVANT" if output["prediction"] == 1 else "IRRELEVANT"
                entry = {
                    "text": output["text"],
                    "confidence": output["confidence"],
                    "relevance": label,
                }
            processed_results.append(entry)
        return processed_results

    def __call__(self, data):
        """Handle one request: preprocess, classify, and format the results."""
        lines = self.preprocess(data)
        outputs = self.inference(lines)
        return self.postprocess(outputs)