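"""Inference endpoint handler for a Llama model wrapped with InferenceMemoryWrapper.

Loads the tokenizer, the base model, and any persisted memory/surprise state from
the model directory, and serves generation requests via `EndpointHandler.__call__`.
"""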
import torch
import os
import json
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig

from .models.inference_memory_wrapper import InferenceMemoryWrapper


class EndpointHandler:
    def __init__(self, model_dir="."):
        """
        Load the tokenizer, the base Llama model, and the InferenceMemoryWrapper,
        restoring any persisted memory/surprise state from `model_dir`.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = Path(model_dir)

        print("Loading Tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if self.tokenizer.pad_token is None:
            # Llama tokenizers typically ship without a pad token; reuse EOS.
            self.tokenizer.pad_token = self.tokenizer.eos_token

print("Loading Base Llama Model...") |
|
|
|
|
|
base_model = AutoModelForCausalLM.from_pretrained( |
|
|
model_dir, |
|
|
torch_dtype=torch.float16, |
|
|
device_map="auto" |
|
|
) |
|
|
base_model.eval() |
|
|
|
|
|
print("Initializing InferenceMemoryWrapper...") |
|
|
|
|
|
|
|
|
wrapper_config_path = model_dir / "config.json" |
|
|
if wrapper_config_path.exists(): |
|
|
config = LlamaConfig.from_pretrained(model_dir) |
|
|
memory_size = getattr(config, "memory_size", 512) |
|
|
update_alpha = getattr(config, "update_alpha", 0.1) |
|
|
|
|
|
else: |
|
|
|
|
|
memory_size = 512 |
|
|
update_alpha = 0.1 |
|
|
print("Warning: Wrapper config not found, using defaults.") |
|
|
|
|
|
|
|
|
|
|
|
        # Move the wrapper's own parameters to the target device and cast them
        # to fp16 to match the base model.
        self.wrapper = InferenceMemoryWrapper(
            llama_model=base_model,
            memory_size=memory_size,
            update_alpha=update_alpha,
        ).to(self.device).half()

        memory_buffer_path = model_dir / "memory_buffer.pt"
        surprise_state_path = model_dir / "surprise_state.pt"

        if memory_buffer_path.exists():
            print("Loading memory buffer state...")
            mem_state_dict = torch.load(memory_buffer_path, map_location=self.device)
            self.wrapper.memory_buffer.load_state_dict(mem_state_dict)
        else:
            print("Warning: memory_buffer.pt not found. Initializing with zeros.")

        if surprise_state_path.exists():
            print("Loading surprise state...")
            surprise_state = torch.load(surprise_state_path, map_location=self.device)
            self.wrapper.surprise_state = surprise_state
        else:
            print("Warning: surprise_state.pt not found. Initializing with zeros.")

        self.wrapper.eval()
        print("Model loaded successfully.")

    def __call__(self, data: dict):
        """
        Handle an inference request.

        `data` is the deserialized request payload, expected to contain the
        prompt under "inputs" and an optional "parameters" dict.
        """
        prompt = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        max_new_tokens = parameters.get("max_new_tokens", 20)
        use_memory = parameters.get("use_memory", True)

        update_rule = parameters.get("update_rule", "ema")
        if update_rule == "surprise":
            print("Warning: 'surprise' update rule requested; it may be slow/costly.")

        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.95)
        do_sample = parameters.get("do_sample", True)
        repetition_penalty = parameters.get("repetition_penalty", 1.0)

        print(f"Generating with params: {parameters}, update_rule: {update_rule}")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.inference_mode():
            output_ids = self.wrapper.generate(
                input_ids=inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                use_memory=use_memory,
                update_rule=update_rule,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        return [{"generated_text": generated_text}]