# SciGuru-delta / handler.py
# Source: Hugging Face Hub (author: golyuval, commit 0ff4a0e "Update handler.py")
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import re
import os
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a math word-problem solver.

    Loads the Llama-3.1-8B-Instruct base model (4-bit NF4 quantized on GPU,
    full precision on CPU) together with a PEFT/LoRA adapter stored at ``path``,
    and answers math questions step by step in the GSM8K-style
    ``#### <answer>`` format.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, (optionally quantized) base model, and PEFT adapter.

        Args:
            path: Path to the model directory (provided by HF Inference
                Endpoints); must contain the PEFT adapter weights.
        """
        # Model configuration
        self.base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Inference Endpoints injects HF_TOKEN; needed because Llama is gated.
        hf_token = os.environ.get("HF_TOKEN", None)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            # Llama ships without a pad token; reuse EOS so padding works.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if torch.cuda.is_available():
            # 4-bit NF4 quantization keeps the 8B model within one GPU.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                token=hf_token,
            )
        else:
            # FIX: float32 on CPU — float16 CPU inference is unsupported or
            # extremely slow for many kernels.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                token=hf_token,
            )

        # Attach the fine-tuned PEFT adapter shipped alongside this handler.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

        # Generation defaults; per-request "parameters" may override any key.
        self.generation_config = {
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "max_new_tokens": 1000,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

    def format_math_prompt(self, question: str) -> str:
        """Wrap *question* in the Llama 3.1 chat template with solver rules."""
        instructions = """Please solve this math problem step by step, following these rules:
1) Start by noting all the facts from the problem.
2) Show your work by performing inner calculations inside double angle brackets, like <<calculation=result>>.
3) You MUST write the final answer on a new line with a #### prefix.
Note - each answer must be of length <= 400."""
        # Llama 3.1 special tokens: system -> user -> assistant; the assistant
        # header is left open so the model generates the assistant turn.
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{instructions}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
        return prompt

    @staticmethod
    def _to_number(text: str):
        """Parse '1,234' / '3.5' style text into int or float; None on failure."""
        cleaned = text.replace(',', '')
        try:
            return float(cleaned) if '.' in cleaned else int(cleaned)
        except ValueError:
            return None

    def extract_answer(self, response: str) -> Any:
        """Extract the numeric final answer from a model response.

        Prefers the '#### <number>' marker the prompt mandates; falls back to
        the last number-like token in the text.

        Returns:
            int or float on success, the raw matched string if the '####'
            value cannot be parsed, or None when nothing numeric is found.
        """
        # FIX: require at least one digit (vs. the naive [-\d,\.]+) so lone
        # punctuation such as '-' or '.' is never reported as an answer.
        answer_match = re.search(r'####\s*(-?[\d,]+(?:\.\d+)?)', response)
        if answer_match:
            answer_str = answer_match.group(1)
            parsed = self._to_number(answer_str)
            if parsed is not None:
                return parsed
            return answer_str.replace(',', '')

        # Fallback: last number-like token anywhere in the response.
        numbers = re.findall(r'-?[\d,]+(?:\.\d+)?', response)
        if numbers:
            return self._to_number(numbers[-1])
        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process an inference request.

        Args:
            data: Request payload with
                - "inputs": str or List[str] — the math question(s) to solve;
                - "parameters" (optional): dict of generation-config overrides.

        Returns:
            One result dict per question with keys "question",
            "full_response", "answer", and "formatted_answer".
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Accept a single question or a batch.
        questions = [inputs] if isinstance(inputs, str) else inputs

        # Per-request overrides win over the defaults.
        gen_config = {**self.generation_config, **parameters}

        results = []
        for question in questions:
            prompt = self.format_math_prompt(question)
            model_inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(**model_inputs, **gen_config)

            # Decode only the newly generated tokens, not the echoed prompt.
            input_length = model_inputs["input_ids"].shape[1]
            assistant_response = self.tokenizer.decode(
                outputs[0][input_length:], skip_special_tokens=True
            ).strip()

            extracted_answer = self.extract_answer(assistant_response)
            results.append({
                "question": question,
                "full_response": assistant_response,
                "answer": extracted_answer,
                "formatted_answer": (
                    f"#### {extracted_answer}"
                    if extracted_answer is not None
                    else "No answer found"
                ),
            })
        return results