from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Truncation limit (in tokens) applied to incoming prompts before generation.
MAX_INPUT_TOKENS = 2048


class EndpointHandler:
    """Custom inference handler for a causal-LM text-generation endpoint.

    Loads the model and tokenizer once at startup, then serves requests of
    the form ``{"inputs": <prompt>, "parameters": {...}}`` by generating a
    completion with ``model.generate``.
    """

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and model.

        Args:
            path: Path to the model directory (``/repository`` inside the
                endpoint container).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading tokenizer from {path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Some causal-LM tokenizers ship without a pad token; reuse EOS so
        # padding and pad_token_id are well defined during generation.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"Loading model from {path} on device: {self.device}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,  # merged checkpoint is fp16
            trust_remote_code=True,     # model repo may ship custom code
            device_map="auto",          # place weights across GPU/CPU
        )
        self.model.eval()  # inference mode: disables dropout etc.
        print("✅ Model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one generation request.

        Args:
            data: Dictionary containing:
                - ``inputs``: str, the code prompt to complete.
                - ``parameters``: optional dict of generation kwargs
                  (``max_new_tokens``, ``temperature``, ``top_p``, ...).

        Returns:
            ``{"generated_text": ...}`` where the text is the decoded full
            sequence — i.e. the prompt followed by the completion.

        Raises:
            ValueError: If ``inputs`` is not a string or is empty/blank.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {}) or {}

        if not isinstance(inputs, str):
            raise ValueError("`inputs` must be a string")
        if not inputs.strip():
            raise ValueError("`inputs` cannot be empty")

        # Sensible defaults; any caller-supplied generation parameter
        # overrides them. Unlike a fixed whitelist, this also forwards
        # additional valid kwargs (e.g. `num_beams`, `num_return_sequences`)
        # instead of silently dropping them.
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": 128,
            "temperature": 0.2,        # lower = more deterministic
            "top_p": 0.95,             # nucleus sampling
            "top_k": 50,               # top-k sampling
            "do_sample": True,         # sample rather than greedy decode
            "repetition_penalty": 1.0,
        }
        gen_kwargs.update(parameters)

        print(f"Generating with parameters: {gen_kwargs}")

        # Tokenize, truncating over-long prompts to the model's budget.
        enc = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_INPUT_TOKENS,
        ).to(self.device)

        # no_grad: inference only — skip autograd bookkeeping.
        with torch.no_grad():
            out = self.model.generate(
                **enc,
                **gen_kwargs,
                pad_token_id=self.tokenizer.pad_token_id,  # silence pad warnings
            )

        # NOTE: decodes the entire output sequence, so the prompt is
        # included in the returned text (standard endpoint behavior).
        generated_text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        return {"generated_text": generated_text}