from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """Custom Inference Endpoints handler for causal code generation."""

    def __init__(self, path: str = "/repository"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading tokenizer from {path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # StarCoder2 FIXES (kept for reference, currently disabled):
        # self.tokenizer.padding_side = "left"  # Critical for code completion

        # Basic tokenizer fixes only: many causal LMs ship without a pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"Loading model from {path} on device: {self.device}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,  # ✅ back to float16 from bfloat16
            trust_remote_code=True,
            device_map="auto",
            low_cpu_mem_usage=True,
            # attn_implementation="flash_attention_2",  # ✅ Faster + stable
        )
        self.model.eval()
        print("✅ Model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {}) or {}

        # Guard against missing or blank prompts.
        if not isinstance(inputs, str) or not inputs.strip():
            return {"generated_text": ""}

        gen_kwargs = {
            "max_new_tokens": min(parameters.get("max_new_tokens", 256), 512),  # Cap for stability
            "temperature": parameters.get("temperature", 0.3),
            "top_p": parameters.get("top_p", 0.95),
            "top_k": parameters.get("top_k", 50),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),  # Slightly higher
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }

        print(f"Generating with parameters: {gen_kwargs}")

        # StarCoder2 tokenization: truncate long prompts to the context budget.
        inputs = inputs.strip()
        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=tokenized.input_ids,
                attention_mask=tokenized.attention_mask,
                **gen_kwargs,
                use_cache=True,
            )

        # generate() returns prompt + completion; slice off the prompt so the
        # response contains ONLY the newly generated tokens.
        new_tokens = outputs[0][tokenized.input_ids.shape[-1]:]
        generated = self.tokenizer.decode(
            new_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        return {"generated_text": generated.strip()}
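
# --- Local smoke test (sketch) ---
# A minimal way to exercise the handler outside the Inference Endpoints
# runtime, assuming a compatible checkpoint exists at the hypothetical local
# path "./my-model" (substitute your own). The payload mirrors the
# {"inputs": ..., "parameters": ...} contract that __call__ expects above.
if __name__ == "__main__":
    handler = EndpointHandler(path="./my-model")  # hypothetical local path
    payload = {
        "inputs": "def fibonacci(n):",
        "parameters": {"max_new_tokens": 64, "temperature": 0.2},
    }
    result = handler(payload)
    print(result["generated_text"])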