import torch
# from unsloth import FastLanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# === Config ===
MODEL_DIR = "RanjithaRuttala/PEFT_starcode2-3b_merged"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_NEW_TOKENS = 256   # default generation budget; callers can override per request
TEMPERATURE = 0.2      # low temperature keeps sampling close to greedy for code
TOP_P = 0.95           # nucleus sampling cutoff
# === Load merged model and tokenizer ===
print("[Handler] Loading model and tokenizer...")
# FastLanguageModel can also load merged PEFT models directly:
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=MODEL_DIR,
#     dtype=torch.float16,
#     load_in_4bit=False,
# )
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch.float16,
    device_map="auto",
)
# device_map="auto" already dispatches the weights, so a manual
# model.to(DEVICE) is redundant here and can fail for offloaded models.
model.eval()
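# Optional alternative: a minimal sketch of 4-bit loading for
# memory-constrained GPUs, assuming the optional bitsandbytes dependency
# is installed. Not part of the original handler, so left commented out.
# from transformers import BitsAndBytesConfig
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.float16,
# )
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_DIR,
#     quantization_config=bnb_config,
#     device_map="auto",
# )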
def handle(inputs):
    """
    inputs: dict, e.g., {"prompt": "def add_numbers(a, b):", "max_new_tokens": 128}
    returns: dict, {"completion": "generated code"}
    """
    prompt = inputs.get("prompt", "")
    if not prompt:
        return {"completion": ""}
    max_new_tokens = inputs.get("max_new_tokens", MAX_NEW_TOKENS)
    # Tokenize the prompt the same way it was tokenized during training.
    # model.device is the first device accelerate placed the weights on.
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device)
    # Generate the continuation; inference_mode skips autograd bookkeeping.
    # Passing the attention mask explicitly avoids ambiguity when
    # pad_token_id is set to eos_token_id.
    with torch.inference_mode():
        output_ids = model.generate(
            enc.input_ids,
            attention_mask=enc.attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens; slicing the token ids is more
    # robust than stripping len(prompt) characters, since decoding can
    # normalize whitespace and shift character offsets.
    completion = tokenizer.decode(
        output_ids[0][enc.input_ids.shape[1]:], skip_special_tokens=True
    )
    return {"completion": completion}
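# Minimal local smoke test of the handle() contract documented above.
# Illustrative only; the prompt is the docstring's example, and this block
# is not part of the serving path.
if __name__ == "__main__":
    result = handle({"prompt": "def add_numbers(a, b):", "max_new_tokens": 64})
    print("[Handler] Completion:\n" + result["completion"])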