RanjithaRuttala committed on
Commit
ef83004
·
verified ·
1 Parent(s): 80f8f92

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -3
handler.py CHANGED
@@ -17,9 +17,10 @@ class EndpointHandler:
17
  print(f"Loading model from {path} on device: {self.device}...")
18
  self.model = AutoModelForCausalLM.from_pretrained(
19
  path,
20
- torch_dtype=torch.float16,
21
  trust_remote_code=True,
22
  device_map="auto",
 
23
  )
24
  self.model.eval()
25
  print("✅ Model loaded successfully!")
@@ -31,6 +32,9 @@ class EndpointHandler:
31
  if not isinstance(inputs, str) or not inputs.strip():
32
  return {"generated_text": ""}
33
 
 
 
 
34
  gen_kwargs = {
35
  "max_new_tokens": min(parameters.get("max_new_tokens", 256), 512), # Cap for stability
36
  "temperature": parameters.get("temperature", 0.2),
@@ -42,11 +46,13 @@ class EndpointHandler:
42
  "pad_token_id": self.tokenizer.pad_token_id,
43
  }
44
 
45
- print(f"Generating with parameters: {gen_kwargs}")
 
46
 
47
  # StarCoder2 tokenization
48
  inputs = inputs.strip()
49
  tokenized = self.tokenizer(
 
50
  inputs,
51
  return_tensors="pt",
52
  truncation=True,
@@ -65,10 +71,11 @@ class EndpointHandler:
65
 
66
  # Extract ONLY newly generated tokens
67
  new_tokens = outputs[0][len(tokenized.input_ids[0]):]
68
- generated_text = self.tokenizer.decode(
69
  new_tokens,
70
  skip_special_tokens=True,
71
  clean_up_tokenization_spaces=True
72
  )
73
 
 
74
  return {"generated_text": generated_text.strip()}
 
17
  print(f"Loading model from {path} on device: {self.device}...")
18
  self.model = AutoModelForCausalLM.from_pretrained(
19
  path,
20
+ torch_dtype=torch.bfloat16, # ✅ Changed to bfloat16
21
  trust_remote_code=True,
22
  device_map="auto",
23
+ attn_implementation="flash_attention_2" # ✅ Faster + stable
24
  )
25
  self.model.eval()
26
  print("✅ Model loaded successfully!")
 
32
  if not isinstance(inputs, str) or not inputs.strip():
33
  return {"generated_text": ""}
34
 
35
+ # ✅ StarCoder2: Add code context prefix
36
+ prompt = f"<fim_prefix>{inputs}<fim_suffix><fim_middle>"
37
+
38
  gen_kwargs = {
39
  "max_new_tokens": min(parameters.get("max_new_tokens", 256), 512), # Cap for stability
40
  "temperature": parameters.get("temperature", 0.2),
 
46
  "pad_token_id": self.tokenizer.pad_token_id,
47
  }
48
 
49
+ # print(f"Generating with parameters: {gen_kwargs}")
50
+ print(f"Prompt length: {len(prompt)} | Gen params: {gen_kwargs}")
51
 
52
  # StarCoder2 tokenization
53
  inputs = inputs.strip()
54
  tokenized = self.tokenizer(
55
+ prompt,
56
- inputs,
57
  return_tensors="pt",
58
  truncation=True,
 
71
 
72
  # Extract ONLY newly generated tokens
73
  new_tokens = outputs[0][len(tokenized.input_ids[0]):]
74
+ generated = self.tokenizer.decode(
75
  new_tokens,
76
  skip_special_tokens=True,
77
  clean_up_tokenization_spaces=True
78
  )
79
 
80
+ generated = generated.replace("<fim_middle>", "").replace("<fim_suffix>", "").strip()
81
+ return {"generated_text": generated.strip()}