# Echo-DSRN-114M / handler.py
# Uploaded by mrs83 in commit f80e336 ("Upload 5 files", verified).
import os
from typing import Any, Dict, List
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
from .modeling_echo import EchoConfig, EchoForCausalLM
# Register the locally defined Echo classes with the transformers Auto* registries so
# that model_type "echo" resolves to this module's EchoConfig / EchoForCausalLM
# (intended to override any remote code shipped with the checkpoint).
AutoConfig.register("echo", EchoConfig)
AutoModelForCausalLM.register(EchoConfig, EchoForCausalLM)
class StringStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the decoded text ends with a stop string.

    Only the first row of the batch (``input_ids[0]``) is inspected. That row is
    decoded in full with special tokens kept, so stop strings that are registered
    as special tokens remain visible to the check.

    Args:
        tokenizer: Tokenizer used to decode ``input_ids[0]``.
        stop_strings: Iterable of strings; generation stops when the decoded text,
            with surrounding whitespace stripped, ends with any of them.
    """

    def __init__(self, tokenizer, stop_strings):
        self.tokenizer = tokenizer
        self.stop_strings = stop_strings

    def __call__(self, input_ids, scores, **kwargs):
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=False)
        trimmed = text.strip()
        # Cheap pre-filter: the stop string must occur within the last
        # len(stop) + 20 characters (leaving room for trailing whitespace)
        # before the exact stripped-suffix test is applied.
        return any(
            stop in text[-(len(stop) + 20) :] and trimmed.endswith(stop)
            for stop in self.stop_strings
        )
class EndpointHandler:
    """
    Custom Handler for Hugging Face Inference Endpoints.
    Ensures correct initialization of the Echo-DSRN model and fixes the pad_token error.

    The model is always loaded through the local ``EchoForCausalLM`` class with
    ``trust_remote_code=False``. A repository containing ``adapter_config.json`` is
    treated as a PEFT/LoRA adapter that is merged into its base model at load time.
    """

    # Generation settings used when a request has no (or a null) "parameters" entry.
    # Always copied before use, so requests never mutate this shared dict.
    _DEFAULT_PARAMETERS = {
        "max_new_tokens": 128,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "repetition_penalty": 1.2,
        "use_cache": False,
    }

    def __init__(self, path=""):
        """Load the Echo-DSRN model and its tokenizer from ``path``.

        Args:
            path: Repository directory provided by the endpoint runtime. If it holds
                ``adapter_config.json``, the adapter's base model is loaded, the adapter
                is applied and merged, and the tokenizer is taken from the base model.
                Otherwise ``path`` is loaded as a full model and is also used for the
                tokenizer.
        """
        print(f"Loading Echo-DSRN from {path}...")
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        tokenizer_path = path
        if os.path.exists(os.path.join(path, "adapter_config.json")):
            # peft is only needed for adapter repositories. Importing it here keeps
            # full-model endpoints working when peft is not installed.
            from peft import PeftConfig, PeftModel

            print(f"Detected LoRA adapter in {path}")
            base_model_path = PeftConfig.from_pretrained(path).base_model_name_or_path
            tokenizer_path = base_model_path  # Use base model for tokenizer
            print(f"Loading base model: {base_model_path}")
            base_model = self._load_echo_model(base_model_path, dtype)
            print("Applying adapter and merging...")
            self.model = PeftModel.from_pretrained(base_model, path).merge_and_unload()
        else:
            print(f"Loading full model: {path}")
            self.model = self._load_echo_model(path, dtype)

        print(f"Loading tokenizer from {tokenizer_path}...")
        # NOTE(review): the tokenizer still allows remote code while the model forces
        # local code -- confirm this asymmetry is intended.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
        self.tokenizer.pad_token_id = 32000  # <|endoftext|>
        # Token ids that terminate generation. 32000 is <|endoftext|> (the pad id above);
        # the other two are presumably end-of-turn tokens -- confirm against the tokenizer.
        self.eos_token_ids = [32000, 32007, 32011]
        # Stop strings checked by StringStoppingCriteria (matching talk.py).
        self.stop_strings = ["<|im_end|>", "<|end|>", "<|user|>"]
        self.model.eval()
        print("Model and Tokenizer loaded successfully (Local Code Forced).")

    @staticmethod
    def _load_echo_model(model_path, dtype):
        """Load ``model_path`` with the local EchoForCausalLM so local fixes are active."""
        return EchoForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=False,
        )

    @staticmethod
    def _flatten_message_content(messages):
        """Reduce list-valued message "content" to its concatenated text parts, in place.

        Parts whose "type" is not "text" (e.g. images) are dropped. Messages whose
        content is not a list are left untouched.
        """
        for message in messages:
            parts = message.get("content")
            if isinstance(parts, list):
                message["content"] = "".join(
                    part.get("text", "") for part in parts if part.get("type") == "text"
                )

    @staticmethod
    def _token_logprobs(logits, token_ids):
        """Return ``log_softmax(logits[j])[token_ids[j]]`` for every row ``j`` as floats.

        Args:
            logits: 2-D tensor of shape (positions, vocab).
            token_ids: 1-D integer tensor with one id per row of ``logits``.
        """
        # Upcast before the softmax so bf16 logits do not yield coarse logprobs.
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        return log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1).tolist()

    @staticmethod
    def _text_offsets(tokens):
        """Return the start offset of each token when the token texts are concatenated."""
        offsets, position = [], 0
        for token_text in tokens:
            offsets.append(position)
            position += len(token_text)
        return offsets

    def _strip_stop_string(self, text):
        """Remove one trailing stop string (and the whitespace after it) from ``text``.

        Stop strings that the tokenizer does not treat as special tokens survive
        ``skip_special_tokens=True`` decoding and would otherwise leak into the
        returned text. Text that does not end with a stop string is returned unchanged.
        """
        stripped = text.rstrip()
        for stop in self.stop_strings:
            # Skip empty stop strings: stripped[:-0] would wipe the whole text.
            if stop and stripped.endswith(stop):
                return stripped[: -len(stop)]
        return text

    def _decode_generated(self, generated_ids):
        """Decode generated ids with special tokens skipped and a trailing stop string removed."""
        text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        return self._strip_stop_string(text)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (:obj: `Dict`):
                - "inputs": The prompt for generation. A string is tokenized as-is. A
                  list of chat messages is rendered with the tokenizer's chat template,
                  after list-valued "content" is reduced to its text parts.
                - "parameters" (optional): Dictionary of generation parameters forwarded
                  to ``generate``. When it is absent or null, ``_DEFAULT_PARAMETERS`` is
                  used. "use_cache" is always forced to False. The extra keys "logprobs"
                  (any non-None value enables logprob output) and "echo" (also score the
                  prompt tokens) are consumed here and not forwarded.
        Returns:
            A :obj:`list`: A list containing the generated text/logprobs. The text is
            decoded with special tokens skipped and any trailing stop string removed.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        # Work on a copy so popping the handler-specific keys below never mutates the
        # caller's dict or the shared defaults.
        parameters = dict(self._DEFAULT_PARAMETERS if parameters is None else parameters)
        # Ensure use_cache is False even if passed
        parameters["use_cache"] = False
        logprobs_count = parameters.pop("logprobs", None)
        echo = parameters.pop("echo", False)
        want_logprobs = logprobs_count is not None

        # Handle Chat vs Completion inputs
        if isinstance(inputs, list):
            self._flatten_message_content(inputs)
            inputs = self.tokenizer.apply_chat_template(
                inputs, tokenize=False, add_generation_prompt=True
            )

        # NOTE(review): after apply_chat_template the tokenizer may add special tokens
        # (e.g. BOS) a second time; consider add_special_tokens=False -- confirm.
        input_tokens = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        input_ids = input_tokens.input_ids
        input_len = input_ids.shape[1]

        eos_token_id = parameters.pop("eos_token_id", self.eos_token_ids)
        pad_token_id = parameters.pop("pad_token_id", self.tokenizer.pad_token_id)
        repetition_penalty = parameters.pop("repetition_penalty", 1.2)

        tokens = []
        token_logprobs = []
        # Handle Prompt Logprobs (Echo)
        if want_logprobs and echo:
            prompt_ids = input_ids[0]
            with torch.no_grad():
                logits = self.model(input_ids).logits  # (B, T, V)
            tokens.extend(self.tokenizer.decode([tid]) for tid in prompt_ids.tolist())
            if input_len > 0:
                # Logits at position i-1 score token i; the first token has no context.
                token_logprobs.append(None)
                token_logprobs.extend(self._token_logprobs(logits[0, :-1], prompt_ids[1:]))

        with torch.no_grad():
            gen_out = self.model.generate(
                **input_tokens,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                repetition_penalty=repetition_penalty,
                stopping_criteria=StoppingCriteriaList(
                    [StringStoppingCriteria(self.tokenizer, self.stop_strings)]
                ),
                output_scores=want_logprobs,
                return_dict_in_generate=want_logprobs,
                **parameters,
            )

        if not want_logprobs:
            return [{"generated_text": self._decode_generated(gen_out[0, input_len:])}]

        generated_ids = gen_out.sequences[0, input_len:]
        tokens.extend(self.tokenizer.decode([tid]) for tid in generated_ids.tolist())
        if generated_ids.numel() > 0:
            # gen_out.scores holds one (B, V) tensor per generated step.
            # NOTE(review): these are generate()'s processed scores, so the logprobs
            # reflect temperature/top_p/repetition_penalty -- confirm this is intended.
            step_scores = torch.stack(
                [step[0] for step in gen_out.scores[: generated_ids.numel()]]
            )
            token_logprobs.extend(self._token_logprobs(step_scores, generated_ids))

        logprobs_dict = {
            "tokens": tokens,
            "token_logprobs": token_logprobs,
            "top_logprobs": [],
            "text_offset": self._text_offsets(tokens),
        }
        return [
            {
                "generated_text": self._decode_generated(generated_ids),
                "logprobs": logprobs_dict,
            }
        ]