|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForQuestionAnswering |
|
|
import logging |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    This handler loads the 'izzelbas/roberta-large-ugm-cs-curriculum' model
    and uses a custom search logic to find the best answer span in a given context.
    """

    def __init__(self, path: str = ""):
        """
        Initializes the model and tokenizer.

        Args:
            path (str): The path to the downloaded model assets.
                Hugging Face Inference Endpoints automatically provides this.
                When empty, falls back to the public Hub model id.

        Raises:
            Exception: Re-raised when the model or tokenizer cannot be loaded,
                so the endpoint fails fast at startup instead of at request time.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)

        # An empty path (e.g. local testing) falls back to the public Hub id.
        model_name_or_path = path if path else "izzelbas/roberta-large-ugm-cs-curriculum"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
            self.model.to(self.device)
            # Inference mode: disables dropout for deterministic outputs.
            self.model.eval()
            logger.info("Model and tokenizer loaded successfully.")
        except Exception:
            # logger.exception records the full traceback; bare `raise`
            # preserves the original exception chain for the platform to report.
            logger.exception("Failed to load model or tokenizer")
            raise

    def __call__(self, data: dict) -> dict:
        """
        Handles an inference request.

        Args:
            data (dict): A dictionary containing the input data. Expected format:
                {
                    "inputs": {
                        "question": "Your question here",
                        "context": "The context paragraph here"
                    },
                    "parameters": {
                        "max_answer_len": 30
                    }
                }

        Returns:
            dict: {"answer": str, "score": float} on success, or
            {"error": str} when inputs are missing or inference fails.
        """
        try:
            # Use .get (not .pop) so the caller's request dict is never mutated.
            # Falls back to treating `data` itself as the inputs when no
            # "inputs" wrapper is present, as the original handler did.
            inputs_data = data.get("inputs", data)
            question = inputs_data.get("question")
            context = inputs_data.get("context")

            if not question or not context:
                return {"error": "Missing 'question' or 'context' in 'inputs'"}

            params = data.get("parameters", {})
            max_answer_len = params.get("max_answer_len", 30)

            inputs = self.tokenizer(
                question,
                context,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            best_score, best_start, best_end = self._search_best_span(
                outputs.start_logits[0], outputs.end_logits[0], max_answer_len
            )

            if best_start is None:
                # Defensive: no candidate span (e.g. empty logits).
                return {"answer": "", "score": best_score}

            # Decode once, after the search, instead of on every improvement.
            span_ids = inputs["input_ids"][0][best_start : best_end + 1]
            answer = self.tokenizer.decode(span_ids, skip_special_tokens=True)

            return {"answer": answer, "score": best_score}

        except Exception as e:
            logger.exception("Error during inference")
            return {"error": str(e)}

    @staticmethod
    def _search_best_span(start_logits, end_logits, max_answer_len):
        """
        Exhaustively score every candidate (start, end) span.

        A span's score is start_logits[start] + end_logits[end]; spans are
        limited to at most `max_answer_len` tokens. Ties keep the earliest
        span found (strict `>` comparison), matching the original search order.

        NOTE(review): the search ranges over ALL positions, including the
        question and special tokens — presumably intentional, but standard QA
        post-processing masks those out; confirm with the model author.

        Args:
            start_logits: 1-D tensor of start logits for one sequence.
            end_logits: 1-D tensor of end logits for one sequence.
            max_answer_len (int): Maximum span length in tokens.

        Returns:
            tuple: (best_score, best_start, best_end); indices are None when
            no candidate exists.
        """
        # Convert to plain Python lists once so the nested loop avoids
        # per-iteration tensor indexing overhead.
        starts = start_logits.tolist()
        ends = end_logits.tolist()

        best_score = float("-inf")
        best_start = best_end = None

        for start, start_logit in enumerate(starts):
            end_limit = min(start + max_answer_len, len(ends))
            for end in range(start, end_limit):
                score = start_logit + ends[end]
                if score > best_score:
                    best_score = score
                    best_start, best_end = start, end

        return best_score, best_start, best_end