ChevalierJoseph committed on
Commit
096d229
·
verified ·
1 Parent(s): e39c5e6

Create handle.py

Browse files
Files changed (1) hide show
  1. handle.py +58 -0
handle.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
3
+ import torch
4
+
5
class EndpointHandler():
    """Inference-endpoint handler wrapping a causal chat LM.

    Loads the tokenizer and model once at construction time, then serves
    chat-style requests through ``__call__``.
    """

    def __init__(self, path=""):
        # Load the model and tokenizer during initialization.
        # FIX: original hard-coded "cuda", crashing on CPU-only hosts;
        # pick the device automatically (behavior unchanged on GPU machines).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path).to(self.device)
        self.model.eval()  # evaluation mode: disables dropout, etc.

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            messages (:obj:`List[Dict[str, Any]]`): A list of dictionaries representing the conversation messages.
        Return:
            A one-element list: [{"generated_text": <decoded model completion>}].
        """
        # Extract messages from input; if no "messages" key, treat the whole
        # payload as the message list (original fallback behavior).
        messages = data.pop("messages", data)

        # Apply chat template to messages and tokenize.
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.device)

        # TextStreamer prints tokens to stdout as they are produced; it does
        # NOT capture the text, so we must decode the returned ids ourselves.
        text_streamer = TextStreamer(self.tokenizer)

        # No gradients are needed at inference time.
        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=inputs,
                streamer=text_streamer,
                max_new_tokens=6048,
                use_cache=True,
            )

        # BUG FIX: the original discarded the generation and returned a
        # hard-coded mock string. Decode only the newly generated tokens
        # (everything after the prompt) and return them to the caller.
        new_tokens = output_ids[0][inputs.shape[-1]:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        response = {"generated_text": generated_text}
        return [response]
46
+
47
# Example to test the EndpointHandler locally
if __name__ == "__main__":
    # Build a handler against the published model repository.
    endpoint = EndpointHandler(path="ChevalierJoseph/typtop4")

    # A single-turn example conversation.
    conversation = [
        {"from": "human", "value": "Based on the following text, give me the svgpath of the glyphs from A to Z.\nI want a classic LINEAL font"},
    ]

    # Simulate a request to the endpoint and show the result.
    print(endpoint({"messages": conversation}))