File size: 6,629 Bytes
2847c03
 
 
 
 
57bb533
2847c03
 
 
 
 
 
 
 
 
 
 
 
 
 
57bb533
 
 
2847c03
57bb533
 
 
 
 
2847c03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57bb533
 
2847c03
 
 
 
 
 
57bb533
 
2847c03
 
57bb533
2847c03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ff4a0e
 
 
 
2847c03
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import re
import os


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the tokenizer, quantized base model, and PEFT adapter.

        Args:
            path: Path to the model/adapter directory (provided by HF
                Inference Endpoints at startup).
        """
        # Model configuration
        self.base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # HF Inference Endpoints injects an auth token for gated models.
        hf_token = os.environ.get("HF_TOKEN", None)

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            token=hf_token,
            trust_remote_code=True
        )
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if torch.cuda.is_available():
            # 4-bit NF4 quantization keeps the 8B model within a
            # single-GPU memory budget.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                token=hf_token
            )
        else:
            # FIX: use float32 on the CPU path — many CPU kernels do not
            # implement half precision, so fp16 inference on CPU can fail.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                token=hf_token
            )

        # Attach the fine-tuned PEFT (LoRA) adapter stored at `path`.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

        # Generation defaults; per-request "parameters" may override these.
        self.generation_config = {
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "max_new_tokens": 1000,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id
        }

    def format_math_prompt(self, question: str) -> str:
        """Format a math question with proper instructions."""
        instructions = """Please solve this math problem step by step, following these rules:
1) Start by noting all the facts from the problem.
2) Show your work by performing inner calculations inside double angle brackets, like <<calculation=result>>.
3) You MUST write the final answer on a new line with a #### prefix.
Note - each answer must be of length <= 400."""

        # Format according to Llama 3.1 chat template
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{instructions}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
        return prompt

    @staticmethod
    def _to_number(text: str) -> Any:
        """Convert a numeric string (thousands commas allowed) to int/float.

        Returns int when there is no decimal part, float otherwise, and the
        comma-stripped string itself if conversion somehow fails (defensive;
        the calling regexes should prevent that).
        """
        cleaned = text.replace(',', '')
        try:
            return float(cleaned) if '.' in cleaned else int(cleaned)
        except ValueError:
            return cleaned

    def extract_answer(self, response: str) -> Any:
        """Extract the final numeric answer from the model response.

        Looks for a '#### <number>' marker first, then falls back to the
        last number anywhere in the text. Returns None when no number is
        found.
        """
        # FIX: the previous pattern [-\d,\.]+ swallowed trailing sentence
        # punctuation ("#### 42." -> 42.0) and could match punctuation-only
        # runs ("." / ","). This pattern requires a leading digit and only
        # accepts a decimal part with digits after the dot.
        answer_match = re.search(r'####\s*(-?\d[\d,]*(?:\.\d+)?)', response)
        if answer_match:
            return self._to_number(answer_match.group(1))

        # Fallback: take the last number appearing anywhere in the response.
        numbers = re.findall(r'-?\d[\d,]*(?:\.\d+)?', response)
        if numbers:
            value = self._to_number(numbers[-1])
            # _to_number only returns a str on conversion failure; mirror the
            # original behavior of ignoring unconvertible fallback matches.
            if not isinstance(value, str):
                return value

        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: A dictionary containing the input data
                - inputs: str or List[str] - The math questions to solve
                - parameters (optional): Dict with generation parameters
                  that override the defaults set in __init__

        Returns:
            One dict per question with keys "question", "full_response",
            "answer" (parsed number or None), and "formatted_answer".
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Accept both a single question string and a list of questions.
        if isinstance(inputs, str):
            questions = [inputs]
        else:
            questions = inputs

        # Merge request-level overrides into the default generation config.
        gen_config = self.generation_config.copy()
        gen_config.update(parameters)

        results = []
        for question in questions:
            prompt = self.format_math_prompt(question)

            # NOTE(review): right-side truncation at 512 tokens can cut off
            # the assistant header for very long questions — confirm the
            # expected question lengths against this limit.
            model_inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **model_inputs,
                    **gen_config
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            input_length = model_inputs['input_ids'].shape[1]
            generated_tokens = outputs[0][input_length:]
            assistant_response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            extracted_answer = self.extract_answer(assistant_response)

            results.append({
                "question": question,
                "full_response": assistant_response,
                "answer": extracted_answer,
                "formatted_answer": f"#### {extracted_answer}" if extracted_answer is not None else "No answer found"
            })

        return results