tans37 committed
Commit d4710a4 · verified · 1 Parent(s): 2db3bee

Upload 2 files

Files changed (2):
  1. handler.py +61 -0
  2. requirements.txt +5 -3
handler.py ADDED
@@ -0,0 +1,61 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ import json
+
+ # Configuration
+ BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
+ ADAPTER = "tans37/mistral-query-router"
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         # Load the base model and tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+
+         # Load base model in half precision for efficiency
+         base_model = AutoModelForCausalLM.from_pretrained(
+             BASE_MODEL,
+             torch_dtype=torch.float16,
+             device_map="auto"
+         )
+
+         # Load the Peft adapter
+         # 'path' is the directory where the handler is located (the adapter repo)
+         self.model = PeftModel.from_pretrained(base_model, path)
+         self.model.eval()
+
+         print(f"[Handler] Loaded LoRA adapter from {path} onto {BASE_MODEL}")
+
+     def __call__(self, data):
+         """
+         Args:
+             data (:obj:`dict`):
+                 subset of the request body with the following keys:
+                 - `inputs`: the prompt to be processed
+                 - `parameters`: optional generation parameters
+         """
+         inputs = data.pop("inputs", data)
+         parameters = data.pop("parameters", {
+             "max_new_tokens": 128,
+             "temperature": 0.1,
+             "top_p": 0.9,
+             "do_sample": False
+         })
+
+         # Tokenize
+         inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
+
+         # Generate
+         with torch.no_grad():
+             output_tokens = self.model.generate(
+                 **inputs,
+                 **parameters
+             )
+
+         # Decode
+         # We only want the new tokens, so we slice the output
+         input_len = inputs["input_ids"].shape[1]
+         new_tokens = output_tokens[0][input_len:]
+         prediction = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+         return [{"generated_text": prediction}]
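
For reference, a minimal local smoke test of this handler might look like the sketch below. It assumes the adapter repo (including handler.py) has been downloaded to ./mistral-query-router and that a GPU with enough memory for Mistral-7B in fp16 is available; the payload shape mirrors what Hugging Face Inference Endpoints passes to __call__, and the example prompt is hypothetical. Note that with the default do_sample=False the model decodes greedily, so the temperature and top_p defaults have no effect on the output.

    # Local smoke test for the custom handler (sketch, not part of the commit).
    # Assumes the adapter repo has been downloaded to ./mistral-query-router.
    from handler import EndpointHandler

    handler = EndpointHandler(path="./mistral-query-router")

    # Payload shaped like an Inference Endpoints request body; prompt is hypothetical.
    payload = {
        "inputs": "[INST] Route this query: What is the capital of France? [/INST]",
        "parameters": {"max_new_tokens": 64, "do_sample": False},
    }

    result = handler(payload)  # returns [{"generated_text": "..."}]
    print(result[0]["generated_text"])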
requirements.txt CHANGED
@@ -1,3 +1,5 @@
- transformers>=4.48.0
- peft>=0.14.0
- tokenizers>=0.21.0
+ transformers>=4.40.0
+ peft>=0.10.0
+ torch>=2.2.0
+ accelerate>=0.29.0
+ bitsandbytes>=0.43.0
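
Of the new dependencies, accelerate is needed for the device_map="auto" call in handler.py, while bitsandbytes is presumably included for optional quantized loading (the handler as written loads in fp16). Once deployed, the endpoint can be exercised with a plain HTTP POST; a sketch, where the URL and token are placeholders:

    # Client-side call to the deployed endpoint (sketch; URL and token are placeholders).
    import requests

    API_URL = "https://<endpoint-name>.endpoints.huggingface.cloud"  # placeholder
    HEADERS = {
        "Authorization": "Bearer hf_xxx",  # placeholder token
        "Content-Type": "application/json",
    }

    response = requests.post(
        API_URL,
        headers=HEADERS,
        json={
            "inputs": "[INST] Route this query: ... [/INST]",
            "parameters": {"max_new_tokens": 64},
        },
    )
    print(response.json())  # -> [{"generated_text": "..."}]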