Trouter-Library committed on
Commit
44b08c5
·
verified ·
1 Parent(s): ef0c6e1

Create inference/utils.py

Browse files
Files changed (1) hide show
  1. inference/utils.py +376 -0
inference/utils.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Helion-2.5-Rnd Utility Functions
4
+ Common utilities for model inference and processing
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ import time
11
+ from pathlib import Path
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+
14
+ import torch
15
+ import yaml
16
+ from transformers import AutoTokenizer
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class ModelConfig:
    """Model configuration manager.

    Reads a YAML file into a nested dictionary and exposes dot-separated
    key lookup through ``get``. When the file does not exist, a small set
    of built-in defaults is used instead.
    """

    def __init__(self, config_path: str = "model_config.yaml"):
        """Load configuration from a YAML file.

        Args:
            config_path: Path of the YAML file to read.
        """
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load the YAML configuration.

        Returns:
            The parsed mapping; the built-in defaults when the file does
            not exist; an empty dict when the file is empty or its top
            level is not a mapping.
        """
        if not self.config_path.exists():
            logger.warning(f"Config file not found: {self.config_path}")
            return self._default_config()

        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        if config is None:
            # safe_load yields None for an empty document; keep the
            # declared Dict contract instead of storing None.
            config = {}
        elif not isinstance(config, dict):
            logger.warning(
                f"Config file {self.config_path} does not contain a mapping; ignoring it"
            )
            config = {}

        logger.info(f"Loaded configuration from {self.config_path}")
        return config

    def _default_config(self) -> Dict[str, Any]:
        """Return the default configuration used when no file is found."""
        return {
            "model": {
                "name": "DeepXR/Helion-2.5-Rnd",
                "max_position_embeddings": 131072,
            },
            "inference": {
                "default_parameters": {
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "max_new_tokens": 4096,
                }
            }
        }

    def get(self, key: str, default: Any = None) -> Any:
        """Get a configuration value by dot-separated key.

        Args:
            key: Path into the nested config, e.g. ``"model.name"``.
            default: Value returned when the path cannot be resolved.

        Returns:
            The value found at ``key``, or ``default`` when any path
            component is missing, maps to ``None``, or is reached through
            a non-dict value.
        """
        value = self.config

        for part in key.split('.'):
            if not isinstance(value, dict):
                return default
            value = value.get(part)
            if value is None:
                return default

        return value
71
+
72
+
73
class TokenCounter:
    """Token counting utilities.

    Wraps a tokenizer loaded through ``AutoTokenizer``. When the tokenizer
    cannot be loaded, every method falls back to a rough estimate of four
    characters per token.
    """

    def __init__(self, model_name: str = "meta-llama/Meta-Llama-3.1-70B"):
        """Initialize the tokenizer used for counting.

        Args:
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as e:
            # Deliberate best-effort: a missing tokenizer degrades to the
            # character-based estimate instead of failing construction.
            logger.warning(f"Failed to load tokenizer: {e}")
            self.tokenizer = None

    def count_tokens(self, text: str) -> int:
        """Count tokens in ``text``.

        Returns:
            The length of ``tokenizer.encode(text)``, or ``len(text) // 4``
            when no tokenizer is available.
        """
        if self.tokenizer is None:
            # Rough estimate: ~4 characters per token
            return len(text) // 4

        return len(self.tokenizer.encode(text))

    def count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
        """Count tokens in a list of chat messages.

        Each message contributes the token count of its ``role`` and
        ``content`` plus a fixed overhead of 4 for formatting. Missing or
        ``None`` fields count as empty strings.
        """
        total = 0
        for msg in messages:
            total += self.count_tokens(msg.get('role') or '')
            total += self.count_tokens(msg.get('content') or '')
            # Add overhead for formatting
            total += 4

        return total

    def truncate_to_tokens(
        self,
        text: str,
        max_tokens: int,
        from_end: bool = False
    ) -> str:
        """Truncate ``text`` to at most ``max_tokens`` tokens.

        Args:
            text: Text to truncate.
            max_tokens: Maximum number of tokens to keep. Values of zero
                or less yield an empty string.
            from_end: Keep the tail of the text instead of the head.

        Returns:
            ``text`` unchanged when it already fits, otherwise the kept
            portion. Without a tokenizer the cut is made on characters
            (four per token).
        """
        if max_tokens <= 0:
            # Guard: `seq[-0:]` is the whole sequence and a negative bound
            # would slice from the wrong end.
            return ""

        if self.tokenizer is None:
            # Character-based truncation
            max_chars = max_tokens * 4
            return text[-max_chars:] if from_end else text[:max_chars]

        # Encode without special tokens so that markers the tokenizer may
        # add on its own do not end up inside the decoded text.
        tokens = self.tokenizer.encode(text, add_special_tokens=False)

        if len(tokens) <= max_tokens:
            return text

        kept = tokens[-max_tokens:] if from_end else tokens[:max_tokens]
        return self.tokenizer.decode(kept)
129
+
130
+
131
class PromptTemplate:
    """Prompt templating utilities.

    ``TEMPLATES`` maps a template name to its text. All entries except
    ``"chat"`` use ``str.format`` placeholders; the ``"chat"`` entry is
    written in Jinja-style markup, which ``str.format`` cannot render, so
    ``format`` renders it through ``format_chat``.
    """

    TEMPLATES = {
        "chat": (
            "{% for message in messages %}"
            "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
            "{% endfor %}"
            "<|im_start|>assistant\n"
        ),
        "instruction": (
            "### Instruction:\n{instruction}\n\n"
            "### Response:\n"
        ),
        "qa": (
            "Question: {question}\n\n"
            "Answer: "
        ),
        "code": (
            "# Task: {task}\n\n"
            "```{language}\n"
        ),
        "analysis": (
            "Analyze the following:\n\n{content}\n\n"
            "Analysis:"
        )
    }

    @classmethod
    def format(cls, template_name: str, **kwargs) -> str:
        """Format a named template with the given arguments.

        Args:
            template_name: Key into ``TEMPLATES``.
            **kwargs: Placeholder values; the ``"chat"`` template expects
                a ``messages`` list.

        Returns:
            The rendered prompt.

        Raises:
            ValueError: If the template name is unknown or a required
                argument is missing.
        """
        template = cls.TEMPLATES.get(template_name)
        if template is None:
            raise ValueError(f"Unknown template: {template_name}")

        if template_name == "chat":
            # str.format would treat "{% for ... %}" as a field name and
            # always fail, so the chat prompt is built by format_chat.
            if "messages" not in kwargs:
                raise ValueError("Missing required argument: 'messages'")
            return cls.format_chat(kwargs["messages"])

        # Simple string formatting
        try:
            return template.format(**kwargs)
        except KeyError as e:
            raise ValueError(f"Missing required argument: {e}") from e

    @classmethod
    def format_chat(cls, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a prompt.

        Each message becomes ``<|im_start|>{role}\\n{content}<|im_end|>\\n``
        (role defaults to ``user``, content to an empty string), followed
        by an opening assistant turn.
        """
        turns = [
            f"<|im_start|>{msg.get('role', 'user')}\n{msg.get('content', '')}<|im_end|>\n"
            for msg in messages
        ]
        return "".join(turns) + "<|im_start|>assistant\n"
182
+
183
+
184
class ResponseParser:
    """Parse and validate model responses."""

    @staticmethod
    def extract_code(response: str, language: Optional[str] = None) -> str:
        """Extract code from markdown code blocks.

        Args:
            response: Raw model output.
            language: When given, only fences tagged with exactly this
                language are considered; otherwise any fence matches.

        Returns:
            The stripped contents of the first matching fenced block, or
            the stripped response when no block matches.
        """
        import re

        if language:
            # Escape the tag so names such as "c++" or "c#" are matched
            # literally instead of being read as regex syntax.
            pattern = f"```{re.escape(language)}\n(.*?)```"
        else:
            pattern = r"```(?:\w+)?\n(.*?)```"

        matches = re.findall(pattern, response, re.DOTALL)

        if matches:
            return matches[0].strip()

        # No code blocks found, return as is
        return response.strip()

    @staticmethod
    def extract_json(response: str) -> Optional[Dict]:
        """Extract and parse JSON from a response.

        Fenced ``json`` blocks are tried in order and the first one that
        parses is returned; if none parses, the whole response is tried.

        Returns:
            The decoded JSON value, or ``None`` when nothing parses.
        """
        import re

        # Try to find JSON in code blocks
        json_pattern = r"```json\n(.*?)```"
        for candidate in re.findall(json_pattern, response, re.DOTALL):
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue

        # Try to parse entire response as JSON
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            return None

    @staticmethod
    def split_sections(response: str) -> Dict[str, str]:
        """Split a response into sections based on markdown headers.

        Lines starting with one to three ``#`` followed by whitespace open
        a new section whose key is the header text lower-cased with spaces
        replaced by underscores. Text before the first header is stored
        under ``"main"``. A section with no lines at all is not recorded,
        and a repeated header name overwrites the earlier entry.
        """
        import re

        sections = {}
        current_section = "main"
        current_content = []

        for line in response.split('\n'):
            # Check for markdown headers
            header_match = re.match(r'^#{1,3}\s+(.+)$', line)
            if header_match is None:
                current_content.append(line)
                continue

            # Save previous section
            if current_content:
                sections[current_section] = '\n'.join(current_content).strip()

            # Start new section
            current_section = header_match.group(1).lower().replace(' ', '_')
            current_content = []

        # Save last section
        if current_content:
            sections[current_section] = '\n'.join(current_content).strip()

        return sections
254
+
255
+
256
class PerformanceMonitor:
    """Monitor inference performance.

    Keeps an in-memory list of per-request records and derives aggregate
    statistics from it.
    """

    def __init__(self):
        """Start with no recorded requests and the uptime clock at now."""
        self.requests = []
        self.start_time = time.time()

    def record_request(
        self,
        duration: float,
        input_tokens: int,
        output_tokens: int,
        success: bool = True
    ):
        """Record a request.

        Args:
            duration: Duration of the request; used as the divisor for
                the per-request throughput.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.
            success: Whether the request succeeded.
        """
        self.requests.append({
            'timestamp': time.time(),
            'duration': duration,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'success': success,
            # A non-positive duration would divide by zero; report 0.
            'tokens_per_second': output_tokens / duration if duration > 0 else 0
        })

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics.

        Returns:
            With no recorded requests, only ``total_requests`` and
            ``uptime_seconds``. Otherwise request counts, uptime, token
            totals over all requests, and ``avg_duration`` /
            ``avg_tokens_per_second`` over successful requests only (both
            ``0.0`` when no request succeeded).
        """
        if not self.requests:
            return {
                'total_requests': 0,
                'uptime_seconds': time.time() - self.start_time
            }

        successful = [r for r in self.requests if r['success']]
        ok_count = len(successful)

        if ok_count:
            avg_duration = sum(r['duration'] for r in successful) / ok_count
            avg_tps = sum(r['tokens_per_second'] for r in successful) / ok_count
        else:
            # Every recorded request failed: avoid dividing by zero.
            avg_duration = 0.0
            avg_tps = 0.0

        return {
            'total_requests': len(self.requests),
            'successful_requests': ok_count,
            'failed_requests': len(self.requests) - ok_count,
            'uptime_seconds': time.time() - self.start_time,
            'avg_duration': avg_duration,
            'avg_tokens_per_second': avg_tps,
            'total_input_tokens': sum(r['input_tokens'] for r in self.requests),
            'total_output_tokens': sum(r['output_tokens'] for r in self.requests),
        }

    def reset(self):
        """Reset statistics and restart the uptime clock."""
        self.requests = []
        self.start_time = time.time()
305
+
306
+
307
class SafetyFilter:
    """Basic safety filtering for outputs.

    Text is checked against the regular expressions in
    ``UNSAFE_PATTERNS`` after being lower-cased.
    """

    UNSAFE_PATTERNS = [
        r'\b(kill|murder|suicide)\s+(?:yourself|myself)',
        r'\b(bomb|weapon)\s+(?:making|instructions)',
        r'\bhate\s+speech\b',
    ]

    @classmethod
    def is_safe(cls, text: str) -> Tuple[bool, Optional[str]]:
        """
        Check if text is safe

        Returns:
            (is_safe, reason) - ``(True, None)`` when no pattern matches,
            otherwise ``(False, reason)`` naming the first matching pattern.
        """
        import re

        # Patterns are written in lower case, so compare on lower-cased text.
        text_lower = text.lower()

        for pattern in cls.UNSAFE_PATTERNS:
            if re.search(pattern, text_lower):
                return False, f"Matched unsafe pattern: {pattern}"

        return True, None

    @classmethod
    def filter_response(cls, text: str, replacement: str = "[FILTERED]") -> str:
        """Filter unsafe content from a response.

        Args:
            text: Text to check.
            replacement: Value returned in place of unsafe text.

        Returns:
            ``text`` unchanged when it is safe; otherwise ``replacement``
            (the whole response is replaced, and a warning is logged).
        """
        is_safe, reason = cls.is_safe(text)

        if is_safe:
            return text

        logger.warning("Filtered unsafe content: %s", reason)
        return replacement
344
+
345
+
346
def get_gpu_info() -> Dict[str, Any]:
    """Get GPU information.

    Returns:
        ``{"available": False}`` when ``torch.cuda.is_available()`` is
        false. Otherwise a dict with ``available``, ``count`` and
        ``devices``, where each device entry holds ``id``, ``name``,
        ``memory_total``, ``memory_allocated`` and ``memory_reserved`` as
        reported by ``torch.cuda``.
    """
    if not torch.cuda.is_available():
        return {"available": False}

    # Query the device count once and reuse it for both fields.
    device_count = torch.cuda.device_count()

    devices = [
        {
            "id": i,
            "name": torch.cuda.get_device_name(i),
            "memory_total": torch.cuda.get_device_properties(i).total_memory,
            "memory_allocated": torch.cuda.memory_allocated(i),
            "memory_reserved": torch.cuda.memory_reserved(i),
        }
        for i in range(device_count)
    ]

    return {
        "available": True,
        "count": device_count,
        "devices": devices,
    }
368
+
369
+
370
def format_bytes(bytes_value: int) -> str:
    """Format a byte count as a human-readable string.

    The value is divided by 1024 until its magnitude is below 1024 and is
    rendered with two decimals and the matching unit (B, KB, MB, GB, TB,
    then PB for anything larger). Negative values keep their sign and are
    scaled by magnitude.

    Args:
        bytes_value: Number of bytes.

    Returns:
        A string such as ``"1.50 KB"``.
    """
    size = float(bytes_value)
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        # Compare on magnitude so negative sizes are scaled as well.
        if abs(size) < 1024.0:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    return f"{size:.2f} PB"