WildnerveAI committed on
Commit
bf64b40
·
verified ·
1 Parent(s): a829f5c

Upload 7 files

Browse files
Files changed (7) hide show
  1. config.json +199 -0
  2. config.py +397 -0
  3. model_Combn.py +387 -0
  4. model_Custm.py +702 -0
  5. model_List.py +138 -0
  6. model_PrTr.py +482 -0
  7. model_manager.py +735 -0
config.json ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "SELECTED_MODEL": ["model_Custm.py", "model_PrTr.py", "model_Combn.py"],
3
+ "MODEL_NAME": "Wildnerve-tlm01",
4
+ "BASE_DATA_DIR": "data",
5
+ "FILE_FORMATS": ["csv", "json", "txt"],
6
+ "MAX_SEQ_LENGTH": 512,
7
+ "SIMILARITY_THRESHOLD": 0.85,
8
+ "DATASET_PATHS": {
9
+ "general": ["data/general.json"],
10
+ "programming_software_dev": ["data/programming_software_dev.json"],
11
+ "other_information": ["data/other_information.json"]
12
+ },
13
+ "LAZY_LOADING_ENABLED": true,
14
+ "MAX_INITIAL_SPECIALIZATIONS": 2,
15
+ "train_file_path": "data/computer_advanced_debugging.json",
16
+ "NUM_EPOCHS": 50,
17
+ "LEARNING_RATE": 0.0001,
18
+ "INPUT_SIZE": 768,
19
+ "OUTPUT_SIZE": 768,
20
+ "SPECIALIZATIONS": [
21
+ "general",
22
+ "programming_software_dev"
23
+ ],
24
+ "ALL_SPECIALIZATIONS": [
25
+ "general",
26
+ "mbpp",
27
+ "programming_software_dev",
28
+ "machine_learning_ai_data_science",
29
+ "industrial_engineering",
30
+ "science_engineering",
31
+ "mathematics",
32
+ "healthcare_and_lifesciences",
33
+ "chemistry",
34
+ "hardware_devops_cloud",
35
+ "cyber_security",
36
+ "business_legal_finance",
37
+ "other_information"
38
+ ],
39
+ "PREPROCESSING": {
40
+ "LOWERCASE": true,
41
+ "REMOVE_SPECIAL_CHARACTERS": true,
42
+ "REPLACE_MULTIPLE_SPACES": true
43
+ },
44
+ "STDP_CONFIG": {
45
+ "WEIGHT_THRESHOLD": 0.5,
46
+ "ACTIVATION_THRESHOLD": 0.2,
47
+ "USE_SNN": true,
48
+ "ALPHA": 0.1,
49
+ "BETA": 0.2,
50
+ "BASE_DIR": "checkpoints",
51
+ "SNN_FILENAME_FORMAT": "snn_model_{specialization}_{epoch}.pt",
52
+ "STDPLearningRate": 0.01,
53
+ "STDPMemDecay": 0.9,
54
+ "SpikeThreshold": 0.5,
55
+ "firing_rate": 10,
56
+ "MAX_SEQ_LENGTH": 2048,
57
+ "STDP_PRETRAIN_EPOCHS": 5,
58
+ "STDP_FINETUNE_EPOCHS": 3,
59
+ "BATCH_SIZE_PRETRAIN": 32,
60
+ "BATCH_SIZE_FINETUNE": 16,
61
+ "NUM_NEURONS": 1024,
62
+ "MAX_RATE": 100
63
+ },
64
+ "TRAINING_CONFIG": {
65
+ "PATIENCE": 3,
66
+ "DELTA": 0.001,
67
+ "VERBOSE": true,
68
+ "NUM_EPOCHS": 10,
69
+ "LEARNING_RATE": 0.0001,
70
+ "TRANSFORMER_LEARNING_RATE": 5e-5,
71
+ "TRANSFORMER_NUM_EPOCHS": 5
72
+ },
73
+ "CHECKPOINT_CONFIG": {
74
+ "PATH": "checkpoints",
75
+ "BASE_DIR": "checkpoints",
76
+ "TRANSFORMER_FILENAME_FORMAT": "transformer_model_{specialization}_{epoch}.pt",
77
+ "SNN_FILENAME_FORMAT": "snn_model_{specialization}_{epoch}.pt"
78
+ },
79
+ "GENERATION_CONFIG": {
80
+ "temperature": 0.7,
81
+ "top_p": 0.9,
82
+ "num_return_sequences": 1
83
+ },
84
+ "TOKENIZER_CONFIG": {
85
+ "MODEL_NAME": "bert-base-uncased",
86
+ "MAX_SEQ_LENGTH": 512,
87
+ "POOLING_MODE": "mean"
88
+ },
89
+ "DATA_LOADER_CONFIG": {
90
+ "BATCH_SIZE": 32,
91
+ "NUM_WORKERS": 0,
92
+ "SHUFFLE": true,
93
+ "INCLUDE_CRAWL": true
94
+ },
95
+ "ATTENTION_CONFIG": {
96
+ "WINDOW_SIZE": 256,
97
+ "STRIDE": 128,
98
+ "MEMORY_SIZE": 64,
99
+ "NUM_HEADS": 8,
100
+ "ATTENTION_DROPOUT": 0.1,
101
+ "ATTENTION_TYPES": {
102
+ "SLIDING": true,
103
+ "HIERARCHICAL": true,
104
+ "GLOBAL": true
105
+ },
106
+ "PROMPT_THRESHOLDS": {
107
+ "LENGTH_THRESHOLD": 500,
108
+ "COMPLEXITY_THRESHOLD": 0.7,
109
+ "PERPLEXITY_THRESHOLD": 50
110
+ },
111
+ "ATTENTION_WEIGHTS": {
112
+ "SHORT_COMPLEX": {
113
+ "SLIDING": 0.4,
114
+ "HIERARCHICAL": 0.6
115
+ },
116
+ "LONG_CONTEXT": {
117
+ "SLIDING": 0.3,
118
+ "HIERARCHICAL": 0.4,
119
+ "GLOBAL": 0.3
120
+ }
121
+ }
122
+ },
123
+ "TRANSFORMER_CONFIG": {
124
+ "TEST_MODE": false,
125
+ "LOGGING_LEVEL": "INFO",
126
+ "LOG_FILE": "logs/training.log",
127
+ "SAVE_CHECKPOINTS": true,
128
+ "BASE_DIR": "checkpoints",
129
+ "TRANSFORMER_FILENAME_FORMAT": "transformer_model_{specialization}_{epoch}.pt",
130
+ "MODEL_NAME": "Wildnerve-tlm01-0.05Bx12",
131
+ "MAX_SEQ_LENGTH": 512,
132
+ "NUM_EPOCHS": 10,
133
+ "LEARNING_RATE": 5e-5,
134
+ "BATCH_SIZE": 32,
135
+ "EMBEDDING_DIM": 768,
136
+ "NUM_HEADS": 12,
137
+ "HIDDEN_DIM": 768,
138
+ "NUM_LAYERS": 12,
139
+ "DROPOUT": 0.1,
140
+ "specialization1": "cpp",
141
+ "specialization2": "java",
142
+ "specialization3": "go",
143
+ "specialization4": "javascript",
144
+ "specialization5": "nim",
145
+ "specialization6": "python",
146
+ "specialization7": "rust",
147
+ "specialization8": "solidity",
148
+ "specialization9": "computer",
149
+ "specialization10": "mathematics",
150
+ "specialization11": "physics",
151
+ "specialization12": "other_information",
152
+ "DATASET_PATH": "data/cpp_ai_language_model.json",
153
+ "OUTPUT_SIZE": 768,
154
+ "POOLING_MODE": "mean",
155
+ "VOCAB_SIZE": 30522,
156
+ "MAX_RATE": 100,
157
+ "MODE": "pretrained",
158
+ "MODE2": "custom",
159
+ "SHUFFLE": true,
160
+ "SIMILARITY_THRESHOLD": 0.85,
161
+ "USE_PRETRAINED_ENCODER": true,
162
+ "ATTENTION_MECHANISM": {
163
+ "TYPE": "hybrid",
164
+ "WINDOW_SIZE": 1024,
165
+ "STRIDE": 512,
166
+ "USE_MEMORY": true
167
+ },
168
+ "SPECIALIZATIONS": {
169
+ "mbpp": "mbpp",
170
+ "programming_software_dev": "programming_software_dev",
171
+ "machine_learning_ai_data_science": "machine_learning_ai_data_science",
172
+ "industrial_engineering": "industrial_engineering",
173
+ "science_engineering": "science_engineering",
174
+ "mathematics": "mathematics",
175
+ "healthcare_and_lifesciences": "healthcare_and_lifesciences",
176
+ "chemistry": "chemistry",
177
+ "hardware_devops_cloud": "hardware_devops_cloud",
178
+ "cyber_security": "cyber_security",
179
+ "business_legal_finance": "business_legal_finance",
180
+ "other_information": "other_information"
181
+ }
182
+ },
183
+ "DUAL_ENCODER_CONFIG": {
184
+ "USE_PRETRAINED_ENCODER": true,
185
+ "USE_CUSTOM_ENCODER": true,
186
+ "DEBUG": false
187
+ },
188
+ "PROMPT_ANALYZER_CONFIG": {
189
+ "MODEL_NAME": "gpt2",
190
+ "DATASET_PATH": null,
191
+ "SPECIALIZATION": null,
192
+ "HIDDEN_DIM": 768,
193
+ "MAX_CACHE_SIZE": 10
194
+ },
195
+ "MAX_ACTIVE_MODELS": 5,
196
+ "MODEL_IDLE_THRESHOLD": 600,
197
+ "MAX_MEMORY_USAGE": 0.8,
198
+ "TOP_K": 3
199
+ }
config.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.py - 21/02/2025, cleaned up version, 5:14pm, C:\Users\User\OneDrive\Documents\tlm\config.py
2
+ import os
3
+ import json
4
+ import logging
5
+ import argparse
6
+ from pathlib import Path
7
+ from typing import Optional, Dict, List, Literal, Any
8
+ from pydantic import BaseModel, Field, ValidationError, ConfigDict
9
+ #from types import SimpleNamespace
10
+
11
# Configure logging
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process; confirm importers do not need their own logging setup.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
17
+
18
+
19
+
20
class PathConfig:
    """Resolve project-relative directories, falling back to the system temp
    directory when the project tree is not writable (e.g. read-only deploys)."""

    @staticmethod
    def get_project_root() -> Path:
        """Return the directory that contains this module."""
        return Path(__file__).resolve().parent

    @staticmethod
    def get_data_dir() -> Path:
        """Get a writable data directory, falling back to temp if needed."""
        # First try in project directory
        project_dir = PathConfig.get_project_root()
        data_dir = project_dir / "data"
        try:
            if not data_dir.exists():
                data_dir.mkdir(parents=True, exist_ok=True)
            # Verify write access with a throwaway file.
            test_file = data_dir / ".write_test"
            test_file.touch()
            test_file.unlink()
            return data_dir
        except (PermissionError, IOError):
            # Fall back to temp directory
            import tempfile
            tmp_dir = Path(tempfile.gettempdir()) / "wildnerve_data"
            tmp_dir.mkdir(parents=True, exist_ok=True)
            logging.getLogger(__name__).info("Using temporary directory for data: %s", tmp_dir)
            return tmp_dir

    @staticmethod
    def get_checkpoint_dir() -> Path:
        """Get a writable checkpoint directory, falling back to temp if needed.

        Fix: the returned directory is now created when missing. Previously
        the temp-dir branch returned a path without ever creating it (unlike
        get_data_dir), so callers saving a checkpoint into a fresh environment
        failed with FileNotFoundError.
        """
        project_dir = PathConfig.get_project_root()
        checkpoint_dir = project_dir / "checkpoints"
        if os.access(project_dir, os.W_OK):
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            return checkpoint_dir
        # Not writable: fall back to temp directory.
        import tempfile
        tmp_dir = Path(tempfile.gettempdir()) / "wildnerve_checkpoints"
        tmp_dir.mkdir(parents=True, exist_ok=True)
        return tmp_dir
64
+
65
# Resolved once at import; DATA_DIR / CHECKPOINT_DIR may be temp fallbacks.
BASE_DIR = PathConfig.get_project_root()
DATA_DIR = PathConfig.get_data_dir()
CHECKPOINT_DIR = PathConfig.get_checkpoint_dir()

# Model architecture parameters
INPUT_SIZE = 768   # BERT base hidden size
OUTPUT_SIZE = 768  # Output embedding size
HIDDEN_SIZE = 768  # Hidden layer size

# Per-language/topic specializations used by the model files.
# NOTE(review): this list differs from the domain-oriented
# "SPECIALIZATIONS"/"ALL_SPECIALIZATIONS" keys in config.json — confirm which
# set consumers expect.
SPECIALIZATIONS = [
    "cpp",
    "java",
    "go",
    "javascript",
    "nim",
    "python",
    "rust",
    "solidity",
    "computer",
    "mathematics",
    "physics",
    "other_information"
]
90
+
91
# Define DATASET_PATHS so that each specialization maps to a list of path strings.
# Fix: entries were written as ``List[str(...)]`` — that subscripts
# ``typing.List`` with a string and yields a typing generic-alias object,
# not a usable path, so any consumer opening these "paths" failed.
# Each entry is now a plain ``str``.
# NOTE(review): DATA_DIR already points at the project "data" directory, so
# ``DATA_DIR / "data" / ...`` resolves to data/data/... — preserved as
# written; confirm whether the extra "data" segment is intended.
DATASET_PATHS = {
    "computer": [
        str(DATA_DIR / "data" / "computer_advanced_debugging.json"),
        str(DATA_DIR / "data" / "computer_agenticAI.json"),
        str(DATA_DIR / "data" / "computer_architecture.json"),
        str(DATA_DIR / "data" / "computer_cloud_security.json"),
        str(DATA_DIR / "data" / "computer_creativity.json"),
        str(DATA_DIR / "data" / "computer_crossplatform.json"),
        str(DATA_DIR / "data" / "computer_cybersecurity.json"),
        str(DATA_DIR / "data" / "computer_error_handling_examples.json"),
        str(DATA_DIR / "data" / "computer_gitInstruct.json")
    ],

    "cpp": [
        str(DATA_DIR / "data" / "cpp_advanced_debugging.json"),
        str(DATA_DIR / "data" / "cpp_blockchain.json"),
        str(DATA_DIR / "data" / "cpp_mbcppp.json"),
        str(DATA_DIR / "data" / "cpp_programming.json")
    ],

    "java": [
        str(DATA_DIR / "data" / "java_ai_language_model.json"),
        str(DATA_DIR / "data" / "java_blockchain.json"),
        str(DATA_DIR / "data" / "java_mbjp.json"),
        str(DATA_DIR / "data" / "java_programming.json"),
        str(DATA_DIR / "data" / "java_transformer_language_model.json"),
    ],

    "go": [
        str(DATA_DIR / "data" / "golang_ai_language_model.json"),
        str(DATA_DIR / "data" / "golang_mbgp.json"),
        str(DATA_DIR / "data" / "golang_programming.json")
    ],

    "javascript": [
        str(DATA_DIR / "data" / "javascript_chatbot.json"),
        str(DATA_DIR / "data" / "javascript_n_Typescript_frontend.json"),
        str(DATA_DIR / "data" / "javascript_n_Typescript_backend.json"),
        str(DATA_DIR / "data" / "javascript_programming.json")
    ],

    "nim": [
        str(DATA_DIR / "data" / "nim_ai_language_model.json"),
        str(DATA_DIR / "data" / "nim_blockchain.json"),
        str(DATA_DIR / "data" / "nim_chatbot.json"),
        str(DATA_DIR / "data" / "nim_mbnp.json"),
        str(DATA_DIR / "data" / "nim_programming.json")
    ],

    "python": [
        str(DATA_DIR / "data" / "python_chatbot_guide.json"),
        str(DATA_DIR / "data" / "python_mbpp.json"),
        str(DATA_DIR / "data" / "python_programming.json"),
        str(DATA_DIR / "data" / "python_transformer_model.json")
    ],

    "rust": [
        str(DATA_DIR / "data" / "rust_ai_language_model.json"),
        str(DATA_DIR / "data" / "rust_blockchain.json"),
        str(DATA_DIR / "data" / "rust_mbrp.json"),
        str(DATA_DIR / "data" / "rust_programming.json")
    ],

    "solidity": [
        str(DATA_DIR / "data" / "solidity_programming.json")
    ],

    "mathematics": [
        str(DATA_DIR / "data" / "mathematics.json"),
        str(DATA_DIR / "data" / "mathematics_training.json")
    ],

    "physics": [
        str(DATA_DIR / "data" / "physics_n_engineering.json"),
        str(DATA_DIR / "data" / "physics_n_engineering_applied.json")
    ],

    "other_information": [
        str(DATA_DIR / "data" / "other_information.json")
    ]
}
173
+
174
# Nested configuration models
class TrainingConfig(BaseModel):
    """Training hyper-parameters (TRAINING_CONFIG section of config.json)."""
    PATIENCE: int = Field(..., description="Early stopping patience")
    DELTA: float = Field(..., description="Minimum change in the monitored value")
    VERBOSE: bool = Field(..., description="Verbosity of training logs")
    NUM_EPOCHS: int = Field(..., description="Number of training epochs")
    LEARNING_RATE: float = Field(..., description="Learning rate for optimizer")
    TRANSFORMER_LEARNING_RATE: float = Field(..., description="Learning rate for transformer")
    TRANSFORMER_NUM_EPOCHS: int = Field(..., description="Transformer training epochs")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
188
+
189
class CheckpointConfig(BaseModel):
    """Checkpoint locations and filename templates (CHECKPOINT_CONFIG section)."""
    PATH: str = Field(..., description="Checkpoint saving folder")
    BASE_DIR: str = Field(..., description="Base directory for checkpoints")
    TRANSFORMER_FILENAME_FORMAT: str = Field(..., description="Transformer checkpoint filename format")
    SNN_FILENAME_FORMAT: str = Field(..., description="SNN checkpoint filename format")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
199
+
200
class TokenizerConfig(BaseModel):
    """Tokenizer settings (TOKENIZER_CONFIG section of config.json)."""
    MODEL_NAME: str = Field(..., description="Name of the tokenizer model")
    MAX_SEQ_LENGTH: int = Field(..., description="Maximum length the tokenizer handles")
    POOLING_MODE: str = Field(..., description="Pooling mode for embeddings")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
209
+
210
class DataLoaderConfig(BaseModel):
    """DataLoader settings (DATA_LOADER_CONFIG section of config.json)."""
    SHUFFLE: bool = Field(..., description="Whether to Shuffle the dataset")
    BATCH_SIZE: int = Field(..., description="Batch size for dataloader")
    NUM_WORKERS: int = Field(..., description="Number of workers for dataloader")
    INCLUDE_CRAWL: bool = Field(..., description="Include crawl parameter")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
220
+
221
class GenerationConfig(BaseModel):
    """Text-generation sampling parameters (GENERATION_CONFIG section)."""
    temperature: float = Field(0.7, description="Decoding temperature.")
    top_p: float = Field(0.9, description="Nucleus sampling probability.")
    num_return_sequences: int = Field(1, description="Number of sequences to generate.")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
230
+
231
class PretrainedLimitsConfig(BaseModel):
    """Hard max-sequence-length limits imposed by pretrained backbones."""
    GPT2: int = Field(1024, description="Maximum sequence length for GPT-2")
    BERT: int = Field(512, description="Maximum sequence length for BERT")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
239
+
240
class CustomWindowsConfig(BaseModel):
    """Sliding-window attention settings for the custom (non-pretrained) models."""
    MAX_SEQ_LENGTH: int = Field(2048, description="Maximum sequence length for custom models")
    WINDOW_SIZE: int = Field(1024, description="Window size for sliding window attention")
    STRIDE: int = Field(512, description="Stride for sliding window attention")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
249
+
250
class AttentionConfig(BaseModel):
    """Groups pretrained-backbone limits with custom sliding-window settings."""
    PRETRAINED_LIMITS: PretrainedLimitsConfig = Field(default_factory=PretrainedLimitsConfig)
    CUSTOM_WINDOWS: CustomWindowsConfig = Field(default_factory=CustomWindowsConfig)

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
258
+
259
class TransformerConfig(BaseModel):
    """Transformer model settings (TRANSFORMER_CONFIG section of config.json).

    Most fields are required (``...``), so this model can only be built from
    a fully-populated config section.
    """
    ATTENTION_MECHANISM: Dict[str, Any] = Field(
        default={
            "TYPE": "hybrid",
            "WINDOW_SIZE": 1024,
            "STRIDE": 512,
            "USE_MEMORY": True,
            "ATTENTION_TYPES": {
                "SLIDING": True,
                "HIERARCHICAL": True,
                "GLOBAL": True
            }
        },
        description="Attention mechanism configuration"
    )

    BASE_DIR: str = Field(..., description="Base directory for transformer checkpoints")
    TRANSFORMER_FILENAME_FORMAT: str = Field(..., description="Filename format for transformer checkpoints")
    # NOTE(review): default differs from config.json's "Wildnerve-tlm01-0.05Bx12" — confirm intended.
    MODEL_NAME: str = Field("bert-base-uncased", description="Name of the primary model from Hugging Face")
    # NOTE(review): default 30 differs from config.json's NUM_EPOCHS of 10 — confirm intended.
    NUM_EPOCHS: int = Field(30, description="Number of epochs for transformer training")
    LEARNING_RATE: float = Field(..., description="Learning rate for transformer")
    BATCH_SIZE: int = Field(..., description="Batch size for transformer training")
    EMBEDDING_DIM: int = Field(..., description="Embedding dimension")
    NUM_HEADS: int = Field(..., description="Number of attention heads")
    HIDDEN_DIM: int = Field(..., description="Hidden dimension")
    NUM_LAYERS: int = Field(..., description="Number of layers")
    DROPOUT: float = Field(..., description="Dropout rate")
    # NOTE(review): required, but config.json's TRANSFORMER_CONFIG only has
    # specialization1..specialization12 — validating that file against this
    # model would fail on the missing "specialization" key; confirm.
    specialization: str = Field(..., description="Specialization type")
    DATASET_PATH: str = Field(..., description="Path to the dataset")
    OUTPUT_SIZE: int = Field(..., description="Size of the output (usually vocab size)")
    MAX_SEQ_LENGTH: int = Field(..., description="Maximum sequence length")
    POOLING_MODE: str = Field(..., description="Pooling mode")
    VOCAB_SIZE: int = Field(..., description="Vocabulary size")
    MAX_RATE: int = Field(..., description="Maximum rate")
    MODE: str = Field(..., description="Model mode")
    MODE2: str = Field(..., description="Secondary mode")
    SHUFFLE: bool = Field(..., description="Shuffle flag for transformer")
    SIMILARITY_THRESHOLD: float = Field(..., description="Similarity threshold for weight sharing")
    USE_PRETRAINED_ENCODER: bool = Field(..., description="Enable pretrained encoder branch")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
303
+
304
class PreprocessingConfig(BaseModel):
    """Text preprocessing switches (PREPROCESSING section of config.json)."""
    LOWERCASE: bool = Field(True, description="Convert text to lowercase")
    REMOVE_SPECIAL_CHARACTERS: bool = Field(True, description="Remove special characters from text")
    REPLACE_MULTIPLE_SPACES: bool = Field(True, description="Replace multiple spaces with a single space")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
313
+
314
class STDPConfig(BaseModel):
    """Spiking-network / STDP settings (STDP_CONFIG section of config.json)."""
    WEIGHT_THRESHOLD: float = Field(..., description="Threshold for STDP weight update")
    ACTIVATION_THRESHOLD: float = Field(..., description="Threshold for STDP activation")
    USE_SNN: bool = Field(..., description="Use spiking neural network")
    ALPHA: float = Field(..., description="STDP alpha parameter")
    BETA: float = Field(..., description="STDP beta parameter")
    BASE_DIR: str = Field(..., description="Directory for STDP checkpoints")
    SNN_FILENAME_FORMAT: str = Field(..., description="Filename format for SNN checkpoints")
    STDPLearningRate: float = Field(..., description="STDP learning rate")
    STDPMemDecay: float = Field(..., description="STDP memory decay factor")
    SpikeThreshold: float = Field(..., description="Spike threshold")
    firing_rate: int = Field(..., description="Firing rate")
    MAX_SEQ_LENGTH: int = Field(..., description="Maximum sequence length")
    STDP_PRETRAIN_EPOCHS: int = Field(..., description="Pre-training epochs for STDP")
    STDP_FINETUNE_EPOCHS: int = Field(..., description="Fine-tuning epochs for STDP")
    BATCH_SIZE_PRETRAIN: int = Field(..., description="Batch size during STDP pre-training")
    BATCH_SIZE_FINETUNE: int = Field(..., description="Batch size during STDP fine-tuning")
    NUM_NEURONS: int = Field(..., description="Number of neurons in the STDP model")
    MAX_RATE: int = Field(..., description="Maximum rate for STDP")

    # Re-validate on attribute assignment; keep unknown config.json keys.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow"
    )
338
+
339
class AppConfig(BaseModel):
    """Top-level application settings.

    Fix: ``TRANSFORMER_CONFIG`` previously used
    ``Field(default_factory=TransformerConfig)``; TransformerConfig declares
    many required fields, so constructing the default raised ValidationError
    whenever ``AppConfig()`` was instantiated without an explicit value.
    The default is now ``None`` and the caller supplies a full section.
    """
    DATA_DIR: str = Field(default="/tmp/tlm_data")
    MODEL_DIR: str = Field(default="/tmp/tlm_data/models")
    # Typed as TransformerConfig (not a plain dict) so attributes can be accessed.
    TRANSFORMER_CONFIG: Optional[TransformerConfig] = Field(default=None)
    SIMILARITY_THRESHOLD: float = Field(default=0.85)
    TOP_K: int = Field(default=3)
    # ... add other expected fields here ...
347
+
348
import json
import logging
import os

logger = logging.getLogger(__name__)


def load_config():
    """Load config.json (located next to this module) and normalize key types.

    Returns:
        dict: the parsed config with DATA_DIR, DATASET_PATHS,
        TRANSFORMER_CONFIG, SIMILARITY_THRESHOLD, TOP_K, MAX_ACTIVE_MODELS,
        MODEL_IDLE_THRESHOLD and MAX_SEQ_LENGTH guaranteed present and typed.

    Raises:
        RuntimeError: if the file is missing or is not valid JSON.

    Fix: this module previously defined ``load_config`` twice; the first
    definition (which logged and returned ``{}`` on failure) was dead code,
    immediately shadowed by the second. This single definition keeps the
    live behavior (raise on failure) and restores the useful logging.
    """
    config_path = os.path.join(os.path.dirname(__file__), "config.json")
    logger.info("Attempting to load config from: %s", config_path)
    try:
        with open(config_path, "r") as f:
            config = json.load(f)
    except Exception as e:
        logger.error("Failed to load config: %s", e)
        raise RuntimeError(f"Failed to load config file: {e}") from e
    # Ensure keys exist and are of the expected type:
    config["DATA_DIR"] = config.get("DATA_DIR", "/tmp/tlm_data")
    config["DATASET_PATHS"] = config.get("DATASET_PATHS", {})
    if not isinstance(config["DATASET_PATHS"], dict):
        config["DATASET_PATHS"] = {}
    config["TRANSFORMER_CONFIG"] = config.get("TRANSFORMER_CONFIG", {})
    if not isinstance(config["TRANSFORMER_CONFIG"], dict):
        config["TRANSFORMER_CONFIG"] = {}
    config["SIMILARITY_THRESHOLD"] = float(config.get("SIMILARITY_THRESHOLD", 0.85))
    config["TOP_K"] = int(config.get("TOP_K", 3))
    config["MAX_ACTIVE_MODELS"] = int(config.get("MAX_ACTIVE_MODELS", 2))
    config["MODEL_IDLE_THRESHOLD"] = int(config.get("MODEL_IDLE_THRESHOLD", 600))
    # Prefer a root-level MAX_SEQ_LENGTH; fall back to TRANSFORMER_CONFIG's.
    if "MAX_SEQ_LENGTH" in config:
        config["MAX_SEQ_LENGTH"] = int(config["MAX_SEQ_LENGTH"])
    else:
        config["MAX_SEQ_LENGTH"] = int(config["TRANSFORMER_CONFIG"].get("MAX_SEQ_LENGTH", 512))
    logger.info("Config loaded and normalized")
    return config
390
+
391
# Load config on import so importers can use ``config.app_config`` directly.
app_config = load_config()
393
+
394
if __name__ == "__main__":
    # CLI entry point: no options are defined yet, but parse_args still
    # provides the standard -h/--help behavior.
    parser = argparse.ArgumentParser(description="Tiny Language Model Configuration")
    args = parser.parse_args()
    print("Configuration loaded:")
    print(app_config)
model_Combn.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, math, torch, logging, importlib
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from config import load_config
5
+ from service_registry import registry, MODEL, TOKENIZER
6
+ from transformers import AutoTokenizer, AutoModel
7
+ from typing import Optional, List, Dict, Any, Union, Tuple
8
+ from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
9
+ from base_interfaces.common_types import *
10
+ from base_interfaces.model_interface import AbstractModel
11
+
12
# Module-level config dict and logger shared by the classes below.
app_config = load_config()
logger = logging.getLogger(__name__)
14
+
15
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (as in "Attention Is All You Need").

    Expects seq-first input of shape (seq_len, batch, d_model); callers in
    this module transpose before/after.

    Fix: the old default ``max_len=app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH``
    was evaluated at import time and crashed with AttributeError, because
    ``load_config()`` returns a plain dict (no attribute access). The default
    is now resolved lazily and tolerates both dict and attribute-style
    configs. Also fixes a shape mismatch for odd ``d_model``.
    """

    def __init__(self, d_model: int, max_len: Optional[int] = None):
        super().__init__()
        if max_len is None:
            # Lazy default: read from app_config whether it is a dict or an object.
            tc = app_config["TRANSFORMER_CONFIG"] if isinstance(app_config, dict) else app_config.TRANSFORMER_CONFIG
            max_len = int(tc.get("MAX_SEQ_LENGTH", 512)) if isinstance(tc, dict) else int(tc.MAX_SEQ_LENGTH)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice div_term so an odd d_model no longer raises a shape mismatch
        # (the cos half has floor(d_model/2) columns).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model): broadcasts over the batch dim
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x`` (seq-first) and return the sum."""
        seq_len = x.size(0)
        return x + self.pe[:seq_len]
28
+
29
class Wildnerve_tlm01(nn.Module, AbstractModel):
    """Transformer encoder/decoder with an optional pretrained branch
    (GPT-2 with BERT fallback) and a custom branch that switches to hybrid
    sliding/hierarchical attention for long inputs."""

    def __init__(self, vocab_size: int, specialization: str, dataset_path: str, model_name: str, embedding_dim: int,
                 num_heads: int, hidden_dim: int, num_layers: int, output_size: int, dropout: float,
                 max_seq_length: int, pooling_mode: str, use_pretrained_encoder: bool = False, use_custom_encoder: bool = True, debug: bool = False) -> None:
        """Build the model; see the class docstring for the two encoder branches.

        NOTE(review): the original indentation was lost in this view; the
        decoder/adapter/classifier modules are assumed unconditional while
        only the encoder-side modules depend on ``use_custom_encoder`` —
        confirm against the original file.
        """
        super(Wildnerve_tlm01, self).__init__()
        self.specialization = specialization
        self.dataset_path = dataset_path
        self.model_name = model_name
        self.pooling_mode = pooling_mode
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_size = output_size
        self.dropout = dropout
        self.use_pretrained_encoder = use_pretrained_encoder
        self.use_custom_encoder = use_custom_encoder
        self.debug = debug
        if use_pretrained_encoder:
            try:
                from transformers import AutoTokenizer, AutoModel
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                self.pretrained_encoder = AutoModel.from_pretrained("gpt2")
                logger.info("Loaded GPT-2 for pretrained encoder")
            except Exception as e:
                # Any load failure (network, cache) falls back to BERT.
                logger.warning(f"GPT-2 load failed: {e} - falling back to bert-base-uncased")
                from transformers import AutoTokenizer, AutoModel
                self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                self.pretrained_encoder = AutoModel.from_pretrained("bert-base-uncased")
            # Projection layer to convert pretrained output (assumed 768) to embedding_dim
            self.pretrained_projection = nn.Linear(768, embedding_dim)
        else:
            # No pretrained branch: tokenizer is None, so methods that rely
            # on it (e.g. encode_sentences) cannot be used in this mode.
            self.tokenizer = None
            self.pretrained_encoder = None
        if use_custom_encoder:
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.pos_encoder = PositionalEncoding(embedding_dim, max_len=max_seq_length)
            self.token_type_embeddings = nn.Embedding(2, embedding_dim)
            encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                                       dim_feedforward=hidden_dim, dropout=dropout, batch_first=True)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            attention_config = get_hybrid_attention_config()
            attention_config["NUM_HEADS"] = num_heads
            # Hybrid-attention window scales with sequence length, floor 256.
            attention_config["WINDOW_SIZE"] = max(256, max_seq_length//4)
            self.hybrid_attention = SmartHybridAttention(attention_config)
        self.tgt_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_decoder = PositionalEncoding(embedding_dim, max_len=max_seq_length)
        decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=num_heads,
                                                   dim_feedforward=hidden_dim, dropout=dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        # Bottleneck adapter applied to encoder output before decoding/pooling.
        self.adapter = nn.Sequential(nn.Linear(embedding_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, embedding_dim))
        self.classifier = nn.Linear(embedding_dim, vocab_size)
        self.dropout_layer = nn.Dropout(dropout)
        self.init_weights()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
86
+
87
+ def init_weights(self) -> None:
88
+ initrange = 0.1
89
+ with torch.no_grad():
90
+ self.embedding.weight.uniform_(-initrange, initrange)
91
+ self.tgt_embedding.weight.uniform_(-initrange, initrange)
92
+ self.classifier.weight.uniform_(-initrange, initrange)
93
+ self.classifier.bias.zero_()
94
+ for layer in self.adapter:
95
+ if isinstance(layer, nn.Linear):
96
+ layer.weight.uniform_(-initrange, initrange)
97
+ if layer.bias is not None:
98
+ layer.bias.zero_()
99
+
100
    def forward(self, src: torch.Tensor, tgt: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None,
                src_mask: Optional[torch.Tensor] = None, tgt_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None,
                return_sequence: bool = False, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode ``src`` and either decode against ``tgt`` (seq2seq) or pool.

        HuggingFace-style aliases ``input_ids``/``attention_mask`` are mapped
        onto ``src``/``src_key_padding_mask`` when the latter are None.
        Returns classifier logits: per-token when ``tgt`` is given and
        ``return_sequence`` is True, otherwise a single vector per example
        (mean over tokens, or pooled per ``self.pooling_mode``).
        NOTE(review): ``token_type_ids`` is accepted but never used in this body.
        """
        # Accept HuggingFace-style keyword inputs.
        if src is None and input_ids is not None:
            src = input_ids
        if src_key_padding_mask is None and attention_mask is not None:
            src_key_padding_mask = attention_mask
        # Heuristic batch-first normalization for 3-D inputs.
        if src.dim() == 2:
            pass
        elif src.dim() == 3 and src.size(0) > src.size(1):
            src = src.transpose(0, 1)
        # Scaled embedding + positional encoding (pos_encoder is seq-first,
        # hence the transpose round-trip).
        src_emb = self.embedding(src)*math.sqrt(self.embedding_dim)
        src_emb = self.pos_encoder(src_emb.transpose(0, 1)).transpose(0, 1)
        # Long sequences (>256 tokens) route through hybrid attention when available.
        if src.size(1) > 256 and hasattr(self, "hybrid_attention"):
            query = src_emb.transpose(0, 1)
            key = query
            value = query
            attended, _ = self.hybrid_attention(query=query, key=key, value=value, key_padding_mask=src_key_padding_mask, attn_mask=src_mask, prompt_length=src.size(1), prompt_complexity=0.5)
            encoded_src = attended.transpose(0, 1)
        else:
            encoded_src = self.transformer_encoder(src_emb, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        adapted = self.adapter(encoded_src)
        if tgt is not None:
            # Seq2seq path: decode the target against the adapted memory.
            if tgt.dim() == 3 and tgt.size(0)>tgt.size(1):
                tgt = tgt.transpose(0,1)
            tgt_emb = self.tgt_embedding(tgt)*math.sqrt(self.embedding_dim)
            tgt_emb = self.pos_decoder(tgt_emb.transpose(0,1)).transpose(0,1)
            decoded = self.transformer_decoder(tgt_emb, adapted, tgt_mask=tgt_mask, memory_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask)
            output = self.classifier(decoded)
            if not return_sequence:
                output = output.mean(dim=1)
        else:
            # Encoder-only path: pool token representations, then classify.
            if self.pooling_mode=="mean":
                pooled = adapted.mean(dim=1)
            elif self.pooling_mode=="max":
                pooled = torch.max(adapted, dim=1)[0]
            elif self.pooling_mode=="cls":
                pooled = adapted[:, 0]
            else:
                # Unknown pooling mode falls back to mean.
                pooled = adapted.mean(dim=1)
            pooled = self.dropout_layer(pooled)
            output = self.classifier(pooled)
        return output
144
+
145
    def encode_sentences(self, sentences, batch_size=32, normalize_embeddings=True):
        """Encode one sentence or a list of sentences into embedding vectors.

        Returns a numpy array with one row per sentence; rows are
        L2-normalized when `normalize_embeddings` is True.
        """
        self.eval()
        from torch.utils.data import DataLoader, Dataset
        if isinstance(sentences, str):
            sentences = [sentences]
        class SentencesDataset(Dataset):
            # Tokenizes lazily, one sentence per item, padded to max_length.
            def __init__(self, sentences, tokenizer, max_length):
                self.sentences = sentences
                self.tokenizer = tokenizer
                self.max_length = max_length
            def __len__(self): return len(self.sentences)
            def __getitem__(self, idx):
                return self.tokenizer(self.sentences[idx], padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
        dataset = SentencesDataset(sentences, self.tokenizer, self.max_seq_length)
        dataloader = DataLoader(dataset, batch_size=batch_size)
        all_emb = []
        device = next(self.parameters()).device
        with torch.no_grad():
            for batch in dataloader:
                # return_tensors="pt" yields a length-1 dim per item; drop it.
                inputs = {k: v.squeeze(1).to(device) for k,v in batch.items()}
                # NOTE(review): the tokenizer's attention_mask (1 = attend) is
                # passed straight through as src_key_padding_mask, whose PyTorch
                # convention is True = ignore -- confirm this is intended.
                outputs = self(inputs["input_ids"], src_key_padding_mask=inputs.get("attention_mask"))
                if normalize_embeddings:
                    outputs = torch.nn.functional.normalize(outputs, p=2, dim=1)
                all_emb.append(outputs.cpu().numpy())
        return np.vstack(all_emb)
170
+
171
+ def similarity(self, sentence1: str, sentence2: str) -> float:
172
+ emb = self.encode_sentences([sentence1, sentence2])
173
+ return np.dot(emb[0], emb[1])/(np.linalg.norm(emb[0])*np.linalg.norm(emb[1]))
174
+
175
    def generate(self, input_ids: torch.Tensor, max_length: int = app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH, device: str = "cpu", temperature: float = 1.0, start_token_id: Optional[int] = None) -> List[List[int]]:
        """Autoregressively generate token ids for each batch element.

        Greedy decoding when temperature == 0, otherwise multinomial
        sampling from the temperature-scaled softmax.  Each returned list
        starts with `start_token_id` (default: first token of the first row).
        """
        self.eval()
        batch_size = input_ids.shape[0]
        start_token_id = start_token_id or (input_ids[0,0].item() if input_ids.numel()>0 else 0)
        generated = [[start_token_id] for _ in range(batch_size)]
        # Encode the prompt once, in sequence-first layout for the custom encoder.
        src = input_ids.transpose(0,1)
        src_emb = self.embedding(src)*math.sqrt(self.embedding_dim)
        src_emb = self.pos_encoder(src_emb)
        encoded_src = self.transformer_encoder(src_emb)
        encoded_src = self.adapter(encoded_src)
        # NOTE(review): encoded_src lives on the model's device while
        # current_tgt below is created on the `device` argument -- these can
        # disagree when the model is not on `device`; confirm callers sync them.
        for _ in range(max_length -1):
            current_tgt = torch.tensor(generated, dtype=torch.long, device=device).transpose(0,1)
            tgt_emb = self.tgt_embedding(current_tgt)*math.sqrt(self.embedding_dim)
            tgt_emb = self.pos_decoder(tgt_emb)
            current_seq_length = current_tgt.size(0)
            # Causal mask so each position attends only to earlier tokens.
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(current_seq_length).to(device)
            decoded = self.transformer_decoder(tgt_emb, encoded_src, tgt_mask=tgt_mask)
            # Logits for the last generated position only.
            logits = self.classifier(decoded[-1, :, :])
            if temperature==0:
                next_tokens = torch.argmax(logits, dim=-1)
            else:
                probs = torch.softmax(logits/temperature, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
            next_tokens = next_tokens.cpu().tolist()
            for i, token in enumerate(next_tokens):
                generated[i].append(token)
        return generated
202
+
203
    def decode_tokens(self, token_ids: List[int]) -> str:
        """Decode token ids to text; on failure, returns the error text.

        NOTE(review): shadowed by a later duplicate `decode_tokens` in this
        class which returns a fixed message instead of str(e) -- consolidate.
        """
        try:
            return self.tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        except Exception as e:
            logger.error(f"Decoding error: {e}")
            return str(e)
209
+
210
    def generate_with_decoding(self, input_ids: torch.Tensor, max_length: int = app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH, device: str = "cpu", temperature: float = 1.0, start_token_id: Optional[int] = None) -> str:
        """Generate and decode the first sequence of the batch to text.

        NOTE(review): shadowed by a later duplicate `generate_with_decoding`
        with a different signature (max_length default 100, **kwargs); only
        the later definition is live.
        """
        generated_sequences = self.generate(input_ids, max_length, device, temperature, start_token_id)
        if generated_sequences:
            return self.decode_tokens(generated_sequences[0])
        return ""
215
+
216
    def generate_streaming(self, prompt, **kwargs):
        """Yield generated text token-by-token from a text prompt.

        NOTE(review): shadowed by a later duplicate `generate_streaming`
        that takes keyword-only input; only the later definition is live.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_length).to(self.device)
        with torch.no_grad():
            outputs = self(inputs.input_ids)
            # NOTE(review): [:, -1, :] presumes a (batch, seq, vocab) output,
            # but forward() without return_sequence pools to 2-D -- verify
            # which shape reaches this path.
            next_token_logits = outputs[:, -1, :]
            if "temperature" in kwargs and kwargs["temperature"] > 0:
                next_token_logits /= kwargs["temperature"]
            # Sample the first token from the (optionally tempered) softmax.
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            generated_ids = next_token
            token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
            yield token_text
            max_length = kwargs.get("max_length", 100)
            for _ in range(max_length-1):
                # Re-run the model over prompt + generated history each step.
                context_ids = torch.cat([inputs.input_ids, generated_ids.unsqueeze(0)], dim=1)
                outputs = self(context_ids)
                next_token_logits = outputs[:, -1, :]
                if "temperature" in kwargs and kwargs["temperature"] > 0:
                    next_token_logits /= kwargs["temperature"]
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
                # generated_ids is 1-D; concatenation along dim 0 grows the
                # generated-token history (effectively assumes batch size 1).
                generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=0)
                token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
                if next_token.item() == self.tokenizer.eos_token_id:
                    break
                yield token_text
242
+
243
+ def forward_with_custom_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
244
+ try:
245
+ device = next(self.parameters()).device
246
+ embeddings = embeddings.to(device)
247
+ batch_first = True
248
+ if not batch_first and embeddings.shape[0] <= embeddings.shape[1]:
249
+ embeddings = embeddings.transpose(0,1)
250
+ if hasattr(self, "pos_encoder"):
251
+ if batch_first:
252
+ embeddings = self.pos_encoder(embeddings)
253
+ else:
254
+ embeddings = self.pos_encoder(embeddings.transpose(0,1)).transpose(0,1)
255
+ encoded = self.transformer_encoder(embeddings)
256
+ if hasattr(self, "adapter"):
257
+ encoded = self.adapter(encoded)
258
+ if self.pooling_mode=="mean":
259
+ pooled = encoded.mean(dim=1)
260
+ elif self.pooling_mode=="max":
261
+ pooled = torch.max(encoded, dim=1)[0]
262
+ elif self.pooling_mode=="cls":
263
+ pooled = encoded[:,0]
264
+ else:
265
+ pooled = encoded.mean(dim=1)
266
+ pooled = self.dropout_layer(pooled)
267
+ output = self.classifier(pooled)
268
+ return output
269
+ except Exception as e:
270
+ logger.error(f"Custom embeddings forward error: {e}")
271
+ batch_size = embeddings.size(0)
272
+ return torch.zeros((batch_size, self.output_size), device=device)
273
+
274
    def forward_with_error_handling(self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forward pass with layered fallbacks.

        1. Try the normal forward().
        2. On a shape-related RuntimeError, retry with a simplified
           encoder-only pipeline (embed -> encode -> pool -> classify).
        3. On any other failure, return zeros of (batch, output_size).
        """
        try:
            return self.forward(src=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, **kwargs)
        except RuntimeError as e:
            # Only shape mismatches get the adaptive retry; other RuntimeErrors re-raise.
            if "shape" in str(e):
                logger.warning(f"Shape error: {e}")
                try:
                    embedded = self.embedding(input_ids)
                    if hasattr(self, "pos_encoder"):
                        embedded = self.pos_encoder(embedded)
                    encoder_output = self.transformer_encoder(embedded)
                    if self.pooling_mode=="mean":
                        pooled = encoder_output.mean(dim=1)
                    elif self.pooling_mode=="max":
                        pooled = torch.max(encoder_output, dim=1)[0]
                    elif self.pooling_mode=="cls":
                        pooled = encoder_output[:,0]
                    else:
                        pooled = encoder_output.mean(dim=1)
                    pooled = self.dropout_layer(pooled)
                    return self.classifier(pooled)
                except Exception as inner_e:
                    logger.error(f"Error adapting input: {inner_e}")
                    batch_size = input_ids.size(0) if input_ids is not None else 1
                    return torch.zeros((batch_size, self.output_size), device=self.device)
            raise
        except Exception as e:
            logger.error(f"Unhandled error: {e}")
            batch_size = input_ids.size(0) if input_ids is not None else 1
            return torch.zeros((batch_size, self.output_size), device=self.device)
304
+
305
    def train_with_emissions_tracking(self, dataloader, optimizer, criterion, num_epochs=1):
        """Standard training loop wrapped in a CodeCarbon emissions tracker.

        Expects `dataloader` to yield (inputs, labels) pairs; logs the
        estimated CO2 emissions (kg) once training finishes.
        """
        from codecarbon import EmissionsTracker
        tracker = EmissionsTracker()
        tracker.start()
        self.train()
        for epoch in range(num_epochs):
            for batch in dataloader:
                inputs, labels = batch
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            logger.info(f"Epoch {epoch+1} completed.")
        emissions = tracker.stop()
        logger.info(f"Training emissions: {emissions:.4f} kg CO2")
322
+
323
    def infer_with_emissions_tracking(self, input_ids):
        """Run a no-grad forward pass while tracking CO2 emissions."""
        from codecarbon import EmissionsTracker
        tracker = EmissionsTracker()
        tracker.start()
        self.eval()
        with torch.no_grad():
            outputs = self(input_ids)
        emissions = tracker.stop()
        logger.info(f"Inference emissions: {emissions:.4f} kg CO2")
        return outputs
333
+
334
    def decode_tokens(self, token_ids: List[int]) -> str:
        """Decode token ids to text; returns a fixed message on failure.

        NOTE(review): duplicate of the earlier `decode_tokens` in this class
        (which returns str(e) on error); Python keeps this later definition --
        consolidate to remove the divergence.
        """
        try:
            return self.tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        except Exception as e:
            logger.error(f"Decoding error: {e}")
            return "Error decoding tokens"
340
+
341
    def generate_with_decoding(self, input_ids, max_length=100, **kwargs) -> str:
        """Generate token ids and decode the first sequence to text.

        NOTE(review): duplicate of the earlier `generate_with_decoding`
        (different default max_length and parameter list); Python keeps this
        later definition.
        """
        generated_ids = self.generate(input_ids, max_length=max_length, **kwargs)
        if generated_ids and len(generated_ids)>0:
            return self.decode_tokens(generated_ids[0])
        return ""
346
+
347
+ def generate_streaming(self, **kwargs):
348
+ device = next(self.parameters()).device
349
+ input_ids = kwargs.get("input_ids")
350
+ prompt = kwargs.get("prompt")
351
+ if prompt and not input_ids and self.tokenizer:
352
+ input_ids = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).input_ids.to(device)
353
+ if input_ids is None:
354
+ raise ValueError("Input must be provided")
355
+ max_length = kwargs.get("max_length", 100)
356
+ generated_ids = None
357
+ with torch.no_grad():
358
+ outputs = self(input_ids)
359
+ next_token_logits = outputs[:, -1, :]
360
+ if kwargs.get("temperature", 1.0) > 0:
361
+ next_token_logits /= kwargs["temperature"]
362
+ probs = torch.softmax(next_token_logits, dim=-1)
363
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
364
+ generated_ids = next_token
365
+ token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
366
+ yield token_text
367
+ for _ in range(max_length-1):
368
+ context_ids = torch.cat([input_ids, generated_ids.unsqueeze(0)], dim=1)
369
+ outputs = self(context_ids)
370
+ next_token_logits = outputs[:, -1, :]
371
+ if kwargs.get("temperature", 1.0) > 0:
372
+ next_token_logits /= kwargs["temperature"]
373
+ probs = torch.softmax(next_token_logits, dim=-1)
374
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
375
+ generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=0)
376
+ token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
377
+ if next_token.item() == self.tokenizer.eos_token_id:
378
+ break
379
+ yield token_text
380
+
381
    # Register CombinedModel in registry
    registry.register("model_class_combined", Wildnerve_tlm01)

def initialize_combined_model():
    """Build a Wildnerve_tlm01 instance from a hard-coded sample config.

    NOTE(review): these keys are upper-case (EMBEDDING_DIM, ...) while the
    sibling model constructors in this project take lower-case keyword
    arguments (embedding_dim, ...) and swallow unknown names via **kwargs --
    if that holds for this class too, these values are silently ignored and
    the model is built with its defaults; confirm against the constructor.
    """
    # For now, simply call the constructor with a sample config.
    config = {"EMBEDDING_DIM":768, "OUTPUT_SIZE":768, "MODEL_NAME":"bert-base-uncased", "MAX_SEQ_LENGTH":512}
    return Wildnerve_tlm01(**config)
model_Custm.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_Custm.py
2
+ import os
3
+ import sys
4
+ import math
5
+ import torch
6
+ import logging
7
+ import numpy as np
8
+ import torch.nn as nn
9
+ from typing import Optional, List, Dict, Union
10
+
11
+ # Import the carbon tracking early - before transformers
12
+ from codecarbon import EmissionsTracker # Import EmissionsTracker
13
+
14
+ # Apply patches before importing transformers
15
+ import transformer_patches
16
+
17
+ # Now we can safely import transformers
18
+ import transformers
19
+
20
+ # Continue with standard imports
21
+ from service_registry import registry, MODEL, TOKENIZER
22
+ from utils.transformer_utils import get_tokenizer
23
+ from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
24
+
25
+ # Import base interfaces
26
+ from base_interfaces.common_types import *
27
+ from base_interfaces.model_interface import AbstractModel
28
+
29
logger = logging.getLogger(__name__)

# Check if transformers integrations has CodeCarbonCallback
if hasattr(transformers, 'integrations') and hasattr(transformers.integrations, 'CodeCarbonCallback'):
    logger.info("transformers.integrations.CodeCarbonCallback is available")

    # Check if we're using our proxy or the real implementation
    if hasattr(transformers.integrations, 'CodeCarbonCallback'):
        callback_module = transformers.integrations.CodeCarbonCallback.__module__
        if callback_module == 'carbon_tracking':
            logger.info("Using our clean architecture implementation for CodeCarbonCallback")
        else:
            logger.info(f"Using original implementation for CodeCarbonCallback from {callback_module}")

# Resolve app_config: prefer the TLM_DATA_DIR environment override (with a
# minimal SimpleNamespace config that the real config module may replace),
# otherwise load the project config; fall back to hard-coded /tmp defaults.
try:
    if 'TLM_DATA_DIR' in os.environ:
        data_dir = os.environ.get('TLM_DATA_DIR', '/tmp/tlm_data')
        model_dir = os.path.join(data_dir, "models")
        logging.info(f"Using data directory from environment: {data_dir}")
        from types import SimpleNamespace
        app_config = SimpleNamespace()
        app_config.DATA_DIR = data_dir
        app_config.MODEL_DIR = model_dir
        app_config.TRANSFORMER_CONFIG = SimpleNamespace()
        app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH = 512
        try:
            # Upgrade to the full project config when it is importable.
            from config import load_config, app_config as config_app_config
            app_config = load_config() if not hasattr(config_app_config, 'DATA_DIR') else config_app_config
        except Exception as config_error:
            logging.warning(f"Using minimal config due to error: {config_error}")
    else:
        from config import load_config, app_config as config_app_config
        app_config = load_config() if not hasattr(config_app_config, 'DATA_DIR') else config_app_config
except Exception as e:
    logging.warning(f"Error importing config: {e}")
    from types import SimpleNamespace
    app_config = SimpleNamespace()
    app_config.DATA_DIR = '/tmp/tlm_data'
    app_config.MODEL_DIR = '/tmp/tlm_data/models'
    app_config.TRANSFORMER_CONFIG = SimpleNamespace()
    app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH = 512

# Ensure the necessary directories exist, but don't fail if they can't be created
try:
    os.makedirs(getattr(app_config, "DATA_DIR", "/tmp/tlm_data"), exist_ok=True)
    os.makedirs(getattr(app_config, "MODEL_DIR", "/tmp/tlm_data/models"), exist_ok=True)
except Exception as e:
    logging.warning(f"Could not create directories: {e}")

# Configure logging and suppress TensorFlow warnings
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
82
+
83
+ # ----------------------------
84
+ # Positional Encoding Module
85
+ # ----------------------------
86
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to token embeddings.

    The table is precomputed once and stored as a (max_len, 1, d_model)
    buffer; forward() expects sequence-first input (seq_len, batch, d_model).
    """

    def __init__(self, d_model: int, max_len: int = app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH):
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        # Even columns carry sine, odd columns cosine, per the standard
        # sinusoidal formulation.
        table[:, 0::2] = torch.sin(positions * frequencies)
        table[:, 1::2] = torch.cos(positions * frequencies)
        # Insert a broadcast batch axis: (max_len, 1, d_model).
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings for the first x.size(0) positions."""
        return x + self.pe[:x.size(0)]
101
+
102
+ # ----------------------------
103
+ # Wildnerve-tlm01 using Only Custom Encoder/Decoder
104
+ # ----------------------------
105
+ class Wildnerve_tlm01(nn.Module, AbstractModel):
106
+ """A Transformer-based Tiny Language Model that uses:
107
+ - A custom built encoder & decoder (embedding, positional encoding, and TransformerEncoder)
108
+ - An adapter and classifier for post-processing
109
+ - The AutoTokenizer for consistent tokenization and decoding
110
+ - SmartHybridAttention for better context handling"""
111
    def __init__(
        self,
        vocab_size=30522,  # Default BERT vocab size
        specialization="general",
        dataset_path=None,
        model_name="Wildnerve-tlm01-0.05Bx12",  # Primary model name
        embedding_dim=768,
        num_heads=12,
        hidden_dim=768,
        num_layers=6,
        output_size=768,
        dropout=0.1,
        max_seq_length=512,
        pooling_mode="mean",
        tokenizer=None,  # Accept tokenizer as parameter
        **kwargs  # Accept additional kwargs for compatibility
    ) -> None:
        """Construct the custom encoder/decoder model.

        Resolves a tokenizer (argument > registry > downloaded primary >
        downloaded BERT fallback), builds embeddings, positional encodings,
        the transformer encoder/decoder, hybrid attention, adapter, and
        classifier, then registers this instance in the service registry.

        Side effects: writes to the global `registry` under
        "model_<specialization>" (and under MODEL for the "general"
        specialization) and may download tokenizer files via transformers.
        Unknown keyword arguments are accepted and ignored for compatibility.
        """
        super().__init__()
        # Set device once at the start; object.__setattr__ bypasses
        # nn.Module's attribute machinery for this plain (non-module) field.
        object.__setattr__(self, "device", torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        self.specialization = specialization
        self.dataset_path = dataset_path
        self.model_name = model_name
        self.pooling_mode = pooling_mode
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_size = output_size
        self.dropout = dropout

        # Optionally track model usage
        self.model_last_used = {}

        # Unified tokenizer initialization:
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            if registry.has(TOKENIZER):
                self.tokenizer = registry.get(TOKENIZER)
            else:
                try:
                    from transformers import AutoTokenizer
                    self.tokenizer = AutoTokenizer.from_pretrained("Wildnerve-tlm01-0.05Bx12")
                    logger.info("Loaded primary tokenizer: Wildnerve-tlm01-0.05Bx12")
                except Exception as e:
                    logger.warning(f"Primary tokenizer load failed: {e}")
                    try:
                        from transformers import BertTokenizer
                        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
                        logger.info("Loaded fallback tokenizer: bert-base-uncased")
                    except Exception as e2:
                        logger.error(f"Fallback tokenizer load failed: {e2}")
                        self.tokenizer = None
        # Share whichever tokenizer was resolved (may be None) via the registry.
        registry.register(TOKENIZER, self.tokenizer, overwrite=True)

        # Register this model instance in the registry by specialization
        model_registry_key = f"model_{specialization}"
        registry.register(model_registry_key, self)

        # Also register as default model if it's the primary specialization
        if specialization == "general":
            registry.register(MODEL, self)

        # Encoder components: token embedding + sinusoidal positions.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = PositionalEncoding(embedding_dim, max_len=max_seq_length)

        # Decoder components (separate target-side embedding table).
        self.tgt_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_decoder = PositionalEncoding(embedding_dim, max_len=max_seq_length)

        # Transformer encoder and decoder; batch_first=True throughout.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True  # Fixed to use batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True  # Fixed to use batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Initialize the smart hybrid attention (used for long sequences).
        attention_config = get_hybrid_attention_config()
        attention_config['NUM_HEADS'] = num_heads
        attention_config['WINDOW_SIZE'] = max(256, max_seq_length // 4)
        self.hybrid_attention = SmartHybridAttention(attention_config)

        # Adapter bottleneck and output classifier over the vocabulary.
        self.adapter = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )
        self.classifier = nn.Linear(embedding_dim, self.vocab_size)
        self.dropout_layer = nn.Dropout(dropout)

        self.init_weights()
229
+
230
+ def init_weights(self) -> None:
231
+ initrange = 0.1
232
+ with torch.no_grad():
233
+ self.embedding.weight.uniform_(-initrange, initrange)
234
+ self.tgt_embedding.weight.uniform_(-initrange, initrange)
235
+ self.classifier.weight.uniform_(-initrange, initrange)
236
+ self.classifier.bias.zero_()
237
+ for layer in self.adapter:
238
+ if isinstance(layer, nn.Linear):
239
+ layer.weight.uniform_(-initrange, initrange)
240
+ if layer.bias is not None:
241
+ layer.bias.zero_()
242
+ def forward(
243
+ self,
244
+ src: torch.Tensor = None,
245
+ tgt: Optional[torch.Tensor] = None,
246
+ token_type_ids: Optional[torch.Tensor] = None, # Not used in this implementation
247
+ src_mask: Optional[torch.Tensor] = None,
248
+ tgt_mask: Optional[torch.Tensor] = None,
249
+ src_key_padding_mask: Optional[torch.Tensor] = None,
250
+ tgt_key_padding_mask: Optional[torch.Tensor] = None,
251
+ return_sequence: bool = False,
252
+ # Add Hugging Face compatibility parameters
253
+ input_ids: Optional[torch.Tensor] = None,
254
+ attention_mask: Optional[torch.Tensor] = None,
255
+ ) -> torch.Tensor:
256
+ # Use Hugging Face parameters if provided
257
+ if src is None and input_ids is not None:
258
+ src = input_ids
259
+ if src_key_padding_mask is None and attention_mask is not None:
260
+ src_key_padding_mask = attention_mask
261
+
262
+ # Handle input shape - our layers expect batch_first=True format
263
+ if src.dim() == 2:
264
+ # src is already [batch_size, seq_len]
265
+ pass
266
+ elif src.dim() == 3 and src.size(0) > src.size(1):
267
+ # src is [seq_len, batch_size, dim] - need to transpose
268
+ src = src.transpose(0, 1)
269
+
270
+ # ----------------------------
271
+ # Encoder: Custom processing of source
272
+ # ----------------------------
273
+ src_emb = self.embedding(src) * math.sqrt(self.embedding_dim)
274
+ src_emb = self.pos_encoder(src_emb.transpose(0, 1)).transpose(0, 1) # Apply positional encoding
275
+
276
+ # Use hybrid attention if sequence length is above the threshold
277
+ if src.size(1) > 256 and hasattr(self, 'hybrid_attention'):
278
+ # Prepare inputs for hybrid attention
279
+ query = src_emb.transpose(0, 1) # Ensure shape is [seq_len, batch, dim]
280
+ key = query
281
+ value = query
282
+
283
+ # Apply smart hybrid attention
284
+ attended_output, _ = self.hybrid_attention(
285
+ query=query,
286
+ key=key,
287
+ value=value,
288
+ key_padding_mask=src_key_padding_mask,
289
+ attn_mask=src_mask,
290
+ prompt_length=src.size(1),
291
+ prompt_complexity=0.5 # Default value, can be computed based on input
292
+ )
293
+
294
+ # Convert back to expected format
295
+ encoded_src = attended_output.transpose(0, 1)
296
+ else:
297
+ # Use standard transformer encoder for shorter sequences
298
+ encoded_src = self.transformer_encoder(src_emb, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
299
+
300
+ # Process through adapter layer
301
+ adapted = self.adapter(encoded_src)
302
+
303
+ # ----------------------------
304
+ # Decoder / Output
305
+ # ----------------------------
306
+ if tgt is not None:
307
+ # Handle tgt shape for batch_first format
308
+ if tgt.dim() == 2:
309
+ # tgt is already [batch_size, seq_len]
310
+ pass
311
+ elif tgt.dim() == 3 and tgt.size(0) > tgt.size(1):
312
+ # tgt is [seq_len, batch_size, dim] - need to transpose
313
+ tgt = tgt.transpose(0, 1)
314
+
315
+ tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.embedding_dim)
316
+ tgt_emb = self.pos_decoder(tgt_emb.transpose(0, 1)).transpose(0, 1) # Apply positional encoding
317
+
318
+ decoded = self.transformer_decoder(
319
+ tgt_emb,
320
+ adapted,
321
+ tgt_mask=tgt_mask,
322
+ memory_key_padding_mask=src_key_padding_mask,
323
+ tgt_key_padding_mask=tgt_key_padding_mask
324
+ )
325
+
326
+ output = self.classifier(decoded) # [batch_size, seq_len, output_size]
327
+
328
+ if not return_sequence:
329
+ output = output.mean(dim=1) # Average over sequence dimension
330
+ else:
331
+ # For encoder-only tasks (e.g., classification)
332
+ if self.pooling_mode == "mean":
333
+ pooled = encoded_src.mean(dim=1)
334
+ elif self.pooling_mode == "max":
335
+ pooled = torch.max(encoded_src, dim=1)[0]
336
+ elif self.pooling_mode == "cls":
337
+ pooled = encoded_src[:, 0] # Use first token (CLS) - batch_first format
338
+ else:
339
+ pooled = encoded_src.mean(dim=1)
340
+ pooled = self.dropout_layer(pooled)
341
+ output = self.classifier(pooled)
342
+
343
+ return output
344
+
345
+ # Add sentence transformer methods
346
+ def encode_sentences(self, sentences, batch_size=32, normalize_embeddings=True):
347
+ """Encode sentences into vectors (sentence transformer functionality)"""
348
+ self.eval()
349
+ from torch.utils.data import DataLoader, Dataset
350
+
351
+ # Handle single sentence
352
+ if isinstance(sentences, str):
353
+ sentences = [sentences]
354
+
355
+ class SentencesDataset(Dataset):
356
+ def __init__(self, sentences, tokenizer, max_length):
357
+ self.sentences = sentences
358
+ self.tokenizer = tokenizer
359
+ self.max_length = max_length
360
+
361
+ def __len__(self):
362
+ return len(self.sentences)
363
+
364
+ def __getitem__(self, idx):
365
+ return self.tokenizer(self.sentences[idx],
366
+ padding='max_length',
367
+ truncation=True,
368
+ max_length=self.max_length,
369
+ return_tensors='pt')
370
+
371
+ # Create dataset and dataloader
372
+ dataset = SentencesDataset(sentences, self.tokenizer, self.max_seq_length)
373
+ dataloader = DataLoader(dataset, batch_size=batch_size)
374
+
375
+ all_embeddings = []
376
+ device = next(self.parameters()).device
377
+
378
+ with torch.no_grad():
379
+ for batch in dataloader:
380
+ inputs = {k: v.squeeze(1).to(device) for k, v in batch.items()}
381
+ outputs = self(inputs['input_ids'], src_key_padding_mask=inputs.get('attention_mask'))
382
+
383
+ if normalize_embeddings:
384
+ outputs = torch.nn.functional.normalize(outputs, p=2, dim=1)
385
+
386
+ all_embeddings.append(outputs.cpu().numpy())
387
+
388
+ return np.vstack(all_embeddings)
389
+
390
+ def similarity(self, sentence1: str, sentence2: str) -> float:
391
+ """Compute cosine similarity between two sentences"""
392
+ embeddings = self.encode_sentences([sentence1, sentence2])
393
+ return np.dot(embeddings[0], embeddings[1]) / (np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
394
+
395
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH,
        device: str = 'cpu',
        temperature: float = 1.0,
        start_token_id: Optional[int] = None
    ) -> List[List[int]]:
        """Generates a sequence of token IDs using the encoder-decoder architecture.

        Greedy decoding when temperature == 0, otherwise multinomial sampling
        from the temperature-scaled softmax.  Each returned list begins with
        `start_token_id` (default: first token of the first sequence).
        """
        self.eval()
        batch_size = input_ids.shape[0]
        if start_token_id is None:
            start_token_id = input_ids[0, 0].item()
        generated = [[start_token_id] for _ in range(batch_size)]

        # Encode source input once using the custom encoder (sequence-first).
        src = input_ids.transpose(0, 1)
        src_emb = self.embedding(src) * math.sqrt(self.embedding_dim)
        src_emb = self.pos_encoder(src_emb)
        encoded_src = self.transformer_encoder(src_emb)
        encoded_src = self.adapter(encoded_src)

        # NOTE(review): encoded_src lives on the model's device while
        # current_tgt below is created on the `device` argument -- these can
        # disagree when the model is not on `device`; confirm callers sync them.
        for _ in range(max_length - 1):
            current_tgt = torch.tensor(generated, dtype=torch.long, device=device)
            current_tgt = current_tgt.transpose(0, 1)
            tgt_emb = self.tgt_embedding(current_tgt) * math.sqrt(self.embedding_dim)
            tgt_emb = self.pos_decoder(tgt_emb)
            current_seq_length = current_tgt.size(0)

            # Create causal mask for the decoder.
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(current_seq_length).to(device)
            decoded = self.transformer_decoder(tgt_emb, encoded_src, tgt_mask=tgt_mask)
            # Logits for the last generated position only.
            logits = self.classifier(decoded[-1, :, :])

            if temperature == 0:
                next_tokens = torch.argmax(logits, dim=-1)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

            next_tokens = next_tokens.cpu().tolist()
            for i, token in enumerate(next_tokens):
                generated[i].append(token)

        return generated
440
+
441
+ def decode_tokens(self, token_ids: List[int]) -> str:
442
+ """Decodes a list of token IDs into a human-readable string."""
443
+ try:
444
+ return self.tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
445
+ except Exception as e:
446
+ logger.error(f"Error decoding tokens: {e}")
447
+ return str(e)
448
+ def generate_with_decoding(
449
+ self,
450
+ input_ids: torch.Tensor,
451
+ max_length: int = app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH,
452
+ device: str = 'cpu',
453
+ temperature: float = 1.0,
454
+ start_token_id: Optional[int] = None
455
+ ) -> str:
456
+ """Generates a sequence and returns the decoded text."""
457
+ generated_sequences = self.generate(input_ids, max_length, device, temperature, start_token_id)
458
+ if generated_sequences:
459
+ return self.decode_tokens(generated_sequences[0])
460
+ return ""
461
+
462
def generate_streaming(self, prompt, **kwargs):
    """Generate a response token-by-token, yielding each decoded token string.

    Args:
        prompt: Raw text prompt to condition generation on.
        **kwargs: Optional generation settings:
            temperature (float): when present and > 0, logits are divided by it
                before sampling.
            max_length (int): maximum number of tokens to generate (default 100).

    Yields:
        str: decoded text of each newly sampled token (EOS is not yielded).

    NOTE(review): sampling assumes ``self(input_ids)`` returns per-position
    logits of shape (batch, seq_len, vocab) -- confirm against forward().
    """
    # Tokenize the prompt to a fixed-length tensor on the model's device.
    inputs = self.tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=self.max_seq_length
    ).to(self.device)

    # Simplified sampling loop -- a production path would use beam search.
    with torch.no_grad():
        # --- First token ---
        outputs = self(inputs.input_ids)
        next_token_logits = outputs[:, -1, :]

        # Temperature scaling only when a positive temperature is supplied.
        if "temperature" in kwargs and kwargs["temperature"] > 0:
            next_token_logits = next_token_logits / kwargs["temperature"]

        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)  # shape: (batch,)

        # Running 1-D tensor of generated token ids.
        generated_ids = next_token

        # Decode and yield the first token.
        token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
        yield token_text

        # --- Remaining tokens up to max_length ---
        max_length = kwargs.get("max_length", 100)

        for _ in range(max_length - 1):
            # Context = original prompt ids followed by everything generated so far.
            context_ids = torch.cat([inputs.input_ids, generated_ids.unsqueeze(0)], dim=1)

            outputs = self(context_ids)
            next_token_logits = outputs[:, -1, :]

            if "temperature" in kwargs and kwargs["temperature"] > 0:
                next_token_logits = next_token_logits / kwargs["temperature"]

            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

            # BUGFIX: both tensors are 1-D of shape (batch,), so concatenate them
            # directly. The previous ``next_token.unsqueeze(0)`` produced a 2-D
            # tensor and raised a dimension-mismatch error on the second step.
            generated_ids = torch.cat([generated_ids, next_token], dim=0)

            token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)

            # Stop without yielding the EOS marker itself.
            if next_token.item() == self.tokenizer.eos_token_id:
                break
            yield token_text
525
+
526
def forward_with_custom_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
    """Forward pass that accepts pre-computed embeddings, bypassing the embedding layer.

    Used as a shape-mismatch escape hatch: the input goes straight into the
    (optional) positional encoder and the transformer encoder.

    Args:
        embeddings: Pre-computed token embeddings. May be (seq, batch, dim) or
            (batch, seq, dim); a heuristic transpose adapts the layout.

    Returns:
        Tensor of shape (batch, output_size); a zero tensor on failure.
    """
    # BUGFIX: resolve the device before the main try-block so it is always
    # bound when the error path builds the fallback zero tensor (previously a
    # failure before the assignment raised NameError in the except handler).
    try:
        device = next(self.parameters()).device
    except StopIteration:
        # Parameterless module -- fall back to the input's own device.
        device = embeddings.device

    try:
        embeddings = embeddings.to(device)

        # Detect whether the encoder was built batch-first.
        batch_first = getattr(self.transformer_encoder, 'batch_first', False)

        if batch_first and embeddings.shape[0] <= embeddings.shape[1]:
            # Heuristic: first dim smaller than second suggests (seq, batch, dim);
            # convert to (batch, seq, dim). NOTE(review): this guess can misfire
            # for short sequences -- confirm callers' tensor layouts.
            embeddings = embeddings.transpose(0, 1)

        # Optional positional encoding, applied in whichever layout it expects.
        if hasattr(self, 'pos_encoder'):
            if not batch_first:
                if embeddings.shape[0] > embeddings.shape[1]:
                    # Already (seq, batch, dim).
                    embeddings = self.pos_encoder(embeddings)
                else:
                    # Transpose into (seq, batch, dim), encode, transpose back.
                    embeddings = embeddings.transpose(0, 1)
                    embeddings = self.pos_encoder(embeddings)
                    embeddings = embeddings.transpose(0, 1)
            else:
                # Batch-first positional encoder: no transpose needed.
                embeddings = self.pos_encoder(embeddings)

        # Run the encoder stack.
        encoded = self.transformer_encoder(embeddings)

        # Optional adapter on top of the encoder output.
        if hasattr(self, 'adapter'):
            encoded = self.adapter(encoded)

        # Pool token representations into one vector per example.
        if self.pooling_mode == "mean":
            pooled = encoded.mean(dim=1)
        elif self.pooling_mode == "max":
            pooled = torch.max(encoded, dim=1)[0]
        elif self.pooling_mode == "cls":
            # Use first token (CLS token) for classification.
            pooled = encoded[:, 0]
        else:
            # Unknown mode: default to mean pooling.
            pooled = encoded.mean(dim=1)

        # Final dropout and classification head.
        pooled = self.dropout_layer(pooled)
        return self.classifier(pooled)
    except Exception as e:
        logger.error(f"Error in custom embeddings forward pass: {e}")
        # Return a correctly shaped tensor to prevent cascading errors.
        return torch.zeros(1, self.output_size, device=device)
586
+
587
def forward_with_error_handling(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    **kwargs
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Forward pass that degrades gracefully on shape mismatches.

    Tries the normal forward() first. On a shape-related RuntimeError it
    attempts a simplified embed -> encode -> pool -> classify path; as a last
    resort it returns a zero tensor of shape (batch, output_size). Non-shape
    RuntimeErrors are re-raised unchanged.
    """
    def _fallback_device():
        # Some instances set an explicit .device; otherwise derive it from
        # the module's parameters (nn.Module has no .device attribute).
        explicit = getattr(self, 'device', None)
        return explicit if explicit is not None else next(self.parameters()).device

    def _zero_output() -> torch.Tensor:
        # Zero tensor matching the expected output shape for the batch.
        batch_size = input_ids.size(0) if input_ids is not None else 1
        return torch.zeros((batch_size, self.output_size), device=_fallback_device())

    try:
        # Standard path.
        return self.forward(
            src=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )
    except RuntimeError as e:
        if "shape" not in str(e):
            # Not a shape problem -- propagate unchanged.
            raise
        logger.warning(f"Shape mismatch detected: {e}")
        try:
            # BUGFIX: guard against input_ids being None before probing its
            # dims; previously this raised an AttributeError that escaped the
            # handler (sibling except clauses do not catch it).
            if input_ids is None:
                return _zero_output()
            if input_ids.dim() == 3 and input_ids.size(0) > input_ids.size(1):
                input_ids = input_ids.transpose(0, 1)  # adjust layout heuristically
            # Retry with a simplified pipeline.
            embedded = self.embedding(input_ids)
            if hasattr(self, 'pos_encoder'):
                embedded = self.pos_encoder(embedded)
            encoder_out = self.transformer_encoder(embedded)
            pooled = encoder_out.mean(dim=1)
            pooled = self.dropout_layer(pooled)
            return self.classifier(pooled)
        except Exception as inner_e:
            logger.error(f"Adaptation failed: {inner_e}")
            return _zero_output()
    except Exception as e:
        logger.error(f"Unhandled error in forward_with_error_handling: {e}")
        return _zero_output()
628
+
629
def train_with_emissions_tracking(self, dataloader, optimizer, criterion, num_epochs=1):
    """
    Train the model while tracking carbon emissions using CodeCarbon.

    Args:
        dataloader: Iterable yielding (inputs, labels) batches.
        optimizer: Optimizer stepping this model's parameters.
        criterion: Loss function applied to (outputs, labels).
        num_epochs: Number of full passes over the dataloader.
    """
    tracker = EmissionsTracker()
    tracker.start()  # Start tracking emissions
    try:
        self.train()  # Set model to training mode
        for epoch in range(num_epochs):
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                outputs = self(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            # Use the module-level logger, consistent with the rest of the file.
            logger.info(f"Epoch {epoch + 1}/{num_epochs} completed.")
    finally:
        # BUGFIX: always stop the tracker -- previously an exception during
        # training leaked the emissions session.
        emissions = tracker.stop()
        logger.info(f"Training completed. Carbon emissions: {emissions:.4f} kg CO2")
652
+
653
def infer_with_emissions_tracking(self, input_ids):
    """
    Perform inference while tracking carbon emissions using CodeCarbon.

    Args:
        input_ids: Tokenized input tensor forwarded to the model.

    Returns:
        The model's raw forward-pass output.
    """
    tracker = EmissionsTracker()
    tracker.start()  # Start tracking emissions
    try:
        self.eval()  # Set model to evaluation mode
        with torch.no_grad():
            outputs = self(input_ids)
    finally:
        # BUGFIX: stop the tracker even if the forward pass raises, so the
        # emissions session is never leaked.
        emissions = tracker.stop()
        logger.info(f"Inference completed. Carbon emissions: {emissions:.4f} kg CO2")
    return outputs
667
+
668
+ # Register the model class in registry for discovery
669
+ registry.register("model_class_custom", Wildnerve_tlm01)
670
+
671
+ # Check if tokenizer is initialized properly.
672
def initialize_tokenizer():
    """
    Fallback tokenizer initialization.

    Makes up to 5 attempts to obtain a tokenizer, preferring one already in
    the service registry, then loading the pretrained model's tokenizer and
    caching it in the registry. As a last resort falls back to a plain BERT
    tokenizer, or None if even that fails.
    """
    from transformers import BertTokenizer, AutoTokenizer
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            from service_registry import registry, TOKENIZER
            # Prefer a tokenizer that was registered earlier.
            if registry.has(TOKENIZER):
                registered = registry.get(TOKENIZER)
                if registered is not None:
                    logger.debug(f"Attempt {attempt}: Successfully retrieved tokenizer from registry.")
                    return registered
            # Otherwise load it fresh and register it for future callers.
            loaded = AutoTokenizer.from_pretrained("Wildnerve-tlm01-0.05Bx12")
            logger.debug(f"Attempt {attempt}: Successfully loaded tokenizer from pretrained model.")
            registry.register(TOKENIZER, loaded)
            return loaded
        except Exception as e:
            logger.debug(f"Attempt {attempt}: Failed to initialize tokenizer due to: {e}")
    # All attempts exhausted -- fall back to the generic BERT tokenizer.
    logger.error("Tokenizer initialization failed after 5 attempts. Using default BertTokenizer.")
    try:
        return BertTokenizer.from_pretrained("bert-base-uncased")
    except Exception as e:
        logger.error(f"Default tokenizer initialization failed: {e}")
        return None
model_List.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_List.py - Model selection and analysis component
2
+ import logging
3
+ import time
4
+ import math
5
+ import torch
6
+ import importlib.util
7
+ import os
8
+ import re
9
+ import logging
10
+ from typing import List, Tuple, Dict
11
+ import torch
12
+ import numpy as np
13
+ from sklearn.metrics.pairwise import cosine_similarity
14
+ import nltk
15
+ try:
16
+ nltk.data.find('tokenizers/punkt')
17
+ except LookupError:
18
+ nltk.download("punkt")
19
+ from transformers import AutoTokenizer, AutoModel
20
+ from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
21
+ from service_registry import registry, TOKENIZER, MODEL
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
class PromptAnalyzer:
    """
    A complete prompt analyzer that:
    - Loads a lightweight Transformer encoder (DistilBERT)
    - Applies SmartHybridAttention to refine token embeddings
    - Compares the resulting prompt embedding against predefined topic embeddings
    - Determines a primary topic and subtopics
    - Provides candidate model identifiers or a single best match.
    """
    def __init__(self):
        # Predefined topics with keyword sets for topic understanding
        self.predefined_topics: Dict[str, List[str]] = {
            "general": ["general", "overview", "basic", "introduction"],
            "programming": ["code", "programming", "debug", "software", "algorithm", "bug"],
            "science": ["research", "experiment", "science", "physics", "biology", "chemistry"],
            "history": ["history", "ancient", "modern", "civilization", "war"],
            "mathematics": ["math", "algebra", "calculus", "geometry", "statistics"]
        }
        # Initialize a lightweight transformer encoder for embeddings
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.encoder = AutoModel.from_pretrained("distilbert-base-uncased")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder.to(self.device)
        # BUGFIX: put the encoder in eval mode so dropout is disabled and
        # embeddings are deterministic across calls.
        self.encoder.eval()
        # Initialize SmartHybridAttention for refined representations
        attention_config = get_hybrid_attention_config()
        self.attention = SmartHybridAttention(attention_config)
        self.attention.to(self.device)
        self.attention.eval()
        # Cache of topic-keyword embeddings, filled lazily on first use so the
        # static topic strings are not re-encoded on every analyze_prompt call.
        self._topic_embedding_cache: Dict[str, np.ndarray] = {}
        logger.info("PromptAnalyzer initialized with DistilBERT and SmartHybridAttention.")

    def _encode_text(self, text: str) -> np.ndarray:
        """
        Encode text into an embedding vector.
        First, obtain token embeddings using DistilBERT.
        Then refine these embeddings with SmartHybridAttention.
        Finally, average-pool to produce a single vector.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.encoder(**inputs)  # shape: [batch, seq_len, dim]
            token_embeds = outputs.last_hidden_state  # [1, seq_len, dim]
            # Transpose for attention: [seq_len, batch, dim]
            token_embeds = token_embeds.transpose(0, 1)
            attended, _ = self.attention(query=token_embeds, key=token_embeds, value=token_embeds)
            # Transpose back and pool over tokens: [batch, seq_len, dim] -> [batch, dim]
            attended = attended.transpose(0, 1)
            pooled = attended.mean(dim=1)
        return pooled.squeeze().cpu().numpy()

    def _topic_embedding(self, topic: str) -> np.ndarray:
        """Return (and lazily cache) the embedding of a predefined topic's keywords."""
        if topic not in self._topic_embedding_cache:
            topic_text = " ".join(self.predefined_topics[topic])
            self._topic_embedding_cache[topic] = self._encode_text(topic_text)
        return self._topic_embedding_cache[topic]

    def analyze_prompt(self, prompt: str) -> Tuple[str, List[str]]:
        """
        Analyze the given prompt:
        - Compute its refined embedding.
        - Compare it (cosine similarity) against cached topic embeddings.
        - Return the primary topic (highest similarity) and any subtopics
          with similarity above 80% of the top score.
        """
        prompt_embedding = self._encode_text(prompt)
        topic_scores = {}
        for topic in self.predefined_topics:
            # PERF: topic embeddings are static -- use the cache instead of
            # re-encoding the keyword strings on every call.
            topic_embedding = self._topic_embedding(topic)
            similarity = cosine_similarity(
                prompt_embedding.reshape(1, -1),
                topic_embedding.reshape(1, -1)
            )[0][0]
            topic_scores[topic] = similarity
        sorted_topics = sorted(topic_scores.items(), key=lambda x: x[1], reverse=True)
        primary_topic = sorted_topics[0][0] if sorted_topics else "general"
        threshold = sorted_topics[0][1] * 0.8 if sorted_topics else 0.0
        subtopics = [topic for topic, score in sorted_topics if score >= threshold and topic != primary_topic]
        logger.debug(f"Prompt analyzed (first 30 chars): '{prompt[:30]}...' -> Primary: {primary_topic}, Subtopics: {subtopics}")
        return primary_topic, subtopics

    def get_selected_models(self) -> List[str]:
        """
        Return candidate model identifiers.
        For example, if the prompt is technical (programming) the custom model might be top.
        This method can later be expanded to select multiple or weighted candidates.
        """
        # Here we return our primary custom model and a fallback general model.
        return ["Wildnerve-tlm01-0.05Bx12", "bert-base-uncased"]

    def choose_model(self, prompt: str) -> str:
        """
        Based on the analyzed prompt, select the most appropriate model identifier.
        For instance, if 'programming' is detected, return the custom model.
        Otherwise, return a general/pretrained model or a combination indicator.
        """
        primary_topic, _ = self.analyze_prompt(prompt)
        if primary_topic == "programming":
            return "Wildnerve-tlm01-0.05Bx12"
        elif primary_topic in ["science", "mathematics", "history"]:
            return "model_Combn.py"
        else:
            return "bert-base-uncased"
122
+
123
+ # Optionally, additional helper methods could be added here for richer topic decomposition.
124
+
125
+ # Register the PromptAnalyzer in the service registry to resolve dependencies.
126
+ registry.register("prompt_analyzer", PromptAnalyzer())
127
+
128
+ # If additional functions or initialization code is needed, include here:
129
def main():
    """Run a quick smoke test of the registered PromptAnalyzer."""
    # In production model_manager retrieves the analyzer; here we fetch it directly.
    prompt_analyzer = registry.get("prompt_analyzer")
    demo_prompt = "I'm having trouble debugging my Python code for a sorting algorithm."
    topic, related_topics = prompt_analyzer.analyze_prompt(demo_prompt)
    chosen_model = prompt_analyzer.choose_model(demo_prompt)
    logger.info(
        f"Sample prompt analysis:\nPrimary Topic: {topic}\nSubtopics: {related_topics}\nSelected Model: {chosen_model}"
    )
136
+
137
+ if __name__ == "__main__":
138
+ main()
model_PrTr.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_PrTr.py
2
+ import os
3
+ import sys
4
+ import math
5
+ import torch
6
+ import logging
7
+ import importlib
8
+ import torch.nn as nn
9
+ from config import load_config
10
+ from transformers import AutoTokenizer, AutoModel
11
+ from typing import Optional, List, Dict, Any, Union
12
+ from sentence_transformers import SentenceTransformer
13
+ # Import service registry
14
+ from service_registry import registry, MODEL, TOKENIZER
15
+ # First import base interfaces
16
+ from base_interfaces.common_types import *
17
+ from base_interfaces.model_interface import AbstractModel
18
+
19
+ # Import environment setup first to ensure config is available
20
+ from model_env_setup import app_config
21
+
22
+ app_config = load_config()
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # ----------------------------
26
+ # Positional Encoding Module (for decoder)
27
+ # ----------------------------
28
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first tensors.

    Precomputes a (max_len, 1, d_model) table of interleaved sin/cos values
    and adds the first ``seq_len`` rows to the input on each forward pass.
    """

    def __init__(self, d_model: int, max_len: int = app_config.MAX_SEQ_LENGTH):
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freq_scale = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freq_scale)  # even dims: sine
        table[:, 1::2] = torch.cos(positions * freq_scale)  # odd dims: cosine
        # Insert a broadcast batch axis: (max_len, 1, d_model).
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch_size, d_model) -- add the matching encoding rows.
        return x + self.pe[:x.size(0)]
44
+
45
+ # ----------------------------
46
+ # Wildnerve-tlm01 using Only Pretrained Encoder
47
+ # ----------------------------
48
class Wildnerve_tlm01(nn.Module, AbstractModel):
    """A Transformer-based language model that uses:
    - A pretrained encoder (via AutoModel)
    - A custom decoder stack
    The model uses the AutoTokenizer for consistent tokenization.

    NOTE(review): forward(), generate() and generate_streaming() all call
    ``self.pretrained_encoder``, but __init__ never assigns that attribute --
    presumably it is injected after construction; confirm with the caller.
    """
    def __init__(
        self,
        vocab_size: int,
        specialization: str,
        dataset_path: str,
        model_name: str,
        embedding_dim: int,
        num_heads: int,
        hidden_dim: int,
        num_layers: int,
        output_size: int,
        dropout: float,
        max_seq_length: int,
        pooling_mode: str,
        tokenizer=None  # Accept tokenizer as parameter
    ) -> None:
        """Build the decoder stack, adapter and classifier, resolve a tokenizer,
        and register this instance in the service registry under its specialization."""
        super().__init__()
        self.specialization = specialization
        self.dataset_path = dataset_path
        self.model_name = model_name
        self.pooling_mode = pooling_mode
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_size = output_size
        self.dropout = dropout

        # Projection from the pretrained encoder's hidden size (768) down to
        # this model's embedding_dim.
        self.pretrained_projection = nn.Linear(768, embedding_dim)  # 768 → 256

        # Initialize projection layer
        nn.init.xavier_uniform_(self.pretrained_projection.weight)
        nn.init.zeros_(self.pretrained_projection.bias)

        # Tokenizer resolution order: explicit parameter -> service registry ->
        # bert-base-uncased -> gpt2 -> None.
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            # Try to get tokenizer from registry first
            if registry.has(TOKENIZER):
                self.tokenizer = registry.get(TOKENIZER)
                logger.info("Using tokenizer from registry")
            else:
                # Load a new tokenizer
                if tokenizer is None:  # NOTE(review): always True here (outer branch) -- redundant check
                    try:
                        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                        logger.info("Loaded primary pretrained tokenizer: bert-base-uncased")
                    except Exception as e:
                        logger.warning(f"Bert tokenizer load failed: {e}")
                        try:
                            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                            logger.info("Loaded fallback tokenizer: GPT2")
                        except Exception as e2:
                            logger.error(f"GPT2 tokenizer load failed: {e2}")
                            self.tokenizer = None

        # Register this model instance in the registry by specialization
        model_registry_key = f"model_{specialization}"
        registry.register(model_registry_key, self)

        # ----------------------------
        # Decoder (Target) Components
        # ----------------------------
        self.tgt_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_decoder = PositionalEncoding(embedding_dim, max_len=max_seq_length)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=False  # decoder operates sequence-first throughout
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # ----------------------------
        # Adapter & Output Components
        # ----------------------------
        # Bottleneck MLP applied to the (projected) encoder memory.
        self.adapter = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )
        self.classifier = nn.Linear(embedding_dim, output_size)
        self.dropout_layer = nn.Dropout(dropout)

        self.init_weights()

    def init_weights(self) -> None:
        """Initialize weights for decoder, adapter and classifier."""
        initrange = 0.1
        with torch.no_grad():
            # Uniform init for embedding and classifier, zeroed biases.
            self.tgt_embedding.weight.uniform_(-initrange, initrange)
            self.classifier.weight.uniform_(-initrange, initrange)
            self.classifier.bias.zero_()
            for layer in self.adapter:
                if isinstance(layer, nn.Linear):
                    layer.weight.uniform_(-initrange, initrange)
                    if layer.bias is not None:
                        layer.bias.zero_()

    def forward(self, src: torch.Tensor, tgt: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None,
                tgt_key_padding_mask: Optional[torch.Tensor] = None,
                return_sequence: bool = False,
                **kwargs) -> torch.Tensor:
        """Encode ``src`` with the pretrained encoder; if ``tgt`` is given, run
        the causal decoder and classify each position (or the sequence mean when
        ``return_sequence`` is False). Without ``tgt``, pool the adapted encoder
        output per ``pooling_mode`` and classify it.

        Raises:
            Exception: any failure is logged and re-raised unchanged.
        """
        try:
            # Pretrained encoder expects input shape: (batch_size, seq_length)
            encoded_output = self.pretrained_encoder(src)[0]  # (batch_size, seq_length, embedding_dim)

            # Project from 768 to 256
            encoded_output = self.pretrained_projection(encoded_output)

            # Transpose to (seq_length, batch_size, embedding_dim)
            encoded_output = encoded_output.transpose(0, 1)

            # Process through adapter layer
            adapted = self.adapter(encoded_output)

            # If a target sequence is provided, run the decoder
            if tgt is not None:
                tgt = tgt.transpose(0, 1)  # (seq_length, batch_size)
                # Scale embeddings by sqrt(d_model) per the Transformer convention.
                tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.embedding_dim)
                tgt_emb = self.pos_decoder(tgt_emb)
                # Causal mask so each position attends only to earlier positions.
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0)).to(src.device)
                decoded = self.transformer_decoder(
                    tgt_emb,
                    adapted,
                    tgt_mask=tgt_mask,
                    memory_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask
                )
                # Per-token logits: (batch, seq, output_size).
                output = self.classifier(decoded.transpose(0, 1))
                if not return_sequence:
                    output = output.mean(dim=1)
            else:
                # Encoder-only path: pool over the sequence dimension (dim 0).
                if self.pooling_mode == "mean":
                    output = adapted.mean(dim=0)
                elif self.pooling_mode == "max":
                    output = torch.max(adapted, dim=0)[0]
                else:
                    output = adapted.mean(dim=0)

                output = self.dropout_layer(output)
                output = self.classifier(output)

            return output

        except Exception as e:
            logger.error(f"Error during forward pass: {e}")
            raise

    @staticmethod
    def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
        """Generate square subsequent (causal) mask for transformer:
        0.0 on and below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH,
        device: str = 'cpu',
        temperature: float = 1.0,
        start_token_id: Optional[int] = None
    ) -> List[List[int]]:
        """Generates token ID sequences using the pretrained encoder and custom decoder.

        Args:
            input_ids: Source token ids, shape (batch, seq_len).
            max_length: Total sequence length to produce (including start token).
            device: Device for the decoding tensors.
            temperature: 0 selects argmax (greedy); otherwise softmax sampling.
            start_token_id: Seed token; defaults to the first token of the batch.

        Returns:
            One list of token ids per batch element.
        """
        self.eval()
        batch_size = input_ids.shape[0]
        if start_token_id is None:
            start_token_id = input_ids[0, 0].item()
        generated = [[start_token_id] for _ in range(batch_size)]

        # Use pretrained encoder to encode source input.
        # NOTE(review): unlike forward(), this path does not apply
        # pretrained_projection before the adapter -- confirm intended.
        encoded_output = self.pretrained_encoder(input_ids)[0]  # (batch_size, seq_length, embedding_dim)
        encoded_output = encoded_output.transpose(0, 1)  # (seq_length, batch_size, embedding_dim)
        adapted = self.adapter(encoded_output)

        for _ in range(max_length - 1):
            # Re-run the decoder over everything generated so far.
            current_tgt = torch.tensor(generated, dtype=torch.long, device=device)
            current_tgt = current_tgt.transpose(0, 1)
            tgt_emb = self.tgt_embedding(current_tgt) * math.sqrt(self.embedding_dim)
            tgt_emb = self.pos_decoder(tgt_emb)
            current_seq_length = current_tgt.size(0)

            tgt_mask = nn.Transformer.generate_square_subsequent_mask(current_seq_length).to(device)
            decoded = self.transformer_decoder(tgt_emb, adapted, tgt_mask=tgt_mask)
            # Logits from the final decoder position only.
            logits = self.classifier(decoded[-1, :, :])

            if temperature == 0:
                next_tokens = torch.argmax(logits, dim=-1)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

            next_tokens = next_tokens.cpu().tolist()
            for i, token in enumerate(next_tokens):
                generated[i].append(token)

        return generated

    def decode_tokens(self, token_ids: List[int]) -> str:
        """Decodes a list of token IDs into text; on failure logs and returns
        the exception's string representation."""
        try:
            return self.tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        except Exception as e:
            logger.error(f"Error decoding tokens: {e}")
            return str(e)

    def generate_with_decoding(
        self,
        input_ids: torch.Tensor,
        max_length: int = 20,
        device: str = 'cpu',
        temperature: float = 1.0,
        start_token_id: Optional[int] = None
    ) -> str:
        """Generates a sequence and decodes the first batch element into text;
        returns an empty string when nothing was generated."""
        generated_sequences = self.generate(input_ids, max_length, device, temperature, start_token_id)
        if generated_sequences:
            return self.decode_tokens(generated_sequences[0])
        return ""

    def generate_streaming(self, prompt=None, input_ids=None, attention_mask=None, **kwargs):
        """Generate a response token-by-token from the model.

        Accepts either a raw text ``prompt`` (tokenized here) or pre-tokenized
        ``input_ids``. Yields each sampled token's decoded text; stops at
        ``max_length`` (kwarg, default 100) or the tokenizer's EOS token
        (which is not yielded).

        Raises:
            ValueError: if neither prompt nor input_ids is provided.
        """
        # Consistent device handling
        device = next(self.parameters()).device

        # Handle either text prompt or tokenized input
        if prompt is not None and input_ids is None:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_seq_length
            )
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs.get("attention_mask", None)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

        # Ensure input_ids is valid
        if input_ids is None:
            raise ValueError("Either prompt or input_ids must be provided")

        # Use pretrained encoder to encode source input
        encoded_output = self.pretrained_encoder(input_ids)[0]  # (batch_size, seq_length, embedding_dim)
        encoded_output = self.pretrained_projection(encoded_output)
        encoded_output = encoded_output.transpose(0, 1)  # (seq_length, batch_size, embedding_dim)
        adapted = self.adapter(encoded_output)

        # Get generation config params
        max_length = kwargs.get('max_length', 100)
        temperature = kwargs.get('temperature', 0.7)

        # Generate first token
        with torch.no_grad():
            # Initialize with start token (could be from input or specified)
            start_token_id = kwargs.get('start_token_id', input_ids[0, 0].item())
            current_tgt = torch.tensor([[start_token_id]], dtype=torch.long, device=device)
            current_tgt = current_tgt.transpose(0, 1)  # (1, batch_size=1)

            # Process first token
            tgt_emb = self.tgt_embedding(current_tgt) * math.sqrt(self.embedding_dim)
            tgt_emb = self.pos_decoder(tgt_emb)
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(1).to(device)
            decoded = self.transformer_decoder(tgt_emb, adapted, tgt_mask=tgt_mask)
            logits = self.classifier(decoded[-1, :, :])

            # Sample from distribution (temperature 0 means greedy argmax)
            if temperature == 0:
                next_token = torch.argmax(logits, dim=-1)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)

            # Convert to token text and yield
            token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
            yield token_text

            # Generate rest of sequence
            generated_ids = [next_token.item()]

            for _ in range(max_length - 1):
                # Update target sequence with everything generated so far
                current_tgt = torch.tensor([generated_ids], dtype=torch.long, device=device)
                current_tgt = current_tgt.transpose(0, 1)  # (seq_len, batch=1)

                # Process next token
                tgt_emb = self.tgt_embedding(current_tgt) * math.sqrt(self.embedding_dim)
                tgt_emb = self.pos_decoder(tgt_emb)
                current_seq_length = current_tgt.size(0)
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(current_seq_length).to(device)
                decoded = self.transformer_decoder(tgt_emb, adapted, tgt_mask=tgt_mask)
                logits = self.classifier(decoded[-1, :, :])

                # Sample next token
                if temperature == 0:
                    next_token = torch.argmax(logits, dim=-1)
                else:
                    probs = torch.softmax(logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)

                # Add to generated sequence
                next_token_id = next_token.item()
                generated_ids.append(next_token_id)

                # Decode and yield the token
                token_text = self.tokenizer.decode([next_token_id], skip_special_tokens=True)

                # Check for EOS token (stop without yielding it)
                if next_token_id == self.tokenizer.eos_token_id:
                    break

                yield token_text
373
+
374
#-------Pretrained Transformer Model-------------
class PretrainedTransformer(nn.Module, AbstractModel):
    """A simple wrapper around a pretrained Hugging Face transformer model.

    The tokenizer is resolved in priority order: the explicit ``tokenizer``
    argument, the shared service registry, ``bert-base-uncased``, then
    ``gpt2`` as a last-resort fallback.
    """

    def __init__(
        self,
        vocab_size=30522,
        specialization="general",
        dataset_path=None,
        model_name="bert-base-uncased",  # Primary model name for pretrained transformer
        embedding_dim=768,
        num_heads=12,
        hidden_dim=768,
        num_layers=6,
        output_size=768,
        dropout=0.1,
        max_seq_length=512,
        pooling_mode="mean",
        tokenizer=None,
        **kwargs
    ) -> None:
        super().__init__()

        # Optionally track model usage
        self.model_last_used = {}

        # Unified tokenizer initialization (priority order):
        #   1. explicit `tokenizer` argument
        #   2. tokenizer already registered in the service registry
        #   3. bert-base-uncased, falling back to gpt2
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            from transformers import AutoTokenizer, BertTokenizer
            if registry.has(TOKENIZER):
                self.tokenizer = registry.get(TOKENIZER)
            else:
                try:
                    self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
                    logger.info("Loaded primary tokenizer: bert-base-uncased")
                except Exception as e:
                    logger.warning(f"Primary tokenizer load failed: {e}")
                    try:
                        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                        logger.info("Loaded fallback tokenizer: GPT2")
                    except Exception as e2:
                        logger.error(f"Fallback tokenizer load failed: {e2}")
                        self.tokenizer = None
                registry.register(TOKENIZER, self.tokenizer)

        # Set model names for fallback chain explicitly
        self.model_name = model_name  # Should be "bert-base-uncased"
        self.fallback_model = "gpt2"  # Fallback tokenization/model if needed

        self.model = AutoModel.from_pretrained(model_name)

        # BUG FIX: the original unconditionally re-loaded the tokenizer for
        # `model_name` here, clobbering the tokenizer resolved above (explicit
        # argument, registry entry, or bert->gpt2 fallback). Only attempt a
        # load if we still have no tokenizer at all.
        if self.tokenizer is None:
            from transformers import AutoTokenizer
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            except Exception as e:
                logger.error(f"Failed to load tokenizer for {model_name}: {e}")
                self.tokenizer = None

    def forward(self, input_ids, attention_mask=None):
        """Run the wrapped model and return its final hidden states
        (batch, seq_len, hidden)."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state

    def encode(self, text: str):
        """Encode `text` into a single pooled embedding (mean over tokens).

        Raises:
            ValueError: if no tokenizer could be initialized.
        """
        if not self.tokenizer:
            raise ValueError("Tokenizer not available")
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = self.forward(inputs.input_ids, inputs.get("attention_mask"))
        # Pool by averaging the token embeddings
        return outputs.mean(dim=1)

    def generate(self, input_ids, max_length=100, **kwargs):
        """Delegate to the wrapped model's `generate` when available.

        Plain encoder models (e.g. BERT) have no `generate`; in that case the
        input ids are returned unchanged as a no-op fallback.
        """
        if hasattr(self.model, "generate"):
            return self.model.generate(input_ids=input_ids, max_length=max_length, **kwargs)
        else:
            # Simple fallback: return input_ids as is
            return input_ids
453
+
454
# Register model classes in registry so other modules can look them up by key.
# NOTE(review): the key "model_class_pretrained" maps to the custom
# Wildnerve_tlm01 class, not PretrainedTransformer — confirm the key name
# is intentional.
registry.register("model_class_pretrained", Wildnerve_tlm01)
registry.register("pretrained_transformer_class", PretrainedTransformer)
457
+
458
# Check if pretrained transformers are properly initialized.
def initialize_pretrained_model():
    """Initialize a pretrained tokenizer with a two-stage fallback.

    Each attempt first tries 'bert-base-uncased'; if that load fails, it
    tries 'gpt2'. Up to 5 attempts are made in total, and the first
    tokenizer that loads is returned immediately.

    Returns:
        The initialized tokenizer instance if successful, otherwise None.
    """
    max_attempts = 5
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        # Stage 1: the preferred tokenizer.
        try:
            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            logger.warning(f"Attempt {attempt}: Loading bert-base-uncased failed: {e}")
        else:
            logger.info(f"Attempt {attempt}: Successfully loaded bert-base-uncased.")
            return tokenizer
        # Stage 2: the fallback tokenizer; only reached when stage 1 failed.
        try:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
        except Exception as e2:
            logger.warning(f"Attempt {attempt}: Loading gpt2 failed as fallback: {e2}")
            logger.info("Retrying tokenizer initialization...")
        else:
            logger.info(f"Attempt {attempt}: Successfully loaded gpt2 as fallback.")
            return tokenizer
    logger.error("Failed to initialize pretrained model tokenizer after 5 attempts.")
    return None
model_manager.py ADDED
@@ -0,0 +1,735 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gc
import importlib.util
import inspect
import logging
import os
import sys
import time
from collections import OrderedDict
from pathlib import Path
from threading import Lock, RLock
from typing import List, Dict, Any, Tuple, Optional, TYPE_CHECKING

import numpy as np
import pandas as pd
import torch
from nltk.stem import WordNetLemmatizer
from sklearn.metrics.pairwise import cosine_similarity

from config import app_config
from dataset import TensorDataset
from utils.transformer_utils import get_sentence_transformer
from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
if TYPE_CHECKING:
    from service_registry import registry
from service_registry import registry, MODEL, TOKENIZER, MODEL_MANAGER, COMMUNICATOR
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
# Optional dependency: psutil is used for process-memory reporting in
# ModelManager.get_health_status(). When it is missing we install a minimal
# stand-in exposing the same psutil.Process(...) call shape so callers
# need no conditional logic.
try:
    import psutil
    PSUTIL_AVAILABLE = True
except ImportError:
    logger.warning("psutil not available")
    PSUTIL_AVAILABLE = False
    class DummyProcess:
        # Mimics psutil.Process just enough for memory queries.
        def __init__(self, pid=None): self.pid = pid or 1
        def memory_info(self):
            # Fixed 1 MB placeholders for resident/virtual size.
            class MemInfo:
                def __init__(self): self.rss = 1e6; self.vms = 1e6
            return MemInfo()
        def memory_percent(self): return 1.0
    class DummyPsutil:
        # Module-like shim exposing only Process().
        @staticmethod
        def Process(pid=None): return DummyProcess(pid)
    psutil = DummyPsutil()
35
+
36
def safe_get_config(config_obj, key, default=None):
    """Fetch `key` from a dict or from an object's attributes.

    Returns `default` when the key/attribute is absent or `config_obj`
    is of an unsupported type.
    """
    if isinstance(config_obj, dict):
        return config_obj.get(key, default)
    if hasattr(config_obj, key):
        return getattr(config_obj, key, default)
    return default
42
+
43
def safe_get_config_value(config_obj, key, default=None):
    """Like `safe_get_config`, but tolerant of plain scalar configs.

    Lookup order:
      1. dict lookup,
      2. attribute lookup,
      3. if `config_obj` is itself a scalar (int/float/str/bool), return it
         unchanged — the caller passed a value rather than a container,
      4. otherwise `default`.

    Any unexpected error during lookup yields `default` instead of raising.
    """
    try:
        if isinstance(config_obj, dict):
            return config_obj.get(key, default)
        elif hasattr(config_obj, key):
            return getattr(config_obj, key, default)
        elif isinstance(config_obj, (int, float, str, bool)):
            return config_obj
        return default
    # BUG FIX: was a bare `except:` — that also swallowed SystemExit and
    # KeyboardInterrupt. Catch only Exception subclasses.
    except Exception:
        return default
54
+
55
class DatasetManager:
    """Loads CSV datasets and caches one TensorDataset per specialization."""

    def __init__(self):
        self.datasets: Dict[str, Any] = {}
        self.lock = Lock()

    def load_dataset(self, path: str, specialization: str) -> Any:
        """Return the cached dataset for `specialization`, loading it on first use."""
        with self.lock:
            if specialization in self.datasets:
                logger.info(f"Using cached dataset for {specialization}")
                return self.datasets[specialization]
            loaded = self._load_and_process_dataset(path, specialization)
            self.datasets[specialization] = loaded
            return loaded

    def _load_and_process_dataset(self, path: str, specialization: str) -> TensorDataset:
        """Read a CSV with a 'label' column into a TensorDataset.

        Raises:
            FileNotFoundError: if `path` does not exist.
            ValueError: if the CSV lacks a 'label' column.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"Dataset {path} not found.")
        logger.info(f"Loading dataset: {specialization}")
        frame = pd.read_csv(path)
        if "label" not in frame.columns:
            raise ValueError("Dataset must have a 'label' column.")
        feature_values = frame.drop("label", axis=1).values
        label_values = frame["label"].values
        return TensorDataset(
            torch.tensor(feature_values, dtype=torch.float32),
            torch.tensor(label_values, dtype=torch.long),
        )

    def get_status(self) -> Dict[str, Any]:
        """Report which specializations are cached and how many."""
        return {"loaded_datasets": list(self.datasets.keys()), "cache_size": len(self.datasets)}

    def clear_cache(self):
        """Drop every cached dataset (thread-safe)."""
        with self.lock:
            self.datasets.clear()
88
+
89
class ModelManager:
    """Routes prompts to per-specialization model instances.

    Keeps an LRU pool of resident models (bounded by `max_active_models`),
    lazily instantiates new specializations on demand, and tracks simple
    per-model performance metrics.
    """

    def __init__(self, tokenizer=None, max_active_models=5, model_idle_threshold=600):
        # specialization -> resident model instance
        self.models = {}
        # BUG FIX: was a non-reentrant Lock. get_or_create_model() and
        # manage_model_cache() call get_model()/remove_model_instance()
        # (which re-acquire this lock) while already holding it, which
        # deadlocks with threading.Lock. RLock keeps the same interface
        # but allows re-entry by the owning thread.
        self.lock = RLock()
        # LRU ordering of resident models; keys mirror self.models, the
        # values are unused placeholders.
        self.model_pool = OrderedDict()
        # Guard against bogus config values: fall back to conservative defaults.
        self.max_active_models = max_active_models if isinstance(max_active_models, int) and max_active_models > 0 else 2
        self.model_idle_threshold = model_idle_threshold if isinstance(model_idle_threshold, int) and model_idle_threshold > 0 else 600
        self.tokenizer = tokenizer
        dataset_paths = safe_get_config(app_config, "DATASET_PATHS", {})
        self.specializations = list(dataset_paths.keys()) if isinstance(dataset_paths, dict) else ["default"]
        # specialization -> {"inference_time", "memory_usage",
        #                    "last_accessed", "num_inferences"}
        self._performance_metrics = {}
        attention_config = get_hybrid_attention_config()
        self.smart_attention = SmartHybridAttention(
            dim=attention_config["DIM"],
            num_heads=attention_config["NUM_HEADS"],
            window_size=attention_config["WINDOW_SIZE"],
            use_sliding=attention_config["USE_SLIDING"],
            use_global=attention_config["USE_GLOBAL"],
            use_hierarchical=attention_config["USE_HIERARCHICAL"],
            global_token_ratio=attention_config["GLOBAL_TOKEN_RATIO"],
            memory_tokens=attention_config["MEMORY_TOKENS"]
        )
        self.dataset_manager = DatasetManager()
        transformer_config = safe_get_config(app_config, "TRANSFORMER_CONFIG", {})
        # Force use of our custom model with no fallback
        model_name = safe_get_config(transformer_config, "MODEL_NAME", "Wildnerve-tlm01-0.05Bx12")
        self.embedding_model = get_sentence_transformer(model_name)
        self.similarity_threshold = safe_get_config(app_config, "SIMILARITY_THRESHOLD", 0.85)
        self.top_k = safe_get_config(app_config, "TOP_K", 3)
        self.prompt_analyzer = None
        self.selected_models = self._get_selected_models()
        logger.info(f"ModelManager initialized with {len(self.specializations)} specializations")
        # Eagerly load only a small initial subset of models (see _load_models).
        self._load_models()
122
+
123
+ def _get_selected_models(self) -> List[str]:
124
+ model_files = safe_get_config(app_config, "SELECTED_MODEL", ["model_Custm.py"])
125
+ return model_files if model_files else ["model_Custm.py"]
126
+
127
+ def _import_model_class(self, model_key: str):
128
+ try:
129
+ abs_path = f"{os.path.dirname(__file__)}/{model_key}.py"
130
+ if os.path.exists(abs_path):
131
+ spec = importlib.util.spec_from_file_location(model_key, abs_path)
132
+ module = importlib.util.module_from_spec(spec)
133
+ spec.loader.exec_module(module)
134
+ elif os.path.exists(f"{model_key}.py"):
135
+ spec = importlib.util.spec_from_file_location(model_key, f"{model_key}.py")
136
+ module = importlib.util.module_from_spec(spec)
137
+ spec.loader.exec_module(module)
138
+ else:
139
+ module = importlib.import_module(model_key)
140
+ if module and hasattr(module, "Wildnerve_tlm01"):
141
+ return getattr(module, "Wildnerve_tlm01")
142
+ else:
143
+ logger.warning(f"Module {model_key} missing Wildnerve_tlm01 class")
144
+ return None
145
+ except Exception as e:
146
+ logger.error(f"Failed to import {model_key}: {e}")
147
+ return None
148
+
149
    def _load_models(self):
        """Initialize models with lazy loading and limited initial specializations"""
        # Define all specializations but only load a minimal subset at startup.
        # NOTE(review): "general" is loaded initially and special-cased by
        # get_or_create_model, yet is absent from this list — confirm whether
        # it should be included here.
        all_specializations = [
            "mbpp",
            "programming_software_dev",
            "machine_learning_ai_data_science",
            "industrial_engineering",
            "science_engineering",
            "mathematics",
            "healthcare_and_lifesciences",
            "chemistry",
            "hardware_devops_cloud",
            "cyber_security",
            "business_legal_finance",
            "other_information"
        ]

        # Only load 2 specializations at startup to prevent resource exhaustion
        initial_specializations = ["general", "programming_software_dev"]
        self.all_specializations = all_specializations  # Store all for later lazy loading
        self.models = {}

        # Set up data directory (overridable via TLM_DATA_DIR).
        data_dir = os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data")
        os.makedirs(data_dir, exist_ok=True)

        # Only initialize the minimal subset at startup; failures are logged
        # per-specialization so one bad model does not block the others.
        for spec in initial_specializations:
            try:
                self._initialize_model_for_specialization(spec, data_dir)
                logger.info(f"Initialized model for {spec}")
            except Exception as e:
                logger.error(f"Error initializing model for {spec}: {e}")

        logger.info(f"Loaded {len(self.models)} initial models, {len(all_specializations)} total available")
        return True
186
+
187
    def _initialize_model_for_specialization(self, spec, data_dir):
        """Initialize a single model with proper error handling and timeouts.

        Resolves (or fabricates) a dataset path for `spec`, constructs a
        Wildnerve_tlm01 instance and registers it in the model pool with
        fresh performance metrics. Re-raises any model-construction error.

        NOTE(review): the "timeout" below is advisory only — it logs a
        warning after construction finishes; it does not interrupt a hung
        model load.
        """
        # Get dataset path with fallbacks
        dataset_path = None
        try:
            if isinstance(app_config, dict) and "DATASET_PATHS" in app_config:
                dataset_path = app_config["DATASET_PATHS"].get(spec)
            elif hasattr(app_config, "DATASET_PATHS"):
                dataset_path = getattr(app_config.DATASET_PATHS, spec, None)
        except Exception as e:
            logger.warning(f"Error getting dataset path: {e}")

        # Use default path if not provided
        # NOTE(review): config DATASET_PATHS values appear to be lists of
        # paths — confirm downstream consumers accept either form.
        if not dataset_path:
            dataset_path = os.path.join(data_dir, f"{spec}.csv")

        # Create minimal dataset if needed (placeholder one-row CSV)
        if not os.path.exists(dataset_path):
            try:
                with open(dataset_path, "w") as f:
                    f.write("text,label\n")
                    f.write(f"sample {spec} text,0\n")
                logger.info(f"Created minimal dataset for {spec}")
            except Exception as e:
                logger.error(f"Error creating dataset for {spec}: {e}")

        # Create model with timeout protection
        start_time = time.time()
        timeout = 30  # 30 second timeout for model creation

        try:
            # Import with timeout check to avoid hanging
            from model_Custm import Wildnerve_tlm01

            # Initialize model with appropriate parameters
            model = Wildnerve_tlm01(
                vocab_size=30522,
                specialization=spec,
                dataset_path=dataset_path,
                model_name="Wildnerve-tlm01-0.05Bx12",
                embedding_dim=768,
                num_heads=12,
                hidden_dim=768,
                num_layers=6,
                output_size=768,
                dropout=0.1,
                max_seq_length=512,
                pooling_mode="mean",
                tokenizer=self.tokenizer
            )

            # Add model to the pool (value is an unused placeholder; ordering
            # of model_pool provides the LRU behavior).
            self.models[spec] = model
            self.model_pool[spec] = None
            self._performance_metrics[spec] = {
                "inference_time": 0.0,
                "memory_usage": 0.0,
                "last_accessed": time.time(),
                "num_inferences": 0
            }

            # Check for timeout (after the fact — see docstring note)
            if time.time() - start_time > timeout:
                logger.warning(f"Model creation for {spec} took longer than {timeout}s!")

        except Exception as e:
            logger.error(f"Error creating model for {spec}: {e}")
            raise
255
+
256
    def get_or_create_model(self, specialization: str) -> Any:
        """Get an existing model or create it on demand if not already loaded.

        Unknown specializations are coerced to "general". When the pool is
        full, the least-recently-used model is evicted first. Falls back to
        the general model, and finally a minimal model, on failure.

        NOTE(review): get_model() and remove_model_instance() re-acquire
        self.lock while this method holds it — this deadlocks unless
        self.lock is a reentrant lock (RLock); confirm.
        """
        with self.lock:
            # Check if model already exists
            model = self.get_model(specialization)
            if model:
                logger.info(f"Using existing model for {specialization}")
                return model

            # Check if it's a valid specialization
            if specialization not in self.all_specializations and specialization != "general":
                logger.warning(f"Unknown specialization: {specialization}, using general")
                specialization = "general"

            # Create model if needed
            logger.info(f"Lazily loading model for {specialization}")

            # Remove least recently used model if needed
            if len(self.models) >= self.max_active_models:
                lru_specialization = next(iter(self.model_pool))
                self.remove_model_instance(lru_specialization)

            # Initialize the requested model
            data_dir = os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data")
            try:
                self._initialize_model_for_specialization(specialization, data_dir)
                return self.models.get(specialization)
            except Exception as e:
                logger.error(f"Error initializing model: {e}")

                # Fallback to general model
                if specialization != "general" and "general" in self.models:
                    return self.models["general"]

                # Last resort - create a minimal model
                return self._create_minimal_model()
292
+
293
+ def _create_minimal_model(self):
294
+ """Create a minimal fallback model for emergencies"""
295
+ try:
296
+ from model_Custm import Wildnerve_tlm01
297
+ model = Wildnerve_tlm01(
298
+ vocab_size=30522,
299
+ specialization="minimal",
300
+ dataset_path=None,
301
+ model_name="bert-base-uncased", # Use simpler base model
302
+ embedding_dim=768,
303
+ num_heads=12,
304
+ hidden_dim=768,
305
+ num_layers=2, # Reduced layers
306
+ output_size=768,
307
+ dropout=0.1,
308
+ max_seq_length=128, # Reduced sequence length
309
+ pooling_mode="mean",
310
+ tokenizer=self.tokenizer
311
+ )
312
+ model._is_minimal = True # Mark as minimal model
313
+ return model
314
+ except Exception as e:
315
+ logger.error(f"Failed to create minimal model: {e}")
316
+ return None
317
+
318
+ def get_model(self, specialization: str) -> Any:
319
+ with self.lock:
320
+ model = self.models.get(specialization)
321
+ if model:
322
+ self.model_pool.move_to_end(specialization)
323
+ if specialization in self._performance_metrics:
324
+ self._performance_metrics[specialization]["last_accessed"] = time.time()
325
+ return model
326
+
327
    def route_input(self, input_text: str) -> dict:
        """Pick the specialization whose model embedding best matches the input.

        Returns a dict with keys "matched_specialization", "confidence" and
        "all_scores". Falls back to the first configured specialization with
        confidence 0.0 when no resident model exposes an `embedding`
        attribute.

        NOTE(review): assumes encode() returns an array-like supporting
        .reshape, and that model.embedding is a comparable vector — confirm.
        """
        input_embedding = self.embedding_model.encode(input_text)
        similarities = {}
        for spec in self.specializations:
            model = self.get_model(spec)
            if model and hasattr(model, "embedding"):
                # Cosine similarity between the two vectors reshaped to (1, d).
                sim = cosine_similarity(input_embedding.reshape(1, -1), model.embedding.reshape(1, -1))[0][0]
                similarities[spec] = sim
        if similarities:
            best_match = max(similarities.items(), key=lambda x: x[1])
            return {"matched_specialization": best_match[0], "confidence": best_match[1], "all_scores": similarities}
        return {"matched_specialization": self.specializations[0], "confidence": 0.0, "all_scores": similarities}
339
+
340
    def get_model_for_prompt(self, prompt: str) -> Tuple[Any, str]:
        """Route `prompt` to a model and return (model, specialization).

        On any routing error, falls back to the first resident model, or
        (None, "none") when no models exist.

        NOTE(review): update_metrics() runs immediately after model selection,
        so `elapsed` measures selection time (≈0) and num_inferences is
        incremented before any actual inference — confirm this is intended.
        """
        try:
            routing_result = self.route_input(prompt)
            specialization = routing_result.get("matched_specialization", self.specializations[0])
            model = self.get_or_create_model(specialization)
            start_time = time.time()
            def update_metrics():
                # Maintains a running mean of inference_time across calls.
                if specialization in self._performance_metrics:
                    m = self._performance_metrics[specialization]
                    elapsed = time.time() - start_time
                    n = m.get("num_inferences", 0) + 1
                    m["inference_time"] = ((m.get("inference_time", 0) * (n-1)) + elapsed) / n
                    m["num_inferences"] = n
                    m["last_accessed"] = time.time()
                    if hasattr(model, "get_memory_usage"):
                        m["memory_usage"] = model.get_memory_usage()
            update_metrics()
            return model, specialization
        except Exception as e:
            logger.error(f"Error selecting model: {e}")
            if self.models:
                default_key = list(self.models.keys())[0]
                return self.models[default_key], default_key
            else:
                logger.error("No models available for routing")
                return None, "none"
366
+
367
    def generate(self, prompt: str, **kwargs):
        """Generate a response for `prompt` with the best-matching model.

        Validates the input, routes to a specialization, records a running
        mean of inference time, and on failure retries once with the default
        (first configured) specialization. Extra kwargs are forwarded to the
        model's generate().
        """
        self.validate_input(prompt)
        model, specialization = self.get_model_for_prompt(prompt)
        start_time = time.time()
        try:
            result = model.generate(prompt=prompt, **kwargs)
            elapsed = time.time() - start_time
            if specialization in self._performance_metrics:
                m = self._performance_metrics[specialization]
                # Running-mean update of inference_time over n inferences.
                n = m.get("num_inferences", 0) + 1
                m["inference_time"] = ((m.get("inference_time", 0) * (n-1)) + elapsed) / n
                m["num_inferences"] = n
                m["last_accessed"] = time.time()
            return result
        except Exception as e:
            logger.error(f"Error generating with {specialization}: {e}")
            # Single fallback attempt with the default specialization;
            # a failure here propagates to the caller.
            default_spec = self.specializations[0]
            default_model = self.get_or_create_model(default_spec)
            return default_model.generate(prompt=prompt, **kwargs)
386
+
387
    def generate_streaming(self, prompt: str, **kwargs):
        """Yield response tokens for `prompt` as they are produced.

        If the selected model has no generate_streaming(), streaming is
        simulated by splitting a full generate() result on whitespace and
        yielding word-by-word. On error, falls back to the default
        specialization (streaming or simulated likewise).
        """
        self.validate_input(prompt)
        model, specialization = self.get_model_for_prompt(prompt)
        start_time = time.time()
        try:
            if hasattr(model, "generate_streaming"):
                for token in model.generate_streaming(prompt=prompt, **kwargs):
                    yield token
            else:
                logger.info("Simulating streaming generation")
                result = model.generate(prompt=prompt, **kwargs)
                for word in result.split():
                    yield word + " "
            # Metrics are only updated when the stream completes without error.
            elapsed = time.time() - start_time
            if specialization in self._performance_metrics:
                m = self._performance_metrics[specialization]
                n = m.get("num_inferences", 0) + 1
                m["inference_time"] = ((m.get("inference_time", 0) * (n-1)) + elapsed) / n
                m["num_inferences"] = n
                m["last_accessed"] = time.time()
        except Exception as e:
            logger.error(f"Error in streaming generation: {e}")
            default_spec = self.specializations[0]
            default_model = self.get_or_create_model(default_spec)
            if hasattr(default_model, "generate_streaming"):
                for token in default_model.generate_streaming(prompt=prompt, **kwargs):
                    yield token
            else:
                fallback_result = default_model.generate(prompt=prompt, **kwargs)
                for word in fallback_result.split():
                    yield word + " "
418
+
419
+ def remove_model_instance(self, specialization: str) -> bool:
420
+ with self.lock:
421
+ if specialization in self.models:
422
+ del self.models[specialization]
423
+ self.model_pool.pop(specialization, None)
424
+ gc.collect()
425
+ if torch.cuda.is_available():
426
+ torch.cuda.empty_cache()
427
+ logger.info(f"Removed model for {specialization}")
428
+ return True
429
+ return False
430
+
431
+ def validate_input(self, input_text: str) -> bool:
432
+ if not input_text or len(input_text.strip()) == 0:
433
+ raise ValueError("Empty input text")
434
+ max_length = safe_get_config(app_config, "MAX_INPUT_LENGTH", safe_get_config(app_config, "MAX_SEQ_LENGTH", 128))
435
+ if len(input_text) > max_length:
436
+ raise ValueError(f"Input exceeds maximum length of {max_length}")
437
+ return True
438
+
439
    def get_health_status(self) -> Dict[str, Any]:
        """Snapshot process memory, per-model metrics and cache state.

        Uses psutil when available; otherwise the module-level DummyPsutil
        shim supplies fixed placeholder numbers.
        """
        with self.lock:
            process = psutil.Process(os.getpid())
            mem_info = process.memory_info()
            return {
                "active_models": len(self.models),
                "memory_usage": {
                    # Convert bytes to MiB.
                    "rss_mb": mem_info.rss / (1024 * 1024),
                    "vms_mb": mem_info.vms / (1024 * 1024),
                    "percent": process.memory_percent()
                },
                "model_performance": self._get_model_metrics(),
                "dataset_status": self.dataset_manager.get_status(),
                # Fraction of the active-model budget currently in use.
                "cache_efficiency": len(self.model_pool) / max(1, self.max_active_models)
            }
454
+
455
+ def _get_model_metrics(self) -> Dict[str, Dict[str, Any]]:
456
+ metrics = {}
457
+ for spec, model in self.models.items():
458
+ base = self._performance_metrics.get(spec, {})
459
+ mem_usage = 0
460
+ if hasattr(model, "get_memory_usage"):
461
+ mem_usage = model.get_memory_usage()
462
+ elif hasattr(model, "parameters"):
463
+ mem_usage = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
464
+ metrics[spec] = {
465
+ "inference_time": base.get("inference_time", 0),
466
+ "memory_usage_mb": mem_usage,
467
+ "last_accessed": base.get("last_accessed", 0),
468
+ "num_inferences": base.get("num_inferences", 0),
469
+ "model_type": model.__class__.__name__
470
+ }
471
+ return metrics
472
+
473
+ def get_available_models(self) -> Dict[str, Any]:
474
+ with self.lock:
475
+ return dict(self.models)
476
+
477
    def shutdown(self):
        """Release every resident model and clear the dataset cache.

        Best-effort: errors are logged, never raised, so shutdown can run
        safely from atexit/teardown paths.
        """
        try:
            logger.info("Initiating shutdown")
            # Copy the keys: remove_model_instance mutates self.models
            # while we iterate.
            for spec in list(self.models.keys()):
                self.remove_model_instance(spec)
            self.dataset_manager.clear_cache()
            logger.info("Shutdown complete")
        except Exception as e:
            logger.error(f"Error during shutdown: {e}")
486
+
487
    def manage_model_cache(self):
        """Evict models over the pool budget or idle past the threshold, then
        re-sort the pool by recency.

        NOTE(review): remove_model_instance() re-acquires self.lock while this
        method holds it — deadlocks unless self.lock is reentrant (RLock);
        confirm. Also note that model_pool values are None placeholders, so
        the loop variable `last` below is unused; recency comes from
        _performance_metrics["last_accessed"].
        """
        try:
            current = time.time()
            with self.lock:
                # Enforce the hard cap on resident models (evict LRU first).
                while len(self.models) > self.max_active_models:
                    oldest = next(iter(self.model_pool))
                    self.remove_model_instance(oldest)
                    logger.info(f"Removed LRU model: {oldest}")
                # Evict models idle longer than model_idle_threshold seconds.
                for spec, last in list(self.model_pool.items()):
                    m = self._performance_metrics.get(spec, {})
                    if m.get("last_accessed", 0) and (current - m["last_accessed"] > self.model_idle_threshold):
                        self.remove_model_instance(spec)
                        logger.info(f"Removed idle model: {spec}")
                # Rebuild the pool ordered most-recently-used first.
                sorted_models = sorted(self.model_pool.items(), key=lambda x: self._performance_metrics.get(x[0], {}).get("last_accessed", 0), reverse=True)
                self.model_pool = OrderedDict(sorted_models)
        except Exception as e:
            logger.error(f"Error in cache management: {e}")
504
+
505
+ def set_tokenizer(self, tokenizer):
506
+ self.tokenizer = tokenizer
507
+ with self.lock:
508
+ for name, model in self.models.items():
509
+ if hasattr(model, "set_tokenizer"):
510
+ try:
511
+ model.tokenizer = tokenizer
512
+ logger.debug(f"Updated tokenizer for {name}")
513
+ except Exception as ex:
514
+ logger.warning(f"Failed to set tokenizer for {name}: {ex}")
515
+ logger.info("Tokenizer updated for models")
516
+ return self
517
+
518
    def initialize_models(self):
        """Build models for a fixed set of specializations from the selected
        model module, then wire up the attention connector.

        Returns True when at least one model was created, False otherwise.

        NOTE(review): the specialization list here ("programming", "science",
        "history", ...) differs from the DATASET_PATHS keys used elsewhere
        (e.g. "programming_software_dev") — confirm which naming is canonical.
        """
        try:
            logger.info("Initializing models from weights")
            # Resolve (or lazily import) the PromptAnalyzer from model_List.py.
            prompt_analyzer = registry.get("prompt_analyzer")
            if not prompt_analyzer:
                try:
                    from pathlib import Path
                    model_list_path = Path(__file__).parent / "model_List.py"
                    if model_list_path.exists():
                        spec = importlib.util.find_spec("model_List")
                        if spec:
                            model_list = importlib.util.module_from_spec(spec)
                            spec.loader.exec_module(model_list)
                            if hasattr(model_list, "PromptAnalyzer"):
                                prompt_analyzer = model_list.PromptAnalyzer()
                                registry.register("prompt_analyzer", prompt_analyzer)
                                logger.info("Imported PromptAnalyzer")
                except Exception as e:
                    logger.error(f"Error importing PromptAnalyzer: {e}")
            self.prompt_analyzer = prompt_analyzer
            # Only the first selected model file is used below.
            selected_models_list = prompt_analyzer.get_selected_models() if prompt_analyzer and hasattr(prompt_analyzer, "get_selected_models") else ["model_Custm.py"]
            logger.info(f"Selected model types: {selected_models_list}")
            specializations = ["general", "programming", "science", "history", "mathematics"]
            for spec in specializations:
                try:
                    model_name = selected_models_list[0].replace(".py", "")
                    from pathlib import Path
                    model_path = Path(__file__).parent / f"{model_name}.py"
                    if model_path.exists():
                        # `spec_obj` is the importlib ModuleSpec; `spec` above is
                        # the specialization string.
                        spec_obj = importlib.util.find_spec(model_name)
                        if spec_obj:
                            model_module = importlib.util.module_from_spec(spec_obj)
                            spec_obj.loader.exec_module(model_module)
                            if hasattr(model_module, "Wildnerve_tlm01"):
                                model_class = getattr(model_module, "Wildnerve_tlm01")
                                embedding_dim = 768
                                # num_heads must divide embedding_dim evenly.
                                num_heads = 12 if embedding_dim % 12 == 0 else 1
                                model_instance = model_class(
                                    vocab_size=30522,
                                    specialization=spec,
                                    dataset_path=None,
                                    model_name="bert-base-uncased",
                                    embedding_dim=embedding_dim,
                                    num_heads=num_heads,
                                    hidden_dim=768,
                                    num_layers=2,
                                    output_size=768,
                                    dropout=0.1,
                                    max_seq_length=128,
                                    pooling_mode="mean"
                                )
                                self.models[spec] = model_instance
                                logger.info(f"Created model for {spec}")
                except Exception as e:
                    logger.error(f"Error creating model for {spec}: {e}")
            if not self.models:
                logger.error("No models created")
                return False
            # Best-effort attention-connector setup; failures are non-fatal.
            try:
                import os
                attention_config_path = os.path.join(app_config.DATA_DIR, "attention_configuration.json")
                from utils.attention_connector import get_attention_connector
                attention_connector = get_attention_connector()
                if hasattr(attention_connector, "config_path"):
                    attention_connector.config_path = attention_config_path
                    attention_connector._init_profile_selector()
                logger.info(f"Initialized attention connector with config: {attention_config_path}")
            except Exception as e:
                logger.warning(f"Failed to initialize attention connector: {e}")
            logger.info(f"Successfully initialized {len(self.models)} models")
            return True
        except Exception as e:
            logger.error(f"Error initializing models: {e}", exc_info=True)
            return False
592
+
593
    def get_alternative_model_for_prompt(self, prompt: str, current_model=None) -> Any:
        """Find a model different from `current_model` for retry scenarios.

        Search order: a fresh instance chosen by the prompt analyzer, any
        other resident model, then a freshly-built Wildnerve_tlm01 fallback.
        Returns None when everything fails.

        NOTE(review): the fallback below uses output_size=30522 (vocab-sized)
        while every other construction site uses 768 — confirm which is
        correct for Wildnerve_tlm01.
        """
        try:
            if self.prompt_analyzer and hasattr(self.prompt_analyzer, "choose_model"):
                # choose_model is expected to return a model CLASS, which we
                # instantiate with the standard configuration.
                model_type = self.prompt_analyzer.choose_model(prompt)
                if model_type:
                    try:
                        alt_model = model_type(
                            vocab_size=30522,
                            specialization="general",
                            dataset_path=None,
                            model_name="bert-base-uncased",
                            embedding_dim=768,
                            num_heads=12,
                            hidden_dim=768,
                            num_layers=6,
                            output_size=768,
                            dropout=0.1,
                            max_seq_length=512,
                            pooling_mode="mean",
                            tokenizer=self.tokenizer
                        )
                        if alt_model != current_model:
                            logger.info("Found alternative model via prompt_analyzer")
                            return alt_model
                    except Exception as e:
                        logger.error(f"Error initializing alternative model: {e}")
            # Second choice: any already-resident model that isn't the current one.
            for name, model in self.get_available_models().items():
                if model != current_model:
                    logger.info(f"Using alternative model: {name}")
                    return model
            # Last resort: build a brand-new fallback model.
            try:
                from model_Custm import Wildnerve_tlm01
                fallback_model = Wildnerve_tlm01(
                    vocab_size=30522,
                    specialization="general",
                    model_name="bert-base-uncased",
                    embedding_dim=768,
                    num_heads=12,
                    hidden_dim=768,
                    num_layers=6,
                    output_size=30522,
                    dropout=0.1,
                    max_seq_length=512,
                    pooling_mode="mean",
                    tokenizer=self.tokenizer
                )
                logger.info("Created fallback model")
                return fallback_model
            except Exception as e:
                logger.error(f"Error creating fallback model: {e}")
                return None
        except Exception as e:
            logger.error(f"Error getting alternative model: {e}")
            return None
647
+
648
    def prepare_model_input(self, text: str, model) -> dict:
        """Tokenize `text` for `model` and package generation parameters.

        Returns {"input_ids", "max_length", "device", "temperature"} when the
        model has a tokenizer, otherwise {"input_text", "max_length"}; on any
        error, a bare {"input_text": text}.

        NOTE(review): `next(model.parameters())` raises StopIteration for a
        parameterless model and sits OUTSIDE the try block — confirm every
        caller passes a real nn.Module.
        """
        device = next(model.parameters()).device
        try:
            tokenizer = getattr(model, "tokenizer", None)
            if tokenizer:
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=safe_get_config_value(app_config, "MAX_SEQ_LENGTH", 512)
                )
                input_ids = inputs["input_ids"].to(device)
                # temperature falls back to 0.7 when no generation_config is set.
                return {"input_ids": input_ids, "max_length": safe_get_config_value(app_config, "MAX_SEQ_LENGTH", 512), "device": device, "temperature": getattr(self, "generation_config", {}).get("temperature", 0.7)}
            else:
                logger.warning("No tokenizer in model; using basic input")
                return {"input_text": text, "max_length": safe_get_config_value(app_config, "MAX_SEQ_LENGTH", 512)}
        except Exception as e:
            logger.error(f"Error preparing model input: {e}")
            return {"input_text": text}
668
+
669
+ def process_with_context(self, input_text: str, context: Optional[dict] = None) -> dict:
670
+ conversation_context = self.get_conversation_context(window_size=3)
671
+ contextualized_prompt = input_text
672
+ if conversation_context:
673
+ max_seq_length = safe_get_config_value(app_config, "MAX_SEQ_LENGTH", 512)
674
+ max_seq_length = int(max_seq_length) if isinstance(max_seq_length, (int, float)) else 512
675
+ contextualized_prompt = f"Previous conversation:\n{conversation_context}\n\nCurrent question: {input_text}"
676
+ result = self.process_input(contextualized_prompt, context)
677
+ if isinstance(result, dict):
678
+ result["original_query"] = input_text
679
+ return result
680
+
681
+ def get_conversation_context(self, window_size: int = 3) -> str:
682
+ if not hasattr(self, "conversation_history"):
683
+ self.conversation_history = []
684
+ recent = self.conversation_history[-window_size*2:]
685
+ lines = []
686
+ for entry in recent:
687
+ prefix = "User: " if entry["role"]=="user" else "Assistant: "
688
+ lines.append(f"{prefix}{entry['content']}")
689
+ return "\n".join(lines)
690
+
691
# Factory methods for model manager creation
def create_model_manager(tokenizer=None) -> ModelManager:
    """Build a ModelManager from app config, register it, and return it.

    Falls back to a minimal single-model manager when full construction
    fails; the fallback is registered and returned instead.
    """
    try:
        active_limit = safe_get_config_value(app_config, "MAX_ACTIVE_MODELS", 2)
        idle_threshold = safe_get_config_value(app_config, "MODEL_IDLE_THRESHOLD", 600)
        manager = ModelManager(
            tokenizer=tokenizer,
            max_active_models=active_limit,
            model_idle_threshold=idle_threshold,
        )
        # Prefer an explicitly supplied tokenizer; otherwise fall back to
        # whatever tokenizer the registry already holds, if any.
        if tokenizer:
            manager.set_tokenizer(tokenizer)
        elif registry.has(TOKENIZER):
            manager.set_tokenizer(registry.get(TOKENIZER))
        registry.register(MODEL_MANAGER, manager)
        return manager
    except Exception as e:
        logger.error(f"Error creating ModelManager: {e}")
        # Last resort: a single-model manager with default idle handling.
        minimal_manager = ModelManager(tokenizer=tokenizer, max_active_models=1)
        registry.register(MODEL_MANAGER, minimal_manager)
        return minimal_manager
708
+
709
def create_model_manager_with_tokenizer(tokenizer):
    """Create a ModelManager bound to *tokenizer*, eagerly initialize its
    models, register the manager, and return it.

    On any construction error, registers and returns a minimal single-model
    manager (with the tokenizer attached but models left uninitialized).
    """
    try:
        active_limit = safe_get_config_value(app_config, "MAX_ACTIVE_MODELS", 2)
        idle_threshold = safe_get_config_value(app_config, "MODEL_IDLE_THRESHOLD", 600)
        manager = ModelManager(
            max_active_models=active_limit,
            model_idle_threshold=idle_threshold,
        )
        manager.tokenizer = tokenizer
        manager.initialize_models()
        registry.register(MODEL_MANAGER, manager)
        return manager
    except Exception as e:
        logger.error(f"Error creating ModelManager with tokenizer: {e}")
        fallback = ModelManager(max_active_models=1)
        fallback.tokenizer = tokenizer
        registry.register(MODEL_MANAGER, fallback)
        return fallback
724
+
725
if __name__ == "__main__":
    # Running as a script: ensure a tokenizer exists, then build a manager.
    tokenizer = registry.get(TOKENIZER)
    if not tokenizer:
        # No tokenizer registered yet — load the default and publish it so
        # later lookups (and the manager itself) can find it.
        from utils.transformer_utils import get_tokenizer
        tokenizer = get_tokenizer("bert-base-uncased")
        registry.register(TOKENIZER, tokenizer)
    model_manager = create_model_manager(tokenizer)
    logger.info(f"Model Manager initialized with {len(model_manager.models)} models")
else:
    # Imported as a module: defer heavy initialization to the importer.
    model_manager = None
    logger.info("ModelManager module imported; initialization deferred")