# handler.py — custom Inference Endpoints handler (author: izzelbas, commit 1ec8314)
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import logging
# Set up logging
logger = logging.getLogger(__name__)
class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    Loads the 'izzelbas/roberta-large-ugm-cs-curriculum' extractive-QA model
    and finds the highest-scoring answer span in a given context via an
    exhaustive (start, end) search, vectorized with torch.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer.

        Args:
            path: Path to the downloaded model assets. Hugging Face
                Inference Endpoints provides this automatically; when empty,
                the model is pulled from the Hub (useful for local testing).

        Raises:
            Exception: Re-raised if the model or tokenizer fails to load.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        model_name_or_path = path if path else "izzelbas/roberta-large-ugm-cs-curriculum"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
            # Move model to the appropriate device and set to evaluation mode.
            self.model.to(self.device)
            self.model.eval()
            logger.info("Model and tokenizer loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model or tokenizer: {e}")
            # Bare `raise` preserves the original traceback (unlike `raise e`).
            raise

    def __call__(self, data: dict) -> dict:
        """
        Handle an inference request.

        Args:
            data: Request payload. Expected format:
                {
                    "inputs": {
                        "question": "Your question here",
                        "context": "The context paragraph here"
                    },
                    "parameters": {
                        "max_answer_len": 30
                    }
                }

        Returns:
            dict: {"answer": str, "score": float} on success,
                  {"error": str} on any failure.
        """
        try:
            # `get` (not `pop`): do not mutate the caller's request dict.
            inputs_data = data.get("inputs", data)
            question = inputs_data.get("question")
            context = inputs_data.get("context")
            if not question or not context:
                return {"error": "Missing 'question' or 'context' in 'inputs'"}

            params = data.get("parameters", {})
            max_answer_len = params.get("max_answer_len", 30)

            # Tokenize the (question, context) pair and move to the model device.
            inputs = self.tokenizer(
                question,
                context,
                return_tensors="pt",
                truncation=True,
                max_length=512
            ).to(self.device)

            # Inference without gradient tracking for efficiency.
            with torch.no_grad():
                outputs = self.model(**inputs)

            start_logits = outputs.start_logits[0]
            end_logits = outputs.end_logits[0]

            # Degenerate length limit: the original loop produced no candidate
            # spans in this case — keep that exact result.
            if max_answer_len < 1:
                return {"answer": "", "score": float("-inf")}

            seq_len = start_logits.size(0)
            # scores[i, j] = start_logits[i] + end_logits[j] for every (start, end)
            # pair; one tensor op replaces the Python-level double loop.
            scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)
            # Valid spans satisfy: end >= start (upper triangle) and
            # end - start < max_answer_len (band of width max_answer_len).
            valid = torch.ones_like(scores, dtype=torch.bool)
            valid = torch.triu(valid) & torch.tril(valid, diagonal=max_answer_len - 1)
            scores = scores.masked_fill(~valid, float("-inf"))

            # torch.max returns the FIRST maximal index in flattened row-major
            # order, matching the original strictly-greater scan order exactly.
            best_score_t, flat_idx = scores.view(-1).max(dim=0)
            best_start, best_end = divmod(flat_idx.item(), seq_len)

            # Decode the winning span once, after the search (the original
            # decoded inside the inner loop on every running-max update).
            span_ids = inputs["input_ids"][0][best_start : best_end + 1]
            best_span = self.tokenizer.decode(span_ids, skip_special_tokens=True)

            return {"answer": best_span, "score": best_score_t.item()}
        except Exception as e:
            logger.error(f"Error during inference: {e}")
            return {"error": str(e)}