# handler.py
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
from pathlib import Path
import json

# Make sure the custom model code is importable
from .models.inference_memory_wrapper import InferenceMemoryWrapper


class EndpointHandler:
    def __init__(self, model_dir="."):
        """
        Load model and tokenizer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = Path(model_dir)

        print("Loading Tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Loading Base Llama Model...")
        # Load the base Llama model first
        base_model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,  # Use float16 for efficiency
            device_map="auto",          # Let HF handle device placement if multiple GPUs
        )
        base_model.eval()  # Ensure base model is in eval mode

        print("Initializing InferenceMemoryWrapper...")
        # Load wrapper config to get memory_size etc. (assuming it's saved)
        # You might need to adjust how config is loaded/passed
        wrapper_config_path = model_dir / "config.json"  # Assuming wrapper config is here
        if wrapper_config_path.exists():
            config = LlamaConfig.from_pretrained(model_dir)
            memory_size = getattr(config, "memory_size", 512)  # Get from config or default
            update_alpha = getattr(config, "update_alpha", 0.1)
            # Add other params as needed
        else:
            # Default values if no specific wrapper config saved
            memory_size = 512
            update_alpha = 0.1
            print("Warning: Wrapper config not found, using defaults.")

        # Initialize the wrapper, passing the loaded base model
        self.wrapper = InferenceMemoryWrapper(
            llama_model=base_model,
            memory_size=memory_size,
            update_alpha=update_alpha,
            # Add other params loaded from config or defaults
        ).to(self.device).half()  # Move wrapper to device and use float16

        # Load the wrapper's specific state (memory buffer)
        memory_buffer_path = model_dir / "memory_buffer.pt"
        surprise_state_path = model_dir / "surprise_state.pt"

        if memory_buffer_path.exists():
            print("Loading memory buffer state...")
            # Load state dict for the nn.Parameter
            mem_state_dict = torch.load(memory_buffer_path, map_location=self.device)
            self.wrapper.memory_buffer.load_state_dict(mem_state_dict)
        else:
            print("Warning: memory_buffer.pt not found. Initializing with zeros.")

        if surprise_state_path.exists():
            print("Loading surprise state...")
            # Load buffer tensor directly
            surprise_state = torch.load(surprise_state_path, map_location=self.device)
            # Manually assign to the registered buffer
            self.wrapper.surprise_state = surprise_state
        else:
            print("Warning: surprise_state.pt not found. Initializing with zeros.")

        self.wrapper.eval()  # Ensure wrapper is also in eval mode
        print("Model loaded successfully.")

    def __call__(self, data: dict):
        """
        Handle inference requests. `data` is the deserialized request payload.
""" prompt = data.pop("inputs", data) parameters = data.pop("parameters", {}) # Default parameters (match wrapper.generate defaults) max_new_tokens = parameters.get("max_new_tokens", 20) use_memory = parameters.get("use_memory", True) # Default to 'ema' or 'none' for endpoints update_rule = parameters.get("update_rule", "ema") if update_rule == 'surprise': print("Warning: 'surprise' update rule requested, may be slow/costly.") # Decide whether to allow it or force 'ema'/'none' # update_rule = 'ema' temperature = parameters.get("temperature", 0.7) top_p = parameters.get("top_p", 0.95) do_sample = parameters.get("do_sample", True) repetition_penalty = parameters.get("repetition_penalty", 1.0) print(f"Generating with params: {parameters}, update_rule: {update_rule}") inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) # --- Inference --- # Note: Memory state persists within this handler instance (stateful per replica) with torch.inference_mode(): # Ensure no gradients are computed unless explicitly needed output_ids = self.wrapper.generate( input_ids=inputs["input_ids"], max_new_tokens=max_new_tokens, use_memory=use_memory, update_rule=update_rule, # Pass the rule temperature=temperature, top_p=top_p, do_sample=do_sample, repetition_penalty=repetition_penalty, eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id, # Add any other relevant generate parameters ) generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True) return [{"generated_text": generated_text}]