import os
import torch
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the model and tokenizer for Phi-4 inference.
        
        Args:
            path (str): Path to the model directory
        """
        # Set default parameters for inference
        self.max_new_tokens = 4096
        self.temperature = 0.7
        self.top_p = 0.9
        self.do_sample = True
        
        # Determine if CUDA is available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        
        # Load model with appropriate settings
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=True
        )
        
        # Without device_map, from_pretrained loads on CPU; keep the placement explicit
        if self.device == "cpu":
            self.model = self.model.to(self.device)
        
        # Set model to evaluation mode
        self.model.eval()
        
        print(f"Model loaded on {self.device} using {self.dtype}")

    def format_prompt(self, prompt: str) -> str:
        """
        Format the user prompt for Phi-4 model.
        
        Args:
            prompt (str): User input prompt
            
        Returns:
            str: Formatted prompt
        """
        # The raw prompt is passed through unchanged. Instruct-tuned Phi-4
        # checkpoints generally expect a chat template; see the hedged
        # format_chat_prompt sketch below if outputs look poorly formatted,
        # and adjust based on your specific fine-tuning.
        return prompt
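
    def format_chat_prompt(self, prompt: str) -> str:
        """
        Hedged alternative to format_prompt: a minimal sketch that applies
        the tokenizer's bundled chat template, assuming the deployed
        checkpoint ships one (instruct-tuned Phi-4 variants do).
        """
        messages = [{"role": "user", "content": prompt}]
        # add_generation_prompt appends the assistant-turn marker so the
        # model continues as the assistant instead of extending the user turn
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )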

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and generate a response using the Phi-4 model.
        
        Args:
            data (Dict[str, Any]): Input data containing the prompt and generation parameters
            
        Returns:
            Dict[str, Any]: Model response
        """
        # Extract input parameters with defaults
        prompt = data.pop("inputs", "")
        parameters = data.pop("parameters", {})
        
        # Get generation parameters with fallbacks to defaults
        max_new_tokens = parameters.get("max_new_tokens", self.max_new_tokens)
        temperature = parameters.get("temperature", self.temperature)
        top_p = parameters.get("top_p", self.top_p)
        do_sample = parameters.get("do_sample", self.do_sample)
        stream = parameters.get("stream", False)
        
        # Format the prompt according to model requirements
        formatted_prompt = self.format_prompt(prompt)
        
        # Tokenize the input
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
        
        # Handle streaming if requested
        if stream:
            return self._generate_stream(inputs, max_new_tokens, temperature, top_p, do_sample)
        else:
            return self._generate(inputs, max_new_tokens, temperature, top_p, do_sample)
    
    def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text non-streaming mode"""
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id
            )
        
        # Return only the newly generated text: slice off the prompt at the
        # token level, which is more robust than trimming the decoded string
        # by character length (special-token handling can shift the offsets)
        prompt_length = inputs.input_ids.shape[1]
        response_text = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        
        return {"generated_text": response_text}
    
    def _generate_stream(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text in streaming mode"""
        # Create a streamer object
        streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
        
        # Set up generation in a separate thread
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.eos_token_id
        )
        
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        
        # Determine input text length to strip it from outputs
        prompt_text = self.tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
        prompt_length = len(prompt_text)
        
        # Stream the output
        def generate_stream():
            # Skip the prompt part in the first chunk
            first_chunk = True
            for text in streamer:
                if first_chunk:
                    # Only yield new tokens, not the original prompt
                    if len(text) > prompt_length:
                        yield {"generated_text": text[prompt_length:]}
                    first_chunk = False
                else:
                    yield {"generated_text": text}
        
        return generate_stream()

# For local testing
if __name__ == "__main__":
    # Example usage
    handler = EndpointHandler()
    result = handler({"inputs": "What are the major features of Phi-4?"})
    print(result)
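
    # Streaming usage sketch: when parameters["stream"] is true, the handler
    # returns a generator of {"generated_text": ...} chunks (the prompt text
    # and parameter values here are illustrative)
    for chunk in handler({
        "inputs": "Summarize Phi-4 in one sentence.",
        "parameters": {"stream": True, "max_new_tokens": 64},
    }):
        print(chunk["generated_text"], end="", flush=True)
    print()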