from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import re
import os
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a math-tutoring LoRA model.

    Loads the Llama-3.1-8B-Instruct base model (4-bit NF4 quantized on GPU,
    full precision on CPU), attaches the PEFT adapter stored at ``path``, and
    answers GSM8K-style math questions. Each result contains the full
    step-by-step response plus the numeric answer extracted from the
    ``#### <answer>`` marker the model is instructed to emit.
    """

    # Matches a number: optional sign, digits with optional thousands commas,
    # optional decimal part (or a bare ".5"-style decimal). Requires at least
    # one digit so trailing punctuation ("42.") or lone "."/"," are never
    # captured — the old pattern `[-\d,\.]+` swallowed those.
    _NUMBER_PATTERN = r'-?(?:\d[\d,]*(?:\.\d+)?|\.\d+)'

    def __init__(self, path=""):
        """
        Initialize the model and tokenizer for the inference endpoint.

        Args:
            path: The path to the model directory (provided by HF Inference
                Endpoints); must contain the PEFT/LoRA adapter weights.
        """
        self.base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Gated model: Inference Endpoints injects the access token via env.
        hf_token = os.environ.get("HF_TOKEN", None)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        # Llama tokenizers ship without a pad token; reuse EOS so padding
        # during tokenization/generation does not fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if torch.cuda.is_available():
            # 4-bit NF4 quantization keeps the 8B model within GPU memory.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                token=hf_token,
            )
        else:
            # FIX: float32 on CPU. float16 is unsupported or extremely slow
            # for many CPU kernels used by generate() (softmax/sampling),
            # so the previous torch.float16 setting could crash inference.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                token=hf_token,
            )

        # Attach the fine-tuned LoRA adapter shipped with this endpoint.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

        # Default sampling configuration; individual requests may override
        # any of these via the payload's "parameters" field.
        self.generation_config = {
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "max_new_tokens": 1000,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

    def format_math_prompt(self, question: str) -> str:
        """Format a math question with proper instructions.

        Wraps the question in the Llama 3.1 chat template with a system
        prompt that forces step-by-step reasoning and the "####" answer
        marker that extract_answer() relies on.
        """
        instructions = """Please solve this math problem step by step, following these rules:
1) Start by noting all the facts from the problem.
2) Show your work by performing inner calculations inside double angle brackets, like <<calculation=result>>.
3) You MUST write the final answer on a new line with a #### prefix.
Note - each answer must be of length <= 400."""
        # Format according to Llama 3.1 chat template
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{instructions}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
        return prompt

    @staticmethod
    def _to_number(token: str) -> Any:
        """Convert a matched numeric string to int or float.

        Strips thousands separators; returns None when unparseable.
        """
        cleaned = token.replace(',', '')
        try:
            return float(cleaned) if '.' in cleaned else int(cleaned)
        except ValueError:
            return None

    def extract_answer(self, response: str) -> Any:
        """Extract the final answer from the model response.

        Prefers the value after the "####" marker the model is instructed to
        emit; falls back to the last number appearing anywhere in the text.

        Returns:
            int or float on success, the raw matched string if the marker
            value cannot be parsed, or None when no number is found.
        """
        # FIX: the old pattern `[-\d,\.]+` captured trailing punctuation, so
        # "#### 42." parsed as float 42.0 instead of int 42.
        answer_match = re.search(r'####\s*(' + self._NUMBER_PATTERN + r')', response)
        if answer_match:
            token = answer_match.group(1)
            value = self._to_number(token)
            # Preserve historical behavior: fall back to the raw (comma-free)
            # string if the match is somehow unparseable.
            return value if value is not None else token.replace(',', '')

        # Fallback: last number anywhere in the response.
        numbers = re.findall(self._NUMBER_PATTERN, response)
        if numbers:
            return self._to_number(numbers[-1])
        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: A dictionary containing the input data
                - inputs: str or List[str] - The math questions to solve
                - parameters (optional): Dict with generation parameters
        Returns:
            List of dictionaries containing the results, one per question,
            with keys "question", "full_response", "answer", and
            "formatted_answer".
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Accept both a single question and a batch.
        questions = [inputs] if isinstance(inputs, str) else inputs

        # Per-request parameters override the defaults set in __init__.
        gen_config = {**self.generation_config, **parameters}

        results = []
        for question in questions:
            prompt = self.format_math_prompt(question)
            model_inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(**model_inputs, **gen_config)

            # Decode only the newly generated tokens, not the echoed prompt.
            input_length = model_inputs['input_ids'].shape[1]
            assistant_response = self.tokenizer.decode(
                outputs[0][input_length:], skip_special_tokens=True
            ).strip()

            extracted_answer = self.extract_answer(assistant_response)
            results.append({
                "question": question,
                "full_response": assistant_response,
                "answer": extracted_answer,
                "formatted_answer": (
                    f"#### {extracted_answer}"
                    if extracted_answer is not None
                    else "No answer found"
                ),
            })
        return results