RanjithaRuttala committed on
Commit
80f8f92
·
verified ·
1 Parent(s): 418c370

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +74 -0
handler.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
class EndpointHandler:
    """Custom inference handler that serves a causal language model.

    The handler loads a tokenizer and model from ``path`` once at start-up
    and, for every request, returns only the newly generated continuation
    of the prompt (the prompt itself is not echoed back).

    Request payload (as read by ``__call__``)::

        {"inputs": "<prompt>", "parameters": {"max_new_tokens": 64, ...}}

    Response::

        {"generated_text": "<continuation>"}
    """

    # Generation / tokenization limits. Kept as class attributes so a
    # deployment can tune them by subclassing instead of editing the logic.
    DEFAULT_MAX_NEW_TOKENS = 256
    MAX_NEW_TOKENS_CAP = 512  # upper bound on generated tokens per request
    MAX_PROMPT_TOKENS = 2048  # prompts longer than this are truncated

    def __init__(self, path: str = "/repository"):
        """Load the tokenizer and model stored at ``path``.

        Args:
            path: Directory (or model identifier) handed to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading tokenizer from {path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Some tokenizers ship without a pad token; reuse EOS so padding and
        # `pad_token_id` are always defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Left padding keeps the prompt directly adjacent to the generated
        # tokens, which is what a decoder-only model needs for generation.
        self.tokenizer.padding_side = "left"

        print(f"Loading model from {path} on device: {self.device}...")
        # Half precision is only worthwhile on GPU; on CPU float16 is slow
        # and not every operator supports it, so fall back to float32.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        # NOTE(review): trust_remote_code=True executes Python shipped with
        # the checkpoint - only point `path` at repositories you trust.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto",
        )
        self.model.eval()
        print("✅ Model loaded successfully!")

    @staticmethod
    def _build_generation_kwargs(
        parameters: Dict[str, Any],
        eos_token_id: Any,
        pad_token_id: Any,
    ) -> Dict[str, Any]:
        """Turn user supplied ``parameters`` into safe ``generate`` kwargs.

        Missing, ``None`` or non-numeric values fall back to the defaults.
        ``max_new_tokens`` is clamped to ``[1, MAX_NEW_TOKENS_CAP]``. A
        non-positive temperature is interpreted as a request for greedy
        decoding, and sampling-only options are omitted whenever sampling is
        disabled.
        """

        def _coerce(key: str, default: Any, cast: Any) -> Any:
            # Booleans are rejected explicitly: `int(True)` would silently
            # become 1, which is never what the caller meant.
            value = parameters.get(key)
            if value is None or isinstance(value, bool):
                return default
            try:
                return cast(value)
            except (TypeError, ValueError, OverflowError):
                return default

        max_new_tokens = _coerce(
            "max_new_tokens", EndpointHandler.DEFAULT_MAX_NEW_TOKENS, int
        )
        max_new_tokens = max(
            1, min(max_new_tokens, EndpointHandler.MAX_NEW_TOKENS_CAP)
        )

        temperature = _coerce("temperature", 0.2, float)
        requested_sampling = parameters.get("do_sample")
        do_sample = True if requested_sampling is None else bool(requested_sampling)
        # Sampling needs a strictly positive temperature; `not (t > 0)` also
        # catches NaN. Clients commonly send 0 to mean "deterministic".
        if not temperature > 0:
            do_sample = False

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "repetition_penalty": _coerce("repetition_penalty", 1.1, float),
            "eos_token_id": eos_token_id,
            "pad_token_id": pad_token_id,
        }
        if do_sample:
            # Only meaningful (and only accepted without complaints) when
            # sampling is enabled.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = _coerce("top_p", 0.95, float)
            gen_kwargs["top_k"] = _coerce("top_k", 50, int)
        return gen_kwargs

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a continuation for ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` is the prompt string and
                ``"parameters"`` an optional dict of generation options
                (``max_new_tokens``, ``temperature``, ``top_p``, ``top_k``,
                ``do_sample``, ``repetition_penalty``).

        Returns:
            ``{"generated_text": <str>}`` holding only the newly generated
            text. An empty string is returned when the prompt is missing,
            not a string, or blank.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {}) or {}
        if not isinstance(parameters, dict):
            # A malformed "parameters" field should not take the request down.
            parameters = {}

        if not isinstance(inputs, str) or not inputs.strip():
            return {"generated_text": ""}

        gen_kwargs = self._build_generation_kwargs(
            parameters,
            self.tokenizer.eos_token_id,
            self.tokenizer.pad_token_id,
        )

        print(f"Generating with parameters: {gen_kwargs}")

        inputs = inputs.strip()
        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_PROMPT_TOKENS,
            padding=True,
        )
        # The model is placed by `device_map="auto"`, so ask the model where
        # its inputs belong instead of assuming it matches `self.device`.
        tokenized = tokenized.to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=tokenized.input_ids,
                attention_mask=tokenized.attention_mask,
                **gen_kwargs,
                use_cache=True,
            )

        # `generate` returns prompt + continuation; drop the prompt tokens.
        prompt_length = tokenized.input_ids.shape[1]
        new_tokens = outputs[0][prompt_length:]
        generated_text = self.tokenizer.decode(
            new_tokens,
            skip_special_tokens=True,
            # The clean-up pass rewrites spacing around punctuation and
            # quotes, which would alter generated source code.
            clean_up_tokenization_spaces=False,
        )

        return {"generated_text": generated_text.strip()}