Trouter-Library committed on
Commit
e2035b7
·
verified ·
1 Parent(s): ee2ec96

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +279 -0
inference.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helion-V2 Inference Script
3
+ Provides optimized inference with various sampling strategies.
4
+ """
5
+
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
8
+ import argparse
9
+ from typing import Optional, List, Dict
10
+ import time
11
+
12
+
13
class HelionInference:
    """Inference wrapper for the Helion-V2 causal language model.

    Handles tokenizer/model loading (optionally 4-/8-bit quantized via
    bitsandbytes) and exposes plain-prompt and chat-template generation.
    """

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V2",
        device: str = "auto",
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        use_flash_attention: bool = True,
    ):
        """
        Initialize the Helion-V2 model for inference.

        Args:
            model_name: HuggingFace model identifier
            device: Device placement ('auto', 'cuda', 'cpu')
            load_in_4bit: Use 4-bit quantization (NF4, double quantization);
                takes precedence over load_in_8bit
            load_in_8bit: Use 8-bit quantization
            use_flash_attention: Enable Flash Attention 2 (skipped when
                either quantization mode is requested)
        """
        self.model_name = model_name
        self.device = device

        print(f"Loading tokenizer from {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Configure quantization; 4-bit wins if both flags are set.
        quantization_config = None
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        print(f"Loading model from {model_name}...")
        model_kwargs = {
            "device_map": device,
            "torch_dtype": torch.float16,
            "quantization_config": quantization_config,
        }

        # Flash Attention 2 is only requested for non-quantized loads.
        if use_flash_attention and not (load_in_4bit or load_in_8bit):
            model_kwargs["attn_implementation"] = "flash_attention_2"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **model_kwargs,
        )

        self.model.eval()
        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        do_sample: bool = True,
        num_return_sequences: int = 1,
    ) -> List[str]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_p: Nucleus sampling threshold
            top_k: Top-k sampling parameter
            repetition_penalty: Penalty for repeating tokens
            do_sample: Use sampling vs greedy decoding
            num_return_sequences: Number of sequences to generate

        Returns:
            List of generated text strings; each string includes the
            original prompt text, since the full output ids are decoded.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        start_time = time.time()

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                num_return_sequences=num_return_sequences,
                # eos is reused as pad so generate() does not warn when the
                # tokenizer defines no dedicated pad token.
                pad_token_id=self.tokenizer.eos_token_id,
            )

        generation_time = time.time() - start_time
        # New tokens = total output length minus the prompt length.
        tokens_generated = outputs.shape[1] - inputs["input_ids"].shape[1]
        # Guard against a zero elapsed time (coarse clock resolution) so
        # the stats printout can never raise ZeroDivisionError.
        tokens_per_second = (
            tokens_generated / generation_time if generation_time > 0 else 0.0
        )

        results = []
        for output in outputs:
            text = self.tokenizer.decode(output, skip_special_tokens=True)
            results.append(text)

        print("\nGeneration stats:")
        print(f"  Tokens generated: {tokens_generated}")
        print(f"  Time: {generation_time:.2f}s")
        print(f"  Speed: {tokens_per_second:.2f} tokens/s")

        return results

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs
    ) -> str:
        """
        Generate a response in chat format.

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling threshold
            **kwargs: Additional generation parameters forwarded to generate()

        Returns:
            Generated response text (assistant turn only, best effort)
        """
        # Render the conversation through the model's chat template.
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        results = self.generate(
            input_text,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            **kwargs
        )

        # Extract only the assistant's response.
        full_text = results[0]
        if "<|assistant|>" in full_text:
            response = full_text.split("<|assistant|>")[-1].split("<|end|>")[0].strip()
        else:
            # NOTE(review): full_text was decoded with skip_special_tokens=True
            # while input_text still contains the template's special tokens, so
            # this length-based slice can misalign — TODO: decode only the new
            # token ids instead and confirm against the model's template.
            response = full_text[len(input_text):].strip()

        return response
173
+
174
+
175
def main():
    """CLI entry point: parse arguments, load the model, print a generation."""
    parser = argparse.ArgumentParser(description="Helion-V2 Inference")
    parser.add_argument("--model", type=str, default="DeepXR/Helion-V2", help="Model name or path")
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument("--max-tokens", type=int, default=512, help="Maximum tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top-p", type=float, default=0.9, help="Nucleus sampling threshold")
    parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling")
    parser.add_argument("--repetition-penalty", type=float, default=1.1, help="Repetition penalty")
    parser.add_argument("--load-in-4bit", action="store_true", help="Load model in 4-bit precision")
    parser.add_argument("--load-in-8bit", action="store_true", help="Load model in 8-bit precision")
    parser.add_argument("--device", type=str, default="auto", help="Device placement")
    parser.add_argument("--chat-mode", action="store_true", help="Use chat format")
    args = parser.parse_args()

    # Initialize model
    engine = HelionInference(
        model_name=args.model,
        device=args.device,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
    )

    # Sampling settings shared by both generation paths.
    sampling = {
        "max_new_tokens": args.max_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "repetition_penalty": args.repetition_penalty,
    }

    if args.chat_mode:
        conversation = [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": args.prompt},
        ]
        reply = engine.chat(conversation, **sampling)
        print(f"\nAssistant: {reply}")
    else:
        texts = engine.generate(args.prompt, **sampling)
        print(f"\nGenerated text:\n{texts[0]}")


if __name__ == "__main__":
    main()