RanjithaRuttala committed
Commit 14032ea · verified · 1 parent: 0a11e2e

Update handler.py

- replaced bfloat16 with float16 (see the dtype sketch after this list)
- added basic tokenizer fixes (pad-token fallback only)
- removed the FIM prompt wrapper
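
The dtype swap is the usual portability fallback: bfloat16 needs newer accelerators (e.g. NVIDIA Ampere or later), while float16 also runs on the T4-class GPUs common on inference endpoints. A minimal sketch of that check, assuming PyTorch is available; pick_dtype is a hypothetical helper, not part of this commit:

    import torch

    def pick_dtype():
        # Hypothetical helper: prefer bfloat16 only where the GPU supports it,
        # otherwise fall back to float16 (the choice this commit makes).
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16

    print(pick_dtype())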

Files changed (1)
  1. handler.py +14 -9
handler.py CHANGED
@@ -10,16 +10,21 @@ class EndpointHandler:
         self.tokenizer = AutoTokenizer.from_pretrained(path)
 
         # StarCoder2 FIXES
+        # if self.tokenizer.pad_token is None:
+        #     self.tokenizer.pad_token = self.tokenizer.eos_token
+        # self.tokenizer.padding_side = "left" # Critical for code completion
+
+        # Basic tokenizer fixes only
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
-        self.tokenizer.padding_side = "left" # Critical for code completion
 
         print(f"Loading model from {path} on device: {self.device}...")
         self.model = AutoModelForCausalLM.from_pretrained(
             path,
-            torch_dtype=torch.bfloat16, # ✅ Changed to bfloat16
+            torch_dtype=torch.float16, # ✅ back to float16 from bfloat16
             trust_remote_code=True,
             device_map="auto",
+            low_cpu_mem_usage=True
             # attn_implementation="flash_attention_2" # ✅ Faster + stable
         )
         self.model.eval()
@@ -32,12 +37,12 @@ class EndpointHandler:
         if not isinstance(inputs, str) or not inputs.strip():
             return {"generated_text": ""}
 
-        # ✅ StarCoder2: Add code context prefix
-        prompt = f"<fim_prefix>{inputs}<fim_suffix><fim_middle>"
+        # # ✅ StarCoder2: Add code context prefix
+        # prompt = f"<fim_prefix>{inputs}<fim_suffix><fim_middle>"
 
         gen_kwargs = {
             "max_new_tokens": min(parameters.get("max_new_tokens", 256), 512), # Cap for stability
-            "temperature": parameters.get("temperature", 0.2),
+            "temperature": parameters.get("temperature", 0.3),
             "top_p": parameters.get("top_p", 0.95),
             "top_k": parameters.get("top_k", 50),
             "do_sample": parameters.get("do_sample", True),
@@ -46,13 +51,13 @@ class EndpointHandler:
             "pad_token_id": self.tokenizer.pad_token_id,
         }
 
-        # print(f"Generating with parameters: {gen_kwargs}")
-        print(f"Prompt length: {len(prompt)} | Gen params: {gen_kwargs}")
+        print(f"Generating with parameters: {gen_kwargs}")
+        # print(f"Prompt length: {len(prompt)} | Gen params: {gen_kwargs}")
 
         # StarCoder2 tokenization
         inputs = inputs.strip()
         tokenized = self.tokenizer(
-            prompt,
+            # prompt,
             inputs,
             return_tensors="pt",
             truncation=True,
@@ -77,5 +82,5 @@
             clean_up_tokenization_spaces=True
         )
 
-        generated = generated.replace("<fim_middle>", "").replace("<fim_suffix>", "").strip()
+        # generated = generated.replace("<fim_middle>", "").replace("<fim_suffix>", "").strip()
         return {"generated_text": generated.strip()}
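
With the FIM wrapper (<fim_prefix>…<fim_suffix><fim_middle>) gone, the handler now does plain left-to-right completion on the raw input. A minimal local smoke test, assuming the Hugging Face Inference Endpoints convention of a dict payload with "inputs" and optional "parameters", and an __init__(path=...) signature as the hunks suggest; the model path is a placeholder:

    from handler import EndpointHandler

    # Placeholder path: substitute the actual model repository or local directory.
    handler = EndpointHandler(path="./starcoder2-model")
    result = handler({
        "inputs": "def fibonacci(n):",
        "parameters": {"max_new_tokens": 64, "temperature": 0.3},
    })
    print(result["generated_text"])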