MudassirFayaz committed on
Commit 8f1b759 · verified · 1 Parent(s): 982e13e
Files changed (1)
  1. handler.py +3 -4
handler.py CHANGED
@@ -4,11 +4,11 @@ import torch
 
 class EndpointHandler:
     def __init__(self, path=""):
-        # Load PEFT config to get base model path
         config = PeftConfig.from_pretrained(path)
         self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
         base_model = AutoModelForCausalLM.from_pretrained(
-            config.base_model_name_or_path, torch_dtype=torch.float16
+            config.base_model_name_or_path,
+            torch_dtype=torch.float16
         )
         self.model = PeftModel.from_pretrained(base_model, path)
         self.model.eval()
@@ -18,5 +18,4 @@ class EndpointHandler:
         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
         with torch.no_grad():
             outputs = self.model.generate(**inputs, max_new_tokens=800)
-        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return {"generated_text": response}
+        return {"generated_text": self.tokenizer.decode(outputs[0], skip_special_tokens=True)}
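For context (not part of this commit): a custom handler.py like this one is loaded by Hugging Face Inference Endpoints, which instantiates EndpointHandler with the repository path and calls it with the request payload. The sketch below shows how it might be smoke-tested locally; the payload key ("inputs") and the __call__ signature are assumed from the standard Inference Endpoints handler convention, since they fall outside the hunks shown above.

# Minimal local smoke test for the handler (a sketch, not part of the commit).
# Assumptions: the adapter repository is checked out in the current directory,
# and __call__ takes a dict payload with an "inputs" key; neither detail is
# visible in this diff.
from handler import EndpointHandler

handler = EndpointHandler(path=".")  # path: directory holding the PEFT adapter and its config
result = handler({"inputs": "Summarize what this endpoint does in one sentence."})
print(result["generated_text"])      # __call__ returns {"generated_text": ...} per this commit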