import logging

import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)


class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    This handler loads the 'izzelbas/roberta-large-ugm-cs-curriculum' model
    and uses a custom search logic to find the best answer span in a given
    context.
    """

    def __init__(self, path: str = ""):
        """
        Initializes the model and tokenizer.

        Args:
            path (str): The path to the downloaded model assets. Hugging Face
                Inference Endpoints automatically provides this.

        Raises:
            Exception: Re-raised if the model or tokenizer cannot be loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Load tokenizer and model from the specified path.
        # If path is empty, it will download from the Hub (useful for local testing).
        model_name_or_path = path if path else "izzelbas/roberta-large-ugm-cs-curriculum"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
            # Move model to the appropriate device and set to evaluation mode.
            self.model.to(self.device)
            self.model.eval()
            logger.info("Model and tokenizer loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model or tokenizer: {e}")
            raise e

    def __call__(self, data: dict) -> dict:
        """
        Handles an inference request.

        Args:
            data (dict): A dictionary containing the input data.
                Expected format:
                {
                    "inputs": {
                        "question": "Your question here",
                        "context": "The context paragraph here"
                    },
                    "parameters": { "max_answer_len": 30 }
                }

        Returns:
            dict: {"answer": str, "score": float} on success, or
                {"error": str} if the input is malformed or inference fails.
        """
        try:
            # Fall back to treating `data` itself as the inputs dict when no
            # "inputs" key is present (convenient for local testing).
            inputs_data = data.pop("inputs", data)
            question = inputs_data.get("question")
            context = inputs_data.get("context")

            if not question or not context:
                return {"error": "Missing 'question' or 'context' in 'inputs'"}

            # Get parameters or use defaults. `max_answer_len` bounds the
            # answer span length in tokens.
            params = data.pop("parameters", {})
            max_answer_len = params.get("max_answer_len", 30)

            # Tokenize the (question, context) pair and send to GPU/CPU.
            inputs = self.tokenizer(
                question,
                context,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            # Inference without gradient calculation for efficiency.
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Convert logits to plain Python lists once, so the nested search
            # below compares floats rather than doing a tensor op per
            # candidate span (O(seq_len * max_answer_len) candidates).
            start_logits = outputs.start_logits[0].tolist()
            end_logits = outputs.end_logits[0].tolist()

            best_score = float("-inf")
            best_start = best_end = -1

            # Exhaustive search over all (start, end) token positions with
            # end constrained so the span is at most `max_answer_len` tokens.
            for start in range(len(start_logits)):
                end_limit = min(start + max_answer_len, len(end_logits))
                for end in range(start, end_limit):
                    score = start_logits[start] + end_logits[end]
                    if score > best_score:
                        best_score = score
                        best_start, best_end = start, end

            # Decode the winning span exactly once, after the search — the
            # previous implementation re-decoded on every new best, costing
            # up to O(seq_len) decode calls.
            best_span = ""
            if best_start >= 0:
                span_ids = inputs["input_ids"][0][best_start : best_end + 1]
                best_span = self.tokenizer.decode(span_ids, skip_special_tokens=True)

            return {"answer": best_span, "score": best_score}

        except Exception as e:
            logger.error(f"Error during inference: {e}")
            return {"error": str(e)}