# handler.py — SCANSKY/distilbertTourism-multilingual-rclassifier
# (page header from source: "SCANSKY's picture / Update handler.py / 54676ac verified")
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import os
# Initialize model and tokenizer
# Hub repo ID of the fine-tuned DistilBERT tourism-relevance classifier.
model_name = "SCANSKY/distilbertTourism-multilingual-rclassifier"
# Both names are populated by load_model_components() at import time;
# pre-declare both so the module-level names always exist (the original
# only pre-declared `model`, leaving `tokenizer` undefined until loading).
tokenizer = None
model = None
def load_model_components():
    """Download and initialize the tokenizer and model exactly once.

    Rebinds the module-level ``tokenizer`` and ``model`` names, then puts
    the model into eval mode so inference runs deterministically (no
    dropout).
    """
    global tokenizer, model
    tokenizer, model = (
        DistilBertTokenizer.from_pretrained(model_name),
        DistilBertForSequenceClassification.from_pretrained(model_name),
    )
    model.eval()
    print("Model and tokenizer loaded successfully.")


# Eagerly load everything at import time so the first request is not
# penalized by model download/initialization latency.
load_model_components()
def predict_relevance(text):
    """Classify a single piece of text as tourism-relevant or not.

    Parameters
    ----------
    text : str
        Raw input text; whitespace-only input is rejected.

    Returns
    -------
    dict
        On success: ``{"prediction": 0|1, "confidence": <percentage float>,
        "text": <original text>}`` — 1 is taken to mean relevant, 0 not
        relevant (assumed label mapping; TODO confirm against the training
        config). On blank input: ``{"error": "Empty text provided."}``.
    """
    if not text.strip():
        return {"error": "Empty text provided."}
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=64,  # must match the max_length used during training
        return_tensors="pt"
    )
    # Move the input tensors (and the model) to GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model.to(device)  # effectively a no-op once already on this device
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)
        # Explicit dim: the original `torch.argmax(probs)` flattened the
        # tensor, which only happens to work for a batch of size 1.
        predicted_class = torch.argmax(probs, dim=1)[0].item()
        confidence = probs[0][predicted_class].item()
    return {
        "prediction": predicted_class,  # 1 for relevant, 0 for not relevant
        "confidence": float(confidence) * 100,  # convert to percentage
        "text": text
    }
class EndpointHandler:
    """Hugging Face Inference Endpoint handler.

    Splits the incoming payload into non-empty lines, classifies each
    line with the globally loaded DistilBERT relevance model, and returns
    one result dict per line.
    """

    def __init__(self, model_dir=None):
        # Model and tokenizer are loaded globally at import time, so no
        # per-handler initialization is needed. The `model_dir` argument
        # exists only because HF's inference toolkit passes it to every
        # handler constructor.
        pass

    def preprocess(self, data):
        """Extract non-empty, stripped lines from the request payload.

        Accepts either a single string under ``"inputs"`` (split on
        newlines, as before) or — backward-compatible generalization — a
        list of strings, each element treated as one line.
        """
        text = data.get("inputs", "")
        if isinstance(text, list):
            # Allow {"inputs": ["line a", "line b"]} payloads; skip
            # non-string entries and blank lines.
            return [item.strip() for item in text
                    if isinstance(item, str) and item.strip()]
        return [line.strip() for line in text.split('\n') if line.strip()]

    def inference(self, lines):
        """Run the relevance classifier on each preprocessed line."""
        return [predict_relevance(line) for line in lines]

    def postprocess(self, outputs):
        """Map raw prediction dicts onto the public response schema."""
        processed_results = []
        for output in outputs:
            if "error" in output:
                processed_results.append({
                    "text": output.get("text", ""),
                    "error": output["error"],
                    "confidence": 0
                })
            else:
                processed_results.append({
                    "text": output["text"],
                    "confidence": output["confidence"],
                    "relevance": "RELEVANT" if output["prediction"] == 1 else "IRRELEVANT"
                })
        return processed_results

    def __call__(self, data):
        """Full request pipeline: preprocess -> inference -> postprocess."""
        lines = self.preprocess(data)
        outputs = self.inference(lines)
        return self.postprocess(outputs)