# PEFT_FP16_starcoder2-3b / handler_old.py
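"""Custom inference handler for Hugging Face Inference Endpoints, serving an
fp16-merged StarCoder2-3B model for code completion."""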
from typing import Dict, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
def __init__(self, path: str = ""):
"""
        Initialize the model and tokenizer.

        Args:
            path: Path to the model directory (will be "/repository" in the endpoint container).
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading tokenizer from {path}...")
self.tokenizer = AutoTokenizer.from_pretrained(path)
        # The tokenizer may ship without a pad token; fall back to EOS so
        # padded inputs don't corrupt generation.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
print(f"Loading model from {path} on device: {self.device}...")
self.model = AutoModelForCausalLM.from_pretrained(
path,
            torch_dtype=torch.float16,  # The merged model weights are fp16
trust_remote_code=True, # StarCoder2 may use custom code
device_map="auto", # Efficient placement on GPU/CPU
)
self.model.eval() # Set to evaluation mode
print("✅ Model loaded successfully!")
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process inference requests.
Args:
data: Dictionary containing:
- inputs: str (code prompt to complete)
- parameters: dict (optional, generation parameters)
Returns:
Dictionary with generated_text key
"""
# Extract inputs
inputs = data.get("inputs", "")
parameters = data.get("parameters", {}) or {}
if not isinstance(inputs, str):
raise ValueError("`inputs` must be a string")
if not inputs.strip():
raise ValueError("`inputs` cannot be empty")
# Generation parameters with sensible defaults
gen_kwargs = {
"max_new_tokens": parameters.get("max_new_tokens", 128),
"temperature": parameters.get("temperature", 0.2), # Lower = more deterministic
"top_p": parameters.get("top_p", 0.95), # Nucleus sampling
"top_k": parameters.get("top_k", 50), # Top-k sampling
"do_sample": parameters.get("do_sample", True), # Use sampling
"repetition_penalty": parameters.get("repetition_penalty", 1.0),
}
print(f"Generating with parameters: {gen_kwargs}")
        # Tokenize the prompt, padding and truncating to the model's context budget
        enc = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(self.device)
        # Generate under no_grad to avoid autograd overhead
        with torch.no_grad():
            out = self.model.generate(
                **enc,
                **gen_kwargs,
                pad_token_id=self.tokenizer.pad_token_id,  # Tell generate() which token is padding
            )
        # Decode the full sequence; note that this includes the original prompt
        generated_text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        return {"generated_text": generated_text}
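

# A minimal local smoke test, assuming the merged model lives in "./model"
# (hypothetical path; Inference Endpoints passes "/repository" instead).
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    result = handler({
        "inputs": "def fibonacci(n):",
        "parameters": {"max_new_tokens": 64, "temperature": 0.2},
    })
    print(result["generated_text"])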