WildnerveAI commited on
Commit
0861a59
·
verified ·
1 Parent(s): 6e29355

Upload 20 files

Browse files
utils/__init__.py CHANGED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import sys

# Add project root to path so sibling packages resolve when this package is
# imported from a script's working directory.
# FIX: guard the append — the original appended unconditionally, duplicating
# the entry every time the package was (re)imported.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

# Utils package initialization
from .transformer_utils import get_tokenizer, get_sentence_transformer

# SmartHybridAttention is optional: try the package-qualified import first,
# then a flat-module fallback, and degrade to None placeholders when absent.
try:
    from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
except ImportError:
    try:
        from smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
    except ImportError:
        print("Warning: Could not import SmartHybridAttention")
        SmartHybridAttention = None
        get_hybrid_attention_config = None

__all__ = ['get_tokenizer', 'get_sentence_transformer', 'SmartHybridAttention', 'get_hybrid_attention_config']
utils/attention_connector.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from typing import Dict, Any, Optional, List, Tuple, Union
5
+ import torch
6
+
7
# Conditional import for attention profile selector.
# The trigger system is optional: when it is missing, the flag below makes
# AttentionConnector fall back to the default attention profile.
try:
    from attention_trigger_system import AttentionProfileSelector
    ATTENTION_SELECTOR_AVAILABLE = True
except ImportError:
    ATTENTION_SELECTOR_AVAILABLE = False
    logging.warning("AttentionProfileSelector not available - content-aware attention disabled")
15
class AttentionConnector:
    """
    Bridge between the core architecture and content-aware attention.

    Glues together three pieces:
    1. The original text from user input.
    2. The attention configuration system (profile selector).
    3. The SmartHybridAttention implementation.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Set up logging, config path, the profile selector, and request state."""
        self.logger = logging.getLogger(__name__)

        # Fall back to the JSON configuration shipped next to this module.
        default_path = os.path.join(os.path.dirname(__file__), "attention_configuration.json")
        self.config_path = default_path if config_path is None else config_path

        # May be None when the trigger system is unavailable.
        self.profile_selector = self._init_profile_selector()

        # Per-request state: current text/context and the selected profile.
        self.current_input_text = None
        self.current_context = {}
        self.active_profile_id = "standard"
        self.profile_confidence = 1.0

    def _init_profile_selector(self) -> Optional[Any]:
        """Build the AttentionProfileSelector, or return None when unavailable."""
        if not ATTENTION_SELECTOR_AVAILABLE:
            self.logger.warning("AttentionProfileSelector not available - using default attention")
            return None

        try:
            selector = AttentionProfileSelector(self.config_path)
            self.logger.info(f"Initialized AttentionProfileSelector with {len(selector.profiles)} profiles")
            return selector
        except Exception as e:
            self.logger.error(f"Error initializing AttentionProfileSelector: {e}")
            return None

    def set_input_text(self, text: str, context: Optional[Dict[str, Any]] = None):
        """Record the active input text/context and pick a matching profile."""
        self.current_input_text = text
        self.current_context = context or {}

        selector = self.profile_selector
        if selector:
            chosen = selector.select_profile(text, self.current_context)
            self.active_profile_id, self.profile_confidence = chosen
            self.logger.info(f"Selected attention profile: {self.active_profile_id} (confidence: {self.profile_confidence:.2f})")

    def get_attention_parameters(self) -> Dict[str, Any]:
        """Return the parameter dict for the active profile ({} without a selector)."""
        if not self.profile_selector:
            return {}
        return self.profile_selector.get_profile_parameters(self.active_profile_id)

    def inject_attention_parameters(self, attention_module: Any) -> Any:
        """Push the active profile's parameters into `attention_module` (if supported)."""
        if hasattr(attention_module, 'set_parameters'):
            attention_module.set_parameters(**self.get_attention_parameters())
        else:
            self.logger.warning("Attention module does not support parameter injection")
        return attention_module

    def get_input_context(self) -> Dict[str, Any]:
        """Snapshot of the current text, context, and profile selection."""
        return {
            "input_text": self.current_input_text,
            "context": self.current_context,
            "profile_id": self.active_profile_id,
            "confidence": self.profile_confidence,
        }
95
# Global singleton instance (created lazily).
_connector_instance = None

def get_attention_connector() -> AttentionConnector:
    """Return the shared AttentionConnector, constructing it on first use."""
    global _connector_instance
    if _connector_instance is None:
        # Lazy construction keeps import-time side effects minimal.
        _connector_instance = AttentionConnector()
    return _connector_instance
105
# Hook functions to integrate with existing architecture

def inject_input_text(input_text: str, model_forward_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Register `input_text` with the global connector and expose it to the model.

    Intended to be called by the communicator right before the model forward
    pass. Returns the kwargs dict with "original_text" added for models that
    accept it directly.
    """
    get_attention_connector().set_input_text(input_text)
    model_forward_kwargs["original_text"] = input_text
    return model_forward_kwargs
120
def prepare_attention_context(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Dict[str, Any]:
    """
    Build the context dict an attention module consumes in its forward pass.

    Combines the connector's current input snapshot with the q/k/v tensor
    shapes so the mechanism can reason about the incoming geometry.
    """
    attention_context = get_attention_connector().get_input_context()
    attention_context["query_shape"] = query.shape
    attention_context["key_shape"] = key.shape
    attention_context["value_shape"] = value.shape
    return attention_context
utils/attention_trigger_system.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import Dict, Any, Optional, List, Tuple
4
+ import re
5
+ from pathlib import Path
6
+
7
class AttentionProfileSelector:
    """
    Chooses an attention profile for a piece of input text.

    Profiles and the scoring strategy come from the JSON configuration
    dataset; selection is a weighted match over document length, content
    keywords, and structure markers.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Load the configuration and index the available profiles.

        Args:
            config_path: Path to the attention configuration JSON; defaults
                to `attention_configuration.json` next to this module.
        """
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), "attention_configuration.json")

        self.config = self._load_config(config_path)
        self.profiles = {entry["profile_id"]: entry for entry in self.config.get("attention_profiles", [])}
        self.default_profile_id = self.config.get("default_profile", "standard")
        self.selection_strategy = self.config.get("profile_selection_strategy", {})

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Read the JSON config; return {} when missing or malformed."""
        try:
            with open(config_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"Error loading attention configuration: {e}")
            return {}

    def _strategy_param(self, name: str, default: float) -> float:
        # Look up a weight in strategy_parameters, with a fallback default.
        return self.selection_strategy.get("strategy_parameters", {}).get(name, default)

    def select_profile(self,
                       input_text: str,
                       context: Optional[Dict[str, Any]] = None) -> Tuple[str, float]:
        """
        Score every profile against the input and return the best match.

        Args:
            input_text: The user's input text.
            context: Optional extra context; may carry an explicit
                "requested_attention" profile id.

        Returns:
            Tuple of (profile_id, confidence). Falls back to the default
            profile when no profiles exist or confidence is below threshold.
        """
        if not self.profiles:
            return self.default_profile_id, 1.0

        lowered = input_text.lower()
        text_len = len(input_text)
        scores = dict.fromkeys(self.profiles, 0.0)

        length_weight = self._strategy_param("input_length_weight", 0.3)
        content_weight = self._strategy_param("content_type_weight", 0.5)

        for pid, profile in self.profiles.items():
            signals = profile.get("activation_signals", {})

            # Length signal: only profiles with a positive threshold qualify.
            threshold = signals.get("document_length_threshold", 0)
            if threshold > 0 and text_len > threshold:
                scores[pid] += length_weight

            # Content-type keyword matches (fraction of signals present).
            keywords = signals.get("content_type_signals", [])
            if keywords:
                hit_ratio = sum(1 for kw in keywords if kw.lower() in lowered) / len(keywords)
                scores[pid] += hit_ratio * content_weight

            # Structure indicator matches (same weighting as content type).
            markers = signals.get("structure_indicators", [])
            if markers:
                hit_ratio = sum(1 for mk in markers if mk.lower() in lowered) / len(markers)
                scores[pid] += hit_ratio * content_weight

        # An explicit request in the context gets a strong bonus.
        if context and "requested_attention" in context:
            requested = context["requested_attention"]
            if requested in self.profiles:
                scores[requested] += self._strategy_param("explicit_request_weight", 1.0)

        if not scores:
            return self.default_profile_id, 1.0

        best = max(scores, key=scores.get)
        confidence = scores[best]

        # Below the confidence floor, fall back to the default profile.
        if confidence < self._strategy_param("minimum_confidence", 0.65):
            return self.default_profile_id, confidence

        return best, confidence

    def get_profile_parameters(self, profile_id: str) -> Dict[str, Any]:
        """Return the `parameters` dict for `profile_id` ({} if unknown)."""
        profile = self.profiles.get(profile_id)
        return profile.get("parameters", {}) if profile is not None else {}

    def get_attention_type(self, profile_id: str) -> str:
        """Return the profile's `attention_type` ("standard" if unknown)."""
        profile = self.profiles.get(profile_id)
        return profile.get("attention_type", "standard") if profile is not None else "standard"
+
129
+
130
# Factory method to create appropriate attention mechanism
def create_attention_mechanism(profile_id: str, model_dim: int, selector: AttentionProfileSelector):
    """
    Build an attention module configured from the selected profile.

    Args:
        profile_id: ID of the selected attention profile.
        model_dim: Model hidden dimension.
        selector: AttentionProfileSelector used to look up type/parameters.

    Returns:
        A configured attention mechanism; falls back to a plain
        nn.MultiheadAttention when smartHybridAttention is unavailable.
    """
    mech_type = selector.get_attention_type(profile_id)
    profile_params = selector.get_profile_parameters(profile_id)

    try:
        # Imported lazily to avoid circular imports.
        from smartHybridAttention import EnhancedSmartHybridAttention, create_smart_hybrid_attention

        # Translate JSON profile parameters into constructor arguments.
        kwargs = {
            "dim": model_dim,
            "num_heads": profile_params.get("num_heads", 8),
            "window_size": profile_params.get("window_size", 256),
            "use_sliding": profile_params.get("use_sliding_window", True),
            "use_global": profile_params.get("use_global_tokens", True),
            "global_token_ratio": profile_params.get("global_token_ratio", 0.05),
            "memory_tokens": profile_params.get("memory_token_count", 16),
        }

        # Hierarchical profiles get an extra flag.
        if mech_type == "hierarchical":
            kwargs["use_hierarchical"] = True

        return create_smart_hybrid_attention(**kwargs)

    except ImportError:
        print("Warning: smartHybridAttention not found. Using placeholder.")
        # Placeholder so callers still receive a working attention module.
        import torch.nn as nn
        return nn.MultiheadAttention(model_dim, 8)
176
# Usage example:
if __name__ == "__main__":
    selector = AttentionProfileSelector()

    # Example inputs
    code_input = "def calculate_fibonacci(n):\n    if n <= 1:\n        return n\n    else:\n        return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)"

    document_input = """# Chapter 1: Introduction
This technical document covers the architecture of our system.
## Section 1.1: Overview
The system consists of multiple components working together.
"""

    conversation_input = "How did you like the movie we saw yesterday? I thought the ending was unexpected."

    # Test profile selection on each example and report the choice.
    examples = [
        ("Code", code_input),
        ("Document", document_input),
        ("Conversation", conversation_input),
    ]
    for label, text in examples:
        profile, conf = selector.select_profile(text)
        print(f"{label} input → {profile} (confidence: {conf:.2f})")
utils/collator.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom data collators for transformer training.
3
+ """
4
+ import torch
5
+ import random
6
+ from typing import Dict, List, Any, Union
7
+ from dataclasses import dataclass
8
+
9
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator for language modeling.

    Pads a batch of tokenized examples and, when `mlm` is enabled,
    dynamically masks tokens for masked-language-model training.

    FIX: the committed version of `mask_tokens` was truncated — it fell off
    the end without returning and referenced an undefined `_is_special_token`.
    The standard BERT-style 80/10/10 masking is completed below.
    """
    tokenizer: Any                 # Must provide pad(); optionally mask_token_id etc.
    mlm: bool = True               # Whether to use masked language modeling
    mlm_probability: float = 0.15  # Probability of masking a token

    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of examples.

        Args:
            examples: List of examples from dataset; each carries "input_ids".

        Returns:
            Batch dictionary for the model: masked inputs + labels in MLM
            mode, otherwise inputs, copied labels, and the attention mask.
        """
        # Extract input_ids
        input_ids = [example["input_ids"] for example in examples]

        # Pad into a rectangular batch of tensors.
        batch = self.tokenizer.pad(
            {"input_ids": input_ids},
            return_tensors="pt"
        )

        if self.mlm:
            inputs, labels = self.mask_tokens(batch["input_ids"])
            return {"input_ids": inputs, "labels": labels}

        # Plain LM: labels are a copy of the inputs.
        labels = batch["input_ids"].clone()
        return {
            "input_ids": batch["input_ids"],
            "labels": labels,
            "attention_mask": batch.get("attention_mask", None)
        }

    def mask_tokens(
        self, inputs: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling.

        Of the positions selected for masking: 80% become [MASK], 10% become
        a random token, and 10% keep the original token. Unselected positions
        get label -100 so the loss ignores them.

        Args:
            inputs: Input id tensor (modified in place at masked positions).

        Returns:
            Tuple of (masked inputs, labels).
        """
        labels = inputs.clone()

        # Per-position masking probability.
        probability_matrix = torch.full(labels.shape, self.mlm_probability)

        # Build a mask of special-token positions so they are never masked.
        if hasattr(self.tokenizer, "get_special_tokens_mask"):
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = torch.tensor(
                [
                    [self._is_special_token(x) for x in val]
                    for val in labels.tolist()
                ],
                dtype=torch.bool,
            )

        # Don't mask special tokens
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Sample the positions to mask.
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Set labels for non-masked tokens to -100 (ignored in loss)
        labels[~masked_indices] = -100

        # 80% of masked tokens -> [MASK].
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        if hasattr(self.tokenizer, "mask_token_id") and self.tokenizer.mask_token_id is not None:
            inputs[indices_replaced] = self.tokenizer.mask_token_id

        # Half of the remaining masked tokens (10% overall) -> random token;
        # the rest keep their original id.
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        if indices_random.any():
            # len(tokenizer) is the conventional vocab-size accessor.
            random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
            inputs[indices_random] = random_words[indices_random]

        return inputs, labels

    def _is_special_token(self, token_id: int) -> bool:
        """Fallback check used when the tokenizer lacks get_special_tokens_mask."""
        special_ids = getattr(self.tokenizer, "all_special_ids", None) or []
        return token_id in special_ids
utils/convert_checkpoints.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility to convert PyTorch (.pt) checkpoints to Hugging Face (.bin) format
3
+ python -m utils.convert_checkpoints --checkpoints checkpoints/stdp_model_epoch_15.pt checkpoints/stdp_model_epoch_20.pt --output hf_stdp_model
4
+ """
5
+ import os
6
+ import torch
7
+ import logging
8
+ import argparse
9
+ import datetime # Added missing import
10
+ from pathlib import Path
11
+ from typing import Dict, Any, Optional
12
+ import json
13
+ import shutil
14
+
15
# Configure logging - Fix the typo in format string (levellevel → levelname)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def convert_stdp_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    config_path: Optional[str] = None
) -> str:
    """
    Convert an STDP/SNN PyTorch checkpoint to Hugging Face format.

    Args:
        checkpoint_path: Path to the .pt checkpoint file.
        output_dir: Directory to save the converted model.
        config_path: Optional path to a config.json whose STDP_CONFIG
            section is merged into the model config.

    Returns:
        Path to the converted model directory.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    logger.info(f"Converting checkpoint: {checkpoint_path}")

    os.makedirs(output_dir, exist_ok=True)

    try:
        # NOTE(review): torch.load on an untrusted checkpoint unpickles
        # arbitrary objects — only convert checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Parse the training epoch out of names like "stdp_model_epoch_15.pt".
        checkpoint_filename = os.path.basename(checkpoint_path)
        epoch = None
        if "epoch_" in checkpoint_filename:
            try:
                epoch = int(checkpoint_filename.split("epoch_")[1].split(".")[0])
            except (ValueError, IndexError):
                pass

        # Base config; may be extended/overridden below.
        config = {
            "model_type": "stdp_snn",
            "architectures": ["STDPSpikeNeuralNetwork"],
            "epoch": epoch,
            "original_checkpoint": checkpoint_path,
            "conversion_date": str(datetime.datetime.now())
        }

        # Overlay config embedded in the checkpoint itself.
        if isinstance(checkpoint, dict) and "config" in checkpoint:
            config.update(checkpoint["config"])

        # Overlay the STDP_CONFIG section of an external config file.
        if config_path and os.path.exists(config_path):
            with open(config_path, 'r') as f:
                file_config = json.load(f)
            if "STDP_CONFIG" in file_config:
                config.update(file_config["STDP_CONFIG"])

        # Pull the weights out of whichever layout the checkpoint uses.
        model_weights = None
        for key in ("model_state_dict", "state_dict"):
            if key in checkpoint:
                model_weights = checkpoint[key]
                break
        if model_weights is None:
            if "weights" in checkpoint:
                model_weights = {"weights": checkpoint["weights"]}
            elif "synaptic_weights" in checkpoint:
                model_weights = {"synaptic_weights": checkpoint["synaptic_weights"]}
            else:
                # Unrecognized layout: assume the checkpoint itself is the weights.
                model_weights = checkpoint

        # Save converted weights in HF format.
        model_dir = os.path.join(output_dir, "pytorch_model.bin")
        torch.save(model_weights, model_dir)
        logger.info(f"Saved model weights to {model_dir}")

        # Save config file.
        config_file = os.path.join(output_dir, "config.json")
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)
        logger.info(f"Saved model config to {config_file}")

        # Minimal README describing the conversion.
        readme_file = os.path.join(output_dir, "README.md")
        readme_lines = [
            "# Converted STDP/SNN Model\n\n",
            f"This model was converted from PyTorch checkpoint: `{checkpoint_path}`\n\n",
            f"Converted on: {config['conversion_date']}\n",
        ]
        if epoch is not None:
            readme_lines.append(f"Training epoch: {epoch}\n")
        with open(readme_file, 'w') as f:
            f.writelines(readme_lines)

        return output_dir

    except Exception as e:
        logger.error(f"Error converting checkpoint: {e}")
        raise
114
+
115
def prepare_for_hf_upload(
    checkpoint_paths: list,
    output_dir: str,
    config_path: Optional[str] = None,
    include_code: bool = True
) -> str:
    """
    Prepare multiple checkpoints for HF upload with code.

    Args:
        checkpoint_paths: List of paths to checkpoint files
        output_dir: Directory to save the prepared model
        config_path: Optional path to config.json file
        include_code: Whether to include inference code

    Returns:
        Path to the prepared directory
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Convert each checkpoint into its own subdirectory named after the file.
    converted_models = []
    for cp_path in checkpoint_paths:
        model_name = os.path.splitext(os.path.basename(cp_path))[0]
        model_dir = os.path.join(output_dir, model_name)
        converted_models.append(convert_stdp_checkpoint(cp_path, model_dir, config_path))

    # Include necessary code files
    if include_code:
        # NOTE(review): assumes these files live in the project root one
        # level above this module — confirm against the repository layout.
        code_files = [
            "communicator_STDP.py",
            "config.py",
            "model_Custm.py"
        ]

        for file in code_files:
            src_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), file)
            # Missing files are silently skipped (best-effort copy).
            if os.path.exists(src_path):
                dst_path = os.path.join(output_dir, file)
                shutil.copy2(src_path, dst_path)
                logger.info(f"Copied {file} to {dst_path}")

        # Create an inference script - FIX: Use single quotes for inner docstring
        inference_script = '''
import torch
import os
import json
import argparse
from pathlib import Path

def load_stdp_model(model_dir):
    """Load STDP model from directory."""
    weights_path = os.path.join(model_dir, "pytorch_model.bin")
    config_path = os.path.join(model_dir, "config.json")

    # Load weights
    weights = torch.load(weights_path, map_location="cpu")

    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)

    return weights, config

def main():
    parser = argparse.ArgumentParser(description="Run inference with STDP model")
    parser.add_argument("--model", type=str, required=True, help="Model directory")
    parser.add_argument("--input", type=str, required=True, help="Input text or file")
    args = parser.parse_args()

    # Load model
    weights, config = load_stdp_model(args.model)
    print(f"Loaded model from {args.model}")
    print(f"Model config: {json.dumps(config, indent=2)}")

    # Get input
    input_text = args.input
    if os.path.exists(args.input):
        with open(args.input, 'r') as f:
            input_text = f.read()

    print(f"Input text: {input_text[:100]}...")

    # Run inference using communicator_STDP if available
    try:
        from communicator_STDP import CommSTDP
        communicator = CommSTDP({}, device="cpu")
        result = communicator.process(input_text, weights)
        print(f"Result: {result}")
    except ImportError:
        print("communicator_STDP not available. Weights loaded successfully.")
        print(f"Weights shape: {weights.shape if hasattr(weights, 'shape') else '[dict of tensors]'}")

if __name__ == "__main__":
    main()
'''

        inference_path = os.path.join(output_dir, "inference.py")
        with open(inference_path, 'w') as f:
            f.write(inference_script.strip())
        logger.info(f"Created inference script: {inference_path}")

    # Create an overall README
    readme_file = os.path.join(output_dir, "README.md")
    with open(readme_file, 'w') as f:
        f.write("# STDP/SNN Trained Models\n\n")
        f.write("This repository contains STDP/SNN models converted from PyTorch checkpoints for use with Hugging Face's infrastructure.\n\n")
        f.write("## Models Included\n\n")
        for i, model in enumerate(converted_models):
            f.write(f"{i+1}. `{os.path.basename(model)}`\n")

        f.write("\n## Usage\n\n")
        f.write("```python\n")
        f.write("from transformers import AutoModel\n\n")
        f.write("# Load the model\n")
        f.write("model = AutoModel.from_pretrained('your-username/your-model-name')\n")
        f.write("```\n\n")
        f.write("Or use the included inference.py script:\n\n")
        f.write("```bash\npython inference.py --model ./stdp_model_epoch_15 --input 'Your input text here'\n```")

    logger.info(f"Prepared {len(converted_models)} models for HF upload in {output_dir}")
    return output_dir
238
+
239
if __name__ == "__main__":
    # CLI entry point: convert one or more checkpoints for HF upload.
    parser = argparse.ArgumentParser(description="Convert PyTorch checkpoints to Hugging Face format")
    parser.add_argument("--checkpoints", nargs="+", required=True, help="Paths to checkpoint files")
    parser.add_argument("--output", type=str, default="hf_model", help="Output directory")
    parser.add_argument("--config", type=str, help="Path to config.json file")
    parser.add_argument("--no-code", action="store_true", help="Don't include inference code")
    cli_args = parser.parse_args()

    # --no-code inverts the include_code flag.
    prepare_for_hf_upload(
        cli_args.checkpoints,
        cli_args.output,
        cli_args.config,
        not cli_args.no_code,
    )
utils/debug_helper.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import psutil
5
+ import traceback
6
+ import logging
7
+ import threading
8
+ from typing import Dict, Any, Optional, List
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class DebugHelper:
    """
    Helper class for debugging hanging issues in STDP training.

    This provides tools for:
    1. Process monitoring and status reporting
    2. Timeout management
    3. Recovery mechanisms for hanging processes
    4. Detailed diagnostics
    """

    @staticmethod
    def get_process_info(pid: Optional[int] = None) -> Dict[str, Any]:
        """
        Get detailed information about a process (defaults to this process).

        Returns:
            Dict of psutil-derived stats, active thread names, and the
            current stack trace; {'error': str} on failure.
        """
        pid = pid or os.getpid()

        try:
            process = psutil.Process(pid)

            # Basic process statistics.
            info = {
                'pid': pid,
                'name': process.name(),
                'status': process.status(),
                'cpu_percent': process.cpu_percent(),
                'memory_percent': process.memory_percent(),
                'memory_info': dict(process.memory_info()._asdict()),
                'create_time': process.create_time(),
                'runtime': time.time() - process.create_time(),
                'num_threads': process.num_threads(),
                'open_files': len(process.open_files()),
                # NOTE(review): psutil deprecates Process.connections() in
                # favor of net_connections() in newer releases — confirm the
                # pinned psutil version.
                'connections': len(process.connections()),
            }

            # Thread names of this interpreter. (FIX: removed the redundant
            # in-function `import threading` — it is imported at module level —
            # and narrowed the bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed.)
            try:
                info['active_threads'] = [t.name for t in threading.enumerate()]
            except Exception:
                info['active_threads'] = "Could not retrieve thread information"

            # Current stack trace of the calling thread.
            info['current_stack'] = traceback.format_stack()

            return info

        except Exception as e:
            logger.error(f"Error getting process info: {e}")
            return {'error': str(e)}

    @staticmethod
    def check_resource_leaks() -> Dict[str, Any]:
        """Check for potential resource leaks (GC stats plus CUDA memory)."""
        import gc

        leaks = {
            'gc_counts': gc.get_count(),
            'gc_objects': len(gc.get_objects()),
        }

        # Include torch CUDA memory usage if torch is installed and a GPU exists.
        try:
            import torch
            if torch.cuda.is_available():
                leaks['torch_memory_allocated'] = torch.cuda.memory_allocated()
                leaks['torch_memory_reserved'] = torch.cuda.memory_reserved()
                leaks['torch_max_memory_allocated'] = torch.cuda.max_memory_allocated()
        except ImportError:
            pass

        return leaks

    @staticmethod
    def register_timeout(seconds: int, callback=None):
        """
        Start a daemon thread that fires after `seconds`.

        Calls `callback` if given, otherwise prints a timeout report with
        process info and a stack trace. Returns the watchdog thread.

        Note: the timer always fires — it is not cancelled if the guarded
        operation finishes early.
        """
        def _timeout_handler():
            time.sleep(seconds)
            if callback:
                callback()
            else:
                print(f"TIMEOUT: Operation took longer than {seconds} seconds")
                info = DebugHelper.get_process_info()
                print(f"Process info: {info}")
                traceback.print_stack()

        thread = threading.Thread(target=_timeout_handler)
        thread.daemon = True
        thread.start()
        return thread

    @staticmethod
    def dump_debug_info(filename: str):
        """Dump process, resource-leak, environment, and stack info to `filename`."""
        process_info = DebugHelper.get_process_info()
        leak_info = DebugHelper.check_resource_leaks()

        with open(filename, 'w') as f:
            f.write("===== PROCESS INFORMATION =====\n")
            for key, value in process_info.items():
                f.write(f"{key}: {value}\n")

            f.write("\n===== RESOURCE LEAK INFORMATION =====\n")
            for key, value in leak_info.items():
                f.write(f"{key}: {value}\n")

            f.write("\n===== ENVIRONMENT VARIABLES =====\n")
            for key, value in os.environ.items():
                f.write(f"{key}: {value}\n")

            f.write("\n===== STACK TRACE =====\n")
            f.write(''.join(traceback.format_stack()))

        # FIX: the log message previously did not interpolate the target path.
        logger.info(f"Debug info dumped to {filename}")
utils/dual_encoder_utils.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for dual encoder configuration and initialization.
3
+ """
4
+ import logging
5
+ from typing import Dict, Any, Optional, Union
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from config import load_config, app_config
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
class DualEncoderConfig:
    """Configuration object for dual encoders.

    Settings are resolved with increasing priority: hard-coded defaults,
    then ``DUAL_ENCODER_CONFIG`` from the loaded application config, then
    the caller-supplied override dictionary.
    """

    def __init__(self, config_dict: Optional[Dict[str, Any]] = None):
        """
        Initialize dual encoder configuration.

        Args:
            config_dict: Optional mapping whose entries win over everything else.
        """
        app_cfg = load_config()

        # Hard-coded defaults (lowest priority).
        self.USE_PRETRAINED_ENCODER = True
        self.USE_CUSTOM_ENCODER = True
        self.FUSION_METHOD = "concat"  # Options: concat, add, weighted_sum
        self.FUSION_WEIGHTS = [0.5, 0.5]  # Weights for pretrained and custom encoders
        self.TRAINING_MODE = "joint"  # Options: joint, alternating, pretrained_first

        # Project-level overrides from the application config, if declared.
        if hasattr(app_cfg, "DUAL_ENCODER_CONFIG"):
            for name, value in app_cfg.DUAL_ENCODER_CONFIG.items():
                setattr(self, name, value)

        # Caller-supplied overrides (highest priority).
        if config_dict:
            for name, value in config_dict.items():
                setattr(self, name, value)


class DualEncoderFusion(nn.Module):
    """
    Module that combines outputs from pretrained and custom encoders.
    """

    def __init__(self, config: Optional[Union[Dict[str, Any], DualEncoderConfig]] = None):
        """
        Initialize fusion module.

        Args:
            config: Fusion configuration as a dict, a ``DualEncoderConfig``,
                or ``None`` for defaults.
        """
        super().__init__()

        # Normalise whatever the caller passed into a config object.
        if isinstance(config, dict):
            self.config = DualEncoderConfig(config)
        else:
            self.config = DualEncoderConfig() if config is None else config

        # Normalised weights are only needed for the weighted-sum strategy;
        # a buffer keeps them on the module's device without being trainable.
        if self.config.FUSION_METHOD == "weighted_sum":
            raw_weights = torch.tensor(self.config.FUSION_WEIGHTS, dtype=torch.float32)
            self.register_buffer('fusion_weights', raw_weights / raw_weights.sum())

    def forward(self, pretrained_output: torch.Tensor, custom_output: torch.Tensor) -> torch.Tensor:
        """
        Combine encoder outputs based on the configured fusion method.

        Args:
            pretrained_output: Output from the pretrained encoder.
            custom_output: Output from the custom encoder.

        Returns:
            The fused tensor, or the single enabled encoder's output.

        Raises:
            ValueError: On shape mismatch for add/weighted_sum, or when the
                fusion method is unknown.
        """
        # With one encoder disabled, fusion degenerates to a pass-through.
        if not self.config.USE_PRETRAINED_ENCODER:
            return custom_output
        if not self.config.USE_CUSTOM_ENCODER:
            return pretrained_output

        method = self.config.FUSION_METHOD
        if method == "concat":
            return torch.cat([pretrained_output, custom_output], dim=-1)

        if method == "add":
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot add tensors with different shapes: {pretrained_output.shape} and {custom_output.shape}")
            return pretrained_output + custom_output

        if method == "weighted_sum":
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot use weighted sum with different shapes: {pretrained_output.shape} and {custom_output.shape}")
            w_pretrained, w_custom = self.fusion_weights
            return w_pretrained * pretrained_output + w_custom * custom_output

        raise ValueError(f"Unknown fusion method: {method}")
105
+
106
def get_dual_encoder_config() -> DualEncoderConfig:
    """
    Build the dual encoder configuration from application defaults.

    Returns:
        A freshly constructed ``DualEncoderConfig`` with no explicit overrides.
    """
    return DualEncoderConfig()
114
+
115
# Testing function
def test_fusion_methods():
    """Exercise every fusion strategy once and print the resulting shapes.

    Returns:
        Dict mapping each fusion method name to its output tensor.
    """
    cfg = DualEncoderConfig()

    # Two dummy encoder outputs of identical shape.
    lhs = torch.randn(2, 10, 768)
    rhs = torch.randn(2, 10, 768)

    outputs = {}

    cfg.FUSION_METHOD = "concat"
    outputs["concat"] = DualEncoderFusion(cfg)(lhs, rhs)
    print(f"Concat output shape: {outputs['concat'].shape}")  # Should be [2, 10, 1536]

    cfg.FUSION_METHOD = "add"
    outputs["add"] = DualEncoderFusion(cfg)(lhs, rhs)
    print(f"Add output shape: {outputs['add'].shape}")  # Should be [2, 10, 768]

    cfg.FUSION_METHOD = "weighted_sum"
    cfg.FUSION_WEIGHTS = [0.7, 0.3]
    outputs["weighted_sum"] = DualEncoderFusion(cfg)(lhs, rhs)
    print(f"Weighted sum output shape: {outputs['weighted_sum'].shape}")  # Should be [2, 10, 768]

    return outputs
148
+
149
# Allow running this module directly as a smoke test of the fusion strategies.
if __name__ == "__main__":
    # Run tests
    test_results = test_fusion_methods()
utils/emergency_abort.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import time
4
+ import signal
5
+ import logging
6
+ import threading
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class EmergencyAbort:
    """Creates an abort file that can be touched to trigger process termination.

    A daemon thread polls the sentinel file's mtime every ``check_interval``
    seconds; when an external actor updates the file, the current process is
    killed with SIGTERM.
    """

    def __init__(self, abort_file="emergency_abort.txt", check_interval=5):
        """
        Args:
            abort_file: Path of the sentinel file to create and watch.
            check_interval: Seconds between mtime polls.
        """
        self.abort_file = abort_file
        self.check_interval = check_interval
        self.running = False
        self.thread = None

        # Create initial abort file with instructions
        with open(self.abort_file, 'w') as f:
            f.write("# Emergency Abort File\n")
            f.write("# To abort training, update the timestamp of this file\n")
            f.write(f"# Last checked: {time.ctime()}\n")

    def _check_file(self):
        """Polling loop run on the monitor thread."""
        last_modified = os.path.getmtime(self.abort_file)

        while self.running:
            time.sleep(self.check_interval)

            try:
                current_modified = os.path.getmtime(self.abort_file)

                if current_modified > last_modified:
                    logger.warning("Emergency abort file modified! Initiating abort sequence.")
                    # Kill this process
                    os.kill(os.getpid(), signal.SIGTERM)
                    return

                # Heartbeat: rewrite the file so its contents show when the
                # monitor last checked it.
                with open(self.abort_file, 'w') as f:
                    f.write("# Emergency Abort File\n")
                    f.write("# To abort training, update the timestamp of this file\n")
                    f.write(f"# Last checked: {time.ctime()}\n")

                # Bug fix: the heartbeat write above bumps the file's mtime,
                # so re-read it now. The previous code stored the pre-write
                # mtime (current_modified), which made the monitor's own
                # write look like an external touch and spuriously killed the
                # process on the very next poll.
                last_modified = os.path.getmtime(self.abort_file)

            except Exception as e:
                logger.error(f"Error checking abort file: {e}")

    def start(self):
        """Start the abort file monitor; returns self for chaining."""
        self.running = True
        self.thread = threading.Thread(target=self._check_file)
        self.thread.daemon = True
        self.thread.start()
        logger.info(f"Emergency abort monitor started. Modify {self.abort_file} to terminate training.")
        return self

    def stop(self):
        """Stop the abort file monitor (joins the thread briefly)."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=2)
        logger.info("Emergency abort monitor stopped.")
utils/event_bus.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple event bus for direct component-to-component communication.
3
+ Provides a lightweight alternative to the full EventSystem.
4
+ """
5
+ import logging
6
+ from typing import Dict, List, Callable, Any
7
+
8
logger = logging.getLogger(__name__)

class EventBus:
    """Minimal synchronous publish/subscribe hub.

    Callbacks run inline from :meth:`publish`; there is no queueing,
    threading, or wildcard matching.
    """

    def __init__(self):
        """Create a bus with no registered subscribers."""
        self.subscribers = {}
        logger.info("Initialized EventBus")

    def subscribe(self, event_type: str, callback: Callable[[str, Any], None]) -> None:
        """Register *callback* to receive every *event_type* event."""
        self.subscribers.setdefault(event_type, []).append(callback)
        logger.debug(f"Added subscriber to {event_type}")

    def unsubscribe(self, event_type: str, callback: Callable[[str, Any], None]) -> None:
        """Remove a previously registered callback; unknown ones are ignored."""
        registered = self.subscribers.get(event_type)
        if registered and callback in registered:
            registered.remove(callback)
            logger.debug(f"Removed subscriber from {event_type}")

    def publish(self, event_type: str, data: Any = None) -> None:
        """Invoke every subscriber of *event_type* synchronously with (type, data)."""
        # Iterate over a copy so callbacks may (un)subscribe during dispatch.
        listeners = list(self.subscribers.get(event_type, []))

        if not listeners:
            logger.debug(f"No subscribers for event {event_type}")
            return

        logger.debug(f"Dispatching event {event_type} to {len(listeners)} subscribers")

        for listener in listeners:
            try:
                listener(event_type, data)
            except Exception as e:
                # One failing subscriber must not break the others.
                logger.error(f"Error in subscriber callback: {e}")

# Create a global instance for convenience
event_bus = EventBus()
utils/event_system.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Event system module for enabling parallel processing across components.
3
+ Implements a publisher-subscriber pattern to decouple components.
4
+ """
5
+ import logging
6
+ import threading
7
+ import queue
8
+ import time
9
+ from typing import Dict, List, Callable, Any, Optional, Set
10
+ import concurrent.futures
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
class Event:
    """Envelope pairing an event type with payload data and origin metadata."""

    def __init__(self, event_type: str, data: Any = None, source: str = None):
        """
        Args:
            event_type: Name of the event (see the EVENT_* constants).
            data: Arbitrary payload forwarded to subscribers.
            source: Optional identifier of the publishing component.
        """
        # Creation time is captured eagerly so dispatch latency can be measured.
        self.timestamp = time.time()
        self.event_type = event_type
        self.data = data
        self.source = source

    def __repr__(self) -> str:
        return f"Event({self.event_type}, source={self.source}, timestamp={self.timestamp})"
24
+
25
class EventSystem:
    """Asynchronous publish/subscribe hub with a dispatcher thread + worker pool.

    Events are queued by :meth:`publish`, drained by a single daemon
    dispatcher thread, and each subscriber callback runs on a
    ``ThreadPoolExecutor`` worker so a slow subscriber cannot stall dispatch.
    Subscribe with event type ``"*"`` to receive every event.
    """

    def __init__(self, max_workers: int = 4):
        """Initialize the event system.

        Args:
            max_workers: Size of the worker pool used to run callbacks.
        """
        self.subscribers: Dict[str, List[Callable[[Event], None]]] = {}
        self.lock = threading.RLock()  # Reentrant lock for thread safety
        self.event_queue = queue.Queue()
        self.running = False
        self.dispatcher_thread = None
        self.max_workers = max_workers
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        # In-flight callback futures, consumed by wait_for_all_events().
        self.futures = set()
        logger.info(f"Initialized EventSystem with {max_workers} workers")

    def subscribe(self, event_type: str, callback: Callable[[Event], None]) -> None:
        """Subscribe a callback to a specific event type ("*" = all events)."""
        with self.lock:
            if event_type not in self.subscribers:
                self.subscribers[event_type] = []
            self.subscribers[event_type].append(callback)
            logger.debug(f"Added subscriber to {event_type}, total: {len(self.subscribers[event_type])}")

    def unsubscribe(self, event_type: str, callback: Callable[[Event], None]) -> None:
        """Unsubscribe a callback from a specific event type."""
        with self.lock:
            if event_type in self.subscribers and callback in self.subscribers[event_type]:
                self.subscribers[event_type].remove(callback)
                logger.debug(f"Removed subscriber from {event_type}, remaining: {len(self.subscribers[event_type])}")

    def publish(self, event: Event) -> None:
        """Queue an event for dispatch; lazily starts the dispatcher thread."""
        self.event_queue.put(event)
        logger.debug(f"Published event: {event}")

        # Start dispatcher if not running
        with self.lock:
            if not self.running:
                self.start()

    def publish_from_dict(self, event_type: str, data: Dict[str, Any], source: str = None) -> None:
        """Convenience wrapper: build an :class:`Event` and publish it."""
        event = Event(event_type, data, source)
        self.publish(event)

    def start(self) -> None:
        """Start the event dispatcher thread (no-op if already running)."""
        with self.lock:
            if not self.running:
                self.running = True
                self.dispatcher_thread = threading.Thread(target=self._dispatch_events)
                self.dispatcher_thread.daemon = True
                self.dispatcher_thread.start()
                logger.info("Event dispatcher thread started")

    def stop(self) -> None:
        """Stop the dispatcher thread and shut down the worker pool."""
        with self.lock:
            if self.running:
                self.running = False
                self.event_queue.put(None)  # Sentinel to stop the thread
                if self.dispatcher_thread and self.dispatcher_thread.is_alive():
                    self.dispatcher_thread.join(timeout=2.0)
                logger.info("Event dispatcher thread stopped")

                # Shut down thread pool (don't block on outstanding callbacks)
                self.thread_pool.shutdown(wait=False)

    def _dispatch_events(self) -> None:
        """Dispatcher loop: drain the queue and fan events out to subscribers."""
        while self.running:
            try:
                # Get next event with timeout to allow checking running flag
                event = self.event_queue.get(timeout=0.5)

                # Handle sentinel value
                if event is None:
                    break

                # Process the event
                self._process_event(event)

                # Mark task as done
                self.event_queue.task_done()
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Error in event dispatcher: {e}")

        logger.info("Event dispatcher thread exiting")

    def _process_event(self, event: Event) -> None:
        """Notify all exact-type and wildcard subscribers of one event."""
        # Copy subscriber lists under the lock, then dispatch outside it so
        # slow submissions cannot block (un)subscription.
        with self.lock:
            subscribers = self.subscribers.get(event.event_type, []).copy()
            # Also check for wildcard subscribers
            wildcard_subscribers = self.subscribers.get("*", []).copy()
            all_subscribers = subscribers + wildcard_subscribers

        if not all_subscribers:
            logger.debug(f"No subscribers for event {event.event_type}")
            return

        logger.debug(f"Dispatching event {event.event_type} to {len(all_subscribers)} subscribers")

        # Submit a task to the thread pool for each subscriber
        for callback in all_subscribers:
            future = self.thread_pool.submit(self._safe_callback, callback, event)
            self.futures.add(future)
            # Bug fix: use set.discard (never raises) instead of set.remove.
            # The done-callback can fire on a worker thread at any time, and
            # remove() would raise KeyError if the future had already been
            # dropped from the bookkeeping set.
            future.add_done_callback(self.futures.discard)

    def _safe_callback(self, callback: Callable[[Event], None], event: Event) -> None:
        """Run one subscriber callback, logging (not propagating) exceptions."""
        try:
            callback(event)
        except Exception as e:
            logger.error(f"Error in subscriber callback: {e}")

    def wait_for_all_events(self, timeout: Optional[float] = None) -> bool:
        """Block until all queued events and in-flight callbacks finish.

        Args:
            timeout: Maximum seconds to wait for outstanding callbacks
                (``None`` = wait indefinitely).

        Returns:
            True if everything completed, False on timeout or error.
        """
        try:
            self.event_queue.join()

            # Also wait for all futures to complete
            done, not_done = concurrent.futures.wait(
                self.futures,
                timeout=timeout,
                return_when=concurrent.futures.ALL_COMPLETED
            )

            return len(not_done) == 0
        except Exception as e:
            logger.error(f"Error waiting for events: {e}")
            return False
160
+
161
# Common event types
# Canonical event-name strings shared by publishers and subscribers so both
# sides agree on spelling; subscribe with "*" to receive every event type.
EVENT_USER_INPUT = "user_input"
EVENT_MODEL_REQUEST = "model_request"
EVENT_MODEL_RESPONSE = "model_response"
EVENT_STDP_REQUEST = "stdp_request"
EVENT_STDP_RESPONSE = "stdp_response"
EVENT_TOKEN_GENERATED = "token_generated"
EVENT_RESPONSE_COMPLETE = "response_complete"
EVENT_ERROR = "error"
utils/gpu_config_optimizer.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility to optimize transformer configuration for GPU memory constraints.
3
+ """
4
+ import os
5
+ import json
6
+ import torch
7
+ import logging
8
+ from pathlib import Path
9
+ import argparse
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
13
+ logger = logging.getLogger(__name__)
14
+
15
def optimize_config_for_gpu(config_path, target_vram_mb=None, batch_reduction_factor=0.5):
    """
    Optimize config settings for the available GPU memory.

    Args:
        config_path: Path to the config.json file.
        target_vram_mb: Target VRAM usage in MB. When None, 80% of the
            detected GPU memory is used (or a conservative 2048MB on
            CPU-only hosts).
        batch_reduction_factor: Multiplier applied to the batch size on
            memory-constrained GPUs (0.5 = half).

    Returns:
        Path of the newly written, device-specific optimized config file.
    """
    # Detect the hardware and derive a memory budget.
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        gpu_properties = torch.cuda.get_device_properties(device)
        total_memory = gpu_properties.total_memory / (1024 * 1024)  # bytes -> MB
        gpu_name = gpu_properties.name
        logger.info(f"GPU detected: {gpu_name} with {total_memory:.0f}MB VRAM")

        # Set target memory if not specified (80% of available)
        if target_vram_mb is None:
            target_vram_mb = int(total_memory * 0.8)
    else:
        logger.warning("No GPU detected, using conservative settings for CPU")
        # Bug fix: only fall back to the conservative default when the caller
        # did not request an explicit budget (the old code always overwrote
        # a user-supplied target_vram_mb on CPU-only hosts).
        if target_vram_mb is None:
            target_vram_mb = 2048
        gpu_name = "CPU"

    logger.info(f"Target VRAM usage: {target_vram_mb}MB")

    # Load current config
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Remember originals so the summary below can show the deltas.
    original_batch_size = config["TRANSFORMER_CONFIG"]["BATCH_SIZE"]
    original_sequence_length = config["TRANSFORMER_CONFIG"]["MAX_SEQ_LENGTH"]
    original_num_layers = config["TRANSFORMER_CONFIG"]["NUM_LAYERS"]

    # Adjust batch size based on GPU
    if torch.cuda.is_available():
        # RTX 4050 specific optimizations (around 6GB VRAM)
        if "4050" in gpu_name or total_memory < 7000:
            # Significant reductions needed; never go below a batch of 4.
            config["TRANSFORMER_CONFIG"]["BATCH_SIZE"] = max(4, int(original_batch_size * batch_reduction_factor))

            # If still too large after batch reduction, shrink other dims too.
            if target_vram_mb < 5500:  # RTX 4050 has ~6GB VRAM
                if config["TRANSFORMER_CONFIG"]["MAX_SEQ_LENGTH"] > 256:
                    config["TRANSFORMER_CONFIG"]["MAX_SEQ_LENGTH"] = 256
                if config["TRANSFORMER_CONFIG"]["NUM_LAYERS"] > 6:
                    config["TRANSFORMER_CONFIG"]["NUM_LAYERS"] = 6

    # Enable gradient checkpointing / mixed precision (trades compute for
    # memory). Applied regardless of device, matching the unconditional
    # "Enabled" lines in the summary logged below.
    if "OPTIMIZATION" not in config:
        config["OPTIMIZATION"] = {}
    config["OPTIMIZATION"]["USE_GRADIENT_CHECKPOINTING"] = True
    config["OPTIMIZATION"]["USE_MIXED_PRECISION"] = True

    # Create optimized filename with the GPU type
    gpu_name_simple = gpu_name.replace(" ", "_").lower()
    opt_config_path = config_path.replace(".json", f"_{gpu_name_simple}_optimized.json")

    # Save optimized config
    with open(opt_config_path, 'w') as f:
        json.dump(config, f, indent=2)

    # Report changes
    logger.info(f"Optimized configuration saved to: {opt_config_path}")
    logger.info("Changes made:")
    logger.info(f"  - Batch size: {original_batch_size} → {config['TRANSFORMER_CONFIG']['BATCH_SIZE']}")
    logger.info(f"  - Sequence length: {original_sequence_length} → {config['TRANSFORMER_CONFIG']['MAX_SEQ_LENGTH']}")
    logger.info(f"  - Num layers: {original_num_layers} → {config['TRANSFORMER_CONFIG']['NUM_LAYERS']}")
    logger.info(f"  - Gradient checkpointing: Enabled")
    logger.info(f"  - Mixed precision: Enabled")

    return opt_config_path
93
+
94
def apply_optimized_config(opt_config_path):
    """
    Promote an optimized config to be the active main config.

    The current main ``config.json`` (two directory levels above the
    optimized file) is backed up once to ``config.json.backup`` before
    being overwritten.

    Args:
        opt_config_path: Path to a previously generated optimized config.
    """
    # Load optimized config
    with open(opt_config_path, 'r') as f:
        opt_config = json.load(f)

    # The optimized file lives under <root>/<subdir>/, the live config in <root>/.
    root_dir = os.path.dirname(os.path.dirname(opt_config_path))
    main_config_path = os.path.join(root_dir, "config.json")

    # Keep a one-time backup of the pristine config.
    backup_path = main_config_path + '.backup'
    if not os.path.exists(backup_path):
        import shutil
        shutil.copy2(main_config_path, backup_path)
        logger.info(f"Original config backed up to: {backup_path}")

    # Overwrite the live config with the optimized settings.
    with open(main_config_path, 'w') as f:
        json.dump(opt_config, f, indent=2)

    logger.info(f"Applied optimized config to {main_config_path}")
117
+
118
# CLI entry point: tune the transformer config for the local GPU and
# optionally promote the tuned file to be the active config.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Optimize transformer config for GPU memory constraints")
    parser.add_argument("--config", type=str, default="config.json", help="Path to config file")
    parser.add_argument("--apply", action="store_true", help="Apply optimized config")
    parser.add_argument("--batch-factor", type=float, default=0.5, help="Batch size reduction factor")
    parser.add_argument("--target-vram", type=int, default=None, help="Target VRAM usage in MB")

    args = parser.parse_args()

    # Resolve config path
    # Relative paths are resolved against the project root (parent of utils/).
    if not os.path.isabs(args.config):
        config_dir = Path(__file__).resolve().parent.parent
        config_path = os.path.join(config_dir, args.config)
    else:
        config_path = args.config

    # Optimize config
    opt_config_path = optimize_config_for_gpu(
        config_path,
        target_vram_mb=args.target_vram,
        batch_reduction_factor=args.batch_factor
    )

    # Apply if requested
    if args.apply:
        apply_optimized_config(opt_config_path)
utils/model_utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Optional, Tuple, Dict, Any
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
# List of validated model types from Hugging Face
VALIDATED_MODEL_TYPES = [
    'bert', 'roberta', 'distilbert', 'gpt2', 't5', 'albert',
    'xlm-roberta', 'bart', 'electra', 'xlnet'
]

def validate_model_name(model_name: str) -> Tuple[bool, Optional[str]]:
    """
    Check whether *model_name* looks like a known Hugging Face model family.

    Args:
        model_name: Candidate model identifier; matching is case-insensitive
            substring search against ``VALIDATED_MODEL_TYPES``.

    Returns:
        ``(True, None)`` when the name contains a validated family name,
        otherwise ``(False, 'bert-base-uncased')`` as the suggested fallback.
    """
    lowered = model_name.lower()
    for family in VALIDATED_MODEL_TYPES:
        if family in lowered:
            return True, None
    return False, 'bert-base-uncased'  # Default fallback
33
+
34
def get_safe_model_name(config):
    """
    Resolve a model name from *config*, substituting a fallback if invalid.

    Args:
        config: Either a plain model-name string, or a mapping with an
            optional ``MODEL_NAME`` key.

    Returns:
        str: The original name when recognised, otherwise the suggested
        fallback from :func:`validate_model_name`.
    """
    # Accept both a raw string and the original dict-style config.
    if isinstance(config, str):
        candidate = config
    else:
        candidate = config.get('MODEL_NAME', 'bert-base-uncased')

    ok, fallback = validate_model_name(candidate)
    return candidate if ok else fallback
56
+
57
def create_model_config_json(model_dir: str, model_type: str = 'bert') -> None:
    """
    Write a minimal Hugging Face style config.json into *model_dir*.

    Ensures the directory exists and that the config carries the mandatory
    ``model_type`` key so the transformers library can resolve the
    architecture.

    Args:
        model_dir: Directory where the model is/will be stored.
        model_type: Model family identifier (e.g. 'bert', 'roberta').
    """
    import json

    os.makedirs(model_dir, exist_ok=True)

    # Minimal config carrying the required model_type key.
    minimal_config = {
        "model_type": model_type,
        "architectures": [f"{model_type.capitalize()}Model"],
        "hidden_size": 768,
        "num_attention_heads": 12,
        "num_hidden_layers": 12
    }

    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(minimal_config, f, indent=2)

    logger.info(f"Created model config.json with model_type: {model_type} in {model_dir}")
utils/nltk_stub.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stub implementation of NLTK to avoid dependencies in container environments
3
+ """
4
+
5
+ import logging
6
+ logger = logging.getLogger(__name__)
7
+
8
+ logger.info("Using stub NLTK implementation")
9
+
10
# Stub for download - do nothing
def download(*args, **kwargs):
    """No-op replacement for ``nltk.download``; always reports success."""
    logger.warning("NLTK download stub called - no actual download performed")
    return True
14
+
15
# Add the missing SimpleTokenizer class
class SimpleTokenizer:
    """Whitespace-splitting tokenizer standing in for NLTK's tokenizers."""

    def __init__(self):
        # Log so stub usage is visible in container logs.
        logger.info("Stub SimpleTokenizer initialized")

    def tokenize(self, text):
        """Split *text* on runs of whitespace and return the token list."""
        return text.split()
25
+
26
# Tokenization stubs
class WordTokenizer:
    """Stub word tokenizer: whitespace splitting only."""

    def tokenize(self, text):
        return text.split()

def word_tokenize(text):
    """Module-level stand-in for ``nltk.word_tokenize`` (whitespace split)."""
    return text.split()

class SentenceTokenizer:
    """Stub sentence tokenizer: naive split on '.' characters."""

    def tokenize(self, text):
        return text.split('.')

def sent_tokenize(text):
    """Module-level stand-in for ``nltk.sent_tokenize`` (split on '.')."""
    return text.split('.')
40
+
41
# Stemmer stubs
class PorterStemmer:
    """Very naive stemmer: strips -ing, -ed and plural -s endings."""

    def stem(self, word):
        for suffix in ('ing', 'ed'):
            if word.endswith(suffix):
                return word[:-len(suffix)]
        # plural -s, but keep words like "glass" intact
        if word.endswith('s') and not word.endswith('ss'):
            return word[:-1]
        return word

class LancasterStemmer:
    """Stub Lancaster stemmer; defers to the Porter stub for identical output."""

    def stem(self, word):
        return PorterStemmer().stem(word)

class SimpleStemmer:
    """Standalone naive stemmer with the same suffix-stripping rules."""

    def __init__(self):
        logger.info("SimpleStemmer stub initialized")

    def stem(self, word):
        # Very basic stemming: remove common endings.
        if word.endswith('ing'):
            return word[:-3]
        if word.endswith('ed'):
            return word[:-2]
        if word.endswith('s') and not word.endswith('ss'):
            return word[:-1]
        return word
70
+
71
# Stub WordNetLemmatizer class
class WordNetLemmatizer:
    """Identity lemmatizer: every word is returned unchanged."""

    def __init__(self):
        logger.info("Stub WordNetLemmatizer initialized")

    def lemmatize(self, word, pos=None):
        """Return *word* unmodified; *pos* is accepted for API compatibility."""
        return word
79
+
80
# Namespace stubs for import compatibility
# These classes mirror the `nltk.tokenize` / `nltk.stem` attribute layout so
# code that does attribute access like `nltk.tokenize.word_tokenize` keeps
# working against this stub module. They only re-export the names above.
class tokenize:
    # Re-exported stub tokenizers (plain attribute aliases, no new behavior).
    WordTokenizer = WordTokenizer
    SentenceTokenizer = SentenceTokenizer
    word_tokenize = word_tokenize
    sent_tokenize = sent_tokenize

class stem:
    # Re-exported stub stemmers (plain attribute aliases, no new behavior).
    PorterStemmer = PorterStemmer
    LancasterStemmer = LancasterStemmer
    SimpleStemmer = SimpleStemmer
91
+
92
# Stub for corpus
class _CorpusModule:
    """Mimics the `nltk.corpus` namespace; currently only `stopwords`."""

    class stopwords:
        @staticmethod
        def words(language="english"):
            # NOTE(review): the `language` argument is ignored — the English
            # list is returned for every language.
            # NOTE(review): this returns a *set*, whereas nltk's
            # stopwords.words() returns a list; membership tests behave the
            # same, but ordering/indexing would not — confirm callers only
            # use `in` checks.
            # Return basic English stopwords
            return {
                "i", "me", "my", "myself", "we", "our", "ours", "ourselves",
                "you", "your", "yours", "yourself", "yourselves", "he", "him",
                "his", "himself", "she", "her", "hers", "herself", "it", "its",
                "itself", "they", "them", "their", "theirs", "themselves",
                "what", "which", "who", "whom", "this", "that", "these",
                "those", "am", "is", "are", "was", "were", "be", "been",
                "being", "have", "has", "had", "having", "do", "does", "did",
                "doing", "a", "an", "the", "and", "but", "if", "or", "because",
                "as", "until", "while", "of", "at", "by", "for", "with",
                "about", "against", "between", "into", "through", "during",
                "before", "after", "above", "below", "to", "from", "up", "down",
                "in", "out", "on", "off", "over", "under", "again", "further",
                "then", "once", "here", "there", "when", "where", "why", "how",
                "all", "any", "both", "each", "few", "more", "most", "other",
                "some", "such", "no", "nor", "not", "only", "own", "same", "so",
                "than", "too", "very", "s", "t", "can", "will", "just", "don",
                "should", "now", "d", "ll", "m", "o", "re", "ve", "y", "ain",
                "aren", "couldn", "didn", "doesn", "hadn", "hasn", "haven",
            }

# Singleton instance so callers can write `corpus.stopwords.words()`.
corpus = _CorpusModule()
utils/output_formatter.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+ from typing import Optional, Dict, Any, List
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ class OutputFormatter:
8
+ """
9
+ Formats model responses for better presentation and usability.
10
+ """
11
+
12
    def __init__(self):
        """
        Initialize the OutputFormatter.

        Builds the specialization -> formatter dispatch table used by
        format_response(); the "default" entry is the fallback for unknown
        or missing specializations.
        """
        # _format_technical_content and _default_formatter are defined later
        # in this class; the references here are bound method lookups.
        self.post_processors = {
            "programming_software_dev": self._format_code,
            "mbpp": self._format_code,
            "machine_learning_ai_data_science": self._format_technical_content,
            "mathematics": self._format_equations,
            "default": self._default_formatter
        }
        logger.info("OutputFormatter initialized")
24
+
25
+ def format_response(self, response: str, specialization: Optional[str] = None) -> str:
26
+ """
27
+ Format the model response based on specialization.
28
+
29
+ Args:
30
+ response: The raw response from the model
31
+ specialization: The specialization area (optional)
32
+
33
+ Returns:
34
+ Formatted response
35
+ """
36
+ if not response:
37
+ return ""
38
+
39
+ # Apply basic formatting to all responses
40
+ formatted_response = self._clean_whitespace(response)
41
+
42
+ # Apply specialization-specific formatting
43
+ processor = self.post_processors.get(specialization, self.post_processors["default"])
44
+ formatted_response = processor(formatted_response)
45
+
46
+ return formatted_response
47
+
48
+ def _clean_whitespace(self, text: str) -> str:
49
+ """
50
+ Clean up excessive whitespace.
51
+ """
52
+ # Replace multiple newlines with double newlines
53
+ text = re.sub(r'\n{3,}', '\n\n', text)
54
+ # Replace multiple spaces with a single space
55
+ text = re.sub(r' {2,}', ' ', text)
56
+ return text.strip()
57
+
58
+ def _format_code(self, text: str) -> str:
59
+ """
60
+ Format code blocks with proper syntax highlighting markers.
61
+ """
62
+ # Identify unmarked code blocks and add markdown code block syntax
63
+ # Look for patterns that suggest code (indentation, common programming keywords)
64
+ code_patterns = [
65
+ r'((?:^|\n)(?:def |class |import |function |public |private |var |let |const |if |for |while ).+(?:\n[ \t]+.+)+)',
66
+ r'((?:^|\n)(?:SELECT |INSERT |UPDATE |DELETE |CREATE |ALTER |DROP ).+(?:;)(?:\n|$))'
67
+ ]
68
+
69
+ for pattern in code_patterns:
70
+ def add_code_markers(match):
71
+ code_block = match.group(1)
72
+ # Try to determine the language based on keywords
73
+ lang = self._detect_language(code_block)
74
+ return f"\n```{lang}\n{code_block}\n```\n"
75
+
76
+ text = re.sub(pattern, add_code_markers, text)
77
+
78
+ return text
79
+
80
+ def _detect_language(self, code_block: str) -> str:
81
+ """
82
+ Attempt to detect the programming language from a code block.
83
+ """
84
+ if re.search(r'def |class |import |if __name__ ==|print\(', code_block):
85
+ return "python"
86
+ elif re.search(r'function |var |const |let |=> |document\.', code_block):
87
+ return "javascript"
88
+ elif re.search(r'public |private |class .+ {|void |String |int |boolean', code_block):
89
+ return "java"
90
+ elif re.search(r'#include|int main|std::|printf|scanf', code_block):
91
+ return "c++"
92
+ elif re.search(r'SELECT |INSERT |UPDATE |DELETE |CREATE TABLE|ALTER TABLE', code_block):
93
+ return "sql"
94
+ else:
95
+ return "" # Generic code block
96
+
97
+ def _format_equations(self, text: str) -> str:
98
+ """
99
+ Format mathematical equations with LaTeX markers if needed.
100
+ """
101
+ # Basic pattern for unmarked equations
102
+ equation_patterns = [
103
+ r'([^$])(\\frac{.+?}{.+?}|\\sum_|\\int_|\\lim_)',
104
+ r'([^$])([a-zA-Z]_[0-9]+)',
105
+ r'([^$])([a-zA-Z]\\in)'
106
+ ]
107
+
108
+ for pattern in equation_patterns:
109
+ text = re.sub(pattern, r'\1$\2$', text)
110
+
111
+ # Ensure equation blocks use proper LaTeX delimiters
112
+ text = re.sub(r'\\begin{equation}(.+?)\\end{equation}', r'$$\1$$', text, flags=re.DOTALL)
113
+
114
+ return text
115
+
116
+ def _format_technical_content(self, text: str) -> str:
117
+ """
118
+ Format technical content with proper highlighting of terms and concepts.
119
+ """
120
+ # Highlight technical terms
121
+ technical_terms = [
122
+ "neural network", "machine learning", "deep learning", "algorithm",
123
+ "regression", "classification", "clustering", "backpropagation",
124
+ "gradient descent", "optimization", "hyperparameter"
125
+ ]
126
+
127
+ for term in technical_terms:
128
+ # Only highlight whole words, not substrings
129
+ text = re.sub(r'\b(' + re.escape(term) + r')\b(?![*_])', r'*\1*', text)
130
+
131
+ return text
132
+
133
+ def _default_formatter(self, text: str) -> str:
134
+ """
135
+ Default formatter that applies general improvements.
136
+ """
137
+ # Add paragraph breaks for readability when appropriate
138
+ text = re.sub(r'(\w\.)\s+([A-Z])', r'\1\n\n\2', text)
139
+
140
+ # Format lists for readability if they're not already formatted
141
+ text = re.sub(r'(?<!\n)(\d+\.)\s+', r'\n\1 ', text)
142
+
143
+ return text
144
+
145
+ def format_structured_output(self, data: Dict[str, Any]) -> str:
146
+ """
147
+ Format structured data outputs (like JSON) into readable text.
148
+
149
+ Args:
150
+ data: Dictionary containing structured data
151
+
152
+ Returns:
153
+ Formatted string representation
154
+ """
155
+ if not isinstance(data, dict):
156
+ return str(data)
157
+
158
+ formatted_parts = []
159
+
160
+ # Format main response if present
161
+ if "response" in data:
162
+ formatted_parts.append(self.format_response(data["response"]))
163
+
164
+ # Add metadata in a clean format if needed
165
+ metadata = {}
166
+ for key, value in data.items():
167
+ if key != "response" and not key.startswith("_"):
168
+ metadata[key] = value
169
+
170
+ if metadata:
171
+ formatted_parts.append("\n\n---\n")
172
+ for key, value in metadata.items():
173
+ formatted_parts.append(f"**{key.replace('_', ' ').title()}**: {value}")
174
+
175
+ return "\n".join(formatted_parts)
utils/prepare_hf_training.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script to prepare your model training for Hugging Face's training infrastructure.
3
+ """
4
+ import os
5
+ import shutil
6
+ import argparse
7
+ import logging
8
+ import json
9
+ from pathlib import Path
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
13
+
14
def prepare_hf_training(output_dir="hf_training"):
    """Assemble a self-contained training package for Hugging Face.

    Copies the core sources, data and training script into *output_dir*,
    then generates the HF launch script, requirements.txt and an
    HF-tuned config.json.

    Args:
        output_dir: Destination directory (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Everything the HF training run depends on (files and directories)
    required_files = [
        "config.json",          # Configuration file
        "model_Custm.py",       # Custom model implementation
        "model_PrTr.py",        # Pretrained model implementation
        "model_Combn.py",       # Combined model implementation
        "tokenizer.py",         # Tokenizer wrapper
        "dataloader.py",        # Data loading utilities
        "utils",                # Utility functions
        "data",                 # Training data
        "run_gpu_training.py"   # Main training script
    ]

    for file_or_dir in required_files:
        src_path = Path(file_or_dir)
        if not src_path.exists():
            logger.warning(f"{file_or_dir} not found, skipping...")
            continue

        # Everything lands flat under output_dir (basename only)
        dst_path = Path(output_dir) / src_path.name
        if not src_path.is_dir():
            shutil.copy2(src_path, dst_path)
            logger.info(f"Copied file {file_or_dir} to {dst_path}")
        else:
            # Replace any stale copy of the directory wholesale
            if dst_path.exists():
                shutil.rmtree(dst_path)
            shutil.copytree(src_path, dst_path)
            logger.info(f"Copied directory {file_or_dir} to {dst_path}")

    # Generate the HF-specific launch assets
    create_hf_train_script(output_dir)
    create_requirements_file(output_dir)
    update_config_for_hf(output_dir)

    logger.info(f"Training package prepared in {output_dir}")
    logger.info(f"Upload this directory to Hugging Face for training")
    logger.info("See https://huggingface.co/docs/hub/spaces-manage-deploy for deployment instructions")
62
+
63
def create_hf_train_script(output_dir):
    """Create the training script used in the Hugging Face environment.

    Writes ``train_hf.py`` into *output_dir*. The generated script runs the
    existing training entry point with mixed precision and records a small
    metadata file under ``/tmp/model`` (a location HF preserves).

    Fix: the metadata file is named ``model_info.json`` but was previously
    written with ``torch.save`` (a binary pickle/zip archive), so it was not
    valid JSON. It is now written with ``json.dump``.

    Args:
        output_dir: Directory in which to write ``train_hf.py``.
    """
    train_script = """
# Hugging Face training script for Transformer model
import os
import json
import torch
from run_gpu_training import run_training

if __name__ == "__main__":
    # HF automatically provides CUDA device if available
    # Always use mixed precision on HF
    run_training(use_mixed_precision=True)

    # Save training metadata to the /tmp/model directory, which HF preserves.
    # json.dump (not torch.save) so the .json file is actually valid JSON.
    os.makedirs("/tmp/model", exist_ok=True)
    with open("/tmp/model/model_info.json", "w") as f:
        json.dump({
            "config": "final_model_config",
            "type": "transformer_trained",
            "epochs_completed": 30
        }, f, indent=2)

    print("Training completed, model saved to /tmp/model")
"""

    with open(os.path.join(output_dir, "train_hf.py"), "w") as f:
        f.write(train_script.strip())

    logger.info("Created HF training script: train_hf.py")
91
+
92
def create_requirements_file(output_dir):
    """Write a requirements.txt pinning minimum versions of all dependencies."""
    requirements = (
        "torch>=2.0.0",
        "transformers>=4.30.0",
        "datasets>=2.12.0",
        "pydantic>=2.0.0",
        "sentence-transformers>=2.2.2",
        "scikit-learn>=1.2.2",
        "numpy>=1.24.0",
        "pandas>=2.0.0",
        "tqdm>=4.65.0",
        "matplotlib>=3.7.1",
    )

    req_path = os.path.join(output_dir, "requirements.txt")
    with open(req_path, "w") as f:
        f.write("\n".join(requirements))

    logger.info("Created requirements.txt")
111
+
112
def update_config_for_hf(output_dir):
    """Update ``config.json`` in *output_dir* for the Hugging Face environment.

    Restores full-size training settings (batch size, sequence length,
    layer count) and records HF-specific optimization flags.

    Fix: the ``TRANSFORMER_CONFIG`` section was accessed unguarded, so a
    config missing that key raised and aborted the whole update (logged as
    an error, nothing written). It now uses ``setdefault``, consistent with
    the existing guard for the ``OPTIMIZATION`` section.

    Args:
        output_dir: Directory containing the config.json to rewrite.
    """
    config_path = os.path.join(output_dir, "config.json")

    if not os.path.exists(config_path):
        logger.warning("config.json not found, skipping configuration update")
        return

    try:
        with open(config_path, "r") as f:
            config = json.load(f)

        # Update for HF environment - restore full settings.
        # setdefault keeps this robust when the section is absent.
        transformer_cfg = config.setdefault("TRANSFORMER_CONFIG", {})
        transformer_cfg["BATCH_SIZE"] = 32
        transformer_cfg["MAX_SEQ_LENGTH"] = 512
        transformer_cfg["NUM_LAYERS"] = 12

        # Add HF-specific config
        optimization = config.setdefault("OPTIMIZATION", {})
        optimization["USE_MIXED_PRECISION"] = True
        optimization["PLATFORM"] = "huggingface"

        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)

        logger.info("Updated config.json for Hugging Face environment")
    except Exception as e:
        logger.error(f"Error updating config: {e}")
142
+
143
if __name__ == "__main__":
    # Command-line entry point: build the Hugging Face training package.
    parser = argparse.ArgumentParser(description="Prepare training package for Hugging Face")
    parser.add_argument("--output-dir", type=str, default="hf_training",
                        help="Output directory for training package")
    args = parser.parse_args()

    prepare_hf_training(args.output_dir)
utils/prepare_hf_transformer_training.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ How to Use prepare_hf_transformer_training.py Safely
3
+
4
+ Here's a secure way to prepare and upload your model to Hugging Face:
5
+
6
+ Step 1: Navigate to Your Project Directory
7
+ cd C:/Users/User/OneDrive/Documents/tlm
8
+
9
+ Step 2: Set Up Authentication for Hugging Face
10
+ huggingface-cli login
11
+
12
+ Step 3: Run the Preparation Script
13
+ python -m utils.prepare_hf_transformer_training --stdp_checkpoint "checkpoints/stdp_model_epoch_20.pt" --output_dir "C:/Users/User/OneDrive/Documents/tlm/Wildnerve-tlm_HF/hf_upload"
14
+
15
+ Step 4: Initialize Git and Upload to Hugging Face
16
+ cd hf_upload
17
+ git init
18
+ git add .
19
+ git commit -m "Add TLM model with STDP checkpoint"
20
+ git remote add origin https://huggingface.co/YOUR-USERNAME/Wildnerve-tlm01
21
+ git pull origin main --allow-unrelated-histories
22
+ git push origin main
23
+ """
24
+ import os
25
+ import shutil
26
+ import logging
27
+ import argparse
28
+ from pathlib import Path
29
+
30
+ logger = logging.getLogger(__name__)
31
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
32
+
33
def prepare_training_package(
    stdp_checkpoint_path,
    output_dir="hf_transformer_training",
    include_all=False
):
    """Prepare a clean training package for Hugging Face with STDP checkpoint.

    Copies the essential (and optionally the supporting) project files plus
    the STDP checkpoint into *output_dir*, then generates the HF training
    script, requirements.txt and README.md. Individual copy failures are
    logged but never abort the packaging.

    Args:
        stdp_checkpoint_path: Path to the STDP checkpoint file
        output_dir: Directory where to create the package
        include_all: Whether to include all supporting files (utils, analyzers, etc.)

    Returns:
        The output directory path now containing the package.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Core files needed for transformer training
    essential_files = [
        # Core components
        "app.py",
        "main.py",
        "config.json",
        "config.py",
        "inference.py",

        # Model implementations
        "model_List.py",
        "model_Custm.py",
        "model_PrTr.py",
        "model_Combn.py",
        "model_manager.py",

        # Communication components
        "communicator.py",
        "communicator_STDP.py",

        # Data and training
        "tokenizer.py",
        "trainer.py",
        "dataloader.py",
        "dataset.py",
        "data",

        # STDP specific components
        "STDP_Communicator/datasets_stdp.py",
        "STDP_Communicator/train_stdp.py",

        # Utils (only essential ones)
        "utils/convert_checkpoints.py",
    ]

    # Additional support files (only included if include_all=True)
    additional_files = [
        "utils/transformer_utils.py",
        "utils/smartHybridAttention.py",
        "utils/sentence_transformer_utils.py",
        "utils/output_formatter.py",
        "emergency_monitor.py",
    ]

    # Choose which files to copy
    required_files = essential_files + (additional_files if include_all else [])

    logger.info(f"Starting package preparation in {output_dir}")
    logger.info(f"Including {'all' if include_all else 'only essential'} files")

    # Track successful and failed copies for the final summary
    copied_files = []
    missing_files = []

    # Copy files, preserving their relative paths under output_dir
    for file_path in required_files:
        src = Path(file_path)
        if not src.exists():
            logger.warning(f"File {file_path} not found, skipping")
            missing_files.append(str(src))
            continue

        # Create destination directories (relative layout is preserved)
        dst = Path(output_dir) / src
        os.makedirs(dst.parent, exist_ok=True)

        # Copy file or directory; failures are logged but do not abort
        try:
            if src.is_dir():
                shutil.copytree(src, dst, dirs_exist_ok=True)
            else:
                shutil.copy2(src, dst)
            copied_files.append(str(src))
            logger.info(f"Copied {src} to {dst}")
        except Exception as e:
            logger.error(f"Error copying {src}: {e}")

    # Copy the STDP checkpoint into a checkpoints/ subdirectory
    if os.path.exists(stdp_checkpoint_path):
        stdp_dst = Path(output_dir) / "checkpoints" / Path(stdp_checkpoint_path).name
        os.makedirs(stdp_dst.parent, exist_ok=True)
        try:
            shutil.copy2(stdp_checkpoint_path, stdp_dst)
            logger.info(f"Copied STDP checkpoint to {stdp_dst}")
            copied_files.append(str(stdp_checkpoint_path))
        except Exception as e:
            logger.error(f"Error copying checkpoint: {e}")
            missing_files.append(str(stdp_checkpoint_path))
    else:
        logger.warning(f"STDP checkpoint not found at {stdp_checkpoint_path}")
        missing_files.append(str(stdp_checkpoint_path))

    # Create Hugging Face training script
    create_transformer_training_script(output_dir, stdp_checkpoint_path)

    # Create requirements.txt if not already copied.
    # NOTE(review): "requirements.txt" never appears in required_files, so
    # this membership check is currently always true — confirm intent.
    if "requirements.txt" not in copied_files:
        create_requirements(output_dir)
        copied_files.append("requirements.txt (generated)")

    # Create README.md if not already copied (same caveat as above)
    if "README.md" not in copied_files:
        create_readme(output_dir, stdp_checkpoint_path)
        copied_files.append("README.md (generated)")

    # Summarize what was done
    logger.info(f"Package prepared in {output_dir}")
    logger.info(f"Copied {len(copied_files)} files: {', '.join(copied_files[:5])}...")
    if missing_files:
        logger.warning(f"Missing {len(missing_files)} files: {', '.join(missing_files)}")

    return output_dir
157
+
158
def create_transformer_training_script(output_dir, stdp_checkpoint_path):
    """Create a script to load STDP checkpoint and train transformer.

    Writes ``train_transformer_hf.py`` into *output_dir*. The generated
    script loads an STDP checkpoint, builds the model/tokenizer/data
    pipeline from the project config, trains, and saves a final checkpoint.

    Args:
        output_dir: Directory in which to write the generated script.
        stdp_checkpoint_path: Path to the STDP checkpoint.
            NOTE(review): not interpolated into the template — the generated
            script uses its own argparse default; confirm this is intended.
    """
    # Fix: the inner docstring uses single quotes to avoid conflict with the
    # outer triple quotes of this template string.
    script = """
import os
import torch
import logging
from config import load_config, app_config
from tokenizer import TokenizerWrapper
from model_manager import ModelManager
from dataloader import prepare_data_loaders
from trainer import Trainer, EarlyStopping

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def train_transformer(stdp_checkpoint_path):
    '''Train the transformer component after loading STDP weights.'''
    logger.info(f"Starting transformer training with STDP checkpoint: {stdp_checkpoint_path}")

    # Initialize components
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Create tokenizer
    tokenizer = TokenizerWrapper()

    # Get model manager
    model_manager = ModelManager()

    # Get specialization
    specialization = app_config.TRANSFORMER_CONFIG.specialization

    # Load STDP weights
    if os.path.exists(stdp_checkpoint_path):
        try:
            stdp_checkpoint = torch.load(stdp_checkpoint_path, map_location=device)
            logger.info(f"Loaded STDP checkpoint from {stdp_checkpoint_path}")

            # Now integrate STDP weights with transformer model if needed
            # This depends on your specific architecture
        except Exception as e:
            logger.error(f"Error loading STDP checkpoint: {e}")
    else:
        logger.warning(f"STDP checkpoint not found at {stdp_checkpoint_path}")

    # Get model and move to device
    model = model_manager.get_model(specialization)
    model.to(device)

    # Get data loaders
    data_path = app_config.DATASET_PATHS.get(specialization)
    if not data_path or not os.path.exists(data_path):
        # Use a default dataset path
        data_path = next(iter(app_config.DATASET_PATHS.values()))
        logger.warning(f"Dataset for {specialization} not found, using {data_path}")

    train_loader, val_loader = prepare_data_loaders(
        data_path,
        tokenizer,
        batch_size=app_config.TRANSFORMER_CONFIG.BATCH_SIZE
    )

    # Set up checkpoint directory
    checkpoint_dir = os.path.join("checkpoints", "transformer")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Set up early stopping
    early_stopping = EarlyStopping(
        patience=app_config.TRAINING_CONFIG.PATIENCE,
        delta=app_config.TRAINING_CONFIG.DELTA,
        verbose=True,
        path=os.path.join(checkpoint_dir, "best_model.pt")
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        device=device,
        early_stopping=early_stopping,
        checkpoint_dir=checkpoint_dir,
        total_epochs=app_config.TRAINING_CONFIG.TRANSFORMER_NUM_EPOCHS
    )

    # Train the model
    logger.info("Starting transformer training...")
    trainer.train()

    # Save final model
    final_model_path = os.path.join(checkpoint_dir, "final_model.pt")
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': {
            'transformer_epochs': app_config.TRAINING_CONFIG.TRANSFORMER_NUM_EPOCHS,
            'stdp_epochs': 20, # Assuming the STDP checkpoint is from epoch 20
            'specialization': specialization
        }
    }, final_model_path)
    logger.info(f"Final model saved to {final_model_path}")

    return final_model_path

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train transformer after STDP")
    parser.add_argument("--stdp_checkpoint", type=str, default="checkpoints/stdp_model_epoch_20.pt",
                        help="Path to pre-trained STDP checkpoint")

    args = parser.parse_args()

    # Train transformer
    train_transformer(args.stdp_checkpoint)
    """
    script_path = os.path.join(output_dir, "train_transformer_hf.py")
    with open(script_path, "w") as f:
        f.write(script.strip())
    logger.info(f"Created training script at {script_path}")
280
+
281
def create_requirements(output_dir):
    """Create requirements.txt listing every runtime dependency with minimum versions."""
    deps = (
        "torch>=2.0.0",
        "transformers>=4.30.0",
        "datasets>=2.12.0",
        "pydantic>=2.0.0",
        "sentence-transformers>=2.2.2",
        "scikit-learn>=1.2.2",
        "numpy>=1.24.0",
        "pandas>=2.0.0",
        "tqdm>=4.65.0",
        "matplotlib>=3.7.1",
        "snntorch>=0.7.0",
    )

    target = os.path.join(output_dir, "requirements.txt")
    with open(target, "w") as f:
        f.write("\n".join(deps))
    logger.info("Created requirements.txt")
300
+
301
def create_readme(output_dir, stdp_checkpoint_path):
    """Create README with model information and usage instructions."""
    # Only the checkpoint's file name is surfaced in the README text
    checkpoint_name = os.path.basename(stdp_checkpoint_path)
    readme = f"""# Wildnerve-tlm01: Transformer Language Model with STDP

This repository contains the Wildnerve-tlm01 model, a transformer-based language model enhanced with
STDP (Spike-Timing-Dependent Plasticity) for improved learning capabilities.

## Pre-trained STDP Checkpoint

The STDP component was trained for 20 epochs and saved in: `{checkpoint_name}`

## Model Architecture

Wildnerve-tlm01 combines:
- Transformer architecture for language understanding
- Spiking Neural Network (SNN) with STDP for biological learning
- Smart Hybrid Attention for efficient processing

## Usage
"""
    readme_path = os.path.join(output_dir, "README.md")
    with open(readme_path, "w") as f:
        f.write(readme)
    logger.info("Created README.md")
324
+
325
if __name__ == "__main__":
    # Command-line entry point: package the model + STDP checkpoint for upload.
    parser = argparse.ArgumentParser(description="Prepare Hugging Face training package")
    parser.add_argument("--stdp_checkpoint", type=str, default="checkpoints/stdp_model_epoch_20.pt",
                        help="Path to pre-trained STDP checkpoint")
    parser.add_argument("--output_dir", type=str, default="hf_upload",
                        help="Output directory for training package")
    parser.add_argument("--include_all", action="store_true",
                        help="Include additional supporting files")

    args = parser.parse_args()
    prepare_training_package(args.stdp_checkpoint, args.output_dir, args.include_all)
utils/sentence_transformer_utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for loading and working with sentence transformers.
3
+ """
4
+ import logging
5
+ import torch
6
+ import os
7
+ from typing import Optional
8
+ from sentence_transformers import SentenceTransformer
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
# Constants
DEFAULT_SENTENCE_TRANSFORMER = "Wildnerve-tlm01-0.05Bx12"  # Removed fallback to all-MiniLM-L6-v2

# Cache for loaded models to avoid reloading, keyed by model name
_sentence_transformer_cache = {}
17
def get_sentence_transformer(model_name: str = DEFAULT_SENTENCE_TRANSFORMER):
    """
    Get a sentence transformer model, reusing previously loaded instances.

    Fix: the module-level ``_sentence_transformer_cache`` was declared but
    never used, so every call reloaded the model from disk/network. Loaded
    models are now cached per ``model_name``.

    Args:
        model_name: Name of the model to load (default is our primary model)

    Returns:
        SentenceTransformer model

    Raises:
        Exception: Propagates any loading failure after logging it.
    """
    # Fast path: serve a previously loaded model from the cache
    if model_name in _sentence_transformer_cache:
        return _sentence_transformer_cache[model_name]

    # Define the expected local directory for your custom model.
    # NOTE(review): hard-coded user-specific path — consider making this
    # configurable (e.g. via an environment variable).
    local_model_dir = os.path.join("c:/Users/User/OneDrive/Documents/tlm/Wildnerve-tlm_HF/models", model_name)
    # Use the local directory if it exists; otherwise, use the provided
    # model_name identifier (which must be on HuggingFace)
    model_path = local_model_dir if os.path.isdir(local_model_dir) else model_name
    try:
        model = SentenceTransformer(model_path)
    except Exception as e:
        logger.error(f"Failed to load SentenceTransformer from {model_path}: {e}")
        raise

    # Cache only after a successful load
    _sentence_transformer_cache[model_name] = model
    return model
36
+
37
def clear_sentence_transformer_cache():
    """Drop every cached sentence-transformer instance so its memory can be reclaimed."""
    # clear() mutates the module-level dict in place; no rebinding needed
    _sentence_transformer_cache.clear()
    logger.info("Cleared sentence transformer cache")
utils/smartHybridAttention.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# smartHybridAttention.py - Enhanced SmartHybridAttention that combines the best features of both implementations:
import os
import sys
import math
import json
import torch
import torch.nn as nn
# NOTE(review): "logging as logger" aliases the logging *module*, so the
# logger.warning()/logger.info() calls below go through the root logger
# rather than a module-named logger — confirm this is intended.
import logging as logger
from typing import Optional, Tuple, List, Dict, Any, Union

# Fix imports for service_registry - make it more robust with fallbacks
try:
    # Try direct import first
    from service_registry import ServiceRegistry
except ImportError:
    try:
        # Try adding parent directories to path
        parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        sys.path.append(parent_dir)
        from service_registry import ServiceRegistry
    except ImportError:
        # Create dummy registry if not found.
        # NOTE(review): this branch binds a `registry` instance but leaves the
        # name `ServiceRegistry` undefined — verify nothing references the
        # class itself when both imports fail.
        class DummyRegistry:
            def get_service(self, name): return None
            def register_service(self, name, service): pass
        registry = DummyRegistry()

# Use conditional import for AttentionProfileSelector
try:
    # Try direct import first
    from utils.attention_trigger_system import AttentionProfileSelector
except ImportError:
    # Try setting up different paths
    try:
        data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
        sys.path.append(data_dir)
        from utils.attention_trigger_system import AttentionProfileSelector
    except ImportError:
        # Create a minimal placeholder if not found; mimics the selector's
        # public surface and always reports the "standard" profile.
        class DummyAttentionProfileSelector:
            def __init__(self, config_path=None): pass
            def select_profile(self, text, context=None): return "standard", 1.0
            def get_attention_type(self, profile_id): return "standard"
            def get_profile_parameters(self, profile_id): return {}
        AttentionProfileSelector = DummyAttentionProfileSelector
46
+
47
+ # Merging the two functions into a single robust implementation
48
def get_hybrid_attention_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Assemble the hybrid-attention configuration from layered sources.

    Precedence (lowest to highest): built-in defaults, values from the
    project's ``app_config`` (when importable), then an optional user
    JSON file whose keys are upper-cased on merge.

    Args:
        config_path: Optional path to a JSON configuration file

    Returns:
        Dictionary with attention configuration parameters
    """
    # Layer 1: built-in defaults
    cfg: Dict[str, Any] = {
        "DIM": 768,
        "NUM_HEADS": 12,
        "WINDOW_SIZE": 256,
        "USE_SLIDING": True,
        "USE_GLOBAL": True,
        "USE_HIERARCHICAL": False,
        "GLOBAL_TOKEN_RATIO": 0.05,
        "MEMORY_TOKENS": 32,
        "STRIDE": 128,
    }

    # Layer 2: project app_config (best effort — any failure keeps defaults)
    try:
        from config import app_config

        dim = safe_get_int_value(app_config, 'EMBEDDING_DIM', cfg["DIM"])
        heads = safe_get_int_value(app_config, 'NUM_HEADS', cfg["NUM_HEADS"])
        window = safe_get_int_value(app_config, 'WINDOW_SIZE', cfg["WINDOW_SIZE"])

        cfg["DIM"] = dim
        cfg["NUM_HEADS"] = heads
        cfg["WINDOW_SIZE"] = window
        # app_config-driven builds favor hierarchical attention with
        # different ratios than the plain defaults
        cfg["USE_HIERARCHICAL"] = True
        cfg["GLOBAL_TOKEN_RATIO"] = 0.2
        cfg["MEMORY_TOKENS"] = 16

        # Derive stride from the window when it is a usable positive int
        cfg["STRIDE"] = window // 2 if isinstance(window, int) and window > 0 else 128
    except Exception as e:
        logger.warning(f"Error loading config from app_config: {e}, using defaults")

    # Layer 3: explicit JSON file (keys are upper-cased on merge)
    if config_path and os.path.exists(config_path):
        try:
            with open(config_path, "r") as f:
                overrides = json.load(f)
            for key, value in overrides.items():
                cfg[key.upper()] = value
            logger.info(f"Loaded attention config from {config_path}")
        except Exception as e:
            logger.warning(f"Error loading attention config from {config_path}: {e}")

    return cfg
112
+
113
# Update the get_attention_config function to use our new merged function
def get_attention_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Get attention configuration using the most appropriate method available.

    Thin compatibility wrapper kept for older call sites; delegates directly
    to ``get_hybrid_attention_config``.
    """
    return get_hybrid_attention_config(config_path)
117
+
118
# Helper used by get_hybrid_attention_config for defensive config reads
def safe_get_int_value(config_obj, key, default=512):
    """Fetch *key* from a config object as an int, falling back to *default*.

    Resolution order: the object's own attribute, then the same attribute
    under its TRANSFORMER_CONFIG. Dicts and non-numeric values are rejected
    with a warning; any unexpected error also yields *default*.
    """
    try:
        # Resolve the attribute: top level first, then nested TRANSFORMER_CONFIG
        if hasattr(config_obj, key):
            value = getattr(config_obj, key)
        elif hasattr(config_obj, 'TRANSFORMER_CONFIG') and hasattr(config_obj.TRANSFORMER_CONFIG, key):
            value = getattr(config_obj.TRANSFORMER_CONFIG, key)
        else:
            return default

        # Reject non-numeric shapes with a warning, coerce numbers to int
        if isinstance(value, dict):
            logger.warning(f"Config value {key} is a dictionary, using default: {default}")
            return default
        if not isinstance(value, (int, float)):
            logger.warning(f"Config value {key} is not a number, using default: {default}")
            return default
        return int(value)
    except Exception as e:
        logger.warning(f"Error getting config value {key}: {e}")
        return default
140
+
141
class SmartHybridAttention(nn.Module):
    """SmartHybridAttention that combines the best features of both implementations:
    - Memory storage via global token selection from Wildnerve-tlm_HF
    - Multiple attention strategies from utils version
    - HuggingFace compatibility layer
    - Optimized for extremely large context windows

    All tensor inputs and outputs are sequence-first, ``[seq_len, batch, dim]``
    (see the ``.size()`` unpacking in ``forward``); ``to_hf_attention`` wraps
    this module for batch-first HuggingFace callers.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        window_size: int = 256,
        use_sliding: bool = True,
        use_global: bool = True,
        use_hierarchical: bool = False,
        global_token_ratio: float = 0.05,
        memory_tokens: int = 32,
        config_path: Optional[str] = None,
        registry_key: Optional[str] = None,
        attention_config_path: Optional[str] = None
    ):
        """Build the attention module.

        Args:
            dim: Hidden size (must be divisible by ``num_heads`` for the
                head split below to be lossless — not validated here).
            num_heads: Number of attention heads.
            window_size: Width of the sliding attention window.
            use_sliding / use_global / use_hierarchical: Strategy feature flags.
            global_token_ratio: Fraction of content tokens promoted to
                globally-attending tokens.
            memory_tokens: Number of learned persistent memory tokens
                prepended to every sequence.
            config_path: Optional JSON config file (loaded best-effort).
            registry_key: Optional key under which to self-register with the
                external ServiceRegistry.
            attention_config_path: Optional config for the content-aware
                AttentionProfileSelector.
        """
        super().__init__()
        # Defensive coercion: upstream configs sometimes deliver floats or
        # wrong types, so each field is coerced with a hard-coded fallback.
        self.dim = int(dim) if isinstance(dim, (int, float)) else 768
        self.num_heads = int(num_heads) if isinstance(num_heads, (int, float)) else 8
        self.head_dim = self.dim // self.num_heads  # Safe integer division
        self.window_size = int(window_size) if isinstance(window_size, (int, float)) else 256
        self.scale = self.head_dim ** -0.5

        # Feature flags - ensure boolean types
        self.use_sliding = bool(use_sliding)
        self.use_global = bool(use_global)
        self.use_hierarchical = bool(use_hierarchical)

        # Ensure float type for ratio
        self.global_token_ratio = float(global_token_ratio) if isinstance(global_token_ratio, (int, float)) else 0.05

        # Ensure int type for memory tokens
        self.memory_tokens = int(memory_tokens) if isinstance(memory_tokens, (int, float)) else 32

        # Learned persistent memory, shape [memory_tokens, 1, dim]; the
        # singleton batch dim is expanded per batch in the forward pass.
        self.persistent_memory = nn.Parameter(torch.zeros(self.memory_tokens, 1, self.dim))
        nn.init.normal_(self.persistent_memory, mean=0.0, std=0.02)

        # Q/K/V/output projections (all dim -> dim)
        self.q_proj = nn.Linear(self.dim, self.dim)
        self.k_proj = nn.Linear(self.dim, self.dim)
        self.v_proj = nn.Linear(self.dim, self.dim)
        self.out_proj = nn.Linear(self.dim, self.dim)

        # Initialize optional components
        self.config = self._load_config(config_path) if config_path else {}
        self.registry_key = registry_key
        self.prompt_analyzer = None
        self._init_external_services()

        # Content-aware attention selector; failure to construct it is
        # non-fatal and disables content-based strategy selection.
        self.attention_config_path = attention_config_path
        if self.attention_config_path:
            try:
                self.profile_selector = AttentionProfileSelector(self.attention_config_path)
            except Exception as e:
                logger.warning(f"Could not initialize AttentionProfileSelector: {e}")
                self.profile_selector = None
        else:
            self.profile_selector = None

    def _init_external_services(self):
        """Initialize external services like registry and analyzer if available.

        Best-effort: any failure (missing ServiceRegistry, bad paths) is
        swallowed and logged at debug level so the module still works
        standalone.
        """
        try:
            # Extend sys.path so the sibling utils packages are importable.
            # NOTE(review): ServiceRegistry is assumed to be importable after
            # this — it is not imported in this block; a NameError here is
            # caught by the except below.
            sys.path.extend([
                os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "utils"),
                os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Wildnerve-tlm_HF", "utils")
            ])

            if self.registry_key:
                registry = ServiceRegistry.get_instance()
                self.prompt_analyzer = registry.get_service("prompt_analyzer")
                # Register self if registry key provided
                registry.register_service(self.registry_key, self)
        # NOTE(review): the first two entries are redundant — Exception
        # already covers ImportError and AttributeError.
        except (ImportError, AttributeError, Exception):
            logger.debug("External services not available")

    def _load_config(self, config_path: str) -> Dict:
        """Load configuration from JSON file.

        Returns an empty dict when the file is missing or unreadable.
        NOTE(review): the bare ``except`` silently hides JSON/IO errors —
        consider narrowing to (OSError, json.JSONDecodeError).
        """
        if not os.path.exists(config_path):
            return {}

        try:
            with open(config_path, "r") as f:
                return json.load(f)
        except:
            return {}

    def _create_sliding_window_mask(
        self,
        seq_len: int,
        window_size: int,
        global_token_indices: Optional[List[int]] = None,
        memory_size: int = 0
    ) -> torch.Tensor:
        """Create attention mask for sliding window with memory tokens and global tokens.

        Returns a boolean mask of shape [seq_len + memory_size,
        seq_len + memory_size] where True means "may attend".
        """
        total_len = seq_len + memory_size
        mask = torch.zeros(total_len, total_len, dtype=torch.bool)

        # Memory tokens attend to everything and everything attends to memory
        if memory_size > 0:
            mask[:memory_size, :] = True  # Memory attends to all
            mask[:, :memory_size] = True  # All attend to memory

        # Set sliding window attention for content tokens
        # (window is centered on each position, clipped at the edges)
        for i in range(memory_size, total_len):
            # Adjust window bounds
            start = max(memory_size, i - window_size // 2)
            end = min(total_len, i + window_size // 2 + 1)
            mask[i, start:end] = True

        # Add global token attention if provided
        if global_token_indices is not None:
            # Adjust indices to account for memory tokens
            adjusted_indices = [idx + memory_size for idx in global_token_indices]
            # Global tokens attend to all tokens
            mask[adjusted_indices, :] = True
            # All tokens attend to global tokens
            mask[:, adjusted_indices] = True
        return mask

    def _select_global_tokens(
        self,
        key_layer: torch.Tensor,
        ratio: Optional[float] = None,
        memory_size: int = 0
    ) -> List[int]:
        """Select global tokens based on importance scoring.

        Importance = mean L2 norm of the key vector plus a mild linear
        recency bias (later positions score higher).

        Args:
            key_layer: Projected keys, [memory_size + seq_len, batch, dim].
            ratio: Fraction of content tokens to promote; defaults to
                ``self.global_token_ratio``.
            memory_size: Number of leading memory tokens to skip.

        Returns: List of indices of selected global tokens (0-based within
        the content portion, i.e. excluding memory tokens).
        """
        if ratio is None:
            ratio = self.global_token_ratio

        seq_len = key_layer.size(0) - memory_size
        num_global_tokens = max(1, int(seq_len * ratio))

        # Skip memory tokens when scoring importance
        content_keys = key_layer[memory_size:]

        # Score tokens by L2 norm and recency (more recent = more important)
        # norm over dim -> [seq, batch]; mean over batch -> [seq]
        base_scores = torch.norm(content_keys, dim=-1).mean(dim=-1)  # [seq_len]

        # Add recency bias
        seq_positions = torch.arange(seq_len, device=base_scores.device) / seq_len
        recency_scores = 0.3 * seq_positions  # Mild recency bias
        final_scores = base_scores + recency_scores

        # Select top-k indices
        _, indices = torch.topk(final_scores, k=min(num_global_tokens, seq_len))

        return indices.tolist()

    def _apply_memory_augmented_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply attention with persistent memory tokens for long-range context.

        The learned memory tokens are prepended to Q/K/V, full multi-head
        attention runs over the augmented sequence, and the memory rows are
        stripped from the result.

        Returns: Output tensor after attention [seq_len, batch, dim]
        """
        seq_len, batch_size, _ = query.size()

        # Expand memory tokens to batch size
        memory_batch = self.persistent_memory.expand(-1, batch_size, -1)

        # Prepend memory tokens to input
        query_with_memory = torch.cat([memory_batch, query], dim=0)
        key_with_memory = torch.cat([memory_batch, key], dim=0)
        value_with_memory = torch.cat([memory_batch, value], dim=0)

        # Project query, key, value
        q = self.q_proj(query_with_memory)
        k = self.k_proj(key_with_memory)
        v = self.v_proj(value_with_memory)

        # Select global tokens for additional global attention
        global_token_indices = None
        if self.use_global and seq_len > self.window_size:
            global_token_indices = self._select_global_tokens(k, memory_size=self.memory_tokens)

        # Create or modify attention mask
        memory_size = self.memory_tokens
        full_seq_len = seq_len + memory_size

        if attention_mask is None:
            window_mask = self._create_sliding_window_mask(
                seq_len, self.window_size, global_token_indices, memory_size
            )
            # Invert: True now marks *blocked* positions, then convert to a
            # large-negative additive mask broadcastable over [batch, heads].
            attention_mask = ~window_mask
            attention_mask = attention_mask.to(q.device).unsqueeze(0).unsqueeze(0)
            attention_mask = attention_mask * -1e9  # Large negative for masked positions
        else:
            # Modify existing mask to include memory tokens.
            # NOTE(review): this assumes the caller's mask is an additive 2-D
            # [seq_len, seq_len] float mask (0 = attend) — confirm callers;
            # a 4-D HF-style mask would not fit this assignment.
            memory_mask = torch.zeros(full_seq_len, full_seq_len, device=attention_mask.device)
            memory_mask[memory_size:, memory_size:] = attention_mask
            # Make memory attend to everything and everything attend to memory
            memory_mask[:memory_size, :] = 0  # 0 = attend (not masked)
            memory_mask[:, :memory_size] = 0
            attention_mask = memory_mask

        # Reshape for multi-head attention:
        # [full_seq, batch, dim] -> [batch, heads, full_seq, head_dim]
        q = q.view(full_seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
        k = k.view(full_seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
        v = v.view(full_seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)

        # Compute scaled dot-product attention with the additive mask
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if attention_mask is not None:
            scores = scores + attention_mask
        attention_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, v)

        # Reshape back to [full_seq, batch, dim]
        context = context.transpose(1, 2).transpose(0, 1).contiguous()
        context = context.view(full_seq_len, batch_size, self.dim)

        # Remove memory tokens from output
        context = context[memory_size:]

        # Final projection
        output = self.out_proj(context)
        return output

    def _apply_hierarchical_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor
    ) -> torch.Tensor:
        """Apply hierarchical attention for very long sequences.

        Splits the sequence into chunks of up to 512 tokens and runs full
        attention independently within each chunk.  NOTE(review): despite the
        name there is no second (cross-chunk) level — tokens in different
        chunks never attend to each other.
        """
        seq_len, batch_size, _ = query.size()
        chunk_size = min(512, seq_len)
        num_chunks = math.ceil(seq_len / chunk_size)

        # First level: process chunks independently
        chunk_outputs = []
        for i in range(num_chunks):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, seq_len)

            # Extract chunk
            q_chunk = query[start_idx:end_idx]
            k_chunk = key[start_idx:end_idx]
            v_chunk = value[start_idx:end_idx]

            # Process chunk with full attention
            q_proj = self.q_proj(q_chunk)
            k_proj = self.k_proj(k_chunk)
            v_proj = self.v_proj(v_chunk)

            # Reshape for multi-head attention
            chunk_len = end_idx - start_idx
            q_proj = q_proj.view(chunk_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
            k_proj = k_proj.view(chunk_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
            v_proj = v_proj.view(chunk_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)

            # Compute attention (unmasked within the chunk)
            scores = torch.matmul(q_proj, k_proj.transpose(-1, -2)) * self.scale
            weights = torch.softmax(scores, dim=-1)
            context = torch.matmul(weights, v_proj)

            # Reshape back
            context = context.transpose(1, 2).transpose(0, 1).contiguous()
            context = context.view(chunk_len, batch_size, self.dim)
            chunk_outputs.append(context)

        # Concatenate chunks back into the full sequence
        output = torch.cat(chunk_outputs, dim=0)
        output = self.out_proj(output)
        return output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        input_text: Optional[str] = None,  # New parameter for content detection
        context: Optional[Dict] = None,  # New parameter for additional context
        **kwargs
    ) -> torch.Tensor:
        """Forward pass for the enhanced smart hybrid attention layer.

        Inputs are sequence-first [seq_len, batch, dim].  Strategy is chosen
        per call: standard attention for short sequences, then
        memory-augmented, hierarchical, or plain sliding-window as a
        fallback, based on ``_get_attention_strategy``.
        """
        seq_len, batch_size, _ = query.size()

        # If input_text is provided, try to use content-aware attention
        if input_text:
            try:
                # Try multiple import paths to find attention_connector
                connector = None
                import_paths = [
                    # Try direct import
                    lambda: __import__('attention_connector').get_attention_connector(),
                    # Try data subdirectory
                    lambda: __import__('data.attention_connector').get_attention_connector(),
                    # Try relative path
                    lambda: __import__('.'.join(['..', 'data', 'attention_connector']),
                                       fromlist=['get_attention_connector']).get_attention_connector()
                ]

                for import_path in import_paths:
                    try:
                        connector = import_path()
                        break
                    except (ImportError, AttributeError):
                        continue

                if connector:
                    connector.set_input_text(input_text)
            except Exception as e:
                # Silently continue if connector not available
                logger.debug(f"Error setting input text for content-aware attention: {e}")

        # Analyze sequence characteristics to choose strategy
        strategy_weights = self._get_attention_strategy(seq_len, input_text, context)

        # For very short sequences, use standard attention
        # (note: this branch ignores strategy_weights by design)
        if seq_len < 128:
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

            # Multi-head attention
            q = q.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
            k = k.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
            v = v.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)

            scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
            if attention_mask is not None:
                scores = scores + attention_mask
            attn_weights = torch.softmax(scores, dim=-1)
            # NOTE(review): this rebinding shadows the `context` dict
            # parameter; harmless here (the dict is no longer read), but
            # worth renaming in a future change.
            context = torch.matmul(attn_weights, v)

            context = context.transpose(1, 2).transpose(0, 1).contiguous()
            context = context.view(seq_len, batch_size, self.dim)
            return self.out_proj(context)

        # For longer sequences, use memory-augmented attention
        if strategy_weights["memory"] > 0:
            return self._apply_memory_augmented_attention(query, key, value, attention_mask)

        # For very long sequences where memory doesn't fit, use hierarchical
        if strategy_weights["hierarchical"] > 0:
            return self._apply_hierarchical_attention(query, key, value)

        # Fallback: standard sliding window without memory
        # NOTE(review): the caller-provided attention_mask is ignored on this
        # path — only the generated window mask is applied.
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        global_token_indices = self._select_global_tokens(k, memory_size=0) if self.use_global else None
        window_mask = self._create_sliding_window_mask(seq_len, self.window_size, global_token_indices, 0)
        masked_attn = ~window_mask
        # bool * float promotes to float: blocked positions become -1e9
        masked_attn = masked_attn.to(query.device).unsqueeze(0).unsqueeze(0) * -1e9

        # Standard multi-head attention with mask
        q = q.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
        k = k.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)
        v = v.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        scores = scores + masked_attn
        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).transpose(0, 1).contiguous()
        context = context.view(seq_len, batch_size, self.dim)
        return self.out_proj(context)

    def _get_attention_strategy(self, seq_len: int, input_text: Optional[str] = None, context: Optional[Dict] = None) -> Dict[str, float]:
        """Determine which attention strategy to use based on sequence length
        and optional prompt analysis.

        Returns a dict of weights for "standard", "sliding", "memory" and
        "hierarchical"; ``forward`` only checks which weights are non-zero.
        """
        weights = {
            "standard": 0.0,
            "sliding": 0.0,
            "memory": 0.0,
            "hierarchical": 0.0
        }

        # Try content-aware selection if available and input_text is provided
        if self.profile_selector and input_text:
            try:
                profile_id, confidence = self.profile_selector.select_profile(input_text, context)

                # If we have high confidence in the profile selection, use profile-specific weights
                if confidence > 0.65:
                    attention_type = self.profile_selector.get_attention_type(profile_id)

                    if attention_type == "hierarchical":
                        weights["hierarchical"] = 0.7
                        weights["memory"] = 0.3
                        return weights
                    elif attention_type == "smartHybrid":
                        weights["memory"] = 0.5
                        weights["sliding"] = 0.5
                        return weights
                    elif attention_type == "recencyBiased":
                        weights["memory"] = 0.8
                        weights["sliding"] = 0.2
                        return weights
                    # Additional attention types can be added here
            except Exception as e:
                # NOTE(review): uses print() while the rest of the class uses
                # logger — consider logger.warning for consistency.
                print(f"Warning: Error in content-based attention selection: {e}")

        # Fall back to sequence length-based selection if content detection fails or has low confidence
        if seq_len < 128:
            weights["standard"] = 1.0
        elif seq_len < 2048:
            weights["sliding"] = 0.2
            weights["memory"] = 0.8
        elif seq_len < 8192:
            weights["memory"] = 1.0
        else:
            weights["memory"] = 0.7
            weights["hierarchical"] = 0.3

        # Adjust based on prompt analyzer if available
        if self.prompt_analyzer:
            try:
                analysis = self.prompt_analyzer.get_current_analysis()
                if analysis:
                    complexity = analysis.get("complexity", 0.5)
                    structure = analysis.get("structure_score", 0.5)

                    # Adjust for highly structured content
                    if structure > 0.7:
                        weights["hierarchical"] = min(0.8, weights["hierarchical"] + 0.3)
                        weights["memory"] = max(0.2, weights["memory"] - 0.3)

                    # Adjust for high complexity content
                    if complexity > 0.8 and seq_len > 1024:
                        weights["memory"] = min(1.0, weights["memory"] + 0.2)
            except:
                logger.debug("Error in prompt analysis")

        return weights

    def to_hf_attention(self):
        """
        Convert to HuggingFace-compatible attention layer.

        The returned wrapper accepts batch-first hidden states
        [batch, seq, dim] and transposes around this module's
        sequence-first interface.
        """
        class HFCompatibleAttention(nn.Module):
            def __init__(self, smart_attention):
                super().__init__()
                self.smart_attention = smart_attention

            # NOTE(review): defining __call__ directly (instead of forward)
            # bypasses nn.Module's hook machinery — confirm this is intended.
            def __call__(self, hidden_states, attention_mask=None, **kwargs):
                # Convert HF format attention mask if present
                if attention_mask is not None:
                    if attention_mask.dim() == 4:  # [batch, 1, 1, seq_len]
                        attention_mask = attention_mask.squeeze(1).squeeze(1)
                    # 1 -> attend (0 bias), 0 -> blocked (-10000 additive)
                    attention_mask = attention_mask.to(dtype=torch.float32)
                    attention_mask = (1.0 - attention_mask) * -10000.0

                # Convert from [batch, seq, dim] to [seq, batch, dim]
                seq_first = hidden_states.transpose(0, 1)

                # Apply self-attention (query = key = value)
                output = self.smart_attention(
                    seq_first, seq_first, seq_first,
                    attention_mask=attention_mask,
                    **kwargs
                )

                # Convert back to [batch, seq, dim]
                return output.transpose(0, 1)
        return HFCompatibleAttention(self)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Create an instance from a pretrained Hugging Face model configuration.

        Reads hidden_size / num_attention_heads / max_position_embeddings
        from the HF config; any user kwargs override the derived defaults.

        Raises:
            ImportError: If the transformers library is unavailable.
            ValueError: For any other failure while reading the config.
        """
        try:
            from transformers import AutoConfig

            # Load config from HF model
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

            # Extract relevant attributes
            attention_kwargs = {
                "dim": config.hidden_size,
                "num_heads": config.num_attention_heads,
                "window_size": kwargs.get("window_size", 512),
                "memory_tokens": kwargs.get("memory_tokens", min(32, config.max_position_embeddings // 64)),
            }

            # Update with any user-provided kwargs (user values win)
            attention_kwargs.update(kwargs)

            # Create instance
            return cls(**attention_kwargs)
        except ImportError:
            raise ImportError("transformers library required to load from pretrained model")
        except Exception as e:
            raise ValueError(f"Failed to initialize from pretrained model: {e}")
642
+
643
def create_smart_hybrid_attention(
    dim: int = 768,
    num_heads: int = 12,
    max_sequence_length: int = 8192,
    for_huggingface: bool = True,
    **kwargs
) -> Union[SmartHybridAttention, nn.Module]:
    """Build a SmartHybridAttention sized for the expected context length.

    Args:
        dim: Hidden dimension size.
        num_heads: Number of attention heads.
        max_sequence_length: Maximum expected sequence length; drives the
            memory-token count and sliding-window width below.
        for_huggingface: When True, wrap the module in its batch-first
            HuggingFace-compatible adapter.
        **kwargs: Forwarded verbatim to the SmartHybridAttention constructor.

    Returns:
        An attention module ready to drop into a transformer.
    """
    # Scale capacity with context length, clamped to sane bounds:
    # 16..64 memory tokens, 128..512 window width.
    memory_tokens = min(max(16, max_sequence_length // 256), 64)
    window_size = min(512, max(128, max_sequence_length // 32))

    module = SmartHybridAttention(
        dim=dim,
        num_heads=num_heads,
        window_size=window_size,
        memory_tokens=memory_tokens,
        **kwargs
    )
    return module.to_hf_attention() if for_huggingface else module
utils/tokenizer_utils.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for tokenizer-related operations.
3
+ """
4
+ import torch
5
+ import logging
6
+ from typing import Dict, List, Any, Union, Optional
7
+ from transformers import AutoTokenizer
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
def get_special_tokens_mask(tokenizer, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
    """
    Retrieve special tokens mask.

    Args:
        tokenizer: Tokenizer to use
        token_ids_0: First token IDs
        token_ids_1: Second token IDs (for pairs)
        already_has_special_tokens: Whether token_ids already contain special tokens

    Returns:
        List of 1s and 0s, where 1 indicates a special token
    """
    # The original three branches all forwarded to the same tokenizer method
    # with identical argument semantics, so a single pass-through suffices.
    return tokenizer.get_special_tokens_mask(
        token_ids_0,
        token_ids_1=token_ids_1,
        already_has_special_tokens=already_has_special_tokens,
    )
43
+
44
def add_tokens_to_tokenizer(tokenizer, new_tokens):
    """
    Extend a tokenizer's vocabulary with additional tokens.

    Args:
        tokenizer: Tokenizer to modify
        new_tokens: List of new tokens to add

    Returns:
        Number of tokens actually added to the vocabulary
    """
    added_count = tokenizer.add_tokens(new_tokens)
    return added_count
56
+
57
def format_batch_for_model(
    batch: Dict[str, torch.Tensor],
    device: torch.device = None
) -> Dict[str, torch.Tensor]:
    """
    Prepare a batch for model input by moving every tensor to a device.

    Non-tensor entries (labels, metadata, strings) pass through untouched.

    Args:
        batch: Dictionary of tensors (and possibly other values)
        device: Device to move tensors to; defaults to CUDA when available,
            otherwise CPU

    Returns:
        New dictionary with tensors relocated to the target device
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    return {
        key: val.to(device) if isinstance(val, torch.Tensor) else val
        for key, val in batch.items()
    }
81
+
82
def batch_encode_plus(
    tokenizer,
    texts: List[str],
    batch_size: int = 32,
    max_length: int = 512,
    return_tensors: str = "pt",
    **kwargs
) -> List[Dict[str, torch.Tensor]]:
    """
    Encode a large list of texts in fixed-size chunks.

    Each chunk of ``batch_size`` texts is tokenized with padding to
    ``max_length`` and truncation enabled, so every batch has uniform shape.

    Args:
        tokenizer: Tokenizer to use
        texts: List of texts to encode
        batch_size: Size of each processing batch
        max_length: Maximum sequence length
        return_tensors: Return format ('pt' for PyTorch)
        **kwargs: Additional encoding parameters forwarded to the tokenizer

    Returns:
        List of encoded batches, one per chunk
    """
    encoded_batches = []
    start = 0
    total = len(texts)

    while start < total:
        chunk = texts[start:start + batch_size]
        encoded_batches.append(
            tokenizer(
                chunk,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors=return_tensors,
                **kwargs
            )
        )
        start += batch_size

    return encoded_batches
119
+
120
def get_tokenizer_info(tokenizer) -> Dict[str, Any]:
    """
    Summarize a tokenizer: vocabulary size, model name, special tokens.

    Args:
        tokenizer: Tokenizer to inspect

    Returns:
        Dictionary with keys ``vocab_size``, ``model_name`` and
        ``special_tokens`` (only the special tokens the tokenizer defines)
    """
    special_token_names = (
        "pad_token", "unk_token", "sep_token",
        "cls_token", "mask_token", "bos_token", "eos_token",
    )

    # Collect only the special tokens this tokenizer actually defines.
    special_tokens = {}
    for name in special_token_names:
        token = getattr(tokenizer, name, None)
        if token is not None:
            special_tokens[name] = token

    return {
        "vocab_size": len(tokenizer),
        "model_name": getattr(tokenizer, "name_or_path", None),
        "special_tokens": special_tokens,
    }
utils/transformer_utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified utilities for handling transformers, tokenizers, and embeddings.
3
+ """
4
+ import os
5
+ import logging
6
+ import torch
7
+ from typing import Dict, Any, Optional, Union, List
8
+ from sentence_transformers import SentenceTransformer
9
+ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, AutoModel
10
+
11
+ # Import the new sentence transformer utilities
12
+ from utils.sentence_transformer_utils import get_sentence_transformer as load_sentence_transformer
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Constants
17
+ DEFAULT_SENTENCE_TRANSFORMER = "sentence-transformers/Wildnerve-tlm01-0.05Bx12"
18
+ DEFAULT_TOKENIZER = "bert-base-uncased"
19
+ FALLBACK_TOKENIZERS = ["bert-base-uncased", "gpt2", "roberta-base"]
20
+
21
+ # Cache for loaded models to avoid reloading
22
+ _model_cache = {}
23
+ _tokenizer_cache = {}
24
+ _sentence_transformer_cache = {}
25
+
26
def get_sentence_transformer(model_name):
    """Load a SentenceTransformer, with caching and a hard-coded fallback.

    Successfully loaded models are stored in the module-level
    ``_sentence_transformer_cache`` (which ``clear_cache`` empties) so
    repeated calls do not reload from disk/network.

    Args:
        model_name: Hub id or local path of the sentence-transformer model.

    Returns:
        A loaded ``SentenceTransformer`` instance (the requested model, or
        the default fallback when loading fails).
    """
    cached = _sentence_transformer_cache.get(model_name)
    if cached is not None:
        return cached

    # Single local import — the original re-imported inside the except
    # branch as well, and the module already imports it at top level.
    from sentence_transformers import SentenceTransformer
    try:
        model = SentenceTransformer(model_name)
        _sentence_transformer_cache[model_name] = model
        return model
    except Exception as e:
        # Use the module logger instead of the root logger.
        logger.error(f"Failed to load sentence transformer {model_name}: {e}")
        logger.warning("Falling back to default model: Wildnerve-tlm01-0.05Bx12")
        return SentenceTransformer("Wildnerve-tlm01-0.05Bx12")
35
+
36
def get_tokenizer(model_name: str = "bert-base-uncased"):
    """Get a tokenizer with proper error handling.

    Tries to load the named HuggingFace tokenizer; on any failure returns a
    deterministic in-process DummyTokenizer so downstream code can keep
    running (useful for offline/dev environments).

    Args:
        model_name: HuggingFace tokenizer id or local path.

    Returns:
        A real ``AutoTokenizer`` instance, or a ``DummyTokenizer`` fallback.
    """
    try:
        from transformers import AutoTokenizer
        logger.info(f"Loading tokenizer: {model_name}")
        return AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        logger.error(f"Failed to load tokenizer {model_name}: {e}")
        # Return a minimal dummy tokenizer that won't break everything
        logger.warning("Using dummy tokenizer as fallback")

        class DummyTokenizer:
            # Mimics the small surface of a BERT-style tokenizer that the
            # rest of the codebase touches: call, decode, id attributes.
            def __init__(self):
                self.vocab_size = 30522  # BERT vocab size
                self.pad_token_id = 0
                self.eos_token_id = 102
                self.bos_token_id = 101

            def __call__(self, text, **kwargs):
                """Convert text to a dict with dummy tensors.

                Token ids are pseudo-random but deterministic per input text
                (seeded from an MD5 hash of the text).
                """
                import torch

                # Handle batch vs single input
                is_batch = isinstance(text, list)
                texts = text if is_batch else [text]

                # Create random but deterministic IDs based on text length
                input_ids = []
                attention_mask = []

                for t in texts:
                    # Use text length to create deterministic pseudo-random sequence
                    import hashlib
                    hash_obj = hashlib.md5(t.encode())
                    seed = int(hash_obj.hexdigest(), 16) % 10000

                    import random
                    # NOTE(review): this reseeds the *global* random module —
                    # a side effect on callers; a random.Random(seed)
                    # instance would be cleaner.
                    random.seed(seed)

                    # Get length or use max_length if provided
                    max_length = kwargs.get("max_length", 128)
                    length = min(len(t.split()), max_length)

                    # Generate ids and mask: BOS + (length-2) random ids + EOS.
                    # For very short texts range(length-2) is empty, leaving
                    # just [BOS, EOS].
                    ids = [self.bos_token_id] + [random.randint(1000, 30000) for _ in range(length-2)] + [self.eos_token_id]
                    mask = [1] * len(ids)

                    # Pad if needed.
                    # NOTE(review): this checks key *presence*, so even
                    # padding=False triggers padding — confirm intended.
                    if "padding" in kwargs:
                        pad_length = max_length - len(ids)
                        if pad_length > 0:
                            ids.extend([self.pad_token_id] * pad_length)
                            mask.extend([0] * pad_length)

                    input_ids.append(torch.tensor(ids))
                    attention_mask.append(torch.tensor(mask))

                # Stack tensors into [batch, seq] when 'pt' tensors requested;
                # otherwise return per-text tensors (unwrapped for single input).
                if "return_tensors" in kwargs and kwargs["return_tensors"] == "pt":
                    if is_batch or len(texts) > 1:
                        return {
                            "input_ids": torch.stack(input_ids),
                            "attention_mask": torch.stack(attention_mask)
                        }
                    else:
                        return {
                            "input_ids": input_ids[0].unsqueeze(0),
                            "attention_mask": attention_mask[0].unsqueeze(0)
                        }
                else:
                    return {
                        "input_ids": input_ids[0] if not is_batch and len(texts) == 1 else input_ids,
                        "attention_mask": attention_mask[0] if not is_batch and len(texts) == 1 else attention_mask
                    }

            def decode(self, token_ids, skip_special_tokens=True, **kwargs):
                """Convert token IDs back to text (placeholder string only)."""
                if isinstance(token_ids, (list, tuple)) and len(token_ids) > 0:
                    return f"Decoded text from {len(token_ids)} tokens"
                return "Decoded text"

        return DummyTokenizer()
118
+
119
def get_hybrid_attention_config():
    """Get configuration for smart hybrid attention mechanism.

    Thin indirection over ``utils.smartHybridAttention`` — imported lazily,
    presumably to avoid a circular import at module load time (TODO confirm).
    Note the local import deliberately shadows this function's own name, so
    the call below invokes the imported implementation, not recursion.
    """
    from utils.smartHybridAttention import get_hybrid_attention_config
    return get_hybrid_attention_config()
123
+
124
def load_transformer_model(model_name: str, device: Optional[torch.device] = None) -> AutoModel:
    """
    Load a pretrained transformer model by name.

    Args:
        model_name: Name of the model to load (hub id or local path)
        device: Optional device to load the model on

    Returns:
        Loaded transformer model

    Raises:
        Exception: Re-raises whatever ``from_pretrained`` or the device
            move fails with, after logging it.
    """
    try:
        logger.info(f"Loading transformer model: {model_name}")
        loaded = AutoModel.from_pretrained(model_name)
        if device:
            loaded = loaded.to(device)
        logger.info(f"Successfully loaded model: {model_name}")
        return loaded
    except Exception as exc:
        logger.error(f"Error loading model {model_name}: {exc}")
        raise
147
+
148
def clear_cache():
    """Clear all model and tokenizer caches to free memory."""
    # dict.clear() mutates in place, so no `global` declaration is needed.
    for cache in (_model_cache, _tokenizer_cache, _sentence_transformer_cache):
        cache.clear()
    logger.info("Cleared transformer model and tokenizer caches")
155
+
156
def get_embedding(text: str, model: Optional[SentenceTransformer] = None) -> torch.Tensor:
    """Embed a text string with a sentence-transformer model.

    Falls back to the module's default model when none is supplied.
    """
    encoder = model if model is not None else get_sentence_transformer(DEFAULT_SENTENCE_TRANSFORMER)
    return encoder.encode(text, convert_to_tensor=True)