File size: 1,800 Bytes
c8723a7
1987c4f
 
c8723a7
 
c8f63e6
c8723a7
 
 
 
 
 
 
 
1987c4f
 
 
 
 
 
 
 
 
 
 
 
c8723a7
1987c4f
c8723a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
# from unsloth import FastLanguageModel
from transformers import AutoTokenizer,AutoModelForCausalLM

# === Config ===
MODEL_DIR = "RanjithaRuttala/PEFT_starcode2-3b_merged"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_NEW_TOKENS = 256   # default generation budget; callers may override per request
TEMPERATURE = 0.2      # low temperature -> mostly-deterministic code completions
TOP_P = 0.95

# === Load merged model and tokenizer ===
print("[Handler] Loading model and tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# float16 is only well-supported on GPU; fall back to float32 on CPU to avoid
# slow or unsupported half-precision ops.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
)

# NOTE: no model.to(DEVICE) here — `device_map="auto"` delegates placement to
# accelerate, and calling .to() on a dispatched model is redundant at best and
# raises a RuntimeError on recent transformers versions.
model.eval()

def handle(inputs):
    """Generate a code completion for a prompt.

    Args:
        inputs: dict, e.g. {"prompt": "def add_numbers(a, b):", "max_new_tokens": 128}.
            "prompt" (str) is required; an empty or missing prompt short-circuits.
            "max_new_tokens" (int) optionally overrides MAX_NEW_TOKENS.

    Returns:
        dict: {"completion": <generated text continuing the prompt>}.
    """
    prompt = inputs.get("prompt", "")
    if not prompt:
        return {"completion": ""}

    max_new_tokens = inputs.get("max_new_tokens", MAX_NEW_TOKENS)

    # Tokenize the prompt (same settings as training). Keep the attention mask
    # so generate() does not have to infer padding (and warn about it).
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    input_ids = encoded.input_ids.to(DEVICE)
    attention_mask = encoded.attention_mask.to(DEVICE)

    # inference_mode() disables autograd bookkeeping — generation never needs
    # gradients, so this saves memory and a little time per request.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode ONLY the newly generated token ids. Slicing the decoded string by
    # len(prompt) is fragile: decode(skip_special_tokens=True) does not always
    # reproduce the prompt byte-for-byte (special tokens, whitespace
    # normalization), so the string slice can clip or leak prompt text.
    new_token_ids = output_ids[0][input_ids.shape[1]:]
    completion = tokenizer.decode(new_token_ids, skip_special_tokens=True)

    return {"completion": completion}