Upload 20 files
Browse files
- utils/__init__.py +18 -0
- utils/attention_connector.py +135 -0
- utils/attention_trigger_system.py +198 -0
- utils/collator.py +100 -0
- utils/convert_checkpoints.py +253 -0
- utils/debug_helper.py +124 -0
- utils/dual_encoder_utils.py +151 -0
- utils/emergency_abort.py +65 -0
- utils/event_bus.py +48 -0
- utils/event_system.py +169 -0
- utils/gpu_config_optimizer.py +143 -0
- utils/model_utils.py +84 -0
- utils/nltk_stub.py +119 -0
- utils/output_formatter.py +175 -0
- utils/prepare_hf_training.py +149 -0
- utils/prepare_hf_transformer_training.py +335 -0
- utils/sentence_transformer_utils.py +41 -0
- utils/smartHybridAttention.py +675 -0
- utils/tokenizer_utils.py +147 -0
- utils/transformer_utils.py +160 -0
utils/__init__.py
CHANGED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utils package initialization: path setup and public re-exports."""
import os
import sys

# Make the project root importable so sibling top-level modules resolve.
_PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(_PACKAGE_DIR))

from .transformer_utils import get_tokenizer, get_sentence_transformer

# SmartHybridAttention is optional: try the package-qualified import first,
# then the bare top-level module, and finally fall back to None placeholders.
try:
    from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
except ImportError:
    try:
        from smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
    except ImportError:
        print("Warning: Could not import SmartHybridAttention")
        SmartHybridAttention = None
        get_hybrid_attention_config = None

__all__ = [
    'get_tokenizer',
    'get_sentence_transformer',
    'SmartHybridAttention',
    'get_hybrid_attention_config',
]
utils/attention_connector.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Dict, Any, Optional, List, Tuple, Union
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
# Conditional import for attention profile selector
|
| 8 |
+
try:
|
| 9 |
+
from attention_trigger_system import AttentionProfileSelector
|
| 10 |
+
ATTENTION_SELECTOR_AVAILABLE = True
|
| 11 |
+
except ImportError:
|
| 12 |
+
ATTENTION_SELECTOR_AVAILABLE = False
|
| 13 |
+
logging.warning("AttentionProfileSelector not available - content-aware attention disabled")
|
| 14 |
+
|
| 15 |
+
class AttentionConnector:
    """
    Connects the core architecture with content-aware attention mechanisms.

    This class serves as the integration layer between:
    1. Original text from user input
    2. The attention configuration system
    3. The SmartHybridAttention implementation
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the connector.

        Args:
            config_path: Path to the attention configuration JSON. Defaults
                to ``attention_configuration.json`` next to this module.
        """
        self.logger = logging.getLogger(__name__)

        if config_path is None:
            self.config_path = os.path.join(os.path.dirname(__file__), "attention_configuration.json")
        else:
            self.config_path = config_path

        # None when the selector module is unavailable or fails to load;
        # every public method degrades gracefully in that case.
        self.profile_selector = self._init_profile_selector()

        # Current-request state. NOTE(review): these are plain instance
        # attributes shared by every thread using this connector instance —
        # not actual thread-local storage, despite the original comment.
        self.current_input_text = None
        self.current_context = {}
        self.active_profile_id = "standard"
        self.profile_confidence = 1.0

    def _init_profile_selector(self) -> Optional[Any]:
        """Initialize the attention profile selector, or return None."""
        if not ATTENTION_SELECTOR_AVAILABLE:
            self.logger.warning("AttentionProfileSelector not available - using default attention")
            return None

        try:
            selector = AttentionProfileSelector(self.config_path)
            # Lazy %-style args avoid string formatting when INFO is disabled.
            self.logger.info("Initialized AttentionProfileSelector with %d profiles", len(selector.profiles))
            return selector
        except Exception as e:
            self.logger.error("Error initializing AttentionProfileSelector: %s", e)
            return None

    def set_input_text(self, text: str, context: Optional[Dict[str, Any]] = None):
        """Record the current input text/context and re-select the profile.

        Args:
            text: The raw user input for the upcoming forward pass.
            context: Optional extra context (e.g. an explicit profile request).
        """
        self.current_input_text = text
        self.current_context = context or {}

        # With a selector present, pick the profile that fits this input.
        if self.profile_selector:
            self.active_profile_id, self.profile_confidence = self.profile_selector.select_profile(
                text, self.current_context
            )
            self.logger.info(
                "Selected attention profile: %s (confidence: %.2f)",
                self.active_profile_id, self.profile_confidence,
            )

    def get_attention_parameters(self) -> Dict[str, Any]:
        """Return parameters for the active profile ({} when no selector)."""
        if not self.profile_selector:
            return {}

        return self.profile_selector.get_profile_parameters(self.active_profile_id)

    def inject_attention_parameters(self, attention_module: Any) -> Any:
        """Apply the active profile's parameters to ``attention_module``.

        The module must expose ``set_parameters(**kwargs)``; otherwise it is
        returned unchanged after logging a warning.
        """
        if not hasattr(attention_module, 'set_parameters'):
            self.logger.warning("Attention module does not support parameter injection")
            return attention_module

        params = self.get_attention_parameters()
        attention_module.set_parameters(**params)
        return attention_module

    def get_input_context(self) -> Dict[str, Any]:
        """Get the current input context for the attention mechanism."""
        return {
            "input_text": self.current_input_text,
            "context": self.current_context,
            "profile_id": self.active_profile_id,
            "confidence": self.profile_confidence
        }
| 94 |
+
|
| 95 |
+
# Lazily-created, process-wide connector shared by the hook functions below.
_connector_instance = None

def get_attention_connector() -> AttentionConnector:
    """Return the shared AttentionConnector, creating it on first use."""
    global _connector_instance
    if _connector_instance is None:
        _connector_instance = AttentionConnector()
    return _connector_instance
| 104 |
+
|
| 105 |
+
# Hook functions to integrate with existing architecture

def inject_input_text(input_text: str, model_forward_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Record ``input_text`` on the global connector and mirror it in kwargs.

    Intended to be called in the communicator just before forwarding to the
    model. The kwargs dict is mutated in place and also returned.
    """
    get_attention_connector().set_input_text(input_text)

    # Models that accept the raw text directly read it from this key.
    model_forward_kwargs["original_text"] = input_text
    return model_forward_kwargs
| 119 |
+
|
| 120 |
+
def prepare_attention_context(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Dict[str, Any]:
    """Build the attention-context dict used inside an attention forward().

    Combines the global connector's current input context with the shapes of
    the query/key/value tensors for the attention mechanism's reference.
    """
    ctx = get_attention_connector().get_input_context()

    # Record tensor shapes alongside the textual context.
    ctx["query_shape"] = query.shape
    ctx["key_shape"] = key.shape
    ctx["value_shape"] = value.shape
    return ctx
utils/attention_trigger_system.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict, Any, Optional, List, Tuple
|
| 4 |
+
import re
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
class AttentionProfileSelector:
    """
    Selects appropriate attention profiles based on input characteristics
    and configuration specified in the JSON dataset.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the selector with the provided configuration.

        Args:
            config_path: Path to the attention configuration JSON; defaults
                to ``attention_configuration.json`` next to this module.
        """
        if config_path is None:
            # Default to the standard location
            config_path = os.path.join(os.path.dirname(__file__), "attention_configuration.json")

        self.config = self._load_config(config_path)
        # Index profiles by id for O(1) lookup.
        self.profiles = {p["profile_id"]: p for p in self.config.get("attention_profiles", [])}
        self.default_profile_id = self.config.get("default_profile", "standard")
        self.selection_strategy = self.config.get("profile_selection_strategy", {})

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from JSON file; return {} when missing/invalid."""
        try:
            with open(config_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"Error loading attention configuration: {e}")
            return {}

    def _strategy_params(self) -> Dict[str, Any]:
        """Return the strategy-parameter sub-dict (hoisted shared lookup)."""
        return self.selection_strategy.get("strategy_parameters", {})

    def select_profile(self,
                       input_text: str,
                       context: Optional[Dict[str, Any]] = None) -> Tuple[str, float]:
        """
        Select the most appropriate attention profile based on input characteristics.

        Args:
            input_text: The user's input text
            context: Additional context about the interaction

        Returns:
            Tuple of (profile_id, confidence)
        """
        if not self.profiles:
            return self.default_profile_id, 1.0

        # Hoist invariants out of the per-profile loop (was recomputed
        # repeatedly, and input_text.lower() per signal).
        params = self._strategy_params()
        length_weight = params.get("input_length_weight", 0.3)
        content_weight = params.get("content_type_weight", 0.5)
        text_lower = input_text.lower()
        input_length = len(input_text)

        # Single pass over profiles (previously three identical loops).
        scores: Dict[str, float] = {}
        for profile_id, profile in self.profiles.items():
            signals = profile.get("activation_signals", {})
            score = 0.0

            # Document length threshold (0 means "not configured").
            length_threshold = signals.get("document_length_threshold", 0)
            if 0 < length_threshold < input_length:
                score += length_weight

            # Keyword lists: fraction of matched signals, weighted.
            for signal_key in ("content_type_signals", "structure_indicators"):
                keywords = signals.get(signal_key, [])
                if keywords:
                    matched = sum(1 for kw in keywords if kw.lower() in text_lower)
                    score += (matched / len(keywords)) * content_weight

            scores[profile_id] = score

        # An explicit request in the context gets a strong boost.
        if context and "requested_attention" in context:
            requested = context["requested_attention"]
            if requested in scores:
                scores[requested] += params.get("explicit_request_weight", 1.0)

        best_profile_id = max(scores, key=scores.get)
        confidence = scores[best_profile_id]

        # Fall back to the default profile below the confidence floor
        # (the low confidence is still reported to the caller).
        min_confidence = params.get("minimum_confidence", 0.65)
        if confidence < min_confidence:
            return self.default_profile_id, confidence

        return best_profile_id, confidence

    def get_profile_parameters(self, profile_id: str) -> Dict[str, Any]:
        """
        Get the parameters for the specified attention profile.

        Args:
            profile_id: ID of the attention profile

        Returns:
            Dictionary of attention parameters ({} for unknown profiles)
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("parameters", {})
        return {}

    def get_attention_type(self, profile_id: str) -> str:
        """
        Get the attention mechanism type for the specified profile.

        Args:
            profile_id: ID of the attention profile

        Returns:
            String identifying the attention type ("standard" if unknown)
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("attention_type", "standard")
        return "standard"
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Factory method to create appropriate attention mechanism
def create_attention_mechanism(profile_id: str, model_dim: int, selector: AttentionProfileSelector):
    """
    Factory function to create an attention mechanism based on the selected profile.

    Args:
        profile_id: ID of the selected attention profile
        model_dim: Model hidden dimension
        selector: AttentionProfileSelector instance

    Returns:
        Configured attention mechanism
    """
    attention_type = selector.get_attention_type(profile_id)
    parameters = selector.get_profile_parameters(profile_id)

    try:
        # Import here to avoid circular imports
        from smartHybridAttention import EnhancedSmartHybridAttention, create_smart_hybrid_attention

        # (constructor kwarg, JSON key, default) mapping from profile config.
        param_map = (
            ("num_heads", "num_heads", 8),
            ("window_size", "window_size", 256),
            ("use_sliding", "use_sliding_window", True),
            ("use_global", "use_global_tokens", True),
            ("global_token_ratio", "global_token_ratio", 0.05),
            ("memory_tokens", "memory_token_count", 16),
        )
        attention_params = {"dim": model_dim}
        for kwarg, json_key, fallback in param_map:
            attention_params[kwarg] = parameters.get(json_key, fallback)

        # Hierarchical profiles flip an extra constructor flag.
        if attention_type == "hierarchical":
            attention_params["use_hierarchical"] = True

        return create_smart_hybrid_attention(**attention_params)

    except ImportError:
        print("Warning: smartHybridAttention not found. Using placeholder.")
        # Return a placeholder if the module is not available
        import torch.nn as nn
        return nn.MultiheadAttention(model_dim, 8)
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# Usage example:
if __name__ == "__main__":
    # Smoke test: run profile selection on three representative input styles
    # and print which profile each one lands on.
    selector = AttentionProfileSelector()

    # Example inputs
    # Source code — intended to exercise any code-oriented signal keywords.
    code_input = "def calculate_fibonacci(n):\n    if n <= 1:\n        return n\n    else:\n        return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)"

    # Markdown-style document — intended to exercise structure indicators.
    document_input = """# Chapter 1: Introduction
This technical document covers the architecture of our system.
## Section 1.1: Overview
The system consists of multiple components working together.
"""

    # Casual conversational text — should typically keep the default profile.
    conversation_input = "How did you like the movie we saw yesterday? I thought the ending was unexpected."

    # Test profile selection
    code_profile, code_conf = selector.select_profile(code_input)
    doc_profile, doc_conf = selector.select_profile(document_input)
    conv_profile, conv_conf = selector.select_profile(conversation_input)

    print(f"Code input → {code_profile} (confidence: {code_conf:.2f})")
    print(f"Document input → {doc_profile} (confidence: {doc_conf:.2f})")
    print(f"Conversation input → {conv_profile} (confidence: {conv_conf:.2f})")
utils/collator.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom data collators for transformer training.
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import random
|
| 6 |
+
from typing import Dict, List, Any, Union
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
@dataclass
|
| 10 |
+
class DataCollatorForLanguageModeling:
|
| 11 |
+
"""
|
| 12 |
+
Data collator for language modeling.
|
| 13 |
+
|
| 14 |
+
This collator will tokenize inputs and dynamically mask tokens
|
| 15 |
+
for masked language modeling tasks.
|
| 16 |
+
"""
|
| 17 |
+
tokenizer: Any
|
| 18 |
+
mlm: bool = True # Whether to use masked language modeling
|
| 19 |
+
mlm_probability: float = 0.15 # Probability of masking a token
|
| 20 |
+
|
| 21 |
+
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
| 22 |
+
"""
|
| 23 |
+
Collate a batch of examples.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
examples: List of examples from dataset
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
Batch dictionary for model
|
| 30 |
+
"""
|
| 31 |
+
# Extract input_ids
|
| 32 |
+
input_ids = [example["input_ids"] for example in examples]
|
| 33 |
+
|
| 34 |
+
# Concatenate inputs
|
| 35 |
+
batch = self.tokenizer.pad(
|
| 36 |
+
{"input_ids": input_ids},
|
| 37 |
+
return_tensors="pt"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# If masked language modeling is enabled
|
| 41 |
+
if self.mlm:
|
| 42 |
+
inputs, labels = self.mask_tokens(batch["input_ids"])
|
| 43 |
+
return {"input_ids": inputs, "labels": labels}
|
| 44 |
+
else:
|
| 45 |
+
labels = batch["input_ids"].clone()
|
| 46 |
+
return {
|
| 47 |
+
"input_ids": batch["input_ids"],
|
| 48 |
+
"labels": labels,
|
| 49 |
+
"attention_mask": batch.get("attention_mask", None)
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
def mask_tokens(
|
| 53 |
+
self, inputs: torch.Tensor
|
| 54 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 55 |
+
"""
|
| 56 |
+
Prepare masked tokens inputs/labels for masked language modeling.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
inputs: Input tensor
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Tuple of (masked inputs, labels)
|
| 63 |
+
"""
|
| 64 |
+
labels = inputs.clone()
|
| 65 |
+
|
| 66 |
+
# Get probability mask
|
| 67 |
+
probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
| 68 |
+
|
| 69 |
+
# Create special tokens mask
|
| 70 |
+
if hasattr(self.tokenizer, "get_special_tokens_mask"):
|
| 71 |
+
special_tokens_mask = [
|
| 72 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
|
| 73 |
+
for val in labels.tolist()
|
| 74 |
+
]
|
| 75 |
+
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
|
| 76 |
+
else:
|
| 77 |
+
special_tokens_mask = torch.tensor(
|
| 78 |
+
[
|
| 79 |
+
[self._is_special_token(x) for x in val]
|
| 80 |
+
for val in labels.tolist()
|
| 81 |
+
],
|
| 82 |
+
dtype=torch.bool,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Don't mask special tokens
|
| 86 |
+
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
|
| 87 |
+
|
| 88 |
+
# Get mask indices
|
| 89 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
| 90 |
+
|
| 91 |
+
# Set labels for non-masked tokens to -100 (ignored in loss)
|
| 92 |
+
labels[~masked_indices] = -100
|
| 93 |
+
|
| 94 |
+
# Set 80% of masked tokens to [MASK]
|
| 95 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
| 96 |
+
|
| 97 |
+
if hasattr(self.tokenizer, "mask_token_id") and self.tokenizer.mask_token_id is not None:
|
| 98 |
+
inputs[indices_replaced] = self.tokenizer.mask_token_id
|
| 99 |
+
|
| 100 |
+
# Set
|
utils/convert_checkpoints.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility to convert PyTorch (.pt) checkpoints to Hugging Face (.bin) format
|
| 3 |
+
python -m utils.convert_checkpoints --checkpoints checkpoints/stdp_model_epoch_15.pt checkpoints/stdp_model_epoch_20.pt --output hf_stdp_model
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import logging
|
| 8 |
+
import argparse
|
| 9 |
+
import datetime # Added missing import
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, Any, Optional
|
| 12 |
+
import json
|
| 13 |
+
import shutil
|
| 14 |
+
|
| 15 |
+
# Configure logging - Fix the typo in format string (levellevel → levelname)
|
| 16 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
def convert_stdp_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    config_path: Optional[str] = None
) -> str:
    """
    Convert STDP/SNN PyTorch checkpoint to Hugging Face format.

    Args:
        checkpoint_path: Path to the .pt checkpoint file
        output_dir: Directory to save the converted model
        config_path: Optional path to a config.json whose "STDP_CONFIG"
            section is merged into the emitted model config

    Returns:
        Path to the converted model directory (== output_dir)

    Raises:
        Exception: any failure is logged and re-raised to the caller.
    """
    # Same logger object as the module-level one; local binding keeps the
    # function usable standalone.
    logger = logging.getLogger(__name__)
    logger.info(f"Converting checkpoint: {checkpoint_path}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    try:
        # NOTE(security): torch.load unpickles arbitrary objects on older
        # torch versions — only convert checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Extract epoch from a filename like "stdp_model_epoch_15.pt".
        checkpoint_filename = os.path.basename(checkpoint_path)
        epoch = None
        if "epoch_" in checkpoint_filename:
            try:
                epoch = int(checkpoint_filename.split("epoch_")[1].split(".")[0])
            except (ValueError, IndexError):
                pass  # non-numeric suffix; epoch stays unknown

        # Base config; checkpoint/file configs merged below may override keys.
        config = {
            "model_type": "stdp_snn",
            "architectures": ["STDPSpikeNeuralNetwork"],
            "epoch": epoch,
            "original_checkpoint": checkpoint_path,
            "conversion_date": str(datetime.datetime.now())
        }

        # Update with loaded config if it exists in checkpoint
        if isinstance(checkpoint, dict) and "config" in checkpoint:
            config.update(checkpoint["config"])

        # Load additional config from file if provided
        if config_path and os.path.exists(config_path):
            with open(config_path, 'r') as f:
                file_config = json.load(f)
            if "STDP_CONFIG" in file_config:
                config.update(file_config["STDP_CONFIG"])

        # Locate the weights under the common checkpoint layouts.
        # FIX: guard on isinstance first — `"key" in checkpoint` raised or
        # misbehaved when the checkpoint was a bare tensor/state object.
        if not isinstance(checkpoint, dict):
            model_weights = checkpoint
        elif "model_state_dict" in checkpoint:
            model_weights = checkpoint["model_state_dict"]
        elif "state_dict" in checkpoint:
            model_weights = checkpoint["state_dict"]
        elif "weights" in checkpoint:
            model_weights = {"weights": checkpoint["weights"]}
        elif "synaptic_weights" in checkpoint:
            model_weights = {"synaptic_weights": checkpoint["synaptic_weights"]}
        else:
            # No recognized layout: assume the dict itself is the weights.
            model_weights = checkpoint

        # FIX: renamed from "model_dir" — this is the weights *file* path.
        weights_path = os.path.join(output_dir, "pytorch_model.bin")

        # Save converted weights in HF format
        torch.save(model_weights, weights_path)
        logger.info(f"Saved model weights to {weights_path}")

        # Save config file
        config_file = os.path.join(output_dir, "config.json")
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)
        logger.info(f"Saved model config to {config_file}")

        # Create a simple README
        readme_file = os.path.join(output_dir, "README.md")
        with open(readme_file, 'w') as f:
            f.write("# Converted STDP/SNN Model\n\n")
            f.write(f"This model was converted from PyTorch checkpoint: `{checkpoint_path}`\n\n")
            f.write(f"Converted on: {config['conversion_date']}\n")
            if epoch is not None:
                f.write(f"Training epoch: {epoch}\n")

        return output_dir

    except Exception as e:
        logger.error(f"Error converting checkpoint: {e}")
        raise
| 114 |
+
|
| 115 |
+
def prepare_for_hf_upload(
|
| 116 |
+
checkpoint_paths: list,
|
| 117 |
+
output_dir: str,
|
| 118 |
+
config_path: Optional[str] = None,
|
| 119 |
+
include_code: bool = True
|
| 120 |
+
) -> str:
|
| 121 |
+
"""
|
| 122 |
+
Prepare multiple checkpoints for HF upload with code.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
checkpoint_paths: List of paths to checkpoint files
|
| 126 |
+
output_dir: Directory to save the prepared model
|
| 127 |
+
config_path: Optional path to config.json file
|
| 128 |
+
include_code: Whether to include inference code
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
Path to the prepared directory
|
| 132 |
+
"""
|
| 133 |
+
# Create output directory
|
| 134 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 135 |
+
|
| 136 |
+
# Convert each checkpoint
|
| 137 |
+
converted_models = []
|
| 138 |
+
for cp_path in checkpoint_paths:
|
| 139 |
+
model_name = os.path.splitext(os.path.basename(cp_path))[0]
|
| 140 |
+
model_dir = os.path.join(output_dir, model_name)
|
| 141 |
+
converted_models.append(convert_stdp_checkpoint(cp_path, model_dir, config_path))
|
| 142 |
+
|
| 143 |
+
# Include necessary code files
|
| 144 |
+
if include_code:
|
| 145 |
+
code_files = [
|
| 146 |
+
"communicator_STDP.py",
|
| 147 |
+
"config.py",
|
| 148 |
+
"model_Custm.py"
|
| 149 |
+
]
|
| 150 |
+
|
| 151 |
+
for file in code_files:
|
| 152 |
+
src_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), file)
|
| 153 |
+
if os.path.exists(src_path):
|
| 154 |
+
dst_path = os.path.join(output_dir, file)
|
| 155 |
+
shutil.copy2(src_path, dst_path)
|
| 156 |
+
logger.info(f"Copied {file} to {dst_path}")
|
| 157 |
+
|
| 158 |
+
# Create an inference script - FIX: Use single quotes for inner docstring
|
| 159 |
+
inference_script = '''
|
| 160 |
+
import torch
|
| 161 |
+
import os
|
| 162 |
+
import json
|
| 163 |
+
import argparse
|
| 164 |
+
from pathlib import Path
|
| 165 |
+
|
| 166 |
+
def load_stdp_model(model_dir):
|
| 167 |
+
"""Load STDP model from directory."""
|
| 168 |
+
weights_path = os.path.join(model_dir, "pytorch_model.bin")
|
| 169 |
+
config_path = os.path.join(model_dir, "config.json")
|
| 170 |
+
|
| 171 |
+
# Load weights
|
| 172 |
+
weights = torch.load(weights_path, map_location="cpu")
|
| 173 |
+
|
| 174 |
+
# Load config
|
| 175 |
+
with open(config_path, 'r') as f:
|
| 176 |
+
config = json.load(f)
|
| 177 |
+
|
| 178 |
+
return weights, config
|
| 179 |
+
|
| 180 |
+
def main():
|
| 181 |
+
parser = argparse.ArgumentParser(description="Run inference with STDP model")
|
| 182 |
+
parser.add_argument("--model", type=str, required=True, help="Model directory")
|
| 183 |
+
parser.add_argument("--input", type=str, required=True, help="Input text or file")
|
| 184 |
+
args = parser.parse_args()
|
| 185 |
+
|
| 186 |
+
# Load model
|
| 187 |
+
weights, config = load_stdp_model(args.model)
|
| 188 |
+
print(f"Loaded model from {args.model}")
|
| 189 |
+
print(f"Model config: {json.dumps(config, indent=2)}")
|
| 190 |
+
|
| 191 |
+
# Get input
|
| 192 |
+
input_text = args.input
|
| 193 |
+
if os.path.exists(args.input):
|
| 194 |
+
with open(args.input, 'r') as f:
|
| 195 |
+
input_text = f.read()
|
| 196 |
+
|
| 197 |
+
print(f"Input text: {input_text[:100]}...")
|
| 198 |
+
|
| 199 |
+
# Run inference using communicator_STDP if available
|
| 200 |
+
try:
|
| 201 |
+
from communicator_STDP import CommSTDP
|
| 202 |
+
communicator = CommSTDP({}, device="cpu")
|
| 203 |
+
result = communicator.process(input_text, weights)
|
| 204 |
+
print(f"Result: {result}")
|
| 205 |
+
except ImportError:
|
| 206 |
+
print("communicator_STDP not available. Weights loaded successfully.")
|
| 207 |
+
print(f"Weights shape: {weights.shape if hasattr(weights, 'shape') else '[dict of tensors]'}")
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
|
| 210 |
+
main()
|
| 211 |
+
'''
|
| 212 |
+
|
| 213 |
+
inference_path = os.path.join(output_dir, "inference.py")
|
| 214 |
+
with open(inference_path, 'w') as f:
|
| 215 |
+
f.write(inference_script.strip())
|
| 216 |
+
logger.info(f"Created inference script: {inference_path}")
|
| 217 |
+
|
| 218 |
+
# Create an overall README
|
| 219 |
+
readme_file = os.path.join(output_dir, "README.md")
|
| 220 |
+
with open(readme_file, 'w') as f:
|
| 221 |
+
f.write("# STDP/SNN Trained Models\n\n")
|
| 222 |
+
f.write("This repository contains STDP/SNN models converted from PyTorch checkpoints for use with Hugging Face's infrastructure.\n\n")
|
| 223 |
+
f.write("## Models Included\n\n")
|
| 224 |
+
for i, model in enumerate(converted_models):
|
| 225 |
+
f.write(f"{i+1}. `{os.path.basename(model)}`\n")
|
| 226 |
+
|
| 227 |
+
f.write("\n## Usage\n\n")
|
| 228 |
+
f.write("```python\n")
|
| 229 |
+
f.write("from transformers import AutoModel\n\n")
|
| 230 |
+
f.write("# Load the model\n")
|
| 231 |
+
f.write("model = AutoModel.from_pretrained('your-username/your-model-name')\n")
|
| 232 |
+
f.write("```\n\n")
|
| 233 |
+
f.write("Or use the included inference.py script:\n\n")
|
| 234 |
+
f.write("```bash\npython inference.py --model ./stdp_model_epoch_15 --input 'Your input text here'\n```")
|
| 235 |
+
|
| 236 |
+
logger.info(f"Prepared {len(converted_models)} models for HF upload in {output_dir}")
|
| 237 |
+
return output_dir
|
| 238 |
+
|
| 239 |
+
if __name__ == "__main__":
    # CLI entry point: convert the given checkpoints and stage them for HF upload.
    cli_parser = argparse.ArgumentParser(description="Convert PyTorch checkpoints to Hugging Face format")
    cli_parser.add_argument("--checkpoints", nargs="+", required=True, help="Paths to checkpoint files")
    cli_parser.add_argument("--output", type=str, default="hf_model", help="Output directory")
    cli_parser.add_argument("--config", type=str, help="Path to config.json file")
    cli_parser.add_argument("--no-code", action="store_true", help="Don't include inference code")

    cli_args = cli_parser.parse_args()

    # --no-code inverts to the include_code flag expected by the helper.
    include_code = not cli_args.no_code
    prepare_for_hf_upload(
        cli_args.checkpoints,
        cli_args.output,
        cli_args.config,
        include_code,
    )
|
utils/debug_helper.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import psutil
|
| 5 |
+
import traceback
|
| 6 |
+
import logging
|
| 7 |
+
import threading
|
| 8 |
+
from typing import Dict, Any, Optional, List
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class DebugHelper:
    """
    Helper class for debugging hanging issues in STDP training.

    This provides tools for:
    1. Process monitoring and status reporting
    2. Timeout management
    3. Recovery mechanisms for hanging processes
    4. Detailed diagnostics
    """

    @staticmethod
    def get_process_info(pid: Optional[int] = None) -> Dict[str, Any]:
        """Get detailed information about a process.

        Args:
            pid: Process id to inspect; defaults to the current process.

        Returns:
            Dict of process metrics, or ``{'error': ...}`` if inspection failed.
        """
        pid = pid or os.getpid()

        try:
            process = psutil.Process(pid)

            # Basic process metrics straight from psutil.
            info = {
                'pid': pid,
                'name': process.name(),
                'status': process.status(),
                'cpu_percent': process.cpu_percent(),
                'memory_percent': process.memory_percent(),
                'memory_info': dict(process.memory_info()._asdict()),
                'create_time': process.create_time(),
                'runtime': time.time() - process.create_time(),
                'num_threads': process.num_threads(),
                'open_files': len(process.open_files()),
                # NOTE(review): psutil deprecated Process.connections() in favor of
                # net_connections() in newer releases — confirm installed version.
                'connections': len(process.connections()),
            }

            # Names of live Python threads in *this* interpreter (threading is
            # already imported at module level; the previous local re-import was
            # redundant).
            try:
                info['active_threads'] = [t.name for t in threading.enumerate()]
            except Exception:  # was a bare except: don't swallow KeyboardInterrupt/SystemExit
                info['active_threads'] = "Could not retrieve thread information"

            # Stack of the caller at the time of the snapshot.
            info['current_stack'] = traceback.format_stack()

            return info

        except Exception as e:
            logger.error(f"Error getting process info: {e}")
            return {'error': str(e)}

    @staticmethod
    def check_resource_leaks() -> Dict[str, Any]:
        """Check for potential resource leaks.

        Returns:
            Dict with garbage-collector counters and, when CUDA is available,
            torch GPU memory statistics.
        """
        import gc

        leaks = {
            'gc_counts': gc.get_count(),
            'gc_objects': len(gc.get_objects()),
        }

        # Check for torch memory usage if available.
        try:
            import torch
            if torch.cuda.is_available():
                leaks['torch_memory_allocated'] = torch.cuda.memory_allocated()
                leaks['torch_memory_reserved'] = torch.cuda.memory_reserved()
                leaks['torch_max_memory_allocated'] = torch.cuda.max_memory_allocated()
        except ImportError:
            pass

        return leaks

    @staticmethod
    def register_timeout(seconds: int, callback=None):
        """Register a timeout that calls the callback after the given seconds.

        Args:
            seconds: Delay before firing.
            callback: Optional callable; when omitted, diagnostics are printed
                instead.

        Returns:
            The daemon watchdog thread (already started).
        """
        def _timeout_handler():
            time.sleep(seconds)
            if callback:
                callback()
            else:
                # Default action: dump diagnostics so a hang can be triaged.
                print(f"TIMEOUT: Operation took longer than {seconds} seconds")
                info = DebugHelper.get_process_info()
                print(f"Process info: {info}")
                traceback.print_stack()

        thread = threading.Thread(target=_timeout_handler)
        thread.daemon = True  # never keep the process alive just for the watchdog
        thread.start()
        return thread

    @staticmethod
    def dump_debug_info(filename: str):
        """Dump process, resource, environment, and stack info to a file.

        Args:
            filename: Path of the report file to write (overwritten).
        """
        process_info = DebugHelper.get_process_info()
        leak_info = DebugHelper.check_resource_leaks()

        with open(filename, 'w') as f:
            f.write("===== PROCESS INFORMATION =====\n")
            for key, value in process_info.items():
                f.write(f"{key}: {value}\n")

            f.write("\n===== RESOURCE LEAK INFORMATION =====\n")
            for key, value in leak_info.items():
                f.write(f"{key}: {value}\n")

            f.write("\n===== ENVIRONMENT VARIABLES =====\n")
            for key, value in os.environ.items():
                f.write(f"{key}: {value}\n")

            f.write("\n===== STACK TRACE =====\n")
            f.write(''.join(traceback.format_stack()))

        # BUGFIX: the log line previously said "(unknown)" instead of reporting
        # the actual destination path.
        logger.info(f"Debug info dumped to {filename}")
|
utils/dual_encoder_utils.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for dual encoder configuration and initialization.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, Any, Optional, Union
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from config import load_config, app_config
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class DualEncoderConfig:
    """Configuration object for dual encoders"""

    def __init__(self, config_dict: Optional[Dict[str, Any]] = None):
        """
        Build the dual-encoder configuration.

        Hard-coded defaults are applied first, then app_config's
        DUAL_ENCODER_CONFIG section (when present), and finally any
        caller-supplied overrides.

        Args:
            config_dict: Optional configuration dictionary to override defaults
        """
        base_config = load_config()

        # Baseline defaults.
        self.USE_PRETRAINED_ENCODER = True
        self.USE_CUSTOM_ENCODER = True
        self.FUSION_METHOD = "concat"  # Options: concat, add, weighted_sum
        self.FUSION_WEIGHTS = [0.5, 0.5]  # Weights for pretrained and custom encoders
        self.TRAINING_MODE = "joint"  # Options: joint, alternating, pretrained_first

        # Layer app-level overrides on top of the defaults.
        if hasattr(base_config, "DUAL_ENCODER_CONFIG"):
            for name, value in base_config.DUAL_ENCODER_CONFIG.items():
                setattr(self, name, value)

        # Explicit caller overrides win last.
        for name, value in (config_dict or {}).items():
            setattr(self, name, value)
|
| 43 |
+
|
| 44 |
+
class DualEncoderFusion(nn.Module):
    """
    Module that combines outputs from pretrained and custom encoders.
    """

    def __init__(self, config: Optional[Union[Dict[str, Any], DualEncoderConfig]] = None):
        """
        Set up the fusion module.

        Args:
            config: Configuration for fusion (dict or DualEncoderConfig object)
        """
        super().__init__()

        # Normalize whatever we were given into a DualEncoderConfig.
        if config is None:
            self.config = DualEncoderConfig()
        elif isinstance(config, dict):
            self.config = DualEncoderConfig(config)
        else:
            self.config = config

        # Normalized weights are only needed for the weighted-sum strategy.
        if self.config.FUSION_METHOD == "weighted_sum":
            raw = torch.tensor(self.config.FUSION_WEIGHTS, dtype=torch.float32)
            self.register_buffer('fusion_weights', raw / raw.sum())

    def forward(self, pretrained_output: torch.Tensor, custom_output: torch.Tensor) -> torch.Tensor:
        """
        Combine encoder outputs based on the configured fusion method.

        Args:
            pretrained_output: Output from pretrained encoder
            custom_output: Output from custom encoder

        Returns:
            Combined tensor

        Raises:
            ValueError: if shapes are incompatible or the method is unknown.
        """
        # Degenerate cases: exactly one encoder enabled -> pass it through.
        if not self.config.USE_PRETRAINED_ENCODER:
            return custom_output
        if not self.config.USE_CUSTOM_ENCODER:
            return pretrained_output

        method = self.config.FUSION_METHOD
        if method == "concat":
            return torch.cat([pretrained_output, custom_output], dim=-1)
        if method == "add":
            # Element-wise addition requires identical shapes.
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot add tensors with different shapes: {pretrained_output.shape} and {custom_output.shape}")
            return pretrained_output + custom_output
        if method == "weighted_sum":
            # Weighted blend also requires identical shapes.
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot use weighted sum with different shapes: {pretrained_output.shape} and {custom_output.shape}")
            w1, w2 = self.fusion_weights
            return w1 * pretrained_output + w2 * custom_output
        raise ValueError(f"Unknown fusion method: {self.config.FUSION_METHOD}")
|
| 105 |
+
|
| 106 |
+
def get_dual_encoder_config() -> DualEncoderConfig:
    """
    Build a DualEncoderConfig populated from app_config defaults.

    Returns:
        DualEncoderConfig object
    """
    cfg = DualEncoderConfig()
    return cfg
|
| 114 |
+
|
| 115 |
+
# Testing function
|
| 116 |
+
def test_fusion_methods():
    """Exercise each fusion strategy and return the fused tensors."""
    cfg = DualEncoderConfig()

    # Two random encoder outputs with identical shapes.
    left = torch.randn(2, 10, 768)
    right = torch.randn(2, 10, 768)

    results = {}

    # Concat doubles the feature dimension.
    cfg.FUSION_METHOD = "concat"
    results["concat"] = DualEncoderFusion(cfg)(left, right)
    print(f"Concat output shape: {results['concat'].shape}")  # Should be [2, 10, 1536]

    # Add keeps the feature dimension.
    cfg.FUSION_METHOD = "add"
    results["add"] = DualEncoderFusion(cfg)(left, right)
    print(f"Add output shape: {results['add'].shape}")  # Should be [2, 10, 768]

    # Weighted sum with explicit 70/30 weights.
    cfg.FUSION_METHOD = "weighted_sum"
    cfg.FUSION_WEIGHTS = [0.7, 0.3]
    results["weighted_sum"] = DualEncoderFusion(cfg)(left, right)
    print(f"Weighted sum output shape: {results['weighted_sum'].shape}")  # Should be [2, 10, 768]

    return results
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
    # Exercise every fusion strategy when executed as a script.
    test_results = test_fusion_methods()
|
utils/emergency_abort.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
import signal
|
| 5 |
+
import logging
|
| 6 |
+
import threading
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class EmergencyAbort:
    """Creates an abort file that can be touched to trigger process termination.

    A background daemon thread polls the abort file's mtime every
    ``check_interval`` seconds; when an external actor updates the file's
    timestamp, the current process is sent SIGTERM.
    """

    def __init__(self, abort_file="emergency_abort.txt", check_interval=5):
        """
        Args:
            abort_file: Path of the sentinel file to watch (created here).
            check_interval: Seconds between mtime checks.
        """
        self.abort_file = abort_file
        self.check_interval = check_interval
        self.running = False
        self.thread = None

        # Create initial abort file with instructions.
        self._write_status()

    def _write_status(self):
        """(Re)write the abort file with usage instructions and a timestamp."""
        with open(self.abort_file, 'w') as f:
            f.write("# Emergency Abort File\n")
            f.write("# To abort training, update the timestamp of this file\n")
            f.write(f"# Last checked: {time.ctime()}\n")

    def _check_file(self):
        """Poll the abort file; terminate the process if it was touched externally."""
        last_modified = os.path.getmtime(self.abort_file)

        while self.running:
            time.sleep(self.check_interval)

            try:
                current_modified = os.path.getmtime(self.abort_file)

                if current_modified > last_modified:
                    logger.warning("Emergency abort file modified! Initiating abort sequence.")
                    # Kill this process
                    os.kill(os.getpid(), signal.SIGTERM)
                    return

                # Update check timestamp in file.
                self._write_status()

                # BUGFIX: re-read the mtime *after* rewriting the file. The
                # rewrite above bumps the mtime, so keeping the pre-write value
                # (as the old code did with `last_modified = current_modified`)
                # made the very next poll look like an external touch and
                # falsely triggered the abort.
                last_modified = os.path.getmtime(self.abort_file)

            except Exception as e:
                logger.error(f"Error checking abort file: {e}")

    def start(self):
        """Start the abort file monitor.

        Returns:
            self, so callers can chain ``EmergencyAbort(...).start()``.
        """
        self.running = True
        self.thread = threading.Thread(target=self._check_file)
        self.thread.daemon = True  # never block interpreter shutdown
        self.thread.start()
        logger.info(f"Emergency abort monitor started. Modify {self.abort_file} to terminate training.")
        return self

    def stop(self):
        """Stop the abort file monitor."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=2)
        logger.info("Emergency abort monitor stopped.")
|
utils/event_bus.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple event bus for direct component-to-component communication.
|
| 3 |
+
Provides a lightweight alternative to the full EventSystem.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, List, Callable, Any
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)

class EventBus:
    """Simple synchronous event bus for direct event routing."""

    def __init__(self):
        """Start with no registered subscribers."""
        self.subscribers: Dict[str, List[Callable[[str, Any], None]]] = {}
        logger.info("Initialized EventBus")

    def subscribe(self, event_type: str, callback: Callable[[str, Any], None]) -> None:
        """Register `callback` to receive events of `event_type`."""
        self.subscribers.setdefault(event_type, []).append(callback)
        logger.debug(f"Added subscriber to {event_type}")

    def unsubscribe(self, event_type: str, callback: Callable[[str, Any], None]) -> None:
        """Remove `callback` from `event_type`, if registered."""
        handlers = self.subscribers.get(event_type)
        if handlers is not None and callback in handlers:
            handlers.remove(callback)
            logger.debug(f"Removed subscriber from {event_type}")

    def publish(self, event_type: str, data: Any = None) -> None:
        """Deliver an event to every subscriber, synchronously and in order."""
        # Snapshot the handler list so callbacks that (un)subscribe during
        # delivery don't perturb this dispatch.
        listeners = tuple(self.subscribers.get(event_type, ()))

        if not listeners:
            logger.debug(f"No subscribers for event {event_type}")
            return

        logger.debug(f"Dispatching event {event_type} to {len(listeners)} subscribers")

        for handler in listeners:
            try:
                handler(event_type, data)
            except Exception as e:
                # A failing subscriber must not break the others.
                logger.error(f"Error in subscriber callback: {e}")

# Create a global instance for convenience
event_bus = EventBus()
|
utils/event_system.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Event system module for enabling parallel processing across components.
|
| 3 |
+
Implements a publisher-subscriber pattern to decouple components.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import threading
|
| 7 |
+
import queue
|
| 8 |
+
import time
|
| 9 |
+
from typing import Dict, List, Callable, Any, Optional, Set
|
| 10 |
+
import concurrent.futures
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class Event:
    """A single event: a type tag plus optional payload and origin metadata."""

    def __init__(self, event_type: str, data: Any = None, source: str = None):
        # Payload and provenance; the timestamp is captured at creation time.
        self.event_type, self.data, self.source = event_type, data, source
        self.timestamp = time.time()

    def __repr__(self) -> str:
        return f"Event({self.event_type}, source={self.source}, timestamp={self.timestamp})"
|
| 24 |
+
|
| 25 |
+
class EventSystem:
    """Event system for parallel processing of prompts and responses.

    Events are queued by publishers, drained by a single dispatcher thread,
    and each subscriber callback runs on a shared ThreadPoolExecutor.
    """

    def __init__(self, max_workers: int = 4):
        """Initialize the event system.

        Args:
            max_workers: Size of the thread pool used to run subscriber callbacks.
        """
        # event_type -> list of callbacks; guarded by self.lock.
        self.subscribers: Dict[str, List[Callable[[Event], None]]] = {}
        self.lock = threading.RLock()  # Reentrant lock for thread safety
        self.event_queue = queue.Queue()
        self.running = False
        self.dispatcher_thread = None
        self.max_workers = max_workers
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        # In-flight callback futures, used by wait_for_all_events.
        # NOTE(review): this plain set is mutated from pool threads via
        # add_done_callback without holding self.lock — confirm that is
        # acceptable for the intended concurrency level.
        self.futures = set()
        logger.info(f"Initialized EventSystem with {max_workers} workers")

    def subscribe(self, event_type: str, callback: Callable[[Event], None]) -> None:
        """Subscribe a callback to a specific event type ("*" acts as a wildcard)."""
        with self.lock:
            if event_type not in self.subscribers:
                self.subscribers[event_type] = []
            self.subscribers[event_type].append(callback)
            logger.debug(f"Added subscriber to {event_type}, total: {len(self.subscribers[event_type])}")

    def unsubscribe(self, event_type: str, callback: Callable[[Event], None]) -> None:
        """Unsubscribe a callback from a specific event type (no-op if absent)."""
        with self.lock:
            if event_type in self.subscribers and callback in self.subscribers[event_type]:
                self.subscribers[event_type].remove(callback)
                logger.debug(f"Removed subscriber from {event_type}, remaining: {len(self.subscribers[event_type])}")

    def publish(self, event: Event) -> None:
        """Publish an event to all subscribers (asynchronously, via the queue)."""
        self.event_queue.put(event)
        logger.debug(f"Published event: {event}")

        # Start dispatcher lazily on first publish (or after a stop).
        with self.lock:
            if not self.running:
                self.start()

    def publish_from_dict(self, event_type: str, data: Dict[str, Any], source: str = None) -> None:
        """Convenient method to publish an event from a dictionary."""
        event = Event(event_type, data, source)
        self.publish(event)

    def start(self) -> None:
        """Start the event dispatcher thread (idempotent)."""
        with self.lock:
            if not self.running:
                self.running = True
                self.dispatcher_thread = threading.Thread(target=self._dispatch_events)
                # Daemon thread: never blocks interpreter shutdown.
                self.dispatcher_thread.daemon = True
                self.dispatcher_thread.start()
                logger.info("Event dispatcher thread started")

    def stop(self) -> None:
        """Stop the event dispatcher thread and shut down the pool."""
        with self.lock:
            if self.running:
                self.running = False
                self.event_queue.put(None)  # Sentinel to stop the thread
                if self.dispatcher_thread and self.dispatcher_thread.is_alive():
                    self.dispatcher_thread.join(timeout=2.0)
                logger.info("Event dispatcher thread stopped")

                # Shut down thread pool without waiting for running callbacks.
                self.thread_pool.shutdown(wait=False)

    def _dispatch_events(self) -> None:
        """Dispatcher thread that processes events from the queue."""
        while self.running:
            try:
                # Get next event with timeout to allow checking running flag.
                event = self.event_queue.get(timeout=0.5)

                # Handle sentinel value.
                # NOTE(review): the sentinel is never marked with task_done(),
                # so a later event_queue.join() (see wait_for_all_events) can
                # block forever after stop() — confirm whether that matters.
                if event is None:
                    break

                # Process the event.
                self._process_event(event)

                # Mark task as done so event_queue.join() can make progress.
                self.event_queue.task_done()
            except queue.Empty:
                # Timeout expired with no event: loop to re-check running flag.
                continue
            except Exception as e:
                logger.error(f"Error in event dispatcher: {e}")

        logger.info("Event dispatcher thread exiting")

    def _process_event(self, event: Event) -> None:
        """Process a single event by fanning callbacks out to the thread pool."""
        with self.lock:
            # Copy subscriber lists under the lock so concurrent
            # (un)subscribes cannot mutate them mid-dispatch.
            subscribers = self.subscribers.get(event.event_type, []).copy()
            # Also check for wildcard subscribers.
            wildcard_subscribers = self.subscribers.get("*", []).copy()
            all_subscribers = subscribers + wildcard_subscribers

            if not all_subscribers:
                logger.debug(f"No subscribers for event {event.event_type}")
                return

            logger.debug(f"Dispatching event {event.event_type} to {len(all_subscribers)} subscribers")

            # Submit a task to the thread pool for each subscriber.
            for callback in all_subscribers:
                future = self.thread_pool.submit(self._safe_callback, callback, event)
                self.futures.add(future)
                # Remove the future from the tracking set once it finishes.
                future.add_done_callback(lambda f: self.futures.remove(f))

    def _safe_callback(self, callback: Callable[[Event], None], event: Event) -> None:
        """Execute a callback safely, catching exceptions so one bad
        subscriber cannot poison the pool."""
        try:
            callback(event)
        except Exception as e:
            logger.error(f"Error in subscriber callback: {e}")

    def wait_for_all_events(self, timeout: Optional[float] = None) -> bool:
        """Wait for all pending events to be processed.

        Args:
            timeout: Max seconds to wait for in-flight callbacks (None = forever).

        Returns:
            True if the queue drained and all callback futures completed.
        """
        try:
            # Blocks until task_done() has been called for every queued event.
            self.event_queue.join()

            # Also wait for all futures to complete.
            done, not_done = concurrent.futures.wait(
                self.futures,
                timeout=timeout,
                return_when=concurrent.futures.ALL_COMPLETED
            )

            return len(not_done) == 0
        except Exception as e:
            logger.error(f"Error waiting for events: {e}")
            return False
|
| 160 |
+
|
| 161 |
+
# Common event types: string channel names shared by EventSystem publishers
# and subscribers. The semantics of each channel are defined by the
# components that emit and consume it.
EVENT_USER_INPUT = "user_input"
EVENT_MODEL_REQUEST = "model_request"
EVENT_MODEL_RESPONSE = "model_response"
EVENT_STDP_REQUEST = "stdp_request"
EVENT_STDP_RESPONSE = "stdp_response"
EVENT_TOKEN_GENERATED = "token_generated"
EVENT_RESPONSE_COMPLETE = "response_complete"
EVENT_ERROR = "error"
|
utils/gpu_config_optimizer.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility to optimize transformer configuration for GPU memory constraints.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import torch
|
| 7 |
+
import logging
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import argparse
|
| 10 |
+
|
| 11 |
+
# Configure logging
|
| 12 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
def optimize_config_for_gpu(config_path, target_vram_mb=None, batch_reduction_factor=0.5):
    """
    Optimize config settings for the available GPU memory.

    Args:
        config_path: Path to the config.json file
        target_vram_mb: Target VRAM usage in MB (if None, will use 80% of available)
        batch_reduction_factor: How much to reduce batch size (0.5 = half)

    Returns:
        Path of the optimized config file that was written next to the input.
    """
    # Get current GPU memory capacity
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        gpu_properties = torch.cuda.get_device_properties(device)
        total_memory = gpu_properties.total_memory / (1024 * 1024)  # bytes -> MB
        gpu_name = gpu_properties.name
        logger.info(f"GPU detected: {gpu_name} with {total_memory:.0f}MB VRAM")

        # Default budget: 80% of physical VRAM leaves headroom for the CUDA context.
        if target_vram_mb is None:
            target_vram_mb = int(total_memory * 0.8)
    else:
        logger.warning("No GPU detected, using conservative settings for CPU")
        # Fix: only fall back to the conservative 2048MB default when the
        # caller did not pass an explicit budget (previously an explicit
        # target_vram_mb was silently overridden on the CPU path).
        if target_vram_mb is None:
            target_vram_mb = 2048
        gpu_name = "CPU"

    logger.info(f"Target VRAM usage: {target_vram_mb}MB")

    # Load current config
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Store original values purely for the change report below
    original_batch_size = config["TRANSFORMER_CONFIG"]["BATCH_SIZE"]
    original_sequence_length = config["TRANSFORMER_CONFIG"]["MAX_SEQ_LENGTH"]
    original_num_layers = config["TRANSFORMER_CONFIG"]["NUM_LAYERS"]

    # Adjust batch size based on GPU
    if torch.cuda.is_available():
        # RTX 4050 specific optimizations (around 6GB VRAM)
        if "4050" in gpu_name or total_memory < 7000:
            # Significant reductions needed for 4050; never go below 4.
            config["TRANSFORMER_CONFIG"]["BATCH_SIZE"] = max(4, int(original_batch_size * batch_reduction_factor))

            # If still too large after batch reduction, reduce sequence length too
            if target_vram_mb < 5500:  # RTX 4050 has ~6GB VRAM
                if config["TRANSFORMER_CONFIG"]["MAX_SEQ_LENGTH"] > 256:
                    config["TRANSFORMER_CONFIG"]["MAX_SEQ_LENGTH"] = 256

                # Reduce model complexity if still needed
                if config["TRANSFORMER_CONFIG"]["NUM_LAYERS"] > 6:
                    config["TRANSFORMER_CONFIG"]["NUM_LAYERS"] = 6

    # Enable gradient checkpointing and mixed precision regardless of device:
    # both trade compute for memory and are reported unconditionally below.
    if "OPTIMIZATION" not in config:
        config["OPTIMIZATION"] = {}
    config["OPTIMIZATION"]["USE_GRADIENT_CHECKPOINTING"] = True
    config["OPTIMIZATION"]["USE_MIXED_PRECISION"] = True

    # Create optimized filename with the GPU type, e.g. config_cpu_optimized.json
    gpu_name_simple = gpu_name.replace(" ", "_").lower()
    opt_config_path = config_path.replace(".json", f"_{gpu_name_simple}_optimized.json")

    # Save optimized config
    with open(opt_config_path, 'w') as f:
        json.dump(config, f, indent=2)

    # Report changes
    logger.info(f"Optimized configuration saved to: {opt_config_path}")
    logger.info("Changes made:")
    logger.info(f"  - Batch size: {original_batch_size} → {config['TRANSFORMER_CONFIG']['BATCH_SIZE']}")
    logger.info(f"  - Sequence length: {original_sequence_length} → {config['TRANSFORMER_CONFIG']['MAX_SEQ_LENGTH']}")
    logger.info(f"  - Num layers: {original_num_layers} → {config['TRANSFORMER_CONFIG']['NUM_LAYERS']}")
    logger.info(f"  - Gradient checkpointing: Enabled")
    logger.info(f"  - Mixed precision: Enabled")

    return opt_config_path
|
| 93 |
+
|
| 94 |
+
def apply_optimized_config(opt_config_path):
    """Copy an optimized configuration over the project's main config file.

    A one-time ``.backup`` of the original main config is kept beside it;
    subsequent calls never overwrite an existing backup.
    """
    # Read the optimized settings first so a malformed file fails before writes.
    with open(opt_config_path, 'r') as f:
        optimized = json.load(f)

    # The main config lives two directory levels above the optimized file.
    project_root = os.path.dirname(os.path.dirname(opt_config_path))
    main_config_path = os.path.join(project_root, "config.json")

    # Preserve the pristine config exactly once.
    backup_path = main_config_path + '.backup'
    if not os.path.exists(backup_path):
        import shutil
        shutil.copy2(main_config_path, backup_path)
        logger.info(f"Original config backed up to: {backup_path}")

    # Overwrite the main config with the optimized settings.
    with open(main_config_path, 'w') as f:
        json.dump(optimized, f, indent=2)

    logger.info(f"Applied optimized config to {main_config_path}")
|
| 117 |
+
|
| 118 |
+
if __name__ == "__main__":
    # CLI entry point: produce an optimized copy of the transformer config
    # and optionally apply it over the main config file.
    parser = argparse.ArgumentParser(description="Optimize transformer config for GPU memory constraints")
    parser.add_argument("--config", type=str, default="config.json", help="Path to config file")
    parser.add_argument("--apply", action="store_true", help="Apply optimized config")
    parser.add_argument("--batch-factor", type=float, default=0.5, help="Batch size reduction factor")
    parser.add_argument("--target-vram", type=int, default=None, help="Target VRAM usage in MB")

    args = parser.parse_args()

    # Resolve config path: relative paths are taken against the project root
    # (the parent of the utils/ directory containing this script).
    if not os.path.isabs(args.config):
        config_dir = Path(__file__).resolve().parent.parent
        config_path = os.path.join(config_dir, args.config)
    else:
        config_path = args.config

    # Optimize config
    opt_config_path = optimize_config_for_gpu(
        config_path,
        target_vram_mb=args.target_vram,
        batch_reduction_factor=args.batch_factor
    )

    # Apply if requested
    if args.apply:
        apply_optimized_config(opt_config_path)
|
utils/model_utils.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional, Tuple, Dict, Any
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
# List of validated model types from Hugging Face
VALIDATED_MODEL_TYPES = [
    'bert', 'roberta', 'distilbert', 'gpt2', 't5', 'albert',
    'xlm-roberta', 'bart', 'electra', 'xlnet'
]

def validate_model_name(model_name: str) -> Tuple[bool, Optional[str]]:
    """Check whether *model_name* belongs to a recognized model family.

    Args:
        model_name: Name of the model to validate

    Returns:
        ``(True, None)`` when the lowercased name contains one of
        ``VALIDATED_MODEL_TYPES``; otherwise ``(False, 'bert-base-uncased')``
        where the second element is the recommended fallback model.
    """
    lowered = model_name.lower()
    for model_type in VALIDATED_MODEL_TYPES:
        if model_type in lowered:
            return True, None
    # Unknown family: advise the default fallback model.
    return False, 'bert-base-uncased'

def get_safe_model_name(config):
    """Return a model name that is safe to hand to Hugging Face loaders.

    Args:
        config: Either a plain model-name string or a config dict holding a
            'MODEL_NAME' entry (missing entry defaults to 'bert-base-uncased').

    Returns:
        str: The original name when recognized, otherwise the fallback.
    """
    # Accept both a raw string and the project's config mapping.
    model_name = config if isinstance(config, str) else config.get('MODEL_NAME', 'bert-base-uncased')

    is_valid, fallback = validate_model_name(model_name)
    return model_name if is_valid else fallback
|
| 56 |
+
|
| 57 |
+
def create_model_config_json(model_dir: str, model_type: str = 'bert') -> None:
    """
    Creates a config.json file for a custom model with proper model_type key.

    The "model_type" key is what transformers' AutoConfig uses to dispatch,
    so a config without it cannot be loaded automatically.

    Args:
        model_dir: Directory where model is/will be stored (created if absent)
        model_type: The type of model (e.g., 'bert', 'roberta')
    """
    import json

    # exist_ok avoids the race between an exists() check and makedirs(),
    # and makes repeated calls idempotent.
    os.makedirs(model_dir, exist_ok=True)

    config_path = os.path.join(model_dir, 'config.json')

    # Create a minimal config with the required model_type key.
    # Sizes mirror the standard "base" architecture defaults.
    config = {
        "model_type": model_type,
        "architectures": [f"{model_type.capitalize()}Model"],
        "hidden_size": 768,
        "num_attention_heads": 12,
        "num_hidden_layers": 12
    }

    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)

    logger.info(f"Created model config.json with model_type: {model_type} in {model_dir}")
|
utils/nltk_stub.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Stub implementation of NLTK to avoid dependencies in container environments
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
logger.info("Using stub NLTK implementation")
|
| 9 |
+
|
| 10 |
+
# Stub for download - do nothing
def download(*args, **kwargs):
    """No-op replacement for ``nltk.download``; accepts any arguments and
    always reports success without touching the network."""
    logger.warning("NLTK download stub called - no actual download performed")
    return True
|
| 14 |
+
|
| 15 |
+
# Add the missing SimpleTokenizer class
class SimpleTokenizer:
    """Whitespace word tokenizer used by the NLTK stub."""

    def __init__(self):
        # Announce that the stub (not real NLTK) is in use.
        logger.info("Stub SimpleTokenizer initialized")

    def tokenize(self, text):
        """Split *text* on runs of whitespace and return the token list."""
        return text.split()
|
| 25 |
+
|
| 26 |
+
# Tokenization stubs
class WordTokenizer:
    """Class-based whitespace tokenizer (mirrors nltk tokenizer objects)."""

    def tokenize(self, text):
        return text.split()

def word_tokenize(text):
    """Function-style whitespace tokenization (stub for nltk.word_tokenize)."""
    return text.split()

class SentenceTokenizer:
    """Naive sentence splitter: breaks on every '.' character."""

    def tokenize(self, text):
        return text.split('.')

def sent_tokenize(text):
    """Function-style sentence splitting (stub for nltk.sent_tokenize).

    Note: splits on every period, so abbreviations and decimals are split too.
    """
    return text.split('.')
|
| 40 |
+
|
| 41 |
+
# Stemmer stubs
class PorterStemmer:
    """Tiny heuristic stemmer standing in for nltk's PorterStemmer.

    Strips at most one common suffix: 'ing', then 'ed', then a trailing 's'
    (but never from words ending in 'ss', e.g. "pass").
    """

    # Ordered (suffix, chars-to-cut) rules; first match wins.
    _SUFFIX_RULES = (('ing', 3), ('ed', 2), ('s', 1))

    def stem(self, word):
        for suffix, cut in self._SUFFIX_RULES:
            if word.endswith(suffix):
                if suffix == 's' and word.endswith('ss'):
                    continue  # keep words like "pass" intact
                return word[:-cut]
        return word

class LancasterStemmer:
    """Stub Lancaster stemmer; applies the same rules as the Porter stub."""

    def stem(self, word):
        return PorterStemmer().stem(word)

class SimpleStemmer:
    """Standalone copy of the heuristic stemmer with an init log message."""

    def __init__(self):
        logger.info("SimpleStemmer stub initialized")

    def stem(self, word):
        # Very basic stemming: remove one common ending if present.
        if word.endswith('ing'):
            return word[:-3]
        if word.endswith('ed'):
            return word[:-2]
        if word.endswith('s') and not word.endswith('ss'):
            return word[:-1]
        return word
|
| 70 |
+
|
| 71 |
+
# Stub WordNetLemmatizer class
class WordNetLemmatizer:
    """Identity lemmatizer standing in for nltk's WordNetLemmatizer."""

    def __init__(self):
        logger.info("Stub WordNetLemmatizer initialized")

    def lemmatize(self, word, pos=None):
        # The stub performs no morphological analysis: the input word is
        # returned unchanged regardless of the part-of-speech tag.
        return word
|
| 79 |
+
|
| 80 |
+
# Namespace stubs for import compatibility
# These classes emulate the `nltk.tokenize` and `nltk.stem` submodules so
# that code doing `from nltk import tokenize, stem` (with this module
# substituted for nltk) keeps working. Attribute access via the class
# returns the aliased callables directly.
class tokenize:
    WordTokenizer = WordTokenizer
    SentenceTokenizer = SentenceTokenizer
    word_tokenize = word_tokenize
    sent_tokenize = sent_tokenize

class stem:
    PorterStemmer = PorterStemmer
    LancasterStemmer = LancasterStemmer
    SimpleStemmer = SimpleStemmer
|
| 91 |
+
|
| 92 |
+
# Stub for corpus
|
| 93 |
+
class _CorpusModule:
|
| 94 |
+
class stopwords:
|
| 95 |
+
@staticmethod
|
| 96 |
+
def words(language="english"):
|
| 97 |
+
# Return basic English stopwords
|
| 98 |
+
return {
|
| 99 |
+
"i", "me", "my", "myself", "we", "our", "ours", "ourselves",
|
| 100 |
+
"you", "your", "yours", "yourself", "yourselves", "he", "him",
|
| 101 |
+
"his", "himself", "she", "her", "hers", "herself", "it", "its",
|
| 102 |
+
"itself", "they", "them", "their", "theirs", "themselves",
|
| 103 |
+
"what", "which", "who", "whom", "this", "that", "these",
|
| 104 |
+
"those", "am", "is", "are", "was", "were", "be", "been",
|
| 105 |
+
"being", "have", "has", "had", "having", "do", "does", "did",
|
| 106 |
+
"doing", "a", "an", "the", "and", "but", "if", "or", "because",
|
| 107 |
+
"as", "until", "while", "of", "at", "by", "for", "with",
|
| 108 |
+
"about", "against", "between", "into", "through", "during",
|
| 109 |
+
"before", "after", "above", "below", "to", "from", "up", "down",
|
| 110 |
+
"in", "out", "on", "off", "over", "under", "again", "further",
|
| 111 |
+
"then", "once", "here", "there", "when", "where", "why", "how",
|
| 112 |
+
"all", "any", "both", "each", "few", "more", "most", "other",
|
| 113 |
+
"some", "such", "no", "nor", "not", "only", "own", "same", "so",
|
| 114 |
+
"than", "too", "very", "s", "t", "can", "will", "just", "don",
|
| 115 |
+
"should", "now", "d", "ll", "m", "o", "re", "ve", "y", "ain",
|
| 116 |
+
"aren", "couldn", "didn", "doesn", "hadn", "hasn", "haven",
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
corpus = _CorpusModule()
|
utils/output_formatter.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import re
|
| 3 |
+
from typing import Optional, Dict, Any, List
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
class OutputFormatter:
    """
    Formats model responses for better presentation and usability.

    A specialization name selects one of several regex-based post-processors;
    unknown (or absent) specializations fall back to a generic formatter.
    """

    def __init__(self):
        """
        Initialize the OutputFormatter.
        """
        # Maps specialization name -> bound formatting method; the "default"
        # entry is the fallback used by format_response().
        self.post_processors = {
            "programming_software_dev": self._format_code,
            "mbpp": self._format_code,
            "machine_learning_ai_data_science": self._format_technical_content,
            "mathematics": self._format_equations,
            "default": self._default_formatter
        }
        logger.info("OutputFormatter initialized")

    def format_response(self, response: str, specialization: Optional[str] = None) -> str:
        """
        Format the model response based on specialization.

        Args:
            response: The raw response from the model
            specialization: The specialization area (optional)

        Returns:
            Formatted response (empty string when the input is empty/falsy)
        """
        if not response:
            return ""

        # Apply basic formatting to all responses
        formatted_response = self._clean_whitespace(response)

        # Apply specialization-specific formatting; dict.get falls through to
        # the default processor for None or unrecognized specializations.
        processor = self.post_processors.get(specialization, self.post_processors["default"])
        formatted_response = processor(formatted_response)

        return formatted_response

    def _clean_whitespace(self, text: str) -> str:
        """
        Clean up excessive whitespace.
        """
        # Replace multiple newlines with double newlines
        text = re.sub(r'\n{3,}', '\n\n', text)
        # Replace multiple spaces with a single space
        text = re.sub(r' {2,}', ' ', text)
        return text.strip()

    def _format_code(self, text: str) -> str:
        """
        Format code blocks with proper syntax highlighting markers.
        """
        # Identify unmarked code blocks and add markdown code block syntax.
        # Look for patterns that suggest code: a line starting with a common
        # programming/SQL keyword followed by indented continuation lines.
        code_patterns = [
            r'((?:^|\n)(?:def |class |import |function |public |private |var |let |const |if |for |while ).+(?:\n[ \t]+.+)+)',
            r'((?:^|\n)(?:SELECT |INSERT |UPDATE |DELETE |CREATE |ALTER |DROP ).+(?:;)(?:\n|$))'
        ]

        for pattern in code_patterns:
            def add_code_markers(match):
                # Wrap the matched block in fenced markdown with a language tag.
                code_block = match.group(1)
                # Try to determine the language based on keywords
                lang = self._detect_language(code_block)
                return f"\n```{lang}\n{code_block}\n```\n"

            text = re.sub(pattern, add_code_markers, text)

        return text

    def _detect_language(self, code_block: str) -> str:
        """
        Attempt to detect the programming language from a code block.

        Checks are ordered; the first keyword family that matches wins.
        Returns "" when no family matches (generic fenced block).
        """
        if re.search(r'def |class |import |if __name__ ==|print\(', code_block):
            return "python"
        elif re.search(r'function |var |const |let |=> |document\.', code_block):
            return "javascript"
        elif re.search(r'public |private |class .+ {|void |String |int |boolean', code_block):
            return "java"
        elif re.search(r'#include|int main|std::|printf|scanf', code_block):
            return "c++"
        elif re.search(r'SELECT |INSERT |UPDATE |DELETE |CREATE TABLE|ALTER TABLE', code_block):
            return "sql"
        else:
            return ""  # Generic code block

    def _format_equations(self, text: str) -> str:
        """
        Format mathematical equations with LaTeX markers if needed.
        """
        # Basic patterns for unmarked equations: LaTeX commands, subscripted
        # identifiers, and set-membership fragments not already inside $...$.
        equation_patterns = [
            r'([^$])(\\frac{.+?}{.+?}|\\sum_|\\int_|\\lim_)',
            r'([^$])([a-zA-Z]_[0-9]+)',
            r'([^$])([a-zA-Z]\\in)'
        ]

        for pattern in equation_patterns:
            # Wrap the equation fragment in inline-math delimiters.
            text = re.sub(pattern, r'\1$\2$', text)

        # Ensure equation blocks use proper LaTeX display-math delimiters
        text = re.sub(r'\\begin{equation}(.+?)\\end{equation}', r'$$\1$$', text, flags=re.DOTALL)

        return text

    def _format_technical_content(self, text: str) -> str:
        """
        Format technical content with proper highlighting of terms and concepts.
        """
        # Highlight technical terms with markdown emphasis (*term*).
        technical_terms = [
            "neural network", "machine learning", "deep learning", "algorithm",
            "regression", "classification", "clustering", "backpropagation",
            "gradient descent", "optimization", "hyperparameter"
        ]

        for term in technical_terms:
            # Only highlight whole words, not substrings; the (?![*_])
            # lookahead avoids re-wrapping already-emphasized occurrences.
            text = re.sub(r'\b(' + re.escape(term) + r')\b(?![*_])', r'*\1*', text)

        return text

    def _default_formatter(self, text: str) -> str:
        """
        Default formatter that applies general improvements.
        """
        # Add paragraph breaks for readability: a sentence end followed by a
        # capitalized word starts a new paragraph.
        text = re.sub(r'(\w\.)\s+([A-Z])', r'\1\n\n\2', text)

        # Format numbered lists for readability if they're not already on
        # their own lines.
        text = re.sub(r'(?<!\n)(\d+\.)\s+', r'\n\1 ', text)

        return text

    def format_structured_output(self, data: Dict[str, Any]) -> str:
        """
        Format structured data outputs (like JSON) into readable text.

        Args:
            data: Dictionary containing structured data

        Returns:
            Formatted string representation (str(data) for non-dict input)
        """
        if not isinstance(data, dict):
            return str(data)

        formatted_parts = []

        # Format main response if present
        if "response" in data:
            formatted_parts.append(self.format_response(data["response"]))

        # Collect metadata: every key except "response" and private "_" keys.
        metadata = {}
        for key, value in data.items():
            if key != "response" and not key.startswith("_"):
                metadata[key] = value

        # Render metadata below a horizontal rule as bolded key/value lines.
        if metadata:
            formatted_parts.append("\n\n---\n")
            for key, value in metadata.items():
                formatted_parts.append(f"**{key.replace('_', ' ').title()}**: {value}")

        return "\n".join(formatted_parts)
|
utils/prepare_hf_training.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to prepare your model training for Hugging Face's training infrastructure.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import shutil
|
| 6 |
+
import argparse
|
| 7 |
+
import logging
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 13 |
+
|
| 14 |
+
def prepare_hf_training(output_dir="hf_training"):
    """Assemble a self-contained training package for Hugging Face.

    Copies the project's core modules and data (resolved relative to the
    current working directory) into *output_dir*, then generates the HF
    launcher script, requirements.txt, and an HF-tuned config.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Required files for HF training
    required_files = [
        "config.json",          # Configuration file
        "model_Custm.py",       # Custom model implementation
        "model_PrTr.py",        # Pretrained model implementation
        "model_Combn.py",       # Combined model implementation
        "tokenizer.py",         # Tokenizer wrapper
        "dataloader.py",        # Data loading utilities
        "utils",                # Utility functions
        "data",                 # Training data
        "run_gpu_training.py"   # Main training script
    ]

    destination_root = Path(output_dir)
    for file_or_dir in required_files:
        src_path = Path(file_or_dir)
        if not src_path.exists():
            logger.warning(f"{file_or_dir} not found, skipping...")
            continue

        dst_path = destination_root / src_path.name
        if src_path.is_dir():
            # Replace any stale copy of the directory wholesale.
            if dst_path.exists():
                shutil.rmtree(dst_path)
            shutil.copytree(src_path, dst_path)
            logger.info(f"Copied directory {file_or_dir} to {dst_path}")
        else:
            shutil.copy2(src_path, dst_path)
            logger.info(f"Copied file {file_or_dir} to {dst_path}")

    # Generate the HF-specific launcher, dependency list and config tweaks.
    create_hf_train_script(output_dir)
    create_requirements_file(output_dir)
    update_config_for_hf(output_dir)

    logger.info(f"Training package prepared in {output_dir}")
    logger.info(f"Upload this directory to Hugging Face for training")
    logger.info("See https://huggingface.co/docs/hub/spaces-manage-deploy for deployment instructions")
|
| 62 |
+
|
| 63 |
+
def create_hf_train_script(output_dir):
    """Write ``train_hf.py``, the entry point executed on Hugging Face infra."""
    train_script = """
# Hugging Face training script for Transformer model
import os
import torch
from run_gpu_training import run_training

if __name__ == "__main__":
    # HF automatically provides CUDA device if available
    # Always use mixed precision on HF
    run_training(use_mixed_precision=True)

    # Save model to the /tmp/model directory, which HF preserves
    os.makedirs("/tmp/model", exist_ok=True)
    torch.save({
        "config": "final_model_config",
        "type": "transformer_trained",
        "epochs_completed": 30
    }, "/tmp/model/model_info.json")

    print("Training completed, model saved to /tmp/model")
"""

    # Strip the leading/trailing blank lines of the template before writing.
    script_path = os.path.join(output_dir, "train_hf.py")
    with open(script_path, "w") as f:
        f.write(train_script.strip())

    logger.info("Created HF training script: train_hf.py")
|
| 91 |
+
|
| 92 |
+
def create_requirements_file(output_dir):
    """Write ``requirements.txt`` listing the training dependencies."""
    # Lower-bound pins matching the versions the project was developed on.
    requirements = [
        "torch>=2.0.0",
        "transformers>=4.30.0",
        "datasets>=2.12.0",
        "pydantic>=2.0.0",
        "sentence-transformers>=2.2.2",
        "scikit-learn>=1.2.2",
        "numpy>=1.24.0",
        "pandas>=2.0.0",
        "tqdm>=4.65.0",
        "matplotlib>=3.7.1"
    ]

    requirements_path = os.path.join(output_dir, "requirements.txt")
    with open(requirements_path, "w") as f:
        f.write("\n".join(requirements))

    logger.info("Created requirements.txt")
|
| 111 |
+
|
| 112 |
+
def update_config_for_hf(output_dir):
    """Rewrite the packaged config.json with full-size settings for HF GPUs.

    Missing config files are skipped with a warning; any failure while
    reading/writing is logged rather than raised.
    """
    config_path = os.path.join(output_dir, "config.json")

    if not os.path.exists(config_path):
        logger.warning("config.json not found, skipping configuration update")
        return

    try:
        with open(config_path, "r") as f:
            config = json.load(f)

        # HF hardware can handle the full model, so restore the big settings.
        transformer_cfg = config["TRANSFORMER_CONFIG"]
        transformer_cfg["BATCH_SIZE"] = 32
        transformer_cfg["MAX_SEQ_LENGTH"] = 512
        transformer_cfg["NUM_LAYERS"] = 12

        # Add HF-specific optimization flags.
        config.setdefault("OPTIMIZATION", {})
        config["OPTIMIZATION"]["USE_MIXED_PRECISION"] = True
        config["OPTIMIZATION"]["PLATFORM"] = "huggingface"

        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)

        logger.info("Updated config.json for Hugging Face environment")
    except Exception as e:
        logger.error(f"Error updating config: {e}")
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
    # CLI entry point: build the Hugging Face training package.
    parser = argparse.ArgumentParser(description="Prepare training package for Hugging Face")
    parser.add_argument("--output-dir", type=str, default="hf_training",
                        help="Output directory for training package")
    args = parser.parse_args()

    prepare_hf_training(args.output_dir)
|
utils/prepare_hf_transformer_training.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
How to Use prepare_hf_transformer_training.py Safely
|
| 3 |
+
|
| 4 |
+
Here's a secure way to prepare and upload your model to Hugging Face:
|
| 5 |
+
|
| 6 |
+
Step 1: Navigate to Your Project Directory
|
| 7 |
+
cd C:/Users/User/OneDrive/Documents/tlm
|
| 8 |
+
|
| 9 |
+
Step 2: Set Up Authentication for Hugging Face
|
| 10 |
+
huggingface-cli login
|
| 11 |
+
|
| 12 |
+
Step 3: Run the Preparation Script
|
| 13 |
+
python -m utils.prepare_hf_transformer_training --stdp_checkpoint "checkpoints/stdp_model_epoch_20.pt" --output_dir "C:/Users/User/OneDrive/Documents/tlm/Wildnerve-tlm_HF/hf_upload"
|
| 14 |
+
|
| 15 |
+
Step 4: Initialize Git and Upload to Hugging Face
|
| 16 |
+
cd hf_upload
|
| 17 |
+
git init
|
| 18 |
+
git add .
|
| 19 |
+
git commit -m "Add TLM model with STDP checkpoint"
|
| 20 |
+
git remote add origin https://huggingface.co/YOUR-USERNAME/Wildnerve-tlm01
|
| 21 |
+
git pull origin main --allow-unrelated-histories
|
| 22 |
+
git push origin main
|
| 23 |
+
"""
|
| 24 |
+
import os
|
| 25 |
+
import shutil
|
| 26 |
+
import logging
|
| 27 |
+
import argparse
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 32 |
+
|
| 33 |
+
def prepare_training_package(
    stdp_checkpoint_path,
    output_dir="hf_transformer_training",
    include_all=False
):
    """Prepare a clean training package for Hugging Face with STDP checkpoint.

    Copies the project files needed to continue transformer training on
    Hugging Face, drops the STDP checkpoint into ``<output_dir>/checkpoints``,
    and generates the training script plus requirements.txt / README.md when
    they were not copied from the project itself.

    Args:
        stdp_checkpoint_path: Path to the STDP checkpoint file.
        output_dir: Directory where to create the package.
        include_all: Whether to include all supporting files (utils, analyzers, etc.).

    Returns:
        The output directory path (same value as ``output_dir``).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Core files needed for transformer training
    essential_files = [
        # Core components
        "app.py",
        "main.py",
        "config.json",
        "config.py",
        "inference.py",

        # Model implementations
        "model_List.py",
        "model_Custm.py",
        "model_PrTr.py",
        "model_Combn.py",
        "model_manager.py",

        # Communication components
        "communicator.py",
        "communicator_STDP.py",

        # Data and training
        "tokenizer.py",
        "trainer.py",
        "dataloader.py",
        "dataset.py",
        "data",

        # STDP specific components
        "STDP_Communicator/datasets_stdp.py",
        "STDP_Communicator/train_stdp.py",

        # Utils (only essential ones)
        "utils/convert_checkpoints.py",
    ]

    # Additional support files (only included if include_all=True)
    additional_files = [
        "utils/transformer_utils.py",
        "utils/smartHybridAttention.py",
        "utils/sentence_transformer_utils.py",
        "utils/output_formatter.py",
        "emergency_monitor.py",
    ]

    # Choose which files to copy
    required_files = essential_files + (additional_files if include_all else [])

    logger.info(f"Starting package preparation in {output_dir}")
    logger.info(f"Including {'all' if include_all else 'only essential'} files")

    # Track successful and failed copies
    copied_files = []
    missing_files = []

    # Copy files
    for file_path in required_files:
        src = Path(file_path)
        if not src.exists():
            logger.warning(f"File {file_path} not found, skipping")
            missing_files.append(str(src))
            continue

        # Create destination directories.
        # NOTE(review): this assumes every entry in required_files is a path
        # relative to the current working directory; an absolute src would
        # yield an unexpected dst — confirm callers run from the project root.
        dst = Path(output_dir) / src
        os.makedirs(dst.parent, exist_ok=True)

        # Copy file or directory
        try:
            if src.is_dir():
                shutil.copytree(src, dst, dirs_exist_ok=True)
            else:
                shutil.copy2(src, dst)
            copied_files.append(str(src))
            logger.info(f"Copied {src} to {dst}")
        except Exception as e:
            logger.error(f"Error copying {src}: {e}")
            # Bug fix: a failed copy previously landed in neither list, so the
            # final summary under-reported failures; record it as missing,
            # mirroring the checkpoint-copy handling below.
            missing_files.append(str(src))

    # Copy STDP checkpoint
    if os.path.exists(stdp_checkpoint_path):
        stdp_dst = Path(output_dir) / "checkpoints" / Path(stdp_checkpoint_path).name
        os.makedirs(stdp_dst.parent, exist_ok=True)
        try:
            shutil.copy2(stdp_checkpoint_path, stdp_dst)
            logger.info(f"Copied STDP checkpoint to {stdp_dst}")
            copied_files.append(str(stdp_checkpoint_path))
        except Exception as e:
            logger.error(f"Error copying checkpoint: {e}")
            missing_files.append(str(stdp_checkpoint_path))
    else:
        logger.warning(f"STDP checkpoint not found at {stdp_checkpoint_path}")
        missing_files.append(str(stdp_checkpoint_path))

    # Create Hugging Face training script
    create_transformer_training_script(output_dir, stdp_checkpoint_path)

    # Create requirements.txt if not already copied
    if "requirements.txt" not in copied_files:
        create_requirements(output_dir)
        copied_files.append("requirements.txt (generated)")

    # Create README.md if not already copied
    if "README.md" not in copied_files:
        create_readme(output_dir, stdp_checkpoint_path)
        copied_files.append("README.md (generated)")

    # Summarize what was done
    logger.info(f"Package prepared in {output_dir}")
    logger.info(f"Copied {len(copied_files)} files: {', '.join(copied_files[:5])}...")
    if missing_files:
        logger.warning(f"Missing {len(missing_files)} files: {', '.join(missing_files)}")

    return output_dir
|
| 157 |
+
|
| 158 |
+
def create_transformer_training_script(output_dir, stdp_checkpoint_path):
    """Create a script to load STDP checkpoint and train transformer.

    Writes ``train_transformer_hf.py`` into *output_dir*. The generated file
    is self-contained: it loads project config, builds tokenizer/model/data
    loaders, optionally loads the STDP checkpoint, trains, and saves the
    final model.

    Args:
        output_dir: Directory in which to write the generated script.
        stdp_checkpoint_path: Path to the STDP checkpoint.
            NOTE(review): this argument is currently unused — the generated
            script hard-codes its own argparse default
            ("checkpoints/stdp_model_epoch_20.pt"); confirm whether the path
            should be embedded instead.
    """
    # Fix: Change the inner docstring to use single quotes to avoid conflict with the outer triple quotes
    # The template below is a plain (non-f) string, so the {…} placeholders
    # inside it are emitted literally and become f-strings in the output file.
    script = """
import os
import torch
import logging
from config import load_config, app_config
from tokenizer import TokenizerWrapper
from model_manager import ModelManager
from dataloader import prepare_data_loaders
from trainer import Trainer, EarlyStopping

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def train_transformer(stdp_checkpoint_path):
    '''Train the transformer component after loading STDP weights.'''
    logger.info(f"Starting transformer training with STDP checkpoint: {stdp_checkpoint_path}")

    # Initialize components
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Create tokenizer
    tokenizer = TokenizerWrapper()

    # Get model manager
    model_manager = ModelManager()

    # Get specialization
    specialization = app_config.TRANSFORMER_CONFIG.specialization

    # Load STDP weights
    if os.path.exists(stdp_checkpoint_path):
        try:
            stdp_checkpoint = torch.load(stdp_checkpoint_path, map_location=device)
            logger.info(f"Loaded STDP checkpoint from {stdp_checkpoint_path}")

            # Now integrate STDP weights with transformer model if needed
            # This depends on your specific architecture
        except Exception as e:
            logger.error(f"Error loading STDP checkpoint: {e}")
    else:
        logger.warning(f"STDP checkpoint not found at {stdp_checkpoint_path}")

    # Get model and move to device
    model = model_manager.get_model(specialization)
    model.to(device)

    # Get data loaders
    data_path = app_config.DATASET_PATHS.get(specialization)
    if not data_path or not os.path.exists(data_path):
        # Use a default dataset path
        data_path = next(iter(app_config.DATASET_PATHS.values()))
        logger.warning(f"Dataset for {specialization} not found, using {data_path}")

    train_loader, val_loader = prepare_data_loaders(
        data_path,
        tokenizer,
        batch_size=app_config.TRANSFORMER_CONFIG.BATCH_SIZE
    )

    # Set up checkpoint directory
    checkpoint_dir = os.path.join("checkpoints", "transformer")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Set up early stopping
    early_stopping = EarlyStopping(
        patience=app_config.TRAINING_CONFIG.PATIENCE,
        delta=app_config.TRAINING_CONFIG.DELTA,
        verbose=True,
        path=os.path.join(checkpoint_dir, "best_model.pt")
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        device=device,
        early_stopping=early_stopping,
        checkpoint_dir=checkpoint_dir,
        total_epochs=app_config.TRAINING_CONFIG.TRANSFORMER_NUM_EPOCHS
    )

    # Train the model
    logger.info("Starting transformer training...")
    trainer.train()

    # Save final model
    final_model_path = os.path.join(checkpoint_dir, "final_model.pt")
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': {
            'transformer_epochs': app_config.TRAINING_CONFIG.TRANSFORMER_NUM_EPOCHS,
            'stdp_epochs': 20, # Assuming the STDP checkpoint is from epoch 20
            'specialization': specialization
        }
    }, final_model_path)
    logger.info(f"Final model saved to {final_model_path}")

    return final_model_path

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train transformer after STDP")
    parser.add_argument("--stdp_checkpoint", type=str, default="checkpoints/stdp_model_epoch_20.pt",
                        help="Path to pre-trained STDP checkpoint")

    args = parser.parse_args()

    # Train transformer
    train_transformer(args.stdp_checkpoint)
"""
    # Write the generated script; .strip() removes the leading/trailing
    # newlines that the triple-quoted template carries.
    script_path = os.path.join(output_dir, "train_transformer_hf.py")
    with open(script_path, "w") as f:
        f.write(script.strip())
    logger.info(f"Created training script at {script_path}")
|
| 280 |
+
|
| 281 |
+
def create_requirements(output_dir):
    """Write a requirements.txt listing every runtime dependency of the package."""
    # Pinned minimum versions for the training/inference stack.
    dependencies = (
        "torch>=2.0.0",
        "transformers>=4.30.0",
        "datasets>=2.12.0",
        "pydantic>=2.0.0",
        "sentence-transformers>=2.2.2",
        "scikit-learn>=1.2.2",
        "numpy>=1.24.0",
        "pandas>=2.0.0",
        "tqdm>=4.65.0",
        "matplotlib>=3.7.1",
        "snntorch>=0.7.0",
    )

    requirements_path = os.path.join(output_dir, "requirements.txt")
    with open(requirements_path, "w") as handle:
        handle.write("\n".join(dependencies))
    logger.info("Created requirements.txt")
|
| 300 |
+
|
| 301 |
+
def create_readme(output_dir, stdp_checkpoint_path):
    """Create README with model information and usage instructions."""
    # Only the checkpoint's basename is surfaced in the README, not its full path.
    checkpoint_name = os.path.basename(stdp_checkpoint_path)
    readme_text = f"""# Wildnerve-tlm01: Transformer Language Model with STDP

This repository contains the Wildnerve-tlm01 model, a transformer-based language model enhanced with
STDP (Spike-Timing-Dependent Plasticity) for improved learning capabilities.

## Pre-trained STDP Checkpoint

The STDP component was trained for 20 epochs and saved in: `{checkpoint_name}`

## Model Architecture

Wildnerve-tlm01 combines:
- Transformer architecture for language understanding
- Spiking Neural Network (SNN) with STDP for biological learning
- Smart Hybrid Attention for efficient processing

## Usage
"""
    readme_path = os.path.join(output_dir, "README.md")
    with open(readme_path, "w") as handle:
        handle.write(readme_text)
    logger.info("Created README.md")
|
| 324 |
+
|
| 325 |
+
if __name__ == "__main__":
    # CLI entry point: assemble the Hugging Face upload package from a
    # previously trained STDP checkpoint.
    parser = argparse.ArgumentParser(description="Prepare Hugging Face training package")
    parser.add_argument("--stdp_checkpoint", type=str, default="checkpoints/stdp_model_epoch_20.pt",
                        help="Path to pre-trained STDP checkpoint")
    parser.add_argument("--output_dir", type=str, default="hf_upload",
                        help="Output directory for training package")
    parser.add_argument("--include_all", action="store_true",
                        help="Include additional supporting files")

    args = parser.parse_args()
    prepare_training_package(args.stdp_checkpoint, args.output_dir, args.include_all)
|
utils/sentence_transformer_utils.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for loading and working with sentence transformers.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import torch
|
| 6 |
+
import os
|
| 7 |
+
from typing import Optional
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
# Constants
|
| 13 |
+
DEFAULT_SENTENCE_TRANSFORMER = "Wildnerve-tlm01-0.05Bx12" # Removed fallback to all-MiniLM-L6-v2
|
| 14 |
+
|
| 15 |
+
# Cache for loaded models to avoid reloading
|
| 16 |
+
_sentence_transformer_cache = {}
|
| 17 |
+
def get_sentence_transformer(model_name: str = DEFAULT_SENTENCE_TRANSFORMER):
    """
    Get a sentence transformer model, reusing a cached instance when possible.

    Bug fix: the module declares ``_sentence_transformer_cache`` ("Cache for
    loaded models to avoid reloading") and exposes
    ``clear_sentence_transformer_cache()``, but this function previously never
    consulted the cache, so every call reloaded the model from disk/hub.

    Args:
        model_name: Name of the model to load (default is our primary model).

    Returns:
        SentenceTransformer model.

    Raises:
        Exception: re-raises whatever SentenceTransformer raises on load failure.
    """
    # Fast path: return the previously loaded instance.
    if model_name in _sentence_transformer_cache:
        return _sentence_transformer_cache[model_name]

    # Define the expected local directory for your custom model.
    # NOTE(review): hard-coded Windows path — consider making this configurable.
    local_model_dir = os.path.join("c:/Users/User/OneDrive/Documents/tlm/Wildnerve-tlm_HF/models", model_name)
    # Use the local directory if it exists; otherwise, use the provided model_name identifier (which must be on HuggingFace)
    model_path = local_model_dir if os.path.isdir(local_model_dir) else model_name
    try:
        model = SentenceTransformer(model_path)
    except Exception as e:
        logger.error(f"Failed to load SentenceTransformer from {model_path}: {e}")
        raise
    _sentence_transformer_cache[model_name] = model
    return model
|
| 36 |
+
|
| 37 |
+
def clear_sentence_transformer_cache():
    """Release all cached sentence transformer models to free memory."""
    # dict.clear() mutates the mapping in place, so no `global` declaration
    # is required to empty the module-level cache.
    _sentence_transformer_cache.clear()
    logger.info("Cleared sentence transformer cache")
|
utils/smartHybridAttention.py
ADDED
|
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# smartHybridAttention.py - Enhanced SmartHybridAttention that combines the best features of both implementations:
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import math
|
| 5 |
+
import json
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import logging as logger
|
| 9 |
+
from typing import Optional, Tuple, List, Dict, Any, Union
|
| 10 |
+
|
| 11 |
+
# Fix imports for service_registry - make it more robust with fallbacks
|
| 12 |
+
try:
|
| 13 |
+
# Try direct import first
|
| 14 |
+
from service_registry import ServiceRegistry
|
| 15 |
+
except ImportError:
|
| 16 |
+
try:
|
| 17 |
+
# Try adding parent directories to path
|
| 18 |
+
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 19 |
+
sys.path.append(parent_dir)
|
| 20 |
+
from service_registry import ServiceRegistry
|
| 21 |
+
except ImportError:
|
| 22 |
+
# Create dummy registry if not found
|
| 23 |
+
class DummyRegistry:
|
| 24 |
+
def get_service(self, name): return None
|
| 25 |
+
def register_service(self, name, service): pass
|
| 26 |
+
registry = DummyRegistry()
|
| 27 |
+
|
| 28 |
+
# Use conditional import for AttentionProfileSelector
|
| 29 |
+
try:
|
| 30 |
+
# Try direct import first
|
| 31 |
+
from utils.attention_trigger_system import AttentionProfileSelector
|
| 32 |
+
except ImportError:
|
| 33 |
+
# Try setting up different paths
|
| 34 |
+
try:
|
| 35 |
+
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
|
| 36 |
+
sys.path.append(data_dir)
|
| 37 |
+
from utils.attention_trigger_system import AttentionProfileSelector
|
| 38 |
+
except ImportError:
|
| 39 |
+
# Create a minimal placeholder if not found
|
| 40 |
+
class DummyAttentionProfileSelector:
|
| 41 |
+
def __init__(self, config_path=None): pass
|
| 42 |
+
def select_profile(self, text, context=None): return "standard", 1.0
|
| 43 |
+
def get_attention_type(self, profile_id): return "standard"
|
| 44 |
+
def get_profile_parameters(self, profile_id): return {}
|
| 45 |
+
AttentionProfileSelector = DummyAttentionProfileSelector
|
| 46 |
+
|
| 47 |
+
# Merging the two functions into a single robust implementation
|
| 48 |
+
def get_hybrid_attention_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Build the hybrid-attention configuration, layering sources by priority.

    Priority, lowest to highest: built-in defaults, the project's
    ``app_config`` (when importable), then an optional JSON file.

    Args:
        config_path: Optional path to a JSON configuration file.

    Returns:
        Dictionary with attention configuration parameters.
    """
    # Layer 1: built-in defaults.
    cfg: Dict[str, Any] = {
        "DIM": 768,
        "NUM_HEADS": 12,
        "WINDOW_SIZE": 256,
        "USE_SLIDING": True,
        "USE_GLOBAL": True,
        "USE_HIERARCHICAL": False,
        "GLOBAL_TOKEN_RATIO": 0.05,
        "MEMORY_TOKENS": 32,
        "STRIDE": 128,
    }

    # Layer 2: values from the project's app_config, when importable.
    try:
        from config import app_config

        # Safe value loading with proper type checking.
        cfg["DIM"] = safe_get_int_value(app_config, 'EMBEDDING_DIM', cfg["DIM"])
        cfg["NUM_HEADS"] = safe_get_int_value(app_config, 'NUM_HEADS', cfg["NUM_HEADS"])
        window = safe_get_int_value(app_config, 'WINDOW_SIZE', cfg["WINDOW_SIZE"])
        cfg["WINDOW_SIZE"] = window

        # These overrides mirror the app_config variant of this module.
        cfg["USE_HIERARCHICAL"] = True
        cfg["GLOBAL_TOKEN_RATIO"] = 0.2
        cfg["MEMORY_TOKENS"] = 16

        # Derive stride from the (validated) window size.
        cfg["STRIDE"] = window // 2 if isinstance(window, int) and window > 0 else 128
    except Exception as e:
        logger.warning(f"Error loading config from app_config: {e}, using defaults")

    # Layer 3: explicit JSON file overrides everything else.
    if config_path and os.path.exists(config_path):
        try:
            with open(config_path, "r") as f:
                user_config = json.load(f)
            # User keys win; normalize them to upper case to match cfg's keys.
            for key, value in user_config.items():
                cfg[key.upper()] = value
            logger.info(f"Loaded attention config from {config_path}")
        except Exception as e:
            logger.warning(f"Error loading attention config from {config_path}: {e}")

    return cfg
|
| 112 |
+
|
| 113 |
+
# Update the get_attention_config function to use our new merged function
|
| 114 |
+
def get_attention_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Get attention configuration using the most appropriate method available.

    Thin compatibility wrapper kept for existing call sites; simply delegates
    to get_hybrid_attention_config().

    Args:
        config_path: Optional path to a JSON configuration file.

    Returns:
        Dictionary with attention configuration parameters.
    """
    return get_hybrid_attention_config(config_path)
|
| 117 |
+
|
| 118 |
+
# Add helper function for the above
|
| 119 |
+
def safe_get_int_value(config_obj, key, default=512):
    """Fetch *key* from a config object (or its TRANSFORMER_CONFIG) as an int.

    Falls back to *default* when the attribute is absent, not numeric, or any
    lookup error occurs.
    """
    try:
        # Look on the object itself first, then on its nested TRANSFORMER_CONFIG.
        if hasattr(config_obj, key):
            raw = getattr(config_obj, key)
        elif hasattr(config_obj, 'TRANSFORMER_CONFIG') and hasattr(config_obj.TRANSFORMER_CONFIG, key):
            raw = getattr(config_obj.TRANSFORMER_CONFIG, key)
        else:
            return default

        # Guard clauses: reject non-numeric values with a warning.
        if isinstance(raw, dict):
            logger.warning(f"Config value {key} is a dictionary, using default: {default}")
            return default
        if isinstance(raw, (int, float)):
            return int(raw)
        logger.warning(f"Config value {key} is not a number, using default: {default}")
        return default
    except Exception as e:
        logger.warning(f"Error getting config value {key}: {e}")
        return default
|
| 140 |
+
|
| 141 |
+
class SmartHybridAttention(nn.Module):
|
| 142 |
+
"""SmartHybridAttention that combines the best features of both implementations:
|
| 143 |
+
- Memory storage via global token selection from Wildnerve-tlm_HF
|
| 144 |
+
- Multiple attention strategies from utils version
|
| 145 |
+
- HuggingFace compatibility layer
|
| 146 |
+
- Optimized for extremely large context windows"""
|
| 147 |
+
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        window_size: int = 256,
        use_sliding: bool = True,
        use_global: bool = True,
        use_hierarchical: bool = False,
        global_token_ratio: float = 0.05,
        memory_tokens: int = 32,
        config_path: Optional[str] = None,
        registry_key: Optional[str] = None,
        attention_config_path: Optional[str] = None
    ):
        """Initialize the hybrid attention module.

        Args:
            dim: Embedding dimension (falls back to 768 if not numeric).
            num_heads: Number of attention heads (falls back to 8).
            window_size: Sliding-window span in tokens (falls back to 256).
            use_sliding: Enable sliding-window attention.
            use_global: Enable global-token attention.
            use_hierarchical: Enable hierarchical attention.
            global_token_ratio: Fraction of tokens promoted to global tokens.
            memory_tokens: Number of learned persistent memory tokens.
            config_path: Optional JSON config file loaded into self.config.
            registry_key: Optional key under which to register self in the
                service registry.
            attention_config_path: Optional config for the content-aware
                AttentionProfileSelector.
        """
        super().__init__()
        # Ensure all parameters are the correct types; non-numeric inputs are
        # silently replaced by hard-coded defaults rather than raising.
        self.dim = int(dim) if isinstance(dim, (int, float)) else 768
        self.num_heads = int(num_heads) if isinstance(num_heads, (int, float)) else 8
        self.head_dim = self.dim // self.num_heads  # Safe integer division
        self.window_size = int(window_size) if isinstance(window_size, (int, float)) else 256
        # Standard scaled-dot-product factor 1/sqrt(head_dim).
        self.scale = self.head_dim ** -0.5

        # Feature flags - ensure boolean types
        self.use_sliding = bool(use_sliding)
        self.use_global = bool(use_global)
        self.use_hierarchical = bool(use_hierarchical)

        # Ensure float type for ratio
        self.global_token_ratio = float(global_token_ratio) if isinstance(global_token_ratio, (int, float)) else 0.05

        # Ensure int type for memory tokens
        self.memory_tokens = int(memory_tokens) if isinstance(memory_tokens, (int, float)) else 32

        # Initialize memory parameter: learned tokens of shape
        # [memory_tokens, 1, dim], expanded across the batch at forward time.
        self.persistent_memory = nn.Parameter(torch.zeros(self.memory_tokens, 1, self.dim))
        nn.init.normal_(self.persistent_memory, mean=0.0, std=0.02)

        # Projections (query/key/value and output), all dim -> dim.
        self.q_proj = nn.Linear(self.dim, self.dim)
        self.k_proj = nn.Linear(self.dim, self.dim)
        self.v_proj = nn.Linear(self.dim, self.dim)
        self.out_proj = nn.Linear(self.dim, self.dim)

        # Initialize optional components
        self.config = self._load_config(config_path) if config_path else {}
        self.registry_key = registry_key
        self.prompt_analyzer = None
        # May overwrite self.prompt_analyzer and register self in the registry.
        self._init_external_services()

        # Initialize content-aware attention selector; failures are logged
        # and degrade to profile_selector = None rather than raising.
        self.attention_config_path = attention_config_path
        if self.attention_config_path:
            try:
                self.profile_selector = AttentionProfileSelector(self.attention_config_path)
            except Exception as e:
                logger.warning(f"Could not initialize AttentionProfileSelector: {e}")
                self.profile_selector = None
        else:
            self.profile_selector = None
|
| 206 |
+
|
| 207 |
+
def _init_external_services(self):
|
| 208 |
+
"""Initialize external services like registry and analyzer if available."""
|
| 209 |
+
try:
|
| 210 |
+
# Try to import service registry
|
| 211 |
+
sys.path.extend([
|
| 212 |
+
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "utils"),
|
| 213 |
+
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Wildnerve-tlm_HF", "utils")
|
| 214 |
+
])
|
| 215 |
+
|
| 216 |
+
if self.registry_key:
|
| 217 |
+
registry = ServiceRegistry.get_instance()
|
| 218 |
+
self.prompt_analyzer = registry.get_service("prompt_analyzer")
|
| 219 |
+
# Register self if registry key provided
|
| 220 |
+
registry.register_service(self.registry_key, self)
|
| 221 |
+
except (ImportError, AttributeError, Exception):
|
| 222 |
+
logger.debug("External services not available")
|
| 223 |
+
|
| 224 |
+
def _load_config(self, config_path: str) -> Dict:
|
| 225 |
+
"""Load configuration from JSON file."""
|
| 226 |
+
if not os.path.exists(config_path):
|
| 227 |
+
return {}
|
| 228 |
+
|
| 229 |
+
try:
|
| 230 |
+
with open(config_path, "r") as f:
|
| 231 |
+
return json.load(f)
|
| 232 |
+
except:
|
| 233 |
+
return {}
|
| 234 |
+
|
| 235 |
+
def _create_sliding_window_mask(
|
| 236 |
+
self,
|
| 237 |
+
seq_len: int,
|
| 238 |
+
window_size: int,
|
| 239 |
+
global_token_indices: Optional[List[int]] = None,
|
| 240 |
+
memory_size: int = 0
|
| 241 |
+
) -> torch.Tensor:
|
| 242 |
+
"""Create attention mask for sliding window with memory tokens and global tokens."""
|
| 243 |
+
total_len = seq_len + memory_size
|
| 244 |
+
mask = torch.zeros(total_len, total_len, dtype=torch.bool)
|
| 245 |
+
|
| 246 |
+
# Memory tokens attend to everything and everything attends to memory
|
| 247 |
+
if memory_size > 0:
|
| 248 |
+
mask[:memory_size, :] = True # Memory attends to all
|
| 249 |
+
mask[:, :memory_size] = True # All attend to memory
|
| 250 |
+
|
| 251 |
+
# Set sliding window attention for content tokens
|
| 252 |
+
for i in range(memory_size, total_len):
|
| 253 |
+
# Adjust window bounds
|
| 254 |
+
start = max(memory_size, i - window_size // 2)
|
| 255 |
+
end = min(total_len, i + window_size // 2 + 1)
|
| 256 |
+
mask[i, start:end] = True
|
| 257 |
+
|
| 258 |
+
# Add global token attention if provided
|
| 259 |
+
if global_token_indices is not None:
|
| 260 |
+
# Adjust indices to account for memory tokens
|
| 261 |
+
adjusted_indices = [idx + memory_size for idx in global_token_indices]
|
| 262 |
+
# Global tokens attend to all tokens
|
| 263 |
+
mask[adjusted_indices, :] = True
|
| 264 |
+
# All tokens attend to global tokens
|
| 265 |
+
mask[:, adjusted_indices] = True
|
| 266 |
+
return mask
|
| 267 |
+
|
| 268 |
+
def _select_global_tokens(
|
| 269 |
+
self,
|
| 270 |
+
key_layer: torch.Tensor,
|
| 271 |
+
ratio: float = None,
|
| 272 |
+
memory_size: int = 0
|
| 273 |
+
) -> List[int]:
|
| 274 |
+
"""Select global tokens based on importance scoring.
|
| 275 |
+
Returns: List of indices of selected global tokens"""
|
| 276 |
+
if ratio is None:
|
| 277 |
+
ratio = self.global_token_ratio
|
| 278 |
+
|
| 279 |
+
seq_len = key_layer.size(0) - memory_size
|
| 280 |
+
num_global_tokens = max(1, int(seq_len * ratio))
|
| 281 |
+
|
| 282 |
+
# Skip memory tokens when scoring importance
|
| 283 |
+
content_keys = key_layer[memory_size:]
|
| 284 |
+
|
| 285 |
+
# Score tokens by L2 norm and recency (more recent = more important)
|
| 286 |
+
base_scores = torch.norm(content_keys, dim=-1).mean(dim=-1) # [seq_len]
|
| 287 |
+
|
| 288 |
+
# Add recency bias
|
| 289 |
+
seq_positions = torch.arange(seq_len, device=base_scores.device) / seq_len
|
| 290 |
+
recency_scores = 0.3 * seq_positions # Mild recency bias
|
| 291 |
+
final_scores = base_scores + recency_scores
|
| 292 |
+
|
| 293 |
+
# Select top-k indices
|
| 294 |
+
_, indices = torch.topk(final_scores, k=min(num_global_tokens, seq_len))
|
| 295 |
+
|
| 296 |
+
return indices.tolist()
|
| 297 |
+
|
| 298 |
+
def _apply_memory_augmented_attention(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Attention with learned persistent-memory tokens prepended for
    long-range context.

    Args:
        query, key, value: Sequence-first tensors; assumed
            [seq_len, batch, dim] from the .size() unpacking below —
            TODO confirm callers.
        attention_mask: Optional additive mask. NOTE(review): the
            not-None branch indexes it as a 2-D [seq, seq] mask —
            confirm callers never pass an HF-style 4-D mask here.

    Returns:
        Output tensor after attention [seq_len, batch, dim]
        (memory positions are stripped before the output projection).
    """
    seq_len, batch_size, _ = query.size()

    # Broadcast the learned memory bank across the batch dimension.
    memory_batch = self.persistent_memory.expand(-1, batch_size, -1)

    # Prepend memory tokens so they take part in attention.
    query_with_memory = torch.cat([memory_batch, query], dim=0)
    key_with_memory = torch.cat([memory_batch, key], dim=0)
    value_with_memory = torch.cat([memory_batch, value], dim=0)

    # Project query, key, value
    q = self.q_proj(query_with_memory)
    k = self.k_proj(key_with_memory)
    v = self.v_proj(value_with_memory)

    # Promote the most salient content positions to global attention,
    # but only when the sequence exceeds the local window.
    global_token_indices = None
    if self.use_global and seq_len > self.window_size:
        global_token_indices = self._select_global_tokens(k, memory_size=self.memory_tokens)

    memory_size = self.memory_tokens
    full_seq_len = seq_len + memory_size

    if attention_mask is None:
        # Build the sliding-window connectivity mask (True = allowed),
        # then convert it to an additive float mask (-1e9 where blocked).
        window_mask = self._create_sliding_window_mask(
            seq_len, self.window_size, global_token_indices, memory_size
        )
        attention_mask = ~window_mask
        attention_mask = attention_mask.to(q.device).unsqueeze(0).unsqueeze(0)
        attention_mask = attention_mask * -1e9  # Large negative for masked positions
    else:
        # Extend the caller's mask so memory tokens are never masked.
        memory_mask = torch.zeros(full_seq_len, full_seq_len, device=attention_mask.device)
        memory_mask[memory_size:, memory_size:] = attention_mask
        # Memory attends to everything and everything attends to memory.
        memory_mask[:memory_size, :] = 0  # 0 = attend (not masked)
        memory_mask[:, :memory_size] = 0
        attention_mask = memory_mask

    # [seq, batch, dim] -> [batch, heads, seq, head_dim]
    q = q.view(full_seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
    k = k.view(full_seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
    v = v.view(full_seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)

    # Scaled dot-product attention with the additive mask.
    scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
    if attention_mask is not None:
        scores = scores + attention_mask
    attention_weights = torch.softmax(scores, dim=-1)
    context = torch.matmul(attention_weights, v)

    # [batch, heads, seq, head_dim] -> [seq, batch, dim]
    context = context.transpose(1, 2).transpose(0, 1).contiguous()
    context = context.view(full_seq_len, batch_size, self.dim)

    # Drop the memory positions; callers only see content tokens.
    context = context[memory_size:]

    # Final projection
    output = self.out_proj(context)
    return output
|
| 369 |
+
|
| 370 |
+
def _apply_hierarchical_attention(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor
) -> torch.Tensor:
    """Chunked attention for very long sequences.

    Splits the sequence into chunks of at most 512 positions and runs
    full attention independently inside each chunk.

    NOTE(review): despite the "first level" wording below, no second
    (cross-chunk) level is implemented — chunks never attend to each
    other. Confirm whether that is intended.

    Args:
        query, key, value: Sequence-first tensors; assumed
            [seq_len, batch, dim] from the .size() unpacking below.

    Returns:
        [seq_len, batch, dim] tensor after per-chunk attention and
        the output projection.
    """
    seq_len, batch_size, _ = query.size()
    chunk_size = min(512, seq_len)
    num_chunks = math.ceil(seq_len / chunk_size)

    # First level: process chunks independently
    chunk_outputs = []
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, seq_len)  # last chunk may be short

        # Extract chunk
        q_chunk = query[start_idx:end_idx]
        k_chunk = key[start_idx:end_idx]
        v_chunk = value[start_idx:end_idx]

        # Project the chunk with the shared q/k/v projections.
        q_proj = self.q_proj(q_chunk)
        k_proj = self.k_proj(k_chunk)
        v_proj = self.v_proj(v_chunk)

        # [chunk, batch, dim] -> [batch, heads, chunk, head_dim]
        chunk_len = end_idx - start_idx
        q_proj = q_proj.view(chunk_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
        k_proj = k_proj.view(chunk_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
        v_proj = v_proj.view(chunk_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)

        # Full scaled dot-product attention within the chunk (no mask).
        scores = torch.matmul(q_proj, k_proj.transpose(-1, -2)) * self.scale
        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, v_proj)

        # Back to [chunk, batch, dim]
        context = context.transpose(1, 2).transpose(0, 1).contiguous()
        context = context.view(chunk_len, batch_size, self.dim)
        chunk_outputs.append(context)

    # Reassemble the full sequence and apply the output projection.
    output = torch.cat(chunk_outputs, dim=0)
    output = self.out_proj(output)
    return output
|
| 417 |
+
|
| 418 |
+
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    input_text: Optional[str] = None,  # New parameter for content detection
    context: Optional[Dict] = None,  # New parameter for additional context
    **kwargs
) -> torch.Tensor:
    """Forward pass for the enhanced smart hybrid attention layer.

    Routes between four strategies — standard full attention (short
    sequences), memory-augmented, hierarchical, and plain sliding-window
    attention — based on sequence length and, when ``input_text`` is
    supplied, content-aware profile selection.

    Args:
        query, key, value: Sequence-first tensors; assumed
            [seq_len, batch, dim] from the .size() unpacking below.
        attention_mask: Optional additive mask (used by the short and
            memory paths).
        input_text: Raw text forwarded to the attention-connector
            service and the strategy selector.
        context: Extra context dict forwarded to the strategy selector.

    Returns:
        [seq_len, batch, dim] attention output.

    NOTE(review): the ``context`` parameter is shadowed below by local
    tensors also named ``context``; it is only read before the
    shadowing, but renaming one of them would be safer.
    """
    seq_len, batch_size, _ = query.size()

    # If input_text is provided, try to notify the (optional, external)
    # attention connector so it can drive content-aware attention.
    if input_text:
        try:
            # Try multiple import paths to find attention_connector;
            # the module may live at top level or under data/.
            connector = None
            import_paths = [
                # Try direct import
                lambda: __import__('attention_connector').get_attention_connector(),
                # Try data subdirectory
                lambda: __import__('data.attention_connector').get_attention_connector(),
                # Try relative path
                lambda: __import__('.'.join(['..', 'data', 'attention_connector']),
                                   fromlist=['get_attention_connector']).get_attention_connector()
            ]

            for import_path in import_paths:
                try:
                    connector = import_path()
                    break
                except (ImportError, AttributeError):
                    continue

            if connector:
                connector.set_input_text(input_text)
        except Exception as e:
            # Silently continue if connector not available
            logger.debug(f"Error setting input text for content-aware attention: {e}")

    # Analyze sequence characteristics to choose strategy
    strategy_weights = self._get_attention_strategy(seq_len, input_text, context)

    # For very short sequences, use standard full attention.
    if seq_len < 128:
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # [seq, batch, dim] -> [batch, heads, seq, head_dim]
        q = q.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
        k = k.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
        v = v.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if attention_mask is not None:
            scores = scores + attention_mask
        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).transpose(0, 1).contiguous()
        context = context.view(seq_len, batch_size, self.dim)
        return self.out_proj(context)

    # For longer sequences, use memory-augmented attention
    if strategy_weights["memory"] > 0:
        return self._apply_memory_augmented_attention(query, key, value, attention_mask)

    # For very long sequences where memory doesn't fit, use hierarchical
    if strategy_weights["hierarchical"] > 0:
        return self._apply_hierarchical_attention(query, key, value)

    # Fallback: sliding-window attention without memory tokens.
    q = self.q_proj(query)
    k = self.k_proj(key)
    v = self.v_proj(value)

    # Optional global tokens, then window mask -> additive float mask.
    global_token_indices = self._select_global_tokens(k, memory_size=0) if self.use_global else None
    window_mask = self._create_sliding_window_mask(seq_len, self.window_size, global_token_indices, 0)
    masked_attn = ~window_mask
    masked_attn = masked_attn.to(query.device).unsqueeze(0).unsqueeze(0) * -1e9

    # Standard multi-head attention with the window mask.
    q = q.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
    k = k.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
    v = v.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)

    scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
    scores = scores + masked_attn
    attn_weights = torch.softmax(scores, dim=-1)
    context = torch.matmul(attn_weights, v)

    context = context.transpose(1, 2).transpose(0, 1).contiguous()
    context = context.view(seq_len, batch_size, self.dim)
    return self.out_proj(context)
|
| 514 |
+
|
| 515 |
+
def _get_attention_strategy(self, seq_len: int, input_text: Optional[str] = None, context: Optional[Dict] = None) -> Dict[str, float]:
|
| 516 |
+
"""Determine which attention strategy to use based on sequence length
|
| 517 |
+
and optional prompt analysis."""
|
| 518 |
+
weights = {
|
| 519 |
+
"standard": 0.0,
|
| 520 |
+
"sliding": 0.0,
|
| 521 |
+
"memory": 0.0,
|
| 522 |
+
"hierarchical": 0.0
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
# Try content-aware selection if available and input_text is provided
|
| 526 |
+
if self.profile_selector and input_text:
|
| 527 |
+
try:
|
| 528 |
+
profile_id, confidence = self.profile_selector.select_profile(input_text, context)
|
| 529 |
+
|
| 530 |
+
# If we have high confidence in the profile selection, use profile-specific weights
|
| 531 |
+
if confidence > 0.65:
|
| 532 |
+
attention_type = self.profile_selector.get_attention_type(profile_id)
|
| 533 |
+
|
| 534 |
+
if attention_type == "hierarchical":
|
| 535 |
+
weights["hierarchical"] = 0.7
|
| 536 |
+
weights["memory"] = 0.3
|
| 537 |
+
return weights
|
| 538 |
+
elif attention_type == "smartHybrid":
|
| 539 |
+
weights["memory"] = 0.5
|
| 540 |
+
weights["sliding"] = 0.5
|
| 541 |
+
return weights
|
| 542 |
+
elif attention_type == "recencyBiased":
|
| 543 |
+
weights["memory"] = 0.8
|
| 544 |
+
weights["sliding"] = 0.2
|
| 545 |
+
return weights
|
| 546 |
+
# Additional attention types can be added here
|
| 547 |
+
except Exception as e:
|
| 548 |
+
print(f"Warning: Error in content-based attention selection: {e}")
|
| 549 |
+
|
| 550 |
+
# Fall back to sequence length-based selection if content detection fails or has low confidence
|
| 551 |
+
if seq_len < 128:
|
| 552 |
+
weights["standard"] = 1.0
|
| 553 |
+
elif seq_len < 2048:
|
| 554 |
+
weights["sliding"] = 0.2
|
| 555 |
+
weights["memory"] = 0.8
|
| 556 |
+
elif seq_len < 8192:
|
| 557 |
+
weights["memory"] = 1.0
|
| 558 |
+
else:
|
| 559 |
+
weights["memory"] = 0.7
|
| 560 |
+
weights["hierarchical"] = 0.3
|
| 561 |
+
|
| 562 |
+
# Adjust based on prompt analyzer if available
|
| 563 |
+
if self.prompt_analyzer:
|
| 564 |
+
try:
|
| 565 |
+
analysis = self.prompt_analyzer.get_current_analysis()
|
| 566 |
+
if analysis:
|
| 567 |
+
complexity = analysis.get("complexity", 0.5)
|
| 568 |
+
structure = analysis.get("structure_score", 0.5)
|
| 569 |
+
|
| 570 |
+
# Adjust for highly structured content
|
| 571 |
+
if structure > 0.7:
|
| 572 |
+
weights["hierarchical"] = min(0.8, weights["hierarchical"] + 0.3)
|
| 573 |
+
weights["memory"] = max(0.2, weights["memory"] - 0.3)
|
| 574 |
+
|
| 575 |
+
# Adjust for high complexity content
|
| 576 |
+
if complexity > 0.8 and seq_len > 1024:
|
| 577 |
+
weights["memory"] = min(1.0, weights["memory"] + 0.2)
|
| 578 |
+
except:
|
| 579 |
+
logger.debug("Error in prompt analysis")
|
| 580 |
+
|
| 581 |
+
return weights
|
| 582 |
+
|
| 583 |
+
def to_hf_attention(self):
    """
    Wrap this attention module in a HuggingFace-compatible interface.

    The wrapper accepts batch-first hidden states ([batch, seq, dim])
    plus an optional HF-style 0/1 attention mask, converts both to the
    formats this module expects, and converts the output back to
    batch-first.

    Returns:
        An nn.Module whose forward(hidden_states, attention_mask=None,
        **kwargs) delegates to this attention instance.
    """
    class HFCompatibleAttention(nn.Module):
        def __init__(self, smart_attention):
            super().__init__()
            self.smart_attention = smart_attention

        # Define forward() rather than overriding __call__ so that
        # nn.Module's hook machinery runs and explicit .forward(...)
        # calls work (the original __call__ override bypassed both).
        def forward(self, hidden_states, attention_mask=None, **kwargs):
            # Convert HF-format mask: [batch, 1, 1, seq] 0/1 keep-mask
            # -> 2-D additive float mask with -10000 at padded positions.
            if attention_mask is not None:
                if attention_mask.dim() == 4:  # [batch, 1, 1, seq_len]
                    attention_mask = attention_mask.squeeze(1).squeeze(1)
                attention_mask = attention_mask.to(dtype=torch.float32)
                attention_mask = (1.0 - attention_mask) * -10000.0

            # [batch, seq, dim] -> [seq, batch, dim]; self-attention
            # uses the same tensor for query, key and value.
            seq_first = hidden_states.transpose(0, 1)

            output = self.smart_attention(
                seq_first, seq_first, seq_first,
                attention_mask=attention_mask,
                **kwargs
            )

            # Convert back to [batch, seq, dim]
            return output.transpose(0, 1)
    return HFCompatibleAttention(self)
|
| 613 |
+
|
| 614 |
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """
    Create an instance from a pretrained Hugging Face model configuration.

    Reads hidden size and head count from the HF config, derives
    attention-specific defaults, and lets user kwargs override any of
    them.

    Args:
        pretrained_model_name_or_path: HF hub id or local path whose
            config supplies hidden_size, num_attention_heads and
            max_position_embeddings.
        **kwargs: Overrides forwarded to the constructor
            (e.g. window_size, memory_tokens).

    Returns:
        A new instance of this attention class.

    Raises:
        ImportError: If the transformers library is not installed.
        ValueError: For any other initialization failure.
    """
    try:
        from transformers import AutoConfig

        # Load config from HF model
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

        # Derive attention hyperparameters from the model config.
        attention_kwargs = {
            "dim": config.hidden_size,
            "num_heads": config.num_attention_heads,
            "window_size": kwargs.get("window_size", 512),
            # Memory scales with context length, capped at 32 tokens.
            "memory_tokens": kwargs.get("memory_tokens", min(32, config.max_position_embeddings // 64)),
        }

        # User-provided kwargs take precedence over derived values.
        attention_kwargs.update(kwargs)

        # Create instance
        return cls(**attention_kwargs)
    except ImportError:
        raise ImportError("transformers library required to load from pretrained model")
    except Exception as e:
        raise ValueError(f"Failed to initialize from pretrained model: {e}")
|
| 642 |
+
|
| 643 |
+
def create_smart_hybrid_attention(
    dim: int = 768,
    num_heads: int = 12,
    max_sequence_length: int = 8192,
    for_huggingface: bool = True,
    **kwargs
) -> Union[SmartHybridAttention, nn.Module]:
    """Factory for a SmartHybridAttention sized to the expected context.

    Args:
        dim: Hidden dimension size.
        num_heads: Number of attention heads.
        max_sequence_length: Maximum expected sequence length; drives the
            derived memory-token count and window size.
        for_huggingface: When True, return the HF-compatible wrapper.
        **kwargs: Extra constructor arguments forwarded verbatim.

    Returns:
        An attention module usable inside a transformer.
    """
    # Memory bank grows with context length, clamped to [16, 64] tokens.
    memory_tokens = min(max(16, max_sequence_length // 256), 64)

    # Local window scales with context length, clamped to [128, 512].
    window_size = min(512, max(128, max_sequence_length // 32))

    attention = SmartHybridAttention(
        dim=dim,
        num_heads=num_heads,
        window_size=window_size,
        memory_tokens=memory_tokens,
        **kwargs
    )
    return attention.to_hf_attention() if for_huggingface else attention
|
utils/tokenizer_utils.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for tokenizer-related operations.
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, List, Any, Union, Optional
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
def get_special_tokens_mask(tokenizer, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
    """
    Retrieve special tokens mask.

    Args:
        tokenizer: Tokenizer to use
        token_ids_0: First token IDs
        token_ids_1: Second token IDs (for pairs)
        already_has_special_tokens: Whether token_ids already contain special tokens

    Returns:
        List of 1s and 0s, where 1 indicates a special token
    """
    # The original three branches all made the identical delegated call
    # with the arguments passed through unchanged, so collapse them.
    return tokenizer.get_special_tokens_mask(
        token_ids_0,
        token_ids_1=token_ids_1,
        already_has_special_tokens=already_has_special_tokens
    )
|
| 43 |
+
|
| 44 |
+
def add_tokens_to_tokenizer(tokenizer, new_tokens):
    """
    Extend the tokenizer vocabulary with additional tokens.

    Args:
        tokenizer: Tokenizer whose vocabulary should be extended.
        new_tokens: List of tokens to register.

    Returns:
        Whatever the tokenizer's add_tokens reports — for HF tokenizers,
        the number of tokens actually added.
    """
    added_count = tokenizer.add_tokens(new_tokens)
    return added_count
|
| 56 |
+
|
| 57 |
+
def format_batch_for_model(
    batch: Dict[str, torch.Tensor],
    device: torch.device = None
) -> Dict[str, torch.Tensor]:
    """
    Move every tensor in a batch dict onto the target device.

    Non-tensor values are passed through untouched.

    Args:
        batch: Mapping of field name to tensor (or arbitrary value).
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        New dict with all tensors on the chosen device.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    return {
        name: value.to(device) if isinstance(value, torch.Tensor) else value
        for name, value in batch.items()
    }
|
| 81 |
+
|
| 82 |
+
def batch_encode_plus(
    tokenizer,
    texts: List[str],
    batch_size: int = 32,
    max_length: int = 512,
    return_tensors: str = "pt",
    **kwargs
) -> List[Dict[str, torch.Tensor]]:
    """
    Tokenize a large list of texts in fixed-size chunks.

    Every chunk is encoded with max-length padding and truncation so
    all outputs share the same shape.

    Args:
        tokenizer: Tokenizer to use.
        texts: Texts to encode.
        batch_size: Number of texts per chunk.
        max_length: Maximum sequence length.
        return_tensors: Return format ('pt' for PyTorch).
        **kwargs: Additional encoding parameters forwarded verbatim.

    Returns:
        One encoded output per chunk, in order.
    """
    encoded_batches = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        encoded_batches.append(
            tokenizer(
                chunk,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors=return_tensors,
                **kwargs
            )
        )
    return encoded_batches
|
| 119 |
+
|
| 120 |
+
def get_tokenizer_info(tokenizer) -> Dict[str, Any]:
    """
    Summarize a tokenizer: vocabulary size, source name and the special
    tokens it defines.

    Args:
        tokenizer: Tokenizer to inspect.

    Returns:
        Dict with keys "vocab_size", "model_name" and "special_tokens"
        (the latter only includes tokens that are actually set).
    """
    summary: Dict[str, Any] = {
        "vocab_size": len(tokenizer),
        "model_name": getattr(tokenizer, "name_or_path", None),
        "special_tokens": {},
    }

    # Collect only the special-token attributes that are defined.
    for attr in ("pad_token", "unk_token", "sep_token",
                 "cls_token", "mask_token", "bos_token", "eos_token"):
        token = getattr(tokenizer, attr, None)
        if token is not None:
            summary["special_tokens"][attr] = token

    return summary
|
utils/transformer_utils.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified utilities for handling transformers, tokenizers, and embeddings.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
import torch
|
| 7 |
+
from typing import Dict, Any, Optional, Union, List
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, AutoModel
|
| 10 |
+
|
| 11 |
+
# Import the new sentence transformer utilities
|
| 12 |
+
from utils.sentence_transformer_utils import get_sentence_transformer as load_sentence_transformer
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# Constants
|
| 17 |
+
DEFAULT_SENTENCE_TRANSFORMER = "sentence-transformers/Wildnerve-tlm01-0.05Bx12"
|
| 18 |
+
DEFAULT_TOKENIZER = "bert-base-uncased"
|
| 19 |
+
FALLBACK_TOKENIZERS = ["bert-base-uncased", "gpt2", "roberta-base"]
|
| 20 |
+
|
| 21 |
+
# Cache for loaded models to avoid reloading
|
| 22 |
+
_model_cache = {}
|
| 23 |
+
_tokenizer_cache = {}
|
| 24 |
+
_sentence_transformer_cache = {}
|
| 25 |
+
|
| 26 |
+
def get_sentence_transformer(model_name):
    """Load a SentenceTransformer, falling back to the default model.

    Args:
        model_name: Model identifier (hub id or local path) to load.

    Returns:
        Loaded SentenceTransformer instance.

    Raises:
        Re-raises the loading error when the requested model *is* the
        default (so a broken default cannot retry forever).
    """
    from sentence_transformers import SentenceTransformer
    try:
        return SentenceTransformer(model_name)
    except Exception as e:
        # Use the module logger, not the root logger, for consistency
        # with the rest of this module.
        logger.error(f"Failed to load sentence transformer {model_name}: {e}")
        if model_name == DEFAULT_SENTENCE_TRANSFORMER:
            raise  # Already the fallback target; nothing else to try.
        # Previous fallback id lacked the org prefix
        # ("Wildnerve-tlm01-0.05Bx12"); use the canonical default.
        logger.warning(f"Falling back to default model: {DEFAULT_SENTENCE_TRANSFORMER}")
        return SentenceTransformer(DEFAULT_SENTENCE_TRANSFORMER)
|
| 35 |
+
|
| 36 |
+
def get_tokenizer(model_name: str = "bert-base-uncased"):
    """Get a tokenizer with proper error handling.

    Tries to load a HuggingFace AutoTokenizer; on any failure returns a
    deterministic DummyTokenizer stub so downstream code keeps running.

    Args:
        model_name: HF hub id or local path of the tokenizer.

    Returns:
        A real tokenizer, or a DummyTokenizer fallback that produces
        pseudo-random (but deterministic per input text) token ids.
    """
    try:
        from transformers import AutoTokenizer
        logger.info(f"Loading tokenizer: {model_name}")
        return AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        logger.error(f"Failed to load tokenizer {model_name}: {e}")
        # Return a minimal dummy tokenizer that won't break everything
        logger.warning("Using dummy tokenizer as fallback")

        class DummyTokenizer:
            """Stand-in tokenizer producing deterministic fake encodings."""

            def __init__(self):
                # Constants mirror BERT's conventions so downstream
                # embedding tables sized for BERT keep working.
                self.vocab_size = 30522  # BERT vocab size
                self.pad_token_id = 0
                self.eos_token_id = 102
                self.bos_token_id = 101

            def __call__(self, text, **kwargs):
                """Convert text to a dict with dummy tensors.

                Ids are pseudo-random but deterministic: the RNG is
                seeded from an MD5 hash of each input text.
                """
                import torch

                # Handle batch vs single input
                is_batch = isinstance(text, list)
                texts = text if is_batch else [text]

                input_ids = []
                attention_mask = []

                for t in texts:
                    # Seed the RNG from the text so the same input
                    # always yields the same fake ids.
                    import hashlib
                    hash_obj = hashlib.md5(t.encode())
                    seed = int(hash_obj.hexdigest(), 16) % 10000

                    import random
                    random.seed(seed)

                    # Sequence length is word count, capped at max_length.
                    max_length = kwargs.get("max_length", 128)
                    length = min(len(t.split()), max_length)

                    # BOS + (length-2) random content ids + EOS.
                    ids = [self.bos_token_id] + [random.randint(1000, 30000) for _ in range(length-2)] + [self.eos_token_id]
                    mask = [1] * len(ids)

                    # NOTE(review): any truthy/falsy "padding" kwarg
                    # triggers padding — the value itself is not checked.
                    if "padding" in kwargs:
                        pad_length = max_length - len(ids)
                        if pad_length > 0:
                            ids.extend([self.pad_token_id] * pad_length)
                            mask.extend([0] * pad_length)

                    input_ids.append(torch.tensor(ids))
                    attention_mask.append(torch.tensor(mask))

                # Stack into batch tensors when PyTorch output requested.
                if "return_tensors" in kwargs and kwargs["return_tensors"] == "pt":
                    if is_batch or len(texts) > 1:
                        return {
                            "input_ids": torch.stack(input_ids),
                            "attention_mask": torch.stack(attention_mask)
                        }
                    else:
                        # Single text: add a leading batch dimension.
                        return {
                            "input_ids": input_ids[0].unsqueeze(0),
                            "attention_mask": attention_mask[0].unsqueeze(0)
                        }
                else:
                    # Without return_tensors: raw tensor (single input)
                    # or list of tensors (batch input).
                    return {
                        "input_ids": input_ids[0] if not is_batch and len(texts) == 1 else input_ids,
                        "attention_mask": attention_mask[0] if not is_batch and len(texts) == 1 else attention_mask
                    }

            def decode(self, token_ids, skip_special_tokens=True, **kwargs):
                """Convert token IDs back to text (placeholder output only)."""
                if isinstance(token_ids, (list, tuple)) and len(token_ids) > 0:
                    return f"Decoded text from {len(token_ids)} tokens"
                return "Decoded text"

        return DummyTokenizer()
|
| 118 |
+
|
| 119 |
+
def get_hybrid_attention_config():
    """Proxy to the smart-hybrid-attention configuration helper.

    Delegates to utils.smartHybridAttention.get_hybrid_attention_config;
    the import is done lazily inside the function body.
    """
    from utils.smartHybridAttention import get_hybrid_attention_config as _impl
    return _impl()
|
| 123 |
+
|
| 124 |
+
def load_transformer_model(model_name: str, device: Optional[torch.device] = None) -> AutoModel:
    """
    Load a HuggingFace transformer model, optionally moving it to a device.

    Args:
        model_name: HF hub id or local path of the model.
        device: Optional device to place the model on.

    Returns:
        The loaded (and possibly device-moved) model.

    Raises:
        Re-raises any failure after logging it at error level.
    """
    try:
        logger.info(f"Loading transformer model: {model_name}")
        loaded = AutoModel.from_pretrained(model_name)
        if device:
            loaded = loaded.to(device)
        logger.info(f"Successfully loaded model: {model_name}")
        return loaded
    except Exception as e:
        logger.error(f"Error loading model {model_name}: {e}")
        raise
|
| 147 |
+
|
| 148 |
+
def clear_cache():
    """Empty all module-level model/tokenizer caches to free memory."""
    global _model_cache, _tokenizer_cache, _sentence_transformer_cache
    for cache in (_model_cache, _tokenizer_cache, _sentence_transformer_cache):
        cache.clear()
    logger.info("Cleared transformer model and tokenizer caches")
|
| 155 |
+
|
| 156 |
+
def get_embedding(text: str, model: Optional[SentenceTransformer] = None) -> torch.Tensor:
    """Encode a text string into a dense embedding tensor.

    Args:
        text: Input text to embed.
        model: Sentence-transformer to use; when None, the default model
            is loaded via get_sentence_transformer.

    Returns:
        Embedding as a torch tensor.
    """
    encoder = model if model is not None else get_sentence_transformer(DEFAULT_SENTENCE_TRANSFORMER)
    return encoder.encode(text, convert_to_tensor=True)
|