WildnerveAI commited on
Commit
fa96858
·
verified ·
1 Parent(s): 6145b18

Upload 8 files

Browse files
Files changed (8) hide show
  1. adapter_layer.py +13 -38
  2. communicator.py +1052 -0
  3. config.json +1 -1
  4. config.py +5 -0
  5. handler.py +12 -90
  6. model_List.py +4 -19
  7. model_manager.py +27 -1
  8. service_registry.py +3 -0
adapter_layer.py CHANGED
@@ -9,6 +9,7 @@ import traceback
9
  import codecarbon
10
  import importlib.util
11
  from typing import Dict, Any, Optional, List
 
12
 
13
  # Directly import the packages that are now installed
14
  try:
@@ -143,49 +144,23 @@ class WildnerveModelAdapter:
143
  }
144
 
145
  def generate(self, prompt: str, **kwargs) -> str:
146
- """Generate a response to the prompt"""
147
  if not self.initialized or self.model is None:
148
  logger.error("Model not initialized for generation")
149
  return "Error: Model not properly initialized"
150
-
151
  try:
152
- logger.info(f"Generating with {type(self.model).__name__} for prompt: {prompt[:50]}...")
153
-
154
- # Check for streaming capability
155
- if hasattr(self.model, "generate_streaming"):
 
 
156
  try:
157
- logger.info("Using streaming generation method")
158
- tokens = []
159
- for token in self.model.generate_streaming(prompt, **kwargs):
160
- tokens.append(token)
161
- return "".join(tokens)
162
  except Exception as e:
163
- logger.warning(f"Streaming generation failed: {e}")
164
-
165
- # Try standard generate methods
166
- gen_methods = ["generate_with_decoding", "generate"]
167
- for method_name in gen_methods:
168
- if hasattr(self.model, method_name):
169
- try:
170
- logger.info(f"Using {method_name} generation method")
171
- # Tokenize the input if needed
172
- input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
173
-
174
- # Get the result
175
- method = getattr(self.model, method_name)
176
- result = method(input_ids, **kwargs)
177
-
178
- if isinstance(result, str) and result:
179
- return result
180
- except Exception as e:
181
- logger.warning(f"{method_name} failed: {e}")
182
- logger.warning(traceback.format_exc())
183
-
184
- # If we get here, try a simple direct generate method
185
- logger.info("Using direct generate method")
186
- return self.model.generate(prompt, **kwargs)
187
-
188
  except Exception as e:
189
- logger.error(f"Error in generate: {e}")
190
- logger.error(traceback.format_exc())
191
  return f"Error generating response: {str(e)}"
 
9
  import codecarbon
10
  import importlib.util
11
  from typing import Dict, Any, Optional, List
12
+ from service_registry import registry, PRETRAINED_MODEL
13
 
14
  # Directly import the packages that are now installed
15
  try:
 
144
  }
145
 
146
  def generate(self, prompt: str, **kwargs) -> str:
147
+ """Generate a combined response: custom model then pretrained model."""
148
  if not self.initialized or self.model is None:
149
  logger.error("Model not initialized for generation")
150
  return "Error: Model not properly initialized"
 
151
  try:
152
+ # (1) custom-specialized inference
153
+ tech_output = self.model.generate(prompt, **kwargs)
154
+
155
+ # (2) append general pretrained-model output if registered
156
+ pre = registry.get(PRETRAINED_MODEL)
157
+ if pre:
158
  try:
159
+ gen_output = pre.generate(prompt, **kwargs)
160
+ return f"{tech_output.strip()}\n\n{gen_output.strip()}"
 
 
 
161
  except Exception as e:
162
+ logger.warning(f"Pretrained model generate failed: {e}")
163
+ return tech_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  except Exception as e:
165
+ logger.error(f"Error in generate: {e}", exc_info=True)
 
166
  return f"Error generating response: {str(e)}"
communicator.py ADDED
@@ -0,0 +1,1052 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
import json
import time
import torch
import logging
from threading import Lock
from config import app_config, load_config
from model_manager import safe_get_config_value  # Added import to fix error on line 458
from typing import Dict, List, Optional, Union, Any, Tuple
# Import ModelManager as a type hint only to avoid circular imports
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from model_manager import ModelManager

# Import service registry for dependencies
from service_registry import registry, MODEL, TOKENIZER, MODEL_MANAGER, COMMUNICATOR

# Then import other dependencies
from utils.sentence_transformer_utils import get_sentence_transformer
from utils.output_formatter import OutputFormatter
from sklearn.metrics.pairwise import cosine_similarity
# Import base interfaces
from base_interfaces.common_types import *
from base_interfaces.communicator_interface import AbstractCommunicator
# Import hybrid attention utils - update this import
from utils.smartHybridAttention import get_hybrid_attention_config

# Configure logging for the module.
# Bug fix: `logger` must be defined BEFORE the snntorch try/except below —
# previously it was defined afterwards, so a missing snntorch raised a
# NameError in the except handler instead of logging a warning.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Conditional imports for SNN/STDP functionality
try:
    from snntorch._neurons.lapicque import LIF
    from snntorch import spikegen
    from snntorch._neurons import Synaptic
    from communicator_STDP import Communicator_STDP
    SNNTORCH_AVAILABLE = True
except ImportError:
    SNNTORCH_AVAILABLE = False
    logger.warning("SNN/STDP functionality not available - some features will be disabled")

# Gracefully handle psutil import - only do this once
try:
    import psutil
    PSUTIL_AVAILABLE = True
except ImportError:
    logger.warning("psutil not available - cannot monitor system resources")
    PSUTIL_AVAILABLE = False

    # Create a minimal psutil-like interface for compatibility
    class DummyProcess:
        # Stand-in for psutil.Process reporting fixed 1 MB / 1% figures.
        def __init__(self, pid=None):
            self.pid = pid or 1

        def memory_info(self):
            class MemInfo:
                def __init__(self):
                    self.rss = 1000000  # 1 MB
                    self.vms = 1000000  # 1 MB
            return MemInfo()

        def memory_percent(self):
            return 1.0  # 1%

    class DummyPsutil:
        # Minimal psutil facade exposing only Process().
        @staticmethod
        def Process(pid=None):
            return DummyProcess(pid)

    psutil = DummyPsutil()
+
70
+ # The Communicator class implementation
71
+ class Communicator(AbstractCommunicator):
72
    def __init__(self, models: Dict[str, torch.nn.Module] = None, model_manager=None):
        """Initialize the Communicator with a model manager and necessary components."""
        # NOTE(review): the `models` parameter is accepted but never used in this
        # constructor; self.models is populated from the model manager inside
        # _init_models_and_embeddings — confirm whether it should be honored.
        self.lock = Lock()
        self.config = load_config()
        self.similarity_threshold = app_config.SIMILARITY_THRESHOLD
        self.top_k = app_config.TOP_K
        self.conversation_history = []
        self.shared_layers = [
            'encoder.layer.0',  # Often early layers capture general language features
            'encoder.layer.1',
            'embeddings'  # Embeddings are often beneficial to share
        ]

        # Initialize model manager - fixed to avoid circular imports
        self._init_model_manager(model_manager)

        # Initialize components
        self.output_formatter = OutputFormatter()
        # NOTE(review): embedding_model is assigned again inside
        # _init_models_and_embeddings below; this first assignment is redundant.
        self.embedding_model = get_sentence_transformer("Wildnerve-tlm01-0.05Bx12")

        # Get models and compute specialization embeddings
        self._init_models_and_embeddings()

        # Initialize SNN/STDP components if enabled
        self._init_snn_components()

        # Initialize with attention configuration
        self.attention_config = get_hybrid_attention_config()

        # Update attention config from app_config — only keys that already
        # exist in the default config are overridden.
        if hasattr(app_config, 'TRANSFORMER_CONFIG') and hasattr(app_config.TRANSFORMER_CONFIG, 'ATTENTION_MECHANISM'):
            attn_mech = app_config.TRANSFORMER_CONFIG.ATTENTION_MECHANISM
            if isinstance(attn_mech, dict):
                for key, value in attn_mech.items():
                    if key in self.attention_config:
                        self.attention_config[key] = value

        # Initialize tokenizer - set this directly to avoid attribute errors later
        self.tokenizer = self._init_tokenizer()

        logger.info("Communicator initialized successfully")
113
+
114
+ def _init_tokenizer(self):
115
+ """Initialize the tokenizer with proper error handling"""
116
+ try:
117
+ if registry.has(TOKENIZER):
118
+ return registry.get(TOKENIZER)
119
+ from transformers import AutoTokenizer
120
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
121
+ logger.info("Tokenizer initialized in communicator")
122
+ registry.register(TOKENIZER, tokenizer) # Register only if not present
123
+ return tokenizer
124
+ except Exception as e:
125
+ logger.error(f"Tokenizer initialization failed: {e}")
126
+ return None
127
+
128
+ def _init_model_manager(self, model_manager):
129
+ """Helper method to initialize model manager"""
130
+ if model_manager is None:
131
+ # Delayed import to avoid circular reference
132
+ from model_manager import ModelManager
133
+
134
+ try:
135
+ max_active_models = getattr(app_config, 'MAX_ACTIVE_MODELS', 5)
136
+ self.model_manager = ModelManager(max_active_models=max_active_models)
137
+ logger.info(f"Created ModelManager with max_active_models={max_active_models}")
138
+ except Exception as e:
139
+ logger.error(f"Error creating ModelManager: {e}")
140
+ self.model_manager = None
141
+ else:
142
+ self.model_manager = model_manager
143
+
144
    def _init_models_and_embeddings(self):
        """Initialize models and compute embeddings for specializations"""
        # Always force primary sentence transformer usage.
        self.embedding_model = get_sentence_transformer("Wildnerve-tlm01-0.05Bx12")
        self.models = self.model_manager.get_available_models() if self.model_manager else {}
        if not self.models:
            logger.warning("No models available in model manager")

        # Create embeddings for each specialization
        self.specialization_embeddings = {}

        if self.model_manager:
            # Access specializations through models dictionary keys
            specializations = []
            if hasattr(self.model_manager, 'models'):
                specializations = list(self.model_manager.models.keys())
            elif hasattr(self.model_manager, 'get_available_models'):
                specializations = list(self.model_manager.get_available_models().keys())

            for spec in specializations:
                # Embed the specialization NAME itself; route_input later
                # compares user text against these name embeddings.
                self.specialization_embeddings[spec] = self.embedding_model.encode(spec, convert_to_numpy=True)

        # Compute weight sharing groups based on cosine similarity
        self.weight_sharing_groups = self.create_weight_sharing_groups(self.similarity_threshold)
        logger.info("Computed weight sharing groups: %s", self.weight_sharing_groups)
169
+
170
    def _init_snn_components(self):
        """Initialize SNN/STDP components if enabled"""
        # Check if SNN should be used
        use_snn = self._get_config_value('STDP_CONFIG', 'USE_SNN', False)

        if use_snn:
            # NOTE(review): if USE_SNN is true but snntorch failed to import
            # (SNNTORCH_AVAILABLE is False at module level), the names Synaptic,
            # spikegen, LIF and Communicator_STDP below are undefined and this
            # raises NameError — consider guarding on SNNTORCH_AVAILABLE.
            # Determine device (CPU/GPU)
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # Get configuration values safely
            alpha = self._get_config_value('STDP_CONFIG', 'ALPHA', 0.1)
            beta = self._get_config_value('STDP_CONFIG', 'BETA', 0.2)
            spike_threshold = self._get_config_value('STDP_CONFIG', 'SpikeThreshold', 0.5)

            # Initialize components
            self.synapse_weights = Synaptic(alpha=alpha, beta=beta)
            self.spike_threshold = spike_threshold
            self.spike_generator = spikegen.rate
            self.beta = beta
            self.snn_layer = LIF(beta=self.beta)
            self.snn_comm = Communicator_STDP(self.models, device=self.device)
            # Membrane potential / spike buffers used by process_with_snn
            self.mem = torch.zeros(1, 1)
            self.spk = torch.zeros(1, 1)
            logger.info("SNN/STDP components initialized successfully")
        else:
            # SNN disabled: leave markers so process_with_snn becomes a no-op
            self.device = None
            self.snn_comm = None
            logger.info("SNN/STDP components not enabled")
198
+
199
+ def _get_config_value(self, config_name, attribute, default=None):
200
+ """Safely retrieve configuration values handling both dict and object access"""
201
+ if not hasattr(app_config, config_name):
202
+ return default
203
+
204
+ config_obj = getattr(app_config, config_name)
205
+
206
+ if isinstance(config_obj, dict):
207
+ return config_obj.get(attribute, default)
208
+ else:
209
+ return getattr(config_obj, attribute, default)
210
+
211
+ def create_weight_sharing_groups(self, similarity_threshold: float) -> Dict[str, set]:
212
+ """Computes cosine similarities among specialization embeddings and groups
213
+ specializations that exceed the similarity threshold to enable weight sharing. Returns:
214
+ Dictionary of groups: {specialization: [other_specializations_exceeding_threshold]}"""
215
+ groups = {}
216
+ for spec1, emb1 in self.specialization_embeddings.items():
217
+ for spec2, emb2 in self.specialization_embeddings.items():
218
+ if spec1 != spec2:
219
+ # Compute similarity
220
+ similarity = cosine_similarity(
221
+ emb1.reshape(1, -1),
222
+ emb2.reshape(1, -1)
223
+ )[0][0]
224
+
225
+ if similarity > similarity_threshold:
226
+ if spec1 not in groups:
227
+ groups[spec1] = set()
228
+ groups[spec1].add(spec2)
229
+ return groups
230
+
231
    def share_weights(self):
        """Share weights between models based on their computed similarity groups."""
        with self.lock:
            for primary_spec, related_specs in self.weight_sharing_groups.items():
                primary_model = self.model_manager.get_model(primary_spec)
                if not primary_model:
                    continue

                # Share weights from primary model to all related models in the group
                for related_spec in related_specs:
                    related_model = self.model_manager.get_model(related_spec)
                    if not related_model:
                        continue

                    # NOTE(review): this copies EVERY parameter pairwise in
                    # iteration order, not just the layers named in
                    # self.shared_layers (which is unused here), and assumes both
                    # models expose identically-shaped parameter sequences —
                    # confirm this is intended.
                    for p_layer, r_layer in zip(primary_model.parameters(), related_model.parameters()):
                        r_layer.data.copy_(p_layer.data)
            logger.info("Completed weight sharing across model groups")
249
+
250
    def process_with_snn(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Process input through SNN components if enabled.

        Returns the (possibly batched) SNN output, or the input unchanged when
        SNN is disabled or processing fails.
        """
        # Pass-through when SNN components were never initialized
        if not hasattr(self, 'snn_comm') or self.snn_comm is None:
            return input_tensor

        # Reset states before processing new input
        self.reset_snn_state()

        # Ensure input has a batch dimension
        if input_tensor.dim() == 1:
            input_tensor = input_tensor.unsqueeze(0)

        try:
            # Use communicator_STDP for processing if available
            if hasattr(self, 'snn_comm') and self.snn_comm is not None:
                # Pass to dedicated STDP communicator - enabling parallel processing
                return self.snn_comm.process_input(input_tensor)
            else:
                # NOTE(review): this branch is unreachable — the guard at the
                # top of the method already returned when snn_comm is None, so
                # the condition above is always true here.
                # Generate spikes from input
                spikes = self.spike_generator(input_tensor, num_steps=1)

                # Process through synaptic layer (may return a tuple)
                syn_out_result = self.synapse_weights(spikes)
                syn_out = syn_out_result[0] if isinstance(syn_out_result, tuple) else syn_out_result

                # Resize membrane buffer if the batch size changed
                batch_size = syn_out.shape[0]
                if self.mem.shape[0] != batch_size:
                    self.mem = torch.zeros(batch_size, syn_out.shape[1], device=syn_out.device)

                # Leaky integrate-and-fire update: decay + input, threshold, reset
                mem_next = self.beta * self.mem + syn_out
                spk_next = (mem_next > self.spike_threshold).float()
                self.mem = mem_next * (1 - spk_next)  # Reset membrane if spiked
                self.spk = spk_next

                return self.spk

        except Exception as e:
            logger.error(f"Error in SNN processing: {e}", exc_info=True)
            return input_tensor
292
+ def reset_snn_state(self):
293
+ """Reset the SNN neuron states"""
294
+ if hasattr(self, 'mem'):
295
+ self.mem = torch.zeros_like(self.mem)
296
+ if hasattr(self, 'spk'):
297
+ self.spk = torch.zeros_like(self.spk)
298
+
299
    def route_input(self, input_text: str, query: Optional[str] = None) -> List[tuple]:
        """Route input to most relevant specializations, returning top-k matches. Returns:
        List of (specialization, similarity_score) tuples"""
        with self.lock:
            # Prefer the explicit query over the raw input text when given
            text_to_analyze = query if query else input_text

            if not self.specialization_embeddings:
                logger.warning("No specialization embeddings available for routing")
                return [("default", 1.0)]

            try:
                # Calculate text embedding
                text_embedding = self.embedding_model.encode(text_to_analyze, convert_to_numpy=True)
                # Apply SNN processing if enabled
                use_snn = self._get_config_value('STDP_CONFIG', 'USE_SNN', False)

                if use_snn:
                    text_embedding = torch.from_numpy(text_embedding).float()
                    text_embedding = self.process_with_snn(text_embedding)
                    # NOTE(review): .detach().numpy() fails for CUDA tensors —
                    # confirm SNN output is always on CPU here.
                    text_embedding = text_embedding.detach().numpy()

                # Calculate similarities against each specialization embedding
                text_embedding = text_embedding.reshape(1, -1)
                similarities = {}

                for spec, spec_embedding in self.specialization_embeddings.items():
                    spec_embedding = spec_embedding.reshape(1, -1)
                    similarity = cosine_similarity(text_embedding, spec_embedding)[0][0]
                    similarities[spec] = float(similarity)

                # Get top-k most similar specializations
                sorted_specs = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
                top_k_specs = sorted_specs[:self.top_k]

                logger.debug("Routing similarities: %s", similarities)
                logger.info("Selected top %d specializations: %s", self.top_k, top_k_specs)

                # Check if prompt is long enough to use sliding window
                prompt_length = len(input_text.split())
                use_sliding_window = prompt_length > self.attention_config['WINDOW_SIZE'] // 2

                if use_sliding_window:
                    logger.info(f"Using sliding window attention for long input (length: {prompt_length})")
                # NOTE(review): source indentation was ambiguous — if this return
                # sat inside the `if use_sliding_window:` branch, short prompts
                # would return None; it is placed at try-body level here as the
                # only always-returning reading. Confirm against the original file.
                return top_k_specs if top_k_specs else [("default", 1.0)]
            except Exception as e:
                logger.error(f"Error in route_input: {str(e)}")
                return [("default", 1.0)]
346
+
347
    def process_input(self, input_text: str, context: Optional[Dict] = None) -> Dict[str, Any]:
        """Process user input through the appropriate model(s) and generate response. Returns:
        Dictionary containing response and metadata"""
        # NOTE(review): the `context` parameter is never read in this method.
        start_time = time.time()
        logger.info(f"Processing input: {input_text[:50]}...")
        try:
            # Add input to conversation history
            self.conversation_history.append({"role": "user", "content": input_text})

            # Route input to determine specialization
            specializations = self.route_input(input_text)
            primary_spec, confidence = specializations[0] if specializations else ("default", 0.0)

            # Get the model for primary specialization
            model = None
            if hasattr(self.model_manager, 'get_model'):
                model = self.model_manager.get_model(primary_spec)
            elif primary_spec in self.models:
                model = self.models[primary_spec]

            if not model:
                logger.warning(f"No model found for {primary_spec}, using default")
                # Try to get any available model
                if hasattr(self.model_manager, 'get_available_models'):
                    models = self.model_manager.get_available_models()
                    if models:
                        model = next(iter(models.values()), None)
                elif self.models:
                    model = next(iter(self.models.values()), None)
                if not model:
                    return {
                        "response": "No models available to process your request.",
                        "specialization": "none",
                        "processing_time": time.time() - start_time
                    }

            # Check if STDP/SNN should be used
            use_snn = self._get_config_value('STDP_CONFIG', 'USE_SNN', False)

            # Process input with standard pipeline
            # NOTE(review): model_inputs is computed but never passed to
            # process_request below — confirm whether this call is still needed.
            model_inputs = self.prepare_model_input(input_text, model)

            # Generate response
            response = self.process_request(input_text, model)

            # If SNN is enabled, also process with STDP - potentially in parallel
            stdp_response = None
            if use_snn and hasattr(self, 'snn_comm') and self.snn_comm:
                try:
                    # Process simultaneously with STDP
                    stdp_response = self.snn_comm.process_request(input_text, model)
                    logger.info("STDP processing completed successfully")
                except Exception as e:
                    logger.error(f"STDP processing failed: {e}")

            # Format response - prefer standard response but use STDP if standard fails
            formatted_response = None
            if response:
                formatted_response = self.output_formatter.format_response(response, primary_spec)
            elif stdp_response:
                formatted_response = self.output_formatter.format_response(stdp_response, primary_spec)
                response = stdp_response
            else:
                formatted_response = "I'm having trouble generating a response."

            # Add to conversation history
            self.conversation_history.append({"role": "assistant", "content": formatted_response})

            # Share weights if needed and more than one specialization
            if len(specializations) > 1:
                self.share_weights()
            # Calculate processing time
            processing_time = time.time() - start_time

            result = {
                "response": formatted_response,
                "specialization": primary_spec,
                "similarity_score": confidence,
                "processing_time": processing_time,
                "alternative_specializations": [s[0] for s in specializations[1:]] if len(specializations) > 1 else []
            }
            # Add STDP information if available
            if stdp_response:
                result["stdp_processed"] = True
                result["parallel_response"] = stdp_response
            return result

        except Exception as e:
            logger.error(f"Error processing input: {str(e)}", exc_info=True)
            return {
                "response": f"An error occurred while processing your request: {str(e)}",
                "error": str(e),
                "processing_time": time.time() - start_time
            }
441
+
442
    def prepare_model_input(self, text: str, model) -> Dict:
        """Prepare input text for model processing. Returns: Dictionary of model inputs"""
        # NOTE(review): next(model.parameters()) raises StopIteration for a
        # parameterless model, and this line sits OUTSIDE the try block below.
        device = next(model.parameters()).device
        try:
            # Get tokenizer from model
            tokenizer = getattr(model, 'tokenizer', None)

            if tokenizer:
                # Tokenize the input
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=safe_get_config_value(app_config, "MAX_SEQ_LENGTH", 512)
                )
                # Move inputs to the same device as model
                # NOTE(review): only input_ids is moved/returned; any
                # attention_mask produced by the tokenizer is dropped here.
                input_ids = inputs["input_ids"].to(device)

                return {
                    "input_ids": input_ids,
                    # NOTE(review): direct attribute access (vs the safe getter
                    # above) raises AttributeError if MAX_SEQ_LENGTH is missing.
                    "max_length": app_config.MAX_SEQ_LENGTH,
                    "device": device,
                    "temperature": getattr(self, 'generation_config', {}).get('temperature', 0.7)
                }
            else:
                # Fallback if tokenizer not available
                logger.warning("Model has no tokenizer attribute, using basic input")
                return {
                    "input_text": text,
                    "max_length": app_config.MAX_SEQ_LENGTH
                }
        except Exception as e:
            logger.error(f"Error preparing model input: {str(e)}")
            # Return minimal inputs
            return {"input_text": text}
478
+
479
+ def clear_conversation_history(self):
480
+ """Clear the conversation history"""
481
+ self.conversation_history = []
482
+
483
+ def get_conversation_history(self) -> List[Dict]:
484
+ """Get the current conversation history"""
485
+ return self.conversation_history.copy()
486
+
487
    def process_request(self, prompt: str, model: Any) -> str:
        """Process a user request through the selected model.

        Tries, in order: model.generate_with_decoding, model.generate with
        only the kwargs its signature accepts, then (on failure) alternative
        models, topic-based responses, a raw forward pass, and finally a
        fallback response. Returns a string; never raises.
        """
        try:
            logger.info(f"Processing request with model")

            # Get the tokenizer - reuse existing tokenizer or initialize if needed
            if not self.tokenizer:
                self.tokenizer = self._init_tokenizer()

            # Tokenize input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=128
            )
            # Generate response with the model
            with torch.no_grad():
                try:
                    # Try using generate method with compatible parameters
                    if hasattr(model, 'generate_with_decoding'):
                        # Use the most direct generation method if available
                        return model.generate_with_decoding(
                            inputs["input_ids"],
                            max_length=256,
                            temperature=0.7
                        )
                    elif hasattr(model, 'generate'):
                        # Check what parameters the generate method accepts
                        import inspect
                        sig_params = inspect.signature(model.generate).parameters
                        generate_kwargs = {'input_ids': inputs["input_ids"]}

                        # Only add parameters the function accepts
                        if 'max_length' in sig_params:
                            generate_kwargs['max_length'] = 256

                        if 'temperature' in sig_params:
                            generate_kwargs['temperature'] = 0.7

                        # Call generate with compatible parameters
                        outputs = model.generate(**generate_kwargs)

                        # Decode the output
                        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

                        # Clean up and return
                        return response.strip()
                    # NOTE(review): a model with NEITHER generate method falls
                    # through here and the method implicitly returns None.
                except Exception as e:
                    logger.warning(f"Error in model.generate: {e}")

                    # Check for shape errors which is the common issue we're encountering
                    if "shape" in str(e):
                        # Extract the specific shape mentioned in the error
                        shape_match = re.search(r'shape \'\[(.*?)\]\'', str(e))
                        if shape_match:
                            # Special handling for shape error - use alternative models
                            logger.info("Detected shape error, trying alternative model inference methods")

                            # Try to get a response using a different model specialization
                            alternative_response = self._get_response_from_alternative_model(prompt)
                            if alternative_response:
                                return alternative_response

                            # If that fails, try using a more dynamic topic detection approach
                            # NOTE(review): _analyze_prompt_for_topics is defined as
                            # (self, score, prompt) elsewhere in this file but called
                            # here with only `prompt` — this raises TypeError if
                            # reached; the definition's extra `score` parameter
                            # looks unintended.
                            topic, subtopics = self._analyze_prompt_for_topics(prompt)
                            logger.info(f"Detected topic: {topic}, subtopics: {subtopics}")

                            return self._get_topic_response(topic, prompt, subtopics)

                    # If not a shape error, try direct model inference
                    try:
                        # Use only input_ids to minimize potential shape issues
                        outputs = model(inputs["input_ids"])

                        # Check if we can extract anything meaningful from the outputs
                        if isinstance(outputs, dict) and "logits" in outputs:
                            logits = outputs["logits"]
                            # Extract top tokens for a coherent response
                            response = self._generate_response_from_logits(logits, prompt)
                            if response:
                                return response
                        elif isinstance(outputs, torch.Tensor) and outputs.dim() >= 2:
                            # For tensor outputs, extract useful information
                            response = self._generate_response_from_tensor(outputs, prompt)
                            if response:
                                return response
                    except Exception as fw_error:
                        logger.error(f"Forward pass error: {fw_error}")

                    # Last resort: check if other models can handle this prompt better
                    return self._get_fallback_response(prompt)
        except Exception as e:
            logger.error(f"Error in process_request: {e}")
            return "I encountered an error processing your request. Could you try asking your question differently?"
582
+
583
    def _get_response_from_alternative_model(self, prompt: str) -> Optional[str]:
        """Try to get a response using a different model from the model manager.

        Re-routes the prompt and walks the runner-up specializations (skipping
        the top match, which just failed). Returns the first non-trivial string
        response, or None if every alternative fails.
        """
        try:
            if not self.model_manager:
                return None
            # Get the top 3 alternative models
            specializations = self.route_input(prompt)
            # Skip the first one (which is the one that just failed)
            for spec, _ in specializations[1:]:
                alt_model = self.model_manager.get_model(spec)
                if alt_model:
                    logger.info(f"Trying alternative model for specialization: {spec}")
                    try:
                        # Prefer the model's own tokenizer, fall back to ours
                        if hasattr(alt_model, 'tokenizer'):
                            tokenizer = alt_model.tokenizer
                        else:
                            tokenizer = self.tokenizer
                        inputs = tokenizer(
                            prompt,
                            return_tensors="pt",
                            truncation=True,
                            max_length=128
                        )
                        # Try generation with this model
                        if hasattr(alt_model, 'generate_with_decoding'):
                            response = alt_model.generate_with_decoding(
                                inputs["input_ids"],
                                max_length=256,
                                temperature=0.7
                            )
                            # Heuristic: treat very short strings as non-answers
                            if response and isinstance(response, str) and len(response) > 10:
                                return response
                    except Exception as alt_error:
                        logger.warning(f"Alternative model {spec} also failed: {alt_error}")
                        continue
            return None
        except Exception as e:
            logger.error(f"Error getting response from alternative model: {e}")
            return None
623
+ def _analyze_prompt_for_topics(self, score, prompt: str) -> Tuple[str, List[str]]:
624
+ """Analyze prompt to dynamically determine the topic and subtopics"""
625
+ # First try to use the embedding model if available
626
+ primary_topic = "general"
627
+ subtopics = []
628
+ try:
629
+ # Option 1: Use embedding similarity to predefined topics
630
+ if hasattr(self, 'embedding_model'):
631
+ # Define a broad range of topics
632
+ candidate_topics = [
633
+ "programming", "math", "science", "history", "art",
634
+ "literature", "music", "politics", "economics", "philosophy",
635
+ "technology", "health", "sports", "entertainment", "education",
636
+ "business", "psychology", "sociology", "linguistics", "physics",
637
+ "chemistry", "biology", "medicine", "engineering", "computer science",
638
+ "artificial intelligence", "data science", "web development", "finance",
639
+ "law", "ethics", "religion", "geography", "astronomy", "environment"
640
+ ]
641
+ # Get embedding for the prompt
642
+ prompt_embedding = self.embedding_model.encode(prompt, convert_to_numpy=True)
643
+
644
+ # Get embeddings for topics
645
+ topic_embeddings = {
646
+ topic: self.embedding_model.encode(f"This text is about {topic}.", convert_to_numpy=True)
647
+ for topic in candidate_topics
648
+ }
649
+ # Calculate similarities
650
+ similarities = {
651
+ topic: float(cosine_similarity(
652
+ prompt_embedding.reshape(1, -1),
653
+ emb.reshape(1, -1)
654
+ )[0][0])
655
+ for topic, emb in topic_embeddings.items()
656
+ }
657
+ # Sort by similarity score
658
+ sorted_topics = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
659
+ # Get primary topic and subtopics
660
+ if sorted_topics:
661
+ primary_topic = sorted_topics[0][0]
662
+ # Get subtopics with similarity score at least 80% of the top score
663
+ threshold = sorted_topics[0][1] * 0.8
664
+ subtopics = [topic for topic, score in sorted_topics[1:6] if score > threshold]
665
+
666
+ # Option 2: Use frequency analysis as fallback
667
+ if primary_topic == "general" or not subtopics:
668
+ # Define topic keywords
669
+ topic_keywords = {
670
+ "programming": ["code", "programming", "python", "java", "javascript", "function", "algorithm", "developer", "software"],
671
+ "math": ["math", "mathematics", "algebra", "calculus", "equation", "geometry", "statistics", "theorem"],
672
+ "science": ["science", "physics", "chemistry", "biology", "scientific", "experiment", "theory"],
673
+ "history": ["history", "historical", "ancient", "century", "civilization", "war", "empire"],
674
+ "technology": ["technology", "tech", "computer", "digital", "internet", "device", "hardware", "software"],
675
+ "ai": ["ai", "artificial intelligence", "machine learning", "neural network", "deep learning", "nlp", "algorithm"],
676
+ "health": ["health", "medical", "medicine", "disease", "treatment", "doctor", "patient", "healthcare"],
677
+ "business": ["business", "company", "market", "industry", "finance", "economic", "management", "strategy"],
678
+ "general": [] # Fallback
679
+ }
680
+ # Clean and tokenize prompt
681
+ words = re.findall(r'\b[a-zA-Z]{3,}\b', prompt.lower())
682
+
683
+ # Count matches for each topic
684
+ topic_scores = {topic: 0 for topic in topic_keywords.keys()}
685
+ for word in words:
686
+ for topic, keywords in topic_keywords.items():
687
+ if word in keywords or any(keyword in word for keyword in keywords):
688
+ topic_scores[topic] += 1
689
+
690
+ # Get top topics by score
691
+ sorted_topics = sorted(topic_scores.items(), key=lambda x: x[1], reverse=True)
692
+ if sorted_topics[0][1] > 0:
693
+ primary_topic = sorted_topics[0][0]
694
+ # Get subtopics with score > 0
695
+ subtopics = [topic for topic in sorted_topics[1:4] if score > 0]
696
+
697
+ # If we still don't have subtopics, add some based on primary topic
698
+ if not subtopics:
699
+ # Define related subtopics for common topics
700
+ related_topics = {
701
+ "programming": ["software development", "algorithms", "data structures"],
702
+ "math": ["algebra", "geometry", "statistics"],
703
+ "science": ["physics", "chemistry", "biology"],
704
+ "history": ["ancient history", "modern history", "world wars"],
705
+ "technology": ["computers", "internet", "gadgets"],
706
+ "ai": ["machine learning", "neural networks", "natural language processing"],
707
+ "health": ["medicine", "wellness", "nutrition"],
708
+ "business": ["economics", "finance", "management"]
709
+ }
710
+ subtopics = related_topics.get(primary_topic, ["information", "knowledge", "details"])
711
+
712
+ return primary_topic, subtopics
713
+ except Exception as e:
714
+ logger.error(f"Error analyzing prompt for topics: {e}")
715
+ return "general", ["information"]
716
+ def _generate_response_from_logits(self, logits: torch.Tensor, prompt: str) -> Optional[str]:
717
+ """Generate a coherent response from model output logits"""
718
+ try:
719
+ # Extract the top tokens from the logits
720
+ if logits.dim() >= 2:
721
+ # Get the last position's logits
722
+ last_logits = logits[:, -1, :] if logits.dim() > 2 else logits
723
+
724
+ # Get top tokens
725
+ top_k = min(5, last_logits.size(-1))
726
+ top_values, top_indices = torch.topk(last_logits, top_k, dim=-1)
727
+
728
+ # Decode top tokens
729
+ if hasattr(self, 'tokenizer') and self.tokenizer is not None:
730
+ top_tokens = [self.tokenizer.decode([idx.item()]) for idx in top_indices[0]]
731
+
732
+ # Create a coherent response using the tokens and context from the prompt
733
+ topic_tokens = [token for token in top_tokens if len(token) > 1 and not token.startswith('[')]
734
+ if topic_tokens:
735
+ # Extract topic from prompt
736
+ topic = self._extract_topic_from_prompt(prompt)
737
+ context = ", ".join(topic_tokens[:3])
738
+ return f"Based on my understanding of {topic}, the key concepts include {context}. Would you like more specific information about any of these aspects?"
739
+ return None
740
+ except Exception as e:
741
+ logger.error(f"Error generating response from logits: {e}")
742
+ return None
743
+ def _generate_response_from_tensor(self, tensor: torch.Tensor, prompt: str) -> Optional[str]:
744
+ """Generate a response from a tensor output"""
745
+ try:
746
+ # For sequence outputs, try to find the most relevant position
747
+ if tensor.dim() >= 2:
748
+ # If it's a sequence, use the mean or the last position
749
+ if tensor.dim() == 3: # [batch, seq, hidden]
750
+ features = tensor[0, -1, :] # Last position of first batch
751
+ else: # [batch, hidden]
752
+ features = tensor[0, :] # First batch
753
+
754
+ # Use these features to generate a meaningful response
755
+ # (Simplified approach - in reality we'd want to use these features more effectively)
756
+ topic = self._extract_topic_from_prompt(prompt)
757
+
758
+ # If the tensor is small enough, we can include some values
759
+ if features.numel() < 10:
760
+ values = [f"{val:.2f}" for val in features[:5].tolist()]
761
+ value_str = ", ".join(values)
762
+ return f"I analyzed your question about {topic}. My analysis indicates values of {value_str}, which suggests this topic involves multiple factors."
763
+ else:
764
+ # Generic response using the tensor shape
765
+ shape_str = "x".join(str(dim) for dim in tensor.size())
766
+ return f"I analyzed your question about {topic}. This is a complex topic with many dimensions (tensor shape: {shape_str}). Could you specify which aspect you'd like me to focus on?"
767
+ return None
768
+ except Exception as e:
769
+ logger.error(f"Error generating response from tensor: {e}")
770
+ return None
771
+ def _extract_topic_from_prompt(self, prompt: str) -> str:
772
+ """Extract a topic phrase from the prompt"""
773
+ # Simple extraction of the main subject using first few words
774
+ words = prompt.strip().split()
775
+
776
+ if not words:
777
+ return "this topic"
778
+
779
+ # Check for common question patterns
780
+ if words[0].lower() in ['what', 'how', 'why', 'when', 'where', 'who', 'which']:
781
+ # For questions, look for the subject after the question word
782
+ # E.g., "What is quantum physics?" -> "quantum physics"
783
+ if len(words) > 1:
784
+ if words[1].lower() in ['is', 'are', 'was', 'were', 'will', 'did', 'does', 'do']:
785
+ if len(words) > 2:
786
+ return ' '.join(words[2:min(5, len(words))])
787
+ return words[1]
788
+ return ' '.join(words[1:min(4, len(words))])
789
+ # For non-questions, use the first few words
790
+ return ' '.join(words[:min(3, len(words))])
791
+
792
+ def _extract_subject(self, text: str) -> str:
793
+ """Extract the primary subject from a text prompt
794
+ This method uses basic NLP techniques to identify the main
795
+ subject or topic of a text, which can be used for routing to specialized models."""
796
+ try:
797
+ # For more advanced implementations, we'd use proper NLP here
798
+ # For now, a simple keyword extraction approach:
799
+
800
+ # Convert to lowercase for easier matching
801
+ text = text.lower()
802
+
803
+ # Define some subject categories and their keywords
804
+ subject_keywords = {
805
+ "programming": ["code", "program", "programming", "function", "algorithm", "software", "developer"],
806
+ "mathematics": ["math", "equation", "calculation", "formula", "number", "geometry"],
807
+ "science": ["science", "physics", "chemistry", "biology", "scientific"],
808
+ "history": ["history", "historical", "past", "ancient", "century"]
809
+ }
810
+ # Find which subject has the most matching keywords
811
+ subject_scores = {}
812
+ for subject, keywords in subject_keywords.items():
813
+ score = sum(1 for keyword in keywords if keyword in text)
814
+ if score > 0:
815
+ subject_scores[subject] = score
816
+
817
+ # Return the subject with the highest score, or empty string if none found
818
+ if subject_scores:
819
+ return max(subject_scores.items(), key=lambda x: x[1])[0]
820
+ return ""
821
+ except Exception as e:
822
+ logger.error(f"Error extracting subject: {e}")
823
+ return ""
824
+
825
+ # Add conversation context methods to enhance chatbot capabilities
826
def add_to_conversation_history(self, role: str, content: str, metadata: Optional[Dict] = None):
    """Append one turn to the rolling conversation history.

    Args:
        role: Speaker identifier, e.g. "user" or "assistant".
        content: The message text.
        metadata: Optional extra data stored alongside the turn.
    """
    record = {
        "role": role,
        "content": content,
        "timestamp": time.time()
    }
    if metadata:
        record["metadata"] = metadata
    self.conversation_history.append(record)
    # Trim to the configured cap so memory stays bounded.
    limit = getattr(app_config, "MAX_CONVERSATION_HISTORY", 10)
    if len(self.conversation_history) > limit:
        self.conversation_history = self.conversation_history[-limit:]
840
+
841
def get_conversation_context(self, window_size: int = 3) -> str:
    """Render the most recent exchanges as one newline-joined string.

    Each turn is prefixed with "User: " or "Assistant: ". Returns ""
    when there is no history yet.
    """
    history = self.conversation_history
    if not history:
        return ""
    # Each exchange is two entries (user + assistant), hence the *2.
    recent = history[-window_size * 2:]
    lines = [
        ("User: " if turn["role"] == "user" else "Assistant: ") + turn["content"]
        for turn in recent
    ]
    return "\n".join(lines)
856
+
857
def process_with_context(self, input_text: str, context: Optional[Dict] = None) -> Dict[str, Any]:
    """Process input, prepending recent conversation turns for continuity.

    When history exists, the prompt is rewritten to include it; the
    original query is recorded in a dict result under "original_query".
    """
    history_text = self.get_conversation_context(window_size=3)

    prompt = input_text
    if history_text:
        # Normalize MAX_SEQ_LENGTH defensively: it has shipped as a dict
        # or other non-numeric value before.
        max_seq_length = getattr(app_config, 'MAX_SEQ_LENGTH', 512)
        if isinstance(max_seq_length, dict):
            max_seq_length = 512
            logger.warning(f"MAX_SEQ_LENGTH is a dictionary, using default: {max_seq_length}")
        elif not isinstance(max_seq_length, (int, float)):
            max_seq_length = 512
            logger.warning(f"MAX_SEQ_LENGTH is not a number, using default: {max_seq_length}")
        else:
            max_seq_length = int(max_seq_length)

        # NOTE(review): computed but never used — truncating the history
        # to this budget appears unimplemented. Confirm intent.
        max_context_length = max_seq_length // 2

        prompt = f"Previous conversation:\n{history_text}\n\nCurrent question: {input_text}"

    result = self.process_input(prompt, context)

    # Preserve the untouched user query for downstream consumers.
    if isinstance(result, dict):
        result["original_query"] = input_text
    return result
889
+
890
+ def _get_fallback_response(self, prompt: str) -> str:
891
+ """Get a fallback response when primary model processing fails"""
892
+ try:
893
+ # Extract topic from prompt
894
+ topic, subtopics = self._analyze_prompt_for_topics(prompt)
895
+ # Try to use any available model for generating a response
896
+ if hasattr(self, 'model_manager') and self.model_manager:
897
+ # Try multiple strategies to get a working model
898
+ # Strategy 1: Try the built-in alternative model getter
899
+ if hasattr(self.model_manager, 'get_alternative_model_for_prompt'):
900
+ alt_model = self.model_manager.get_alternative_model_for_prompt(prompt)
901
+ if alt_model:
902
+ logger.info(f"Using alternative model for fallback response")
903
+ try:
904
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
905
+ if hasattr(alt_model, 'generate_with_decoding'):
906
+ response = alt_model.generate_with_decoding(
907
+ inputs["input_ids"],
908
+ max_length=256,
909
+ temperature=0.7
910
+ )
911
+ if response and isinstance(response, str) and len(response) > 10:
912
+ return response
913
+ except Exception as alt_error:
914
+ logger.warning(f"Alternative model also failed: {alt_error}")
915
+
916
+ # Strategy 2: Try any other available model from the manager
917
+ try:
918
+ available_models = self.model_manager.get_available_models()
919
+ for spec_name, model in available_models.items():
920
+ if spec_name != topic: # Skip the model that likely failed already
921
+ logger.info(f"Trying model from specialization: {spec_name}")
922
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
923
+ if hasattr(model, 'generate_with_decoding'):
924
+ response = model.generate_with_decoding(
925
+ inputs["input_ids"],
926
+ max_length=256,
927
+ temperature=0.9 # Higher temperature for diversity
928
+ )
929
+ if response and isinstance(response, str) and len(response) > 10:
930
+ return response
931
+ except Exception as e:
932
+ logger.warning(f"Failed to use alternative models: {e}")
933
+
934
+ # If no model worked, build a dynamic response based on topic analysis
935
+ return self._build_dynamic_response(topic, prompt, subtopics)
936
+ except Exception as e:
937
+ logger.error(f"Error getting fallback response: {e}")
938
+ # Absolute last resort generic response
939
+ return "I'm having trouble understanding that request. Could you rephrase it or try asking something else?"
940
+
941
+ def _build_dynamic_response(self, topic: str, prompt: str, subtopics: List[str] = None) -> str:
942
+ """Build a dynamic response based on topic analysis without hardcoded templates"""
943
+ try:
944
+ # Extract subject if possible
945
+ subject = self._extract_subject(prompt)
946
+ # Ensure we have subtopics list
947
+ subtopics = subtopics or []
948
+ # Build a response that acknowledges the topic but doesn't contain hardcoded knowledge
949
+ topic_str = subject if subject else topic
950
+
951
+ # Construct a dynamic response prompt for a model
952
+ meta_prompt = f"""
953
+ Topic: {topic_str}
954
+ Related areas: {', '.join(subtopics[:3]) if subtopics else 'various fields'}
955
+ Request: {prompt}
956
+
957
+ Create a brief response that acknowledges the topic but asks for clarification.
958
+ Do not provide specific information about the topic, just acknowledge understanding and ask for more details."""
959
+ # Try to use a lightweight model for this meta-generation if possible
960
+ try:
961
+ if hasattr(self, 'model_manager') and self.model_manager:
962
+ # Try to find any working model
963
+ models = self.model_manager.get_available_models()
964
+ if models:
965
+ model = next(iter(models.values()))
966
+ inputs = self.tokenizer(meta_prompt, return_tensors="pt", truncation=True, max_length=256)
967
+ meta_response = model.generate_with_decoding(
968
+ inputs["input_ids"],
969
+ max_length=256,
970
+ temperature=0.7
971
+ )
972
+ if meta_response and len(meta_response) > 20:
973
+ return meta_response
974
+ except Exception as e:
975
+ logger.warning(f"Meta-generation failed: {e}")
976
+
977
+ # Fallback to a very simple dynamic response if all else fails
978
+ subtopic_str = ", ".join(subtopics[:3]) if subtopics else "related areas"
979
+
980
+ return f"""I understand you're asking about {topic_str}. This relates to {subtopic_str}.
981
+ To provide a helpful response, I'd need more specific details about what aspect you're interested in learning about.
982
+ Could you please clarify what specific information you're looking for?"""
983
+ except Exception as e:
984
+ logger.error(f"Error building dynamic response: {e}")
985
+ return "I need more information to help you with that topic. Could you provide more details about what you'd like to know?"
986
+ def _get_topic_response(self, topic: str, prompt: str, subtopics: List[str] = None) -> str:
987
+ """Get a response for a specific topic using model-driven approach"""
988
+ return self._build_dynamic_response(topic, prompt, subtopics)
989
+
990
def process_input(self, prompt, **kwargs):
    """Run the prompt through the best available model path.

    Preference order: real model inference (decode / generate / forward
    pass), then keyword-based knowledge lookup in minimal mode, then a
    generic minimal-mode acknowledgement. The returned dict carries a
    "minimal_mode" flag so callers can tell the paths apart.
    """
    model = self.model
    is_minimal = bool(model) and hasattr(model, '_is_minimal') and model._is_minimal

    if model and not is_minimal and self.tokenizer:
        try:
            logger.info("Attempting model inference with actual model")
            encoded = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

            # Wall-clock guard: inference can't be interrupted, but overruns
            # are logged for diagnostics.
            max_inference_time = 30  # seconds
            start_time = time.time()

            if hasattr(model, "generate_with_decoding"):
                response = model.generate_with_decoding(encoded.input_ids)
            elif hasattr(model, "generate"):
                output_ids = model.generate(encoded.input_ids)
                response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            else:
                # Plain forward pass: greedy-decode the argmax tokens.
                outputs = model(encoded.input_ids)
                response = self.tokenizer.decode(torch.argmax(outputs, dim=-1)[0], skip_special_tokens=True)

            elapsed_time = time.time() - start_time
            if elapsed_time > max_inference_time:
                logger.warning(f"Model inference took too long: {elapsed_time:.2f} seconds")

            if response and len(response) > 10:  # Require reasonably long response
                logger.info("Generated model response successfully")
                return {"response": response, "minimal_mode": False}
        except Exception as e:
            logger.warning(f"Model inference failed: {e}")
    elif model and hasattr(model, '_is_minimal') and model._is_minimal:
        logger.warning("Using minimal model - full model unavailable")

    # Minimal-mode path: keyword lookup first, generic echo last.
    logger.debug(f"Minimal communicator processing: {prompt[:30]}...")
    knowledge = self._get_knowledge_response(prompt)
    if knowledge:
        return {"response": knowledge, "minimal_mode": True}

    first_word = prompt.split()[0] if prompt.split() else 'this topic'
    return {"response": f"I'm operating in minimal mode. Your query was about {first_word}...",
            "minimal_mode": True}
1032
+
1033
# Factory that produces and registers the main Communicator.
def create_communicator(model_manager=None):
    """Build the Communicator from the `communicator` module, register it
    under the COMMUNICATOR key, and return it."""
    from communicator import Communicator
    instance = Communicator(model_manager=model_manager)
    registry.register(COMMUNICATOR, instance)
    return instance
1039
+
1040
+ from service_registry import registry, COMMUNICATOR
1041
+ from adapter_layer import WildnerveModelAdapter
1042
+
1043
class Communicator:
    """Thin request facade over the model adapter.

    NOTE(review): this local Communicator is distinct from the one that
    `create_communicator` imports from the `communicator` module, and both
    register under the same COMMUNICATOR key — confirm which is intended.
    """

    def __init__(self):
        self.adapter = WildnerveModelAdapter()

    def process_request(self, prompt: str, **kwargs):
        """Generate a reply for *prompt* via the underlying adapter."""
        return self.adapter.generate(prompt, **kwargs)


# Register a module-level instance, replacing any prior registration.
comm = Communicator()
registry.register(COMMUNICATOR, comm, overwrite=True)
config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "model_type": "wildnerve_tlm01",
3
  "architectures": ["Wildnerve_tlm01"],
4
- "SELECTED_MODEL": ["model_Custm.py", "model_PrTr.py", "model_Combn.py"],
5
  "MODEL_NAME": "Wildnerve-tlm01_Hybrid_Model",
6
  "BASE_DATA_DIR": "data",
7
  "FILE_FORMATS": ["csv", "json", "txt"],
 
1
  {
2
  "model_type": "wildnerve_tlm01",
3
  "architectures": ["Wildnerve_tlm01"],
4
+ "SELECTED_MODEL": ["model_Custm.py", "model_PrTr.py"],
5
  "MODEL_NAME": "Wildnerve-tlm01_Hybrid_Model",
6
  "BASE_DATA_DIR": "data",
7
  "FILE_FORMATS": ["csv", "json", "txt"],
config.py CHANGED
@@ -352,6 +352,11 @@ class STDPConfig(BaseModel):
352
  )
353
 
354
  class AppConfig(BaseModel):
 
 
 
 
 
355
  DATA_DIR: str = Field(default="/tmp/tlm_data")
356
  MODEL_DIR: str = Field(default="/tmp/tlm_data/models")
357
  TRANSFORMER_CONFIG: TransformerConfig = Field(default_factory=TransformerConfig)
 
352
  )
353
 
354
  class AppConfig(BaseModel):
355
+ # which model files to load by default
356
+ SELECTED_MODEL: List[str] = Field(
357
+ default=["model_Custm.py", "model_PrTr.py"],
358
+ description="Default model files (custom first, then pretrained)"
359
+ )
360
  DATA_DIR: str = Field(default="/tmp/tlm_data")
361
  MODEL_DIR: str = Field(default="/tmp/tlm_data/models")
362
  TRANSFORMER_CONFIG: TransformerConfig = Field(default_factory=TransformerConfig)
handler.py CHANGED
@@ -73,99 +73,21 @@ except ImportError as e:
73
  return f"Model adapter unavailable. Received input: {text_input[:30]}..."
74
 
75
  class EndpointHandler:
76
- def __init__(self, path=""):
77
- self.path = path or os.getcwd()
78
- logger.info(f"Handler init with path: {self.path}")
79
- self.model_adapter = None
80
- self.initialized = False
81
-
82
- def __call__(self, data: Dict[str, Any], parameters: Dict[str, Any] = None) -> List[Dict[str, Any]]:
83
- """Handler entry point"""
84
- # On first call, if init fails, return the real error
85
- if not self.initialized:
86
- ok = self.initialize()
87
- if not ok:
88
- return [{"generated_text": f"Initialization error: {self.init_error}"}]
89
  try:
90
- logger.info(f"Handler received request: {data}")
91
- result = self.predict(data, parameters)
92
-
93
- # Handle result formatting
94
- if isinstance(result, list):
95
- logger.info(f"Returning list result with {len(result)} items")
96
- return result
97
- elif isinstance(result, dict):
98
- return [result]
99
- else:
100
- return [{"generated_text": str(result) if result is not None else "No output generated"}]
101
-
102
  except Exception as e:
103
- logger.error(f"Error in __call__: {e}", exc_info=True)
104
- return [{"generated_text": f"Runtime error: {e}"}]
105
-
106
- def initialize(self):
107
- if self.initialized:
108
- return True
109
- try:
110
- logger.debug(f"Calling WildnerveModelAdapter with path {self.path}")
111
- self.model_adapter = WildnerveModelAdapter(self.path)
112
- self.initialized = True
113
- return True
114
- except Exception as e:
115
- # log full stack trace
116
- logger.error(f"Adapter initialization failed for path '{self.path}': {e}", exc_info=True)
117
- # store message for client
118
  self.init_error = str(e)
119
- return False
120
-
121
- def predict(self, inputs: Dict[str, Any], parameters: Dict[str, Any] = None) -> List[Dict[str, Any]]:
122
- """Process the input and generate a response"""
123
- # Initialize on first call
124
- if not self.initialized:
125
- success = self.initialize()
126
- if not success:
127
- return [{"generated_text": "Failed to initialize the model."}]
128
 
129
- # Extract the prompt text
130
- text_input = self._extract_input_text(inputs)
131
-
132
- # Process parameters
133
- parameters = parameters or {}
134
-
135
  try:
136
- # Generate text with the adapter
137
- generated_text = self.model_adapter.generate(
138
- text_input,
139
- max_length=parameters.get("max_length", 100),
140
- max_new_tokens=parameters.get("max_new_tokens", None),
141
- temperature=parameters.get("temperature", 0.7),
142
- top_p=parameters.get("top_p", 0.9),
143
- top_k=parameters.get("top_k", 40)
144
- )
145
-
146
- # Return the result
147
- return [{"generated_text": generated_text}]
148
-
149
  except Exception as e:
150
- logger.error(f"Error during prediction: {e}")
151
- logger.error(traceback.format_exc())
152
-
153
- return [{"generated_text": f"Error generating response: {str(e)}"}]
154
-
155
- def _extract_input_text(self, inputs) -> str:
156
- """Extract the input text from various possible input formats"""
157
- if isinstance(inputs, str):
158
- return inputs
159
- elif isinstance(inputs, dict):
160
- if "inputs" in inputs:
161
- return inputs["inputs"]
162
- elif "prompt" in inputs:
163
- return inputs["prompt"]
164
- else:
165
- # Try the first string value we find
166
- for key, value in inputs.items():
167
- if isinstance(value, str):
168
- return value
169
- return str(inputs)
170
- else:
171
- return str(inputs)
 
73
  return f"Model adapter unavailable. Received input: {text_input[:30]}..."
74
 
75
class EndpointHandler:
    """Hugging Face inference endpoint entry point.

    BUG FIX: the rewritten handler dropped the ``path`` argument that the
    HF runtime passes positionally at instantiation (the previous version
    accepted it), which would raise TypeError on deploy. It also only read
    ``data["inputs"]``, regressing support for ``{"prompt": ...}`` payloads
    and passing None to generate when the key was absent.
    """

    def __init__(self, path: str = ""):
        # Accept and forward the repository path as the old handler did.
        self.path = path or os.getcwd()
        self.init_error = None
        try:
            self.adapter = WildnerveModelAdapter(self.path)
        except Exception as e:
            logger.error(f"Adapter init failed: {e}", exc_info=True)
            self.init_error = str(e)
            self.adapter = None

    def __call__(self, data, parameters=None):
        """Handle one request; always returns [{"generated_text": ...}]."""
        if self.adapter is None:
            return [{"generated_text": f"Initialization error: {self.init_error}"}]
        text = self._extract_input_text(data)
        try:
            out = self.adapter.generate(text, **(parameters or {}))
            return [{"generated_text": out}]
        except Exception as e:
            logger.error(f"Generation error: {e}", exc_info=True)
            return [{"generated_text": f"Error: {e}"}]

    @staticmethod
    def _extract_input_text(data) -> str:
        """Extract the prompt from the supported request payload shapes."""
        if isinstance(data, str):
            return data
        if isinstance(data, dict):
            # Preferred keys first, then any string value, then a repr.
            for key in ("inputs", "prompt"):
                value = data.get(key)
                if isinstance(value, str):
                    return value
            for value in data.values():
                if isinstance(value, str):
                    return value
            return str(data)
        return str(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_List.py CHANGED
@@ -98,27 +98,12 @@ class PromptAnalyzer:
98
  return primary_topic, subtopics
99
 
100
  def get_selected_models(self) -> List[str]:
101
- """
102
- Return candidate model identifiers.
103
- For example, if the prompt is technical (programming) the custom model might be top.
104
- This method can later be expanded to select multiple or weighted candidates.
105
- """
106
- # Here we return our primary custom model and a fallback general model.
107
- return ["Wildnerve-tlm01-0.05Bx12", "bert-base-uncased"]
108
 
109
  def choose_model(self, prompt: str) -> str:
110
- """
111
- Based on the analyzed prompt, select the most appropriate model identifier.
112
- For instance, if 'programming' is detected, return the custom model.
113
- Otherwise, return a general/pretrained model or a combination indicator.
114
- """
115
- primary_topic, _ = self.analyze_prompt(prompt)
116
- if primary_topic == "programming":
117
- return "Wildnerve-tlm01-0.05Bx12"
118
- elif primary_topic in ["science", "mathematics", "history"]:
119
- return "model_Combn.py"
120
- else:
121
- return "bert-base-uncased"
122
 
123
  # Optionally, additional helper methods could be added here for richer topic decomposition.
124
 
 
98
  return primary_topic, subtopics
99
 
100
def get_selected_models(self) -> List[str]:
    """Return the model files that should always be kept loaded.

    Only the custom hybrid model is a default candidate."""
    return ["model_Custm.py"]
 
 
 
 
 
103
 
104
def choose_model(self, prompt: str) -> str:
    """Return the default model file.

    The adapter no longer consults this method; it is retained only for
    backward compatibility with older callers."""
    return "model_Custm.py"
 
 
 
 
 
 
 
 
 
 
107
 
108
  # Optionally, additional helper methods could be added here for richer topic decomposition.
109
 
model_manager.py CHANGED
@@ -745,4 +745,30 @@ if __name__ == "__main__":
745
  logger.info(f"Model Manager initialized with {len(model_manager.models)} models")
746
  else:
747
  model_manager = None
748
- logger.info("ModelManager module imported; initialization deferred")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745
  logger.info(f"Model Manager initialized with {len(model_manager.models)} models")
746
  else:
747
  model_manager = None
748
+ logger.info("ModelManager module imported; initialization deferred")
749
+
750
import os
# BUG FIX: MODEL_MANAGER was referenced below without being imported,
# raising NameError at module import time.
from service_registry import registry, MODEL, PRETRAINED_MODEL, TOKENIZER, MODEL_MANAGER
from model_Custm import Wildnerve_tlm01
from model_PrTr import Wildnerve_tlm01 as PretrainedModel
from tokenizer import TokenizerWrapper

# Instantiate & register the shared tokenizer.
tok = TokenizerWrapper()
registry.register(TOKENIZER, tok, overwrite=True)

# Instantiate & register the custom model.
custom = Wildnerve_tlm01(tokenizer=tok)
registry.register(MODEL, custom, overwrite=True)

# Instantiate & register the pretrained model.
pre = PretrainedModel(tokenizer=tok)
registry.register(PRETRAINED_MODEL, pre, overwrite=True)


# NOTE(review): this stub shadows the full ModelManager defined earlier in
# this module — confirm whether the real manager should be registered instead.
class ModelManager:
    """Placeholder manager registered for service lookup."""
    pass


# Create and register the manager stub.
manager = ModelManager()
registry.register(MODEL_MANAGER, manager, overwrite=True)
service_registry.py CHANGED
@@ -8,7 +8,10 @@ logger = logging.getLogger(__name__)
8
 
9
  # Constants used as keys
10
  MODEL = "model"
 
11
  TOKENIZER = "tokenizer"
 
 
12
 
13
  class ServiceRegistry:
14
  """A simple service registry for dependency management"""
 
8
 
9
  # Constants used as keys
10
  MODEL = "model"
11
+ PRETRAINED_MODEL = "pretrained_model"
12
  TOKENIZER = "tokenizer"
13
+ MODEL_MANAGER = "model_manager"
14
+ COMMUNICATOR = "communicator"
15
 
16
  class ServiceRegistry:
17
  """A simple service registry for dependency management"""