# Custom inference handler for SCANSKY/distilbertTourism-multilingual-rclassifier.
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import os
# Hugging Face Hub repo id of the tourism relevance classifier.
model_name = "SCANSKY/distilbertTourism-multilingual-rclassifier"
# Both globals are populated by load_model_components(); None until then so
# that an unloaded state is detectable instead of raising NameError.
tokenizer = None
model = None
def load_model_components():
    """Load the model and tokenizer once at startup.

    Downloads (or reads from the local HF cache) the tokenizer and the
    sequence-classification model named by ``model_name``, binds them to
    the module-level ``tokenizer`` and ``model`` globals, and switches
    the model to eval mode so dropout is disabled during inference.
    """
    global tokenizer, model
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    model = DistilBertForSequenceClassification.from_pretrained(model_name)
    model.eval()
    print("Model and tokenizer loaded successfully.")
# Load model components eagerly at import time so the first request does not
# pay the model-download/initialization cost (the HF toolkit imports this
# module once when the endpoint starts).
load_model_components()
def predict_relevance(text):
    """Classify a single line of text as tourism-relevant or not.

    Parameters
    ----------
    text : str or None
        The text to classify.

    Returns
    -------
    dict
        On success: ``{"prediction": int, "confidence": float, "text": str}``
        where prediction is 1 for relevant, 0 for not relevant and
        confidence is a percentage. On blank/None input:
        ``{"error": "Empty text provided."}``.
    """
    # Fix: the original called text.strip() unconditionally, so text=None
    # raised AttributeError instead of returning the error dict.
    if not text or not text.strip():
        return {"error": "Empty text provided."}
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=64,  # Should match training max_length
        return_tensors="pt",
    )
    # Move to GPU if available; model.to() is a no-op once already moved.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model.to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        # logits shape is (1, num_labels); softmax/argmax over the label axis.
        probs = torch.softmax(outputs.logits, dim=-1)
        predicted_class = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][predicted_class].item()
    return {
        "prediction": predicted_class,  # 1 for relevant, 0 for not relevant
        "confidence": float(confidence) * 100,  # Convert to percentage
        "text": text,
    }
class EndpointHandler:
    """Hugging Face Inference Endpoints request handler.

    Splits the request text into lines, classifies each line with the
    module-level ``predict_relevance``, and formats per-line results.
    """

    def __init__(self, model_dir=None):
        # Model and tokenizer are loaded globally, so no need to reinitialize here
        # The `model_dir` argument is required by Hugging Face's inference toolkit
        pass

    def preprocess(self, data):
        """Extract the request text and split it into non-empty lines.

        Accepts ``{"inputs": "<text>"}`` or ``{"inputs": ["<text>", ...]}`` —
        some HF clients batch inputs as a list of strings, which previously
        crashed on ``text.split``.
        """
        text = data.get("inputs", "")
        # Generalization: fold list payloads into one newline-separated string
        # so the existing line-splitting logic applies unchanged.
        if isinstance(text, (list, tuple)):
            text = "\n".join(str(item) for item in text)
        # Split by newlines and remove empty lines
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        return lines

    def inference(self, lines):
        """Run the relevance classifier on each line, preserving order."""
        return [predict_relevance(line) for line in lines]

    def postprocess(self, outputs):
        """Map raw prediction dicts onto the response schema.

        Error results keep their message with confidence 0; successful
        results are labeled RELEVANT/IRRELEVANT from the predicted class.
        """
        processed_results = []
        for output in outputs:
            if "error" in output:
                processed_results.append({
                    "text": output.get("text", ""),
                    "error": output["error"],
                    "confidence": 0,
                })
            else:
                processed_results.append({
                    "text": output["text"],
                    "confidence": output["confidence"],
                    "relevance": "RELEVANT" if output["prediction"] == 1 else "IRRELEVANT",
                })
        return processed_results

    def __call__(self, data):
        """Handle one request: preprocess -> inference -> postprocess."""
        lines = self.preprocess(data)
        outputs = self.inference(lines)
        return self.postprocess(outputs)