File size: 4,463 Bytes
608c050
a0c0878
1ec8314
a0c0878
1ec8314
 
2439254
1ec8314
 
 
 
 
 
 
 
 
 
 
 
 
 
663e929
1ec8314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8b54ee
1ec8314
 
 
 
 
 
 
a0c0878
1ec8314
 
a0c0878
1ec8314
 
 
a0c0878
1ec8314
a0c0878
1ec8314
 
 
 
 
 
 
 
 
 
 
 
36bbfbe
1ec8314
 
36bbfbe
1ec8314
 
5777bbb
1ec8314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36bbfbe
1ec8314
a0c0878
1ec8314
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import logging

# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)

class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    Loads the 'izzelbas/roberta-large-ugm-cs-curriculum' question-answering
    model and performs an exhaustive span search over the start/end logits
    to find the best answer span in a given context.
    """
    def __init__(self, path: str = ""):
        """
        Initializes the model and tokenizer.

        Args:
            path (str): Path to the downloaded model assets. Hugging Face
                Inference Endpoints provides this automatically; when empty,
                the model is pulled from the Hub (useful for local testing).

        Raises:
            Exception: Re-raised after logging if loading fails, so the
                endpoint fails fast instead of serving a broken handler.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Lazy %-style args avoid formatting when the level is disabled.
        logger.info("Using device: %s", self.device)

        # Load tokenizer and model from the specified path.
        # If path is empty, download from the Hub (useful for local testing).
        model_name_or_path = path if path else "izzelbas/roberta-large-ugm-cs-curriculum"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)

            # Move model to the appropriate device and disable dropout etc.
            self.model.to(self.device)
            self.model.eval()
            logger.info("Model and tokenizer loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model or tokenizer: {e}")
            raise  # bare raise preserves the original traceback

    def __call__(self, data: dict) -> dict:
        """
        Handles an inference request.

        Args:
            data (dict): A dictionary containing the input data. Expected format:
                         {
                             "inputs": {
                                 "question": "Your question here",
                                 "context": "The context paragraph here"
                             },
                             "parameters": {
                                 "max_answer_len": 30
                             }
                         }

        Returns:
            dict: {"answer": str, "score": float} on success, or
                  {"error": str} on invalid input or inference failure.
        """
        try:
            # .get() instead of .pop(): do not mutate the caller's payload.
            inputs_data = data.get("inputs", data)
            question = inputs_data.get("question")
            context = inputs_data.get("context")

            if not question or not context:
                return {"error": "Missing 'question' or 'context' in 'inputs'"}

            # Get parameters or use defaults.
            params = data.get("parameters", {})
            max_answer_len = params.get("max_answer_len", 30)

            # Tokenize input and send to GPU/CPU.
            inputs = self.tokenizer(
                question,
                context,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            # Inference without gradient calculation for efficiency.
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Hoist the tensor -> Python-list conversion out of the search:
            # per-element tensor indexing inside a double Python loop is
            # far slower than iterating plain floats.
            start_logits = outputs.start_logits[0].tolist()
            end_logits = outputs.end_logits[0].tolist()

            best_score = float("-inf")
            best_start = 0
            best_end = 0

            # Exhaustive span search over all (start, end) pairs with
            # end - start + 1 <= max_answer_len. The strict ">" keeps the
            # earliest span on ties, matching the original behavior.
            # NOTE(review): the search spans the full sequence, so the answer
            # can come from question/special tokens too — preserved from the
            # original; confirm whether restricting to context positions is
            # desired.
            for start, start_logit in enumerate(start_logits):
                # Constrain the end position based on max_answer_len.
                end_limit = min(start + max_answer_len, len(end_logits))
                for end in range(start, end_limit):
                    score = start_logit + end_logits[end]
                    if score > best_score:
                        best_score = score
                        best_start = start
                        best_end = end

            # Decode only the single winning span (the original decoded on
            # every score improvement, which is wasted work in an O(n*k) loop).
            span_ids = inputs["input_ids"][0][best_start : best_end + 1]
            best_span = self.tokenizer.decode(span_ids, skip_special_tokens=True)

            return {"answer": best_span, "score": best_score}

        except Exception as e:
            logger.error(f"Error during inference: {e}")
            return {"error": str(e)}