# handler.py
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
from pathlib import Path
import json

# Make sure the custom model code is importable. handler.py is loaded as a
# top-level module (not as part of a package), so a relative import would fail;
# put the repository root on sys.path and import the wrapper absolutely.
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from models.inference_memory_wrapper import InferenceMemoryWrapper

class EndpointHandler:
    def __init__(self, model_dir="."):
        """
        Load model and tokenizer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = Path(model_dir)

        print("Loading Tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Loading Base Llama Model...")
        # Load the base Llama model first. Keep it on a single device so the
        # wrapper (and its extra parameters) can be moved as one unit below;
        # device_map="auto" would conflict with the explicit .to(self.device) call.
        base_model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,  # Use float16 for efficiency
        )
        base_model.eval() # Ensure base model is in eval mode

        print("Initializing InferenceMemoryWrapper...")
        # Load wrapper config to get memory_size etc. (assuming it's saved)
        # You might need to adjust how config is loaded/passed
        wrapper_config_path = model_dir / "config.json"  # Assuming wrapper config is here
        if wrapper_config_path.exists():
            config = LlamaConfig.from_pretrained(model_dir)
            memory_size = getattr(config, "memory_size", 512)  # Get from config or default
            update_alpha = getattr(config, "update_alpha", 0.1)
            # Add other params as needed
        else:
            # Default values if no specific wrapper config saved
            memory_size = 512
            update_alpha = 0.1
            print("Warning: Wrapper config not found, using defaults.")

        # Initialize the wrapper, passing the loaded base model
        self.wrapper = InferenceMemoryWrapper(
            llama_model=base_model,
            memory_size=memory_size,
            update_alpha=update_alpha
            # Add other params loaded from config or defaults
        ).to(self.device).half() # Move wrapper to device and use float16

        # Load the wrapper's specific state (memory buffer)
        memory_buffer_path = model_dir / "memory_buffer.pt"
        surprise_state_path = model_dir / "surprise_state.pt"

        if memory_buffer_path.exists():
            print("Loading memory buffer state...")
            # The saved file may hold the raw tensor or a one-entry state dict;
            # copy it into the wrapper's nn.Parameter in place
            # (an nn.Parameter has no load_state_dict method).
            mem_state = torch.load(memory_buffer_path, map_location=self.device)
            if isinstance(mem_state, dict):
                mem_state = next(iter(mem_state.values()))
            with torch.no_grad():
                self.wrapper.memory_buffer.copy_(mem_state)
        else:
            print("Warning: memory_buffer.pt not found. Initializing with zeros.")

        if surprise_state_path.exists():
            print("Loading surprise state...")
            # Load buffer tensor directly
            surprise_state = torch.load(surprise_state_path, map_location=self.device)
            # Manually assign to the registered buffer
            self.wrapper.surprise_state = surprise_state
        else:
            print("Warning: surprise_state.pt not found. Initializing with zeros.")

        self.wrapper.eval() # Ensure wrapper is also in eval mode
        print("Model loaded successfully.")

    def __call__(self, data: dict):
        """
        Handle inference requests.
        `data` is the deserialized request payload.
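        Expected payload shape (illustrative example, matching the keys used below):
            {"inputs": "Once upon a time", "parameters": {"max_new_tokens": 50, "use_memory": true}}
        Returns a list with a single dict:
            [{"generated_text": "..."}]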
        """
        prompt = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Default parameters (match wrapper.generate defaults)
        max_new_tokens = parameters.get("max_new_tokens", 20)
        use_memory = parameters.get("use_memory", True)
        # Default to 'ema' or 'none' for endpoints
        update_rule = parameters.get("update_rule", "ema")
        if update_rule == 'surprise':
            print("Warning: 'surprise' update rule requested, may be slow/costly.")
            # Decide whether to allow it or force 'ema'/'none'
            # update_rule = 'ema'

        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.95)
        do_sample = parameters.get("do_sample", True)
        repetition_penalty = parameters.get("repetition_penalty", 1.0)

        print(f"Generating with params: {parameters}, update_rule: {update_rule}")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # --- Inference ---
        # Note: Memory state persists within this handler instance (stateful per replica)
        with torch.inference_mode():  # Ensure no gradients are computed unless explicitly needed
            output_ids = self.wrapper.generate(
                input_ids=inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                use_memory=use_memory,
                update_rule=update_rule,  # Pass the rule
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                # Add any other relevant generate parameters
            )

        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        return [{"generated_text": generated_text}]
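

# Minimal local smoke test (a sketch, not part of the endpoint contract). It
# assumes the current directory holds the tokenizer/model files plus the
# optional memory_buffer.pt / surprise_state.pt used above.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=".")
    result = handler({
        "inputs": "Once upon a time",
        "parameters": {"max_new_tokens": 20, "use_memory": True, "update_rule": "ema"},
    })
    print(result)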