# handler.py
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
from pathlib import Path
import json
# Make sure the custom model code is importable. When handler.py is loaded as a
# top-level module (as Inference Endpoints does), a relative import has no parent
# package, so fall back to an absolute import from the repository root.
try:
    from .models.inference_memory_wrapper import InferenceMemoryWrapper
except ImportError:
    from models.inference_memory_wrapper import InferenceMemoryWrapper


class EndpointHandler:
    def __init__(self, model_dir="."):
        """
        Load model and tokenizer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = Path(model_dir)

        print("Loading Tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Loading Base Llama Model...")
        # Load the base Llama model first
        base_model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,  # Use float16 for efficiency
            device_map="auto"           # Let HF handle device placement if multiple GPUs
        )
        base_model.eval()  # Ensure base model is in eval mode

        print("Initializing InferenceMemoryWrapper...")
        # Load wrapper config to get memory_size etc. (assuming it's saved)
        # You might need to adjust how config is loaded/passed
        wrapper_config_path = model_dir / "config.json"  # Assuming wrapper config is here
        if wrapper_config_path.exists():
            config = LlamaConfig.from_pretrained(model_dir)
            memory_size = getattr(config, "memory_size", 512)  # Get from config or default
            update_alpha = getattr(config, "update_alpha", 0.1)
            # Add other params as needed
        else:
            # Default values if no specific wrapper config saved
            memory_size = 512
            update_alpha = 0.1
            print("Warning: Wrapper config not found, using defaults.")

        # Initialize the wrapper, passing the loaded base model
        self.wrapper = InferenceMemoryWrapper(
            llama_model=base_model,
            memory_size=memory_size,
            update_alpha=update_alpha
            # Add other params loaded from config or defaults
        ).to(self.device).half()  # Move wrapper to device and use float16

        # Load the wrapper's specific state (memory buffer)
        memory_buffer_path = model_dir / "memory_buffer.pt"
        surprise_state_path = model_dir / "surprise_state.pt"
        if memory_buffer_path.exists():
            print("Loading memory buffer state...")
            # Load state dict for the nn.Parameter
            mem_state_dict = torch.load(memory_buffer_path, map_location=self.device)
            self.wrapper.memory_buffer.load_state_dict(mem_state_dict)
        else:
            print("Warning: memory_buffer.pt not found. Initializing with zeros.")

        if surprise_state_path.exists():
            print("Loading surprise state...")
            # Load buffer tensor directly
            surprise_state = torch.load(surprise_state_path, map_location=self.device)
            # Manually assign to the registered buffer
            self.wrapper.surprise_state = surprise_state
        else:
            print("Warning: surprise_state.pt not found. Initializing with zeros.")

        self.wrapper.eval()  # Ensure wrapper is also in eval mode
        print("Model loaded successfully.")
    def __call__(self, data: dict):
        """
        Handle inference requests.
        `data` is the deserialized request payload.
        """
        prompt = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Default parameters (match wrapper.generate defaults)
        max_new_tokens = parameters.get("max_new_tokens", 20)
        use_memory = parameters.get("use_memory", True)
        # Default to 'ema' or 'none' for endpoints
        update_rule = parameters.get("update_rule", "ema")
        if update_rule == 'surprise':
            print("Warning: 'surprise' update rule requested, may be slow/costly.")
            # Decide whether to allow it or force 'ema'/'none'
            # update_rule = 'ema'
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.95)
        do_sample = parameters.get("do_sample", True)
        repetition_penalty = parameters.get("repetition_penalty", 1.0)

        print(f"Generating with params: {parameters}, update_rule: {update_rule}")
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # --- Inference ---
        # Note: Memory state persists within this handler instance (stateful per replica)
        with torch.inference_mode():  # Ensure no gradients are computed unless explicitly needed
            output_ids = self.wrapper.generate(
                input_ids=inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                use_memory=use_memory,
                update_rule=update_rule,  # Pass the rule
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                # Add any other relevant generate parameters
            )

        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return [{"generated_text": generated_text}]
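

# Minimal local smoke test (a sketch, not part of the endpoint contract). It assumes
# the model files described above sit in the current directory and mirrors the payload
# shape that Inference Endpoints passes to __call__.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=".")
    payload = {
        "inputs": "The memory wrapper should remember that",
        "parameters": {
            "max_new_tokens": 32,
            "use_memory": True,
            "update_rule": "ema",
            "temperature": 0.7,
            "top_p": 0.95,
            "do_sample": True,
        },
    }
    print(handler(payload))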