0chanly committed on
Commit
8408dc0
·
verified ·
1 Parent(s): fc65fcc

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +80 -0
handler.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom handler for Constitutional AI models
3
+ """
4
+
5
+ from typing import Dict, List, Any
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+
10
class EndpointHandler:
    """Custom inference handler that serves a causal language model.

    The handler loads a tokenizer and model once at start-up and then answers
    each request by generating a continuation of the supplied prompt.
    """

    def __init__(self, path=""):
        """
        Initialize the handler with model and tokenizer

        Args:
            path: Path to the model directory
        """
        # Load tokenizer; fall back to EOS as the padding token when the
        # tokenizer does not define one, so generate() has a pad_token_id.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model in half precision and let device_map="auto" decide
        # where the weights are placed.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request

        Args:
            data: A dictionary containing:
                - inputs (str): The input text
                - parameters (dict, optional): Generation parameters. The
                  recognised keys are ``max_new_tokens`` (default 200),
                  ``do_sample`` (default True), ``temperature`` (default 0.7)
                  and ``top_p`` (default 0.95). The last two are only
                  forwarded when sampling is enabled.

        Returns:
            A single-element list ``[{"generated_text": <str>}]`` holding the
            newly generated text with the prompt removed and surrounding
            whitespace stripped.

        Raises:
            ValueError: If ``data`` has no string under the ``"inputs"`` key.
        """
        # Read without pop() so the caller's payload is left untouched.
        prompt = data.get("inputs")
        if not isinstance(prompt, str):
            raise ValueError('Request payload must contain an "inputs" string')
        # An explicit null for "parameters" is treated like an absent key.
        parameters = data.get("parameters") or {}

        do_sample = parameters.get("do_sample", True)

        # Tokenize
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")

        # Move to the device the model actually lives on. With
        # device_map="auto" that is not necessarily the default CUDA device.
        device = getattr(self.model, "device", None)
        if device is not None:
            input_ids = input_ids.to(device)

        # A single unpadded prompt attends to every token. Passing the mask
        # explicitly matters because the pad token may alias the EOS token,
        # in which case it cannot be inferred from the ids.
        attention_mask = torch.ones_like(input_ids)

        generation_kwargs = {
            "attention_mask": attention_mask,
            "max_new_tokens": parameters.get("max_new_tokens", 200),
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Sampling knobs only apply when sampling; leave them out otherwise.
        if do_sample:
            generation_kwargs["temperature"] = parameters.get("temperature", 0.7)
            generation_kwargs["top_p"] = parameters.get("top_p", 0.95)

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(input_ids, **generation_kwargs)

        # Remove the input prompt from the output by dropping the prompt
        # tokens. Comparing decoded strings is unreliable because decoding
        # does not always reproduce the prompt text character for character.
        prompt_length = input_ids.shape[-1]
        new_tokens = outputs[0][prompt_length:]
        generated_text = self.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        return [{"generated_text": generated_text}]