from typing import Any, Dict, List
import os
import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a LoRA-adapted math solver.

    Loads the Llama-3.1-8B-Instruct base model (4-bit quantized on GPU,
    full precision on CPU), attaches the PEFT adapter stored at ``path``,
    and serves step-by-step math solutions with a ``#### <answer>`` suffix.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer for the inference endpoint.

        Args:
            path: The path to the model directory (provided by HF Inference
                Endpoints); must contain the PEFT adapter weights.
        """
        # Model configuration
        self.base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Get HF token from environment if available (Inference Endpoints
        # sets this for gated models such as Llama 3.1).
        hf_token = os.environ.get("HF_TOKEN", None)

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            # Llama tokenizers ship without a pad token; reuse EOS so that
            # batched generation and padding both work.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load base model with quantization for memory efficiency
        if torch.cuda.is_available():
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                token=hf_token,
            )
        else:
            # FIX: the original used torch.float16 on CPU, but most CPU
            # kernels have no half-precision implementation, which makes
            # generation fail (or fall back pathologically). Use float32.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                token=hf_token,
            )

        # Load PEFT (LoRA) adapter from the endpoint's model directory.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

        # Default generation config; per-request "parameters" may override
        # any of these keys in __call__.
        self.generation_config = {
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "max_new_tokens": 1000,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

    def format_math_prompt(self, question: str) -> str:
        """Format a math question with proper instructions.

        Wraps the system instructions and the user question in the Llama 3.1
        chat template (header/eot special tokens) and leaves the template
        open at the assistant turn so the model continues from there.
        """
        # NOTE(review): "double angle brackets, like <>" looks garbled —
        # presumably the example was "<<48/2=24>>". Left byte-identical
        # because the adapter was likely fine-tuned on this exact prompt;
        # confirm against the training data before changing it.
        instructions = """Please solve this math problem step by step, following these rules: 1) Start by noting all the facts from the problem. 2) Show your work by performing inner calculations inside double angle brackets, like <>. 3) You MUST write the final answer on a new line with a #### prefix. Note - each answer must be of length <= 400."""

        # Format according to Llama 3.1 chat template
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{instructions}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
        return prompt

    # Matches a well-formed signed number with optional thousands commas and
    # an optional decimal part (e.g. "-1,234", "3.14", ".5"). The original
    # pattern "[-\d,\.]+" also matched junk like "...", "," or "--" and
    # captured trailing sentence punctuation.
    _NUMBER_RE = re.compile(r'-?(?:[\d,]*\d(?:\.\d+)?|\.\d+)')

    def extract_answer(self, response: str) -> Any:
        """Extract the final numeric answer from the model response.

        Returns:
            int or float when the answer parses as a number, the raw string
            if it matched but failed conversion, or None when nothing
            answer-like is found.
        """
        # Preferred path: the instructed "#### <answer>" marker.
        answer_match = re.search(r'####\s*(' + self._NUMBER_RE.pattern + r')', response)
        if answer_match:
            return self._to_number(answer_match.group(1))

        # Fallback: take the last number anywhere in the response.
        numbers = self._NUMBER_RE.findall(response)
        if numbers:
            result = self._to_number(numbers[-1])
            if result is not None:
                return result

        return None

    @staticmethod
    def _to_number(answer_str: str) -> Any:
        """Convert a matched numeric string to int/float, stripping commas."""
        answer_str = answer_str.replace(',', '')
        try:
            # Integers stay integers so callers can compare exactly.
            return float(answer_str) if '.' in answer_str else int(answer_str)
        except ValueError:
            # Keep the raw text rather than losing the match entirely.
            return answer_str or None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: A dictionary containing the input data
                - inputs: str or List[str] - The math questions to solve
                - parameters (optional): Dict with generation parameters

        Returns:
            List of dictionaries containing, per question: "question",
            "full_response", "answer" (parsed value or None), and
            "formatted_answer".
        """
        # Extract inputs; `or {}` also guards against an explicit
        # {"parameters": null} in the request body (the original would
        # crash on gen_config.update(None)).
        inputs = data.get("inputs", "")
        parameters = data.get("parameters") or {}

        # Handle both single string and list of strings
        questions = [inputs] if isinstance(inputs, str) else inputs

        # Update generation config with any provided parameters
        gen_config = self.generation_config.copy()
        gen_config.update(parameters)

        # Process each question independently (no batching: prompts are
        # short and this keeps per-question decoding simple).
        results = []
        for question in questions:
            # Format the prompt
            prompt = self.format_math_prompt(question)

            # Tokenize; prompts longer than 512 tokens are truncated.
            model_inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **model_inputs,
                    **gen_config,
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            input_length = model_inputs['input_ids'].shape[1]
            generated_tokens = outputs[0][input_length:]
            assistant_response = self.tokenizer.decode(
                generated_tokens, skip_special_tokens=True
            ).strip()

            # Extract the final answer
            extracted_answer = self.extract_answer(assistant_response)

            results.append({
                "question": question,
                "full_response": assistant_response,
                "answer": extracted_answer,
                "formatted_answer": f"#### {extracted_answer}" if extracted_answer is not None else "No answer found",
            })

        return results