| | from typing import Dict, List, Any |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| | from peft import PeftModel |
| | import re |
| | import os |
| |
|
| |
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a LoRA-tuned Llama 3.1 math solver.

    Loads the base model (4-bit NF4 quantized on GPU, full precision on CPU),
    attaches the PEFT adapter stored at ``path``, and answers math questions
    step by step, extracting a ``#### <number>`` final answer from the output.
    """

    def __init__(self, path=""):
        """
        Initialize the model and tokenizer for the inference endpoint.

        Args:
            path: The path to the model directory (provided by HF Inference Endpoints);
                  this directory holds the PEFT/LoRA adapter weights.
        """
        self.base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Llama 3.1 is a gated model: the token must come from the endpoint's env.
        hf_token = os.environ.get("HF_TOKEN", None)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        # Llama tokenizers ship without a pad token; reuse EOS so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if torch.cuda.is_available():
            # 4-bit NF4 + double quantization keeps the 8B model in single-GPU memory.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                token=hf_token,
            )
        else:
            # FIX: use float32 on CPU. The original loaded float16, but many CPU
            # kernels (layer norm, softmax, etc.) lack fp16 implementations, so
            # the CPU fallback would crash or produce unstable results.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                token=hf_token,
            )

        # Attach the fine-tuned LoRA adapter shipped with this endpoint.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

        # Defaults; any of these can be overridden per-request via "parameters".
        self.generation_config = {
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "max_new_tokens": 1000,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

    def format_math_prompt(self, question: str) -> str:
        """Format a math question into the Llama 3.1 chat template with solving rules."""
        instructions = """Please solve this math problem step by step, following these rules:
1) Start by noting all the facts from the problem.
2) Show your work by performing inner calculations inside double angle brackets, like <<calculation=result>>.
3) You MUST write the final answer on a new line with a #### prefix.
Note - each answer must be of length <= 400."""

        # Llama 3.1 instruct template: system block, user block, then an open
        # assistant header so generation continues as the assistant turn.
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{instructions}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
        return prompt

    def extract_answer(self, response: str) -> Any:
        """Extract the final numeric answer from the model response.

        Prefers the "#### <number>" marker the prompt requires; falls back to
        the last parseable number-like token anywhere in the text.

        Returns:
            int when the value has no decimal point, float when it does, the
            raw (comma-stripped) string if the #### value is not numeric, or
            None when nothing number-like is found.
        """
        answer_match = re.search(r'####\s*([-\d,\.]+)', response)
        if answer_match:
            # FIX: also strip a trailing period — the character class captures
            # sentence punctuation, so "#### 42." used to parse as 42.0 (float)
            # instead of the intended integer 42.
            answer_str = answer_match.group(1).replace(',', '').rstrip('.')
            try:
                return float(answer_str) if '.' in answer_str else int(answer_str)
            except ValueError:
                return answer_str

        # Fallback: scan number-like tokens from the end of the response.
        # FIX: the original tried only the single last token and gave up on a
        # parse failure (e.g. a trailing "..."), discarding earlier valid numbers.
        for candidate in reversed(re.findall(r'[-\d,\.]+', response)):
            cleaned = candidate.replace(',', '').rstrip('.')
            try:
                return float(cleaned) if '.' in cleaned else int(cleaned)
            except ValueError:
                continue

        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: A dictionary containing the input data
                - inputs: str or List[str] - The math questions to solve
                - parameters (optional): Dict with generation parameters
                  (merged over the defaults set in __init__)

        Returns:
            List of dictionaries, one per question, with keys "question",
            "full_response", "answer", and "formatted_answer".
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Normalize to a list so single-question and batch requests share a path.
        questions = [inputs] if isinstance(inputs, str) else inputs

        # Per-request overrides must not mutate the shared defaults.
        gen_config = self.generation_config.copy()
        gen_config.update(parameters)

        results = []
        for question in questions:
            prompt = self.format_math_prompt(question)

            model_inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **model_inputs,
                    **gen_config,
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            input_length = model_inputs['input_ids'].shape[1]
            generated_tokens = outputs[0][input_length:]
            assistant_response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            extracted_answer = self.extract_answer(assistant_response)

            results.append({
                "question": question,
                "full_response": assistant_response,
                "answer": extracted_answer,
                "formatted_answer": f"#### {extracted_answer}" if extracted_answer is not None else "No answer found",
            })

        return results