MudassirFayaz committed on
Commit
982e13e
·
verified ·
1 Parent(s): 6ef538a
Files changed (1) hide show
  1. handler.py +9 -3
handler.py CHANGED
@@ -1,16 +1,22 @@
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
2
  import torch
3
 
4
class EndpointHandler:
    """Serve a causal LM whose full weights live directly at ``path``."""

    def __init__(self, path=""):
        """Load tokenizer and half-precision model from the checkpoint dir."""
        # Tokenizer and model both come from the same checkpoint directory.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=torch.float16
        )
        # Inference only — disable dropout etc.
        self.model.eval()

    def __call__(self, inputs):
        """Generate up to 200 new tokens for the ``{"inputs": ...}`` payload.

        Returns a dict of the form ``{"generated_text": <decoded string>}``.
        """
        # Pull the prompt out of the standard endpoint payload.
        prompt = inputs.get("inputs", "")
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            generated = self.model.generate(**encoded, max_new_tokens=200)
        text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        return {"generated_text": text}
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ from peft import PeftModel, PeftConfig
3
  import torch
4
 
5
class EndpointHandler:
    """Inference endpoint handler for a PEFT/LoRA adapter checkpoint.

    ``path`` holds only the adapter; the PEFT config inside it names the
    base model, which is downloaded and then wrapped with the adapter.
    """

    def __init__(self, path=""):
        """Load tokenizer, base model, and PEFT adapter.

        Args:
            path: Directory containing the adapter weights and its
                ``adapter_config.json`` (which records the base model id).
        """
        # The adapter config records which base model it was trained on.
        config = PeftConfig.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

        # FIX: the original forced float16 unconditionally but never moved
        # the model off CPU, where many half-precision ops are unimplemented.
        # Use fp16 only when CUDA is available; fall back to fp32 on CPU.
        use_cuda = torch.cuda.is_available()
        dtype = torch.float16 if use_cuda else torch.float32
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path, torch_dtype=dtype
        )
        self.model = PeftModel.from_pretrained(base_model, path)
        if use_cuda:
            self.model = self.model.to("cuda")
        # Inference only — disable dropout etc.
        self.model.eval()

    def __call__(self, inputs):
        """Generate a completion for the request payload.

        Args:
            inputs: Either a raw prompt string or the standard endpoint
                payload ``{"inputs": "<prompt>"}``.

        Returns:
            ``{"generated_text": <decoded output, prompt included>}``.
        """
        # FIX: accept a bare string as well as the usual dict payload;
        # the original crashed on string input (str has no .get).
        prompt = inputs if isinstance(inputs, str) else inputs.get("inputs", "")
        # Renamed from `inputs` so the argument is no longer shadowed.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(**encoded, max_new_tokens=800)
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": response}