yol146 committed
Commit ec5bb4e · 1 Parent(s): 3be7031

add handler

Files changed (1)
  1. handler.py +146 -0
handler.py ADDED
@@ -0,0 +1,146 @@
+ import os
+ import torch
+ from typing import Dict, List, Any
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from threading import Thread
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         """
+         Initialize the model and tokenizer for Phi-4 inference.
+
+         Args:
+             path (str): Path to the model directory
+         """
+         # Set default parameters for inference
+         self.max_new_tokens = 4096
+         self.temperature = 0.7
+         self.top_p = 0.9
+         self.do_sample = True
+
+         # Determine if CUDA is available
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+         # Load tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+
+         # Load model with appropriate settings
+         self.model = AutoModelForCausalLM.from_pretrained(
+             path,
+             torch_dtype=self.dtype,
+             device_map="auto" if self.device == "cuda" else None,
+             trust_remote_code=True
+         )
+
+         # Move model to device if running on CPU
+         if self.device == "cpu":
+             self.model = self.model.to(self.device)
+
+         # Set model to evaluation mode
+         self.model.eval()
+
+         print(f"Model loaded on {self.device} using {self.dtype}")
+
+     def format_prompt(self, prompt: str) -> str:
+         """
+         Format the user prompt for the Phi-4 model.
+
+         Args:
+             prompt (str): User input prompt
+
+         Returns:
+             str: Formatted prompt
+         """
+         # For Phi-4-mini-instruct, the prompt format is simple;
+         # you may need to adjust this based on your specific fine-tuning
+         return prompt
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Process the input data and generate a response using the Phi-4 model.
+
+         Args:
+             data (Dict[str, Any]): Input data containing the prompt and generation parameters
+
+         Returns:
+             Dict[str, Any]: Model response
+         """
+         # Extract input parameters with defaults
+         prompt = data.pop("inputs", "")
+         parameters = data.pop("parameters", {})
+
+         # Get generation parameters with fallbacks to defaults
+         max_new_tokens = parameters.get("max_new_tokens", self.max_new_tokens)
+         temperature = parameters.get("temperature", self.temperature)
+         top_p = parameters.get("top_p", self.top_p)
+         do_sample = parameters.get("do_sample", self.do_sample)
+         stream = parameters.get("stream", False)
+
+         # Format the prompt according to model requirements
+         formatted_prompt = self.format_prompt(prompt)
+
+         # Tokenize the input
+         inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
+
+         # Handle streaming if requested
+         if stream:
+             return self._generate_stream(inputs, max_new_tokens, temperature, top_p, do_sample)
+         else:
+             return self._generate(inputs, max_new_tokens, temperature, top_p, do_sample)
+
+     def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
+         """Generate text in non-streaming mode"""
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=do_sample,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+         # Return only the newly generated text: slice off the prompt tokens
+         # before decoding, since slicing the decoded string by character
+         # count can drop or leak text around token boundaries
+         prompt_length = inputs.input_ids.shape[1]
+         new_tokens = outputs[0][prompt_length:]
+         response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+         return {"generated_text": response_text}
+
+     def _generate_stream(self, inputs, max_new_tokens, temperature, top_p, do_sample):
+         """Generate text in streaming mode"""
+         # Create a streamer object; skip_prompt=True makes it yield only
+         # newly generated text, so the prompt never has to be stripped
+         # from the output manually
+         streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+         # Set up generation in a separate thread
+         generation_kwargs = dict(
+             **inputs,
+             streamer=streamer,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             do_sample=do_sample,
+             pad_token_id=self.tokenizer.eos_token_id
+         )
+
+         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         # Yield chunks of generated text as they become available
+         def generate_stream():
+             for text in streamer:
+                 yield {"generated_text": text}
+
+         return generate_stream()
+
+ # For local testing
+ if __name__ == "__main__":
+     # Example usage
+     handler = EndpointHandler()
+     result = handler({"inputs": "What are the major features of Phi-4?"})
+     print(result)
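
A quick way to exercise the handler beyond the smoke test above is to call it locally with explicit generation parameters and with streaming enabled. The sketch below assumes handler.py is importable from the working directory and that the Phi-4 weights have been downloaded locally (the ./phi-4 path is a placeholder); the parameter names mirror the ones read in __call__.

    from handler import EndpointHandler

    # Placeholder path: point this at a local copy of the Phi-4 weights
    handler = EndpointHandler(path="./phi-4")

    # Non-streaming request with explicit generation parameters
    result = handler({
        "inputs": "Summarize the transformer architecture in two sentences.",
        "parameters": {"max_new_tokens": 128, "temperature": 0.2, "top_p": 0.95},
    })
    print(result["generated_text"])

    # Streaming request: the handler returns a generator of text chunks
    stream = handler({
        "inputs": "List three uses of small language models.",
        "parameters": {"stream": True, "max_new_tokens": 128},
    })
    for chunk in stream:
        print(chunk["generated_text"], end="", flush=True)

Note that the generator return is only useful to callers that can iterate it, such as this local script; a serving layer that expects a single JSON-serializable response from __call__ may not stream it incrementally.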