from typing import Any, Dict, List

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os

MAX_INPUT_LENGTH = 256
MAX_OUTPUT_LENGTH = 128


class EndpointHandler:
    def __init__(
        self,
        model_dir: str = "",
        num_threads: int | None = None,
        generation_config: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        # Set environment hints for CPU efficiency
        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

        # Configure torch threading for CPU
        if num_threads:
            try:
                torch.set_num_threads(num_threads)
                torch.set_num_interop_threads(max(1, num_threads // 2))
            except Exception:
                pass
            # Best-effort hints only: torch is already imported, so
            # torch.set_num_threads above is the authoritative control.
            os.environ.setdefault("OMP_NUM_THREADS", str(num_threads))
            os.environ.setdefault("MKL_NUM_THREADS", str(num_threads))

        self.device = "cpu"  # Force CPU usage

        # Load tokenizer & model with CPU-friendly settings
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, low_cpu_mem_usage=True)
        self.model.eval()
        self.model.to(self.device)

        # Optional bfloat16 cast on CPU (beneficial on Sapphire Rapids/oneDNN)
        self._use_bf16 = False
        if os.getenv("ENABLE_BF16", "1") == "1":
            try:
                self.model = self.model.to(dtype=torch.bfloat16)
                self._use_bf16 = True
            except Exception:
                self._use_bf16 = False

        # Determine a safe pad token id
        pad_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        )

        # Default deterministic generation config (beam search), overridable by caller
        default_gen = {
            "max_length": MAX_OUTPUT_LENGTH,
            "num_beams": 4,  # Beam search width; set to 1 for pure greedy decoding on constrained CPUs
            "do_sample": False,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
            "use_cache": True,
            "pad_token_id": pad_id,
        }
        if generation_config:
            default_gen.update(generation_config)
        self.generation_args = default_gen

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("No 'inputs' found in the request data.")
        if isinstance(inputs, str):
            inputs = [inputs]

        # Allow per-request overrides under 'parameters'
        per_request_params = data.get("parameters") or {}

        # Unpack nested generate_parameters dict if provided
        if isinstance(per_request_params.get("generate_parameters"), dict):
            nested = per_request_params.pop("generate_parameters")
            per_request_params.update(nested)

        # Extract decode-only params
        decode_params = {}
        if "clean_up_tokenization_spaces" in per_request_params:
            decode_params["clean_up_tokenization_spaces"] = per_request_params.pop(
                "clean_up_tokenization_spaces"
            )

        # Sanitize sampling-related params to prevent invalid configs
        do_sample_req = bool(
            per_request_params.get("do_sample", self.generation_args.get("do_sample", False))
        )
        if "temperature" in per_request_params:
            # If not sampling, drop temperature entirely
            if not do_sample_req:
                per_request_params.pop("temperature", None)
            else:
                # Ensure a strictly positive float; fall back to 1.0 otherwise
                try:
                    temp_val = float(per_request_params["temperature"])
                except (TypeError, ValueError):
                    temp_val = None
                if not temp_val or temp_val <= 0:
                    per_request_params["temperature"] = 1.0

        # Filter to supported generation args to avoid warnings
        allowed = set(self.model.generation_config.to_dict().keys()) | {
            "max_length", "min_length", "max_new_tokens", "num_beams", "num_return_sequences",
            "temperature", "top_k", "top_p", "repetition_penalty", "length_penalty",
            "early_stopping", "do_sample", "no_repeat_ngram_size", "use_cache",
            "pad_token_id", "eos_token_id", "bos_token_id", "decoder_start_token_id",
            "num_beam_groups", "diversity_penalty", "penalty_alpha", "typical_p",
            "return_dict_in_generate", "output_scores", "output_attentions", "output_hidden_states",
        }
        # Important: don't pass attention_mask via kwargs since we pass it explicitly
        per_request_params.pop("attention_mask", None)
        filtered_params = {k: v for k, v in per_request_params.items() if k in allowed}
        gen_args = {**self.generation_args, **filtered_params}

        tokenized_inputs = self.tokenizer(
            inputs,
            max_length=MAX_INPUT_LENGTH,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        try:
            with torch.inference_mode():
                outputs = self.model.generate(
                    tokenized_inputs["input_ids"],
                    attention_mask=tokenized_inputs["attention_mask"],
                    **gen_args,
                )
            decoded_outputs = self.tokenizer.batch_decode(
                outputs, skip_special_tokens=True, **decode_params
            )
            results = [{"generated_text": text} for text in decoded_outputs]
            return results
        except Exception as e:
            return [{"generated_text": f"Error: {str(e)}"}]