from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import os

# Hugging Face Hub repo of the fine-tuned tourism-relevance classifier.
model_name = "SCANSKY/distilbertTourism-multilingual-rclassifier"

tokenizer = None
model = None

# Resolve the inference device once at import time instead of on every
# request (the original re-created the device and called model.to(device)
# inside predict_relevance for each call).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model_components():
    """Load the tokenizer and model once at startup.

    Populates the module-level ``tokenizer`` and ``model`` globals, moves
    the model to the selected device, and puts it in eval mode.
    """
    global tokenizer, model
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    model = DistilBertForSequenceClassification.from_pretrained(model_name)
    model.to(device)  # move once here rather than per request
    model.eval()
    print("Model and tokenizer loaded successfully.")


# Load model components when the handler is initialized
load_model_components()


def predict_relevance(text):
    """Classify a single text as relevant (1) or not relevant (0).

    Args:
        text: The input string to classify.

    Returns:
        dict with keys ``prediction`` (int), ``confidence`` (percentage
        float), and ``text`` — or a dict with an ``error`` key when the
        input is empty/whitespace-only.
    """
    if not text.strip():
        return {"error": "Empty text provided."}

    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=64,  # Should match training max_length
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)
        predicted_class = torch.argmax(probs).item()
        confidence = probs[0][predicted_class].item()

    return {
        "prediction": predicted_class,  # 1 for relevant, 0 for not relevant
        "confidence": float(confidence) * 100,  # Convert to percentage
        "text": text,
    }


class EndpointHandler:
    """Request handler for Hugging Face's inference toolkit.

    Splits the request text into non-empty lines and classifies each
    line's tourism relevance independently.
    """

    def __init__(self, model_dir=None):
        # Model and tokenizer are loaded globally, so there is nothing to
        # reinitialize here. The `model_dir` argument is required by
        # Hugging Face's inference toolkit signature.
        pass

    def preprocess(self, data):
        """Extract the input text and split it into stripped, non-empty lines."""
        text = data.get("inputs", "")
        return [line.strip() for line in text.split('\n') if line.strip()]

    def inference(self, lines):
        """Run the relevance classifier on each line."""
        return [predict_relevance(line) for line in lines]

    def postprocess(self, outputs):
        """Map raw prediction dicts onto the response schema.

        Error entries keep their message with a confidence of 0; normal
        entries get a human-readable ``relevance`` label.
        """
        processed_results = []
        for output in outputs:
            if "error" in output:
                processed_results.append({
                    "text": output.get("text", ""),
                    "error": output["error"],
                    "confidence": 0,
                })
            else:
                processed_results.append({
                    "text": output["text"],
                    "confidence": output["confidence"],
                    "relevance": "RELEVANT" if output["prediction"] == 1 else "IRRELEVANT",
                })
        return processed_results

    def __call__(self, data):
        # Main method to handle the request: preprocess -> inference -> postprocess.
        lines = self.preprocess(data)
        outputs = self.inference(lines)
        return self.postprocess(outputs)