"""Hugging Face Inference Endpoints handler for the Echo-DSRN model."""

import os
from typing import Any, Dict, List, Optional

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

from .modeling_echo import EchoConfig, EchoForCausalLM

# Register local architecture to override remote code
AutoConfig.register("echo", EchoConfig)
AutoModelForCausalLM.register(EchoConfig, EchoForCausalLM)

# Generation settings used when a request carries no (or a null) "parameters" entry.
_DEFAULT_GENERATION_PARAMETERS = {
    "max_new_tokens": 128,
    "temperature": 0.7,
    "top_p": 0.9,
    "do_sample": True,
    "repetition_penalty": 1.2,
    "use_cache": False,
}


class StringStoppingCriteria(StoppingCriteria):
    """Stop generation once the generated text ends with one of ``stop_strings``.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, tokenizer, stop_strings, prompt_length: int = 0):
        """
        Args:
            tokenizer: Tokenizer used to decode token ids back into text.
            stop_strings: Iterable of strings that terminate generation. Empty
                strings are ignored (they would match every step).
            prompt_length: Number of leading tokens in ``input_ids`` that belong
                to the prompt. They are excluded from the check so a stop string
                already present in the prompt cannot end generation, and so the
                prompt is not re-decoded on every step. Defaults to 0, which
                inspects the whole sequence.
        """
        self.tokenizer = tokenizer
        self.stop_strings = stop_strings
        self.prompt_length = prompt_length
        # str.endswith accepts a tuple, so every stop string is tested in one call.
        self._stop_tuple = tuple(s for s in stop_strings if s)

    def __call__(self, input_ids, scores, **kwargs):
        new_tokens = input_ids[0, self.prompt_length :]
        if not self._stop_tuple or new_tokens.numel() == 0:
            return False
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=False)
        # Match the raw text (so stop strings ending in whitespace work) and the
        # rstripped text (a stop string followed by whitespace still stops).
        return text.endswith(self._stop_tuple) or text.rstrip().endswith(self._stop_tuple)


class EndpointHandler:
    """
    Custom Handler for Hugging Face Inference Endpoints.

    Ensures correct initialization of the Echo-DSRN model (always loaded through
    the local ``EchoForCausalLM`` class rather than remote code) and fixes the
    pad_token error.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer from ``path``.

        If ``path`` contains ``adapter_config.json`` it is treated as a PEFT/LoRA
        adapter. In that case the base model named in the adapter config is
        loaded, the adapter is applied and merged, and the tokenizer comes from
        the base model. Otherwise ``path`` is loaded as a full model checkpoint.
        """
        print(f"Loading Echo-DSRN from {path}...")

        load_kwargs = {
            "device_map": "auto",
            "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            # USE LOCAL EchoForCausalLM to ensure our fixes are active!
            "trust_remote_code": False,
        }
        tokenizer_path = path

        if os.path.exists(os.path.join(path, "adapter_config.json")):
            # peft is only needed for adapter checkpoints.
            from peft import PeftConfig, PeftModel

            print(f"Detected LoRA adapter in {path}")
            peft_config = PeftConfig.from_pretrained(path)
            base_model_path = peft_config.base_model_name_or_path
            tokenizer_path = base_model_path  # Use base model for tokenizer
            print(f"Loading base model: {base_model_path}")
            model = EchoForCausalLM.from_pretrained(base_model_path, **load_kwargs)
            print("Applying adapter and merging...")
            model = PeftModel.from_pretrained(model, path)
            self.model = model.merge_and_unload()
        else:
            print(f"Loading full model: {path}")
            self.model = EchoForCausalLM.from_pretrained(path, **load_kwargs)

        print(f"Loading tokenizer from {tokenizer_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
        self.tokenizer.pad_token_id = 32000  # <|endoftext|>
        self.eos_token_ids = [32000, 32007, 32011]
        # Default stop strings (matching talk.py); a request may add more via "stop".
        self.stop_strings = ["<|im_end|>", "<|end|>", "<|user|>"]

        self.model.eval()
        print("Model and Tokenizer loaded successfully (Local Code Forced).")

    @staticmethod
    def _flatten_message(msg):
        """Return ``msg`` with list-style content collapsed to its text parts.

        Returns a new dict when the content is rewritten, so the caller's
        message is never modified.
        """
        content = msg.get("content")
        if not isinstance(content, list):
            return msg
        text = "".join(part.get("text", "") for part in content if part.get("type") == "text")
        return {**msg, "content": text}

    def _prompt_logprobs(self, input_ids) -> List[Optional[float]]:
        """Return the log-probability of each prompt token under the model.

        The first token has no preceding context, so its entry is ``None``.
        """
        with torch.no_grad():
            logits = self.model(input_ids).logits  # (B, T, V)
        # Token i is scored by the distribution predicted at position i - 1.
        log_probs = torch.log_softmax(logits[0, :-1, :], dim=-1)
        # device_map="auto" may place the logits on a different device than the inputs.
        targets = input_ids[0, 1:].to(log_probs.device).unsqueeze(-1)
        return [None] + log_probs.gather(-1, targets).squeeze(-1).tolist()

    @staticmethod
    def _generated_logprobs(scores, generated_ids) -> List[float]:
        """Return the log-probability of each generated token.

        ``scores`` is ``generate(..., output_scores=True).scores``: one tensor
        per generation step. These are the *processed* scores, i.e. after
        repetition penalty, temperature, top-p etc. have been applied, not raw
        model probabilities.
        """
        if len(scores) == 0:
            return []
        step_scores = torch.stack([step[0] for step in scores])  # (steps, V)
        log_probs = torch.log_softmax(step_scores, dim=-1)
        targets = generated_ids.to(log_probs.device).unsqueeze(-1)
        return log_probs.gather(-1, targets).squeeze(-1).tolist()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (:obj: `Dict`):
                - "inputs": The prompt for generation. This is either a string, or
                  a list of chat messages rendered with the tokenizer's chat
                  template. List-style message content is reduced to its
                  "text" parts.
                - "parameters" (optional): Dictionary of generation parameters.
                  A missing or null value selects the handler defaults. These
                  keys are handled here and not forwarded to ``generate``:
                  "logprobs" (None/False disables), "echo" (also score the
                  prompt when logprobs are requested), "stop" (extra stop
                  string(s)), "eos_token_id", "pad_token_id" and
                  "repetition_penalty". Every other key is forwarded to
                  ``generate``. "use_cache" is always forced to False.

        Returns:
            A :obj:`list`: A list containing the generated text/logprobs.
            When logprobs are requested, the "logprobs" entry holds "tokens",
            "token_logprobs", "top_logprobs" (always empty) and "text_offset".
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        if parameters is None:
            parameters = _DEFAULT_GENERATION_PARAMETERS
        # Work on a copy so the caller's payload is never mutated.
        parameters = dict(parameters)

        # Ensure use_cache is False even if passed
        parameters["use_cache"] = False

        # Extract special flags
        logprobs_count = parameters.pop("logprobs", None)
        if logprobs_count is False:
            logprobs_count = None
        want_logprobs = logprobs_count is not None
        echo = parameters.pop("echo", False)

        # "stop" is not a generate() argument; merge it into the stop strings instead.
        extra_stops = parameters.pop("stop", None)
        if isinstance(extra_stops, str):
            extra_stops = [extra_stops]
        stop_strings = self.stop_strings + [s for s in (extra_stops or []) if s]

        # Handle Chat vs Completion inputs
        if isinstance(inputs, list):
            inputs = self.tokenizer.apply_chat_template(
                [self._flatten_message(msg) for msg in inputs],
                tokenize=False,
                add_generation_prompt=True,
            )

        # Tokenize inputs
        input_tokens = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        input_ids = input_tokens.input_ids
        prompt_length = input_ids.shape[1]

        # Extract special params
        eos_token_id = parameters.pop("eos_token_id", self.eos_token_ids)
        pad_token_id = parameters.pop("pad_token_id", self.tokenizer.pad_token_id)
        repetition_penalty = parameters.pop("repetition_penalty", 1.2)
        stopping_criteria = StoppingCriteriaList(
            [StringStoppingCriteria(self.tokenizer, stop_strings, prompt_length=prompt_length)]
        )

        scored_ids: List[int] = []
        scored_logprobs: List[Optional[float]] = []

        # Handle Prompt Logprobs (Echo)
        if want_logprobs and echo:
            scored_ids.extend(input_ids[0].tolist())
            scored_logprobs.extend(self._prompt_logprobs(input_ids))

        # Generate output
        with torch.no_grad():
            gen_out = self.model.generate(
                **input_tokens,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                repetition_penalty=repetition_penalty,
                stopping_criteria=stopping_criteria,
                output_scores=want_logprobs,
                return_dict_in_generate=True,
                **parameters,
            )

        generated_ids = gen_out.sequences[0, prompt_length:]
        decoded_output = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        if not want_logprobs:
            return [{"generated_text": decoded_output}]

        scored_ids.extend(generated_ids.tolist())
        scored_logprobs.extend(self._generated_logprobs(gen_out.scores, generated_ids))

        # Per-token text and its character offset in the concatenated token texts.
        tokens: List[str] = []
        text_offsets: List[int] = []
        current_offset = 0
        for token_id in scored_ids:
            token_text = self.tokenizer.decode([token_id])
            tokens.append(token_text)
            text_offsets.append(current_offset)
            current_offset += len(token_text)

        logprobs_dict = {
            "tokens": tokens,
            "token_logprobs": scored_logprobs,
            "top_logprobs": [],
            "text_offset": text_offsets,
        }
        return [{"generated_text": decoded_output, "logprobs": logprobs_dict}]