| | |
| | """ |
| | Helion-2.5-Rnd Utility Functions |
| | Common utilities for model inference and processing |
| | """ |
| |
|
| | import json |
| | import logging |
| | import os |
| | import time |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Tuple, Union |
| |
|
| | import torch |
| | import yaml |
| | from transformers import AutoTokenizer |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class ModelConfig:
    """Model configuration manager.

    Loads a YAML file into ``self.config`` and exposes its values through
    dot-separated key lookup. When the file is missing, or its top-level
    value is not a mapping, the built-in default configuration is used.
    """

    def __init__(self, config_path: str = "model_config.yaml"):
        """Load configuration from a YAML file.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load the YAML configuration, falling back to the defaults.

        Returns:
            The parsed mapping, or the result of ``_default_config()`` when
            the file does not exist or does not contain a mapping.
        """
        if not self.config_path.exists():
            logger.warning(f"Config file not found: {self.config_path}")
            return self._default_config()

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # safe_load yields None for an empty document and may yield a list or
        # scalar; none of those satisfy the Dict contract of self.config.
        if not isinstance(config, dict):
            logger.warning(
                f"Config file {self.config_path} does not contain a mapping; "
                "using default configuration"
            )
            return self._default_config()

        logger.info(f"Loaded configuration from {self.config_path}")
        return config

    def _default_config(self) -> Dict[str, Any]:
        """Return the built-in default configuration."""
        return {
            "model": {
                "name": "DeepXR/Helion-2.5-Rnd",
                "max_position_embeddings": 131072,
            },
            "inference": {
                "default_parameters": {
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "max_new_tokens": 4096,
                }
            }
        }

    def get(self, key: str, default: Any = None) -> Any:
        """Get a configuration value by dot-separated key.

        Args:
            key: Path into the nested configuration, e.g. ``"model.name"``.
            default: Value returned when the path cannot be resolved.

        Returns:
            The value found at ``key``, or ``default`` when any path segment
            is missing, an intermediate value is not a dict, or the stored
            value is ``None``.
        """
        value = self.config

        for part in key.split('.'):
            if not isinstance(value, dict):
                return default
            value = value.get(part)
            # A stored None is indistinguishable from a missing key here.
            if value is None:
                return default

        return value
| |
|
| |
|
class TokenCounter:
    """Token counting utilities.

    Uses a Hugging Face tokenizer when one can be loaded; otherwise falls
    back to an estimate of four characters per token.
    """

    def __init__(self, model_name: str = "meta-llama/Meta-Llama-3.1-70B"):
        """Initialize the tokenizer used for counting.

        Args:
            model_name: Name passed to ``AutoTokenizer.from_pretrained``.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as e:
            # Deliberate best-effort: any load failure switches the instance
            # to the character-based estimate instead of raising.
            logger.warning(f"Failed to load tokenizer: {e}")
            self.tokenizer = None

    def count_tokens(self, text: str) -> int:
        """Count tokens in ``text``.

        Returns:
            ``len(tokenizer.encode(text))`` when a tokenizer is loaded,
            otherwise ``len(text) // 4``.
        """
        if self.tokenizer is None:
            # No tokenizer: estimate at four characters per token.
            return len(text) // 4

        return len(self.tokenizer.encode(text))

    def count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
        """Count tokens in a list of chat messages.

        Each message contributes the token counts of its ``role`` and
        ``content`` plus a fixed overhead of 4 (presumably for chat-format
        delimiters — confirm against the target chat template).

        Args:
            messages: Dicts that may contain ``role`` and ``content`` keys.

        Returns:
            The summed token count over all messages.
        """
        total = 0
        for msg in messages:
            # `or ''` also covers keys that are present but set to None.
            total += self.count_tokens(msg.get('role') or '')
            total += self.count_tokens(msg.get('content') or '')
            total += 4

        return total

    def truncate_to_tokens(
        self,
        text: str,
        max_tokens: int,
        from_end: bool = False
    ) -> str:
        """Truncate ``text`` to at most ``max_tokens`` tokens.

        Args:
            text: Text to truncate.
            max_tokens: Maximum number of tokens to keep; values <= 0 yield
                an empty string.
            from_end: Keep the tail of the text instead of the head.

        Returns:
            ``text`` unchanged when it already fits, otherwise the truncated
            text.
        """
        # Guard first: a slice such as text[-0:] would return everything.
        if max_tokens <= 0:
            return ""

        if self.tokenizer is None:
            # No tokenizer: estimate at four characters per token.
            max_chars = max_tokens * 4
            return text[-max_chars:] if from_end else text[:max_chars]

        # Encode without BOS/EOS so they neither consume the budget nor show
        # up as literal markers in the decoded result.
        tokens = self.tokenizer.encode(text, add_special_tokens=False)

        if len(tokens) <= max_tokens:
            return text

        kept = tokens[-max_tokens:] if from_end else tokens[:max_tokens]
        return self.tokenizer.decode(kept)
| |
|
| |
|
class PromptTemplate:
    """Prompt templating utilities.

    ``TEMPLATES`` maps a template name to its text. All entries except
    ``"chat"`` use ``str.format`` placeholders; ``"chat"`` is written in
    Jinja-style markup and is rendered through ``format_chat``.
    """

    TEMPLATES = {
        "chat": (
            "{% for message in messages %}"
            "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
            "{% endfor %}"
            "<|im_start|>assistant\n"
        ),
        "instruction": (
            "### Instruction:\n{instruction}\n\n"
            "### Response:\n"
        ),
        "qa": (
            "Question: {question}\n\n"
            "Answer: "
        ),
        "code": (
            "# Task: {task}\n\n"
            "```{language}\n"
        ),
        "analysis": (
            "Analyze the following:\n\n{content}\n\n"
            "Analysis:"
        )
    }

    @classmethod
    def format(cls, template_name: str, **kwargs) -> str:
        """Format a named template with the given arguments.

        Args:
            template_name: Key into ``TEMPLATES``.
            **kwargs: Values for the template's placeholders. The ``"chat"``
                template requires ``messages``.

        Returns:
            The rendered prompt.

        Raises:
            ValueError: If the template name is unknown or a required
                argument is missing.
        """
        template = cls.TEMPLATES.get(template_name)
        if template is None:
            raise ValueError(f"Unknown template: {template_name}")

        # The chat template is not a str.format template: str.format would
        # read "{% for ... %}" as a field name and always fail.
        if template_name == "chat":
            if "messages" not in kwargs:
                raise ValueError("Missing required argument: 'messages'")
            return cls.format_chat(kwargs["messages"])

        try:
            return template.format(**kwargs)
        except KeyError as e:
            raise ValueError(f"Missing required argument: {e}") from e

    @classmethod
    def format_chat(cls, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a prompt.

        Args:
            messages: Dicts with optional ``role`` (default ``'user'``) and
                ``content`` (default ``''``) keys.

        Returns:
            Each message wrapped in ``<|im_start|>`` / ``<|im_end|>`` markers,
            followed by an opening assistant turn.
        """
        parts = [
            f"<|im_start|>{msg.get('role', 'user')}\n"
            f"{msg.get('content', '')}<|im_end|>\n"
            for msg in messages
        ]
        parts.append("<|im_start|>assistant\n")
        return "".join(parts)
| |
|
| |
|
class ResponseParser:
    """Parse and validate model responses."""

    @staticmethod
    def extract_code(response: str, language: Optional[str] = None) -> str:
        """Extract code from the first matching markdown code block.

        Args:
            response: Model output that may contain fenced code blocks.
            language: When given, only fences tagged with exactly this
                language are considered.

        Returns:
            The stripped contents of the first matching block, or the
            stripped response when no block matches.
        """
        import re

        if language:
            # Escape the tag so names like "c++" are matched literally
            # instead of being parsed as regex syntax.
            pattern = "```" + re.escape(language) + r"\n(.*?)```"
        else:
            # Accept an optional language tag, including tags with
            # punctuation such as c++, c# or objective-c.
            pattern = r"```[\w+#.\-]*\n(.*?)```"

        match = re.search(pattern, response, re.DOTALL)
        if match:
            return match.group(1).strip()

        # No fenced block found: treat the whole response as code.
        return response.strip()

    @staticmethod
    def extract_json(response: str) -> Optional[Dict]:
        """Extract and parse JSON from a response.

        Tries each fenced ``json`` block in order, then the response as a
        whole.

        Returns:
            The first value that parses, or ``None`` when nothing does.
            The result is whatever ``json.loads`` produces, so it is not
            necessarily a dict.
        """
        import re

        for block in re.findall(r"```json\n(.*?)```", response, re.DOTALL):
            try:
                return json.loads(block)
            except json.JSONDecodeError:
                # This block is not valid JSON; a later one may be.
                continue

        try:
            return json.loads(response)
        except json.JSONDecodeError:
            return None

    @staticmethod
    def split_sections(response: str) -> Dict[str, str]:
        """Split a response into sections based on markdown headers.

        Lines starting with one to three ``#`` characters followed by
        whitespace open a new section. The section key is the header text
        lower-cased with spaces replaced by underscores. Text before the
        first header is stored under ``"main"``.

        Returns:
            Mapping of section key to stripped section text. A section with
            no lines after its header is omitted, and a repeated key
            overwrites the earlier entry.
        """
        import re

        header_re = re.compile(r'^#{1,3}\s+(.+)$')
        sections = {}
        current_section = "main"
        current_content = []

        for line in response.split('\n'):
            header_match = header_re.match(line)
            if header_match:
                # Flush what was collected for the previous section.
                if current_content:
                    sections[current_section] = '\n'.join(current_content).strip()

                current_section = header_match.group(1).lower().replace(' ', '_')
                current_content = []
            else:
                current_content.append(line)

        # Flush the trailing section.
        if current_content:
            sections[current_section] = '\n'.join(current_content).strip()

        return sections
| |
|
| |
|
class PerformanceMonitor:
    """Monitor inference performance.

    Keeps an in-memory list of request records and derives aggregate
    statistics from it.
    """

    def __init__(self):
        """Start with no recorded requests and the uptime clock at now."""
        self.requests = []
        self.start_time = time.time()

    def record_request(
        self,
        duration: float,
        input_tokens: int,
        output_tokens: int,
        success: bool = True
    ):
        """Record a request.

        Args:
            duration: Request duration; used as the divisor for throughput.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.
            success: Whether the request succeeded.
        """
        self.requests.append({
            'timestamp': time.time(),
            'duration': duration,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'success': success,
            # A non-positive duration gives 0 rather than dividing by zero.
            'tokens_per_second': output_tokens / duration if duration > 0 else 0
        })

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics.

        Returns:
            With no recorded requests, only ``total_requests`` and
            ``uptime_seconds``. Otherwise also request counts, token totals
            over all requests, and averages over successful requests. The
            averages are ``0.0`` when no request succeeded.
        """
        uptime = time.time() - self.start_time

        if not self.requests:
            return {
                'total_requests': 0,
                'uptime_seconds': uptime
            }

        successful = [r for r in self.requests if r['success']]
        ok_count = len(successful)

        # Every recorded request may have failed; avoid dividing by zero.
        if ok_count:
            avg_duration = sum(r['duration'] for r in successful) / ok_count
            avg_tps = sum(r['tokens_per_second'] for r in successful) / ok_count
        else:
            avg_duration = 0.0
            avg_tps = 0.0

        return {
            'total_requests': len(self.requests),
            'successful_requests': ok_count,
            'failed_requests': len(self.requests) - ok_count,
            'uptime_seconds': uptime,
            'avg_duration': avg_duration,
            'avg_tokens_per_second': avg_tps,
            'total_input_tokens': sum(r['input_tokens'] for r in self.requests),
            'total_output_tokens': sum(r['output_tokens'] for r in self.requests),
        }

    def reset(self):
        """Reset statistics and restart the uptime clock."""
        self.requests = []
        self.start_time = time.time()
| |
|
| |
|
class SafetyFilter:
    """Basic safety filtering for outputs.

    Text is lower-cased and searched against each regular expression in
    ``UNSAFE_PATTERNS``; the first pattern that matches decides the result.
    """

    UNSAFE_PATTERNS = [
        r'\b(kill|murder|suicide)\s+(?:yourself|myself)',
        r'\b(bomb|weapon)\s+(?:making|instructions)',
        r'\bhate\s+speech\b',
    ]

    @classmethod
    def is_safe(cls, text: str) -> Tuple[bool, Optional[str]]:
        """Check whether ``text`` matches any unsafe pattern.

        Returns:
            ``(True, None)`` when no pattern matches, otherwise
            ``(False, reason)`` where ``reason`` names the first pattern
            that matched.
        """
        import re

        lowered = text.lower()
        # Patterns are tried in list order; the first hit wins.
        hit = next(
            (candidate for candidate in cls.UNSAFE_PATTERNS
             if re.search(candidate, lowered)),
            None,
        )

        if hit is None:
            return True, None
        return False, f"Matched unsafe pattern: {hit}"

    @classmethod
    def filter_response(cls, text: str, replacement: str = "[FILTERED]") -> str:
        """Return ``text``, or ``replacement`` when the text is unsafe.

        The whole response is replaced, not just the matching span, and a
        warning is logged with the reason.
        """
        passed, why = cls.is_safe(text)

        if passed:
            return text

        logger.warning(f"Filtered unsafe content: {why}")
        return replacement
| |
|
| |
|
def get_gpu_info() -> Dict[str, Any]:
    """Collect information about the visible CUDA devices.

    Returns:
        ``{"available": False}`` when CUDA is unavailable. Otherwise a dict
        with ``available``, ``count`` and ``devices``, where ``devices``
        holds one entry per device index with its ``id``, ``name``,
        ``memory_total``, ``memory_allocated`` and ``memory_reserved`` as
        reported by ``torch.cuda``.
    """
    if not torch.cuda.is_available():
        return {"available": False}

    device_total = torch.cuda.device_count()

    per_device = [
        {
            "id": idx,
            "name": torch.cuda.get_device_name(idx),
            "memory_total": torch.cuda.get_device_properties(idx).total_memory,
            "memory_allocated": torch.cuda.memory_allocated(idx),
            "memory_reserved": torch.cuda.memory_reserved(idx),
        }
        for idx in range(device_total)
    ]

    return {
        "available": True,
        "count": device_total,
        "devices": per_device,
    }
| |
|
| |
|
def format_bytes(bytes_value: int) -> str:
    """Format a byte count as a human-readable string.

    The value is divided by 1024 until it drops below 1024 or the unit
    reaches PB, then rendered with two decimals.

    Args:
        bytes_value: Number of bytes; may be negative (e.g. a size delta).

    Returns:
        A string such as ``"1.50 KB"`` or ``"-2.00 KB"``.
    """
    # Scale on the magnitude: a negative value is always < 1024 and would
    # otherwise never leave the "B" unit.
    sign = "-" if bytes_value < 0 else ""
    size = float(abs(bytes_value))

    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        if size < 1024.0:
            return f"{sign}{size:.2f} {unit}"
        size /= 1024.0

    return f"{sign}{size:.2f} PB"