WildnerveAI committed on
Commit
d5c47f9
·
verified ·
1 Parent(s): 4261f51

Upload 8 files

Browse files
Files changed (7) hide show
  1. adapter_layer.py +105 -58
  2. config.json +12 -1
  3. config.py +29 -6
  4. model_Custm.py +24 -14
  5. model_List.py +212 -180
  6. model_PrTr.py +104 -25
  7. service_registry.py +3 -3
adapter_layer.py CHANGED
@@ -208,65 +208,112 @@ class WildnerveModelAdapter:
208
  gen_list.append(s)
209
  return " ".join(tech_list).strip(), " ".join(gen_list).strip()
210
 
211
- def generate(self, prompt: str, **kwargs) -> str:
212
- """Generate a response to the given prompt."""
213
- # Determine prompt type
214
- primary, _ = PromptAnalyzer().analyze_prompt(prompt) if hasattr(PromptAnalyzer(), 'analyze_prompt') else (None, None)
215
-
216
- # Set appropriate max_length to prevent length errors
217
- if 'max_length' in kwargs and isinstance(kwargs['max_length'], int):
218
- if kwargs['max_length'] < 512: # If max_length is too small
219
- kwargs['max_length'] = 512 # Use a reasonable default
220
- else:
221
- kwargs['max_length'] = 1024 # Set a default if not provided
222
-
223
- # Try using the pretrained GPT-2 model first for generation
224
- pre = registry.get(PRETRAINED_MODEL)
225
- if pre:
226
  try:
227
- logger.info("Using GPT-2 pretrained model for generation")
228
- # Try to use the pretrained model's generate method
229
- if hasattr(pre, "generate"):
230
- # Check the signature of the generate method to determine correct parameters
231
- import inspect
232
- sig = inspect.signature(pre.generate)
233
- if "prompt" in sig.parameters:
234
- return pre.generate(prompt=prompt, **kwargs)
235
- else:
236
- # If no prompt parameter, try tokenizing first
237
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
238
- return pre.generate(input_ids=inputs.input_ids, **kwargs) # Explicitly pass as input_ids
239
- else:
240
- logger.warning("Pretrained model doesn't have generate method")
241
  except Exception as e:
242
- logger.error(f"Error using pretrained model: {e}")
243
-
244
- # Fall back to using the custom model if needed
245
- if self.model:
246
- try:
247
- logger.info("Using custom model for generation")
248
 
249
- # Check if the model is expecting a prompt parameter or input_ids
250
- import inspect
251
- if hasattr(self.model, "generate"):
252
- sig = inspect.signature(self.model.generate)
253
- if "prompt" in sig.parameters:
254
- # Model accepts prompt parameter directly
255
- return self.model.generate(prompt=prompt, **kwargs) # Explicitly pass as prompt
256
- else:
257
- # Model expects tokenized input_ids instead
258
- logger.info("Model expects tokenized input - converting prompt to input_ids")
259
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
260
- return self.model.generate(input_ids=inputs.input_ids, **kwargs) # Explicitly pass as input_ids
261
- else:
262
- logger.error("Model has no generate method")
263
- # Simple fallback for models without generate
264
- return f"I'm processing your request about '{prompt[:30]}...'"
265
- except Exception as e:
266
- logger.error(f"Error using custom model: {e}")
267
 
268
- # Add last-chance fallback with generic response
269
- return f"I apologize, but I'm experiencing some technical difficulties processing your request about '{prompt[:30]}...'. (Error: {str(e)})"
270
-
271
- # Final fallback
272
- return f"I apologize, but I'm unable to process your request about '{prompt[:30]}...' at this time."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  gen_list.append(s)
209
  return " ".join(tech_list).strip(), " ".join(gen_list).strip()
210
 
211
+ def generate(self, text_input, max_length=None, **kwargs):
212
+ """Generate text using the model - centralized generation point"""
213
+ try:
214
+ # Use PromptAnalyzer to determine which model to use
 
 
 
 
 
 
 
 
 
 
 
215
  try:
216
+ from model_List import PromptAnalyzer
217
+ analyzer = PromptAnalyzer()
218
+ model_type, confidence = analyzer.analyze_prompt(text_input)
219
+ logger.info(f"PromptAnalyzer selected {model_type} with confidence {confidence:.2f}")
 
 
 
 
 
 
 
 
 
 
220
  except Exception as e:
221
+ logger.error(f"Error using PromptAnalyzer: {e}")
222
+ model_type = "model_Custm" # Default to custom model on error
 
 
 
 
223
 
224
+ # Enhanced generation parameters with strong repetition prevention
225
+ generation_kwargs = {
226
+ 'max_length': max_length or 150,
227
+ 'temperature': kwargs.get('temperature', 0.7),
228
+ 'top_p': kwargs.get('top_p', 0.95),
229
+ 'top_k': kwargs.get('top_k', 50),
230
+ 'repetition_penalty': kwargs.get('repetition_penalty', 1.3), # Increased from 1.2
231
+ 'no_repeat_ngram_size': kwargs.get('no_repeat_ngram_size', 3), # Increased from 2
232
+ 'do_sample': kwargs.get('do_sample', True),
233
+ 'num_return_sequences': kwargs.get('num_return_sequences', 1),
234
+ 'early_stopping': kwargs.get('early_stopping', True),
235
+ 'bad_words_ids': kwargs.get('bad_words_ids', None), # Block repetitive phrases
236
+ 'min_length': kwargs.get('min_length', 10), # Ensure reasonable response length
237
+ }
238
+
239
+ # Set penalty_alpha to enable contrastive-search decoding (a decoding parameter; GPT-2 is decoder-only and has no encoder-decoder attention)
240
+ if 'penalty_alpha' not in kwargs:
241
+ generation_kwargs['penalty_alpha'] = 0.6 # Helps prevent looping in GPT-2
242
 
243
+ # Override with any explicitly provided kwargs
244
+ generation_kwargs.update({k:v for k,v in kwargs.items() if k not in ('prompt', 'context')})
245
+
246
+ if model_type == "model_Custm":
247
+ # Use the Custom Wildnerve model for technical topics
248
+ custom_model = registry.get(MODEL)
249
+ if custom_model:
250
+ try:
251
+ logger.info("Using custom Wildnerve-tlm01_Hybrid_Model for technical prompt")
252
+ # Check signature of the generate method
253
+ import inspect
254
+ if hasattr(custom_model, "generate"):
255
+ sig = inspect.signature(custom_model.generate)
256
+ if "prompt" in sig.parameters:
257
+ return custom_model.generate(prompt=text_input, **generation_kwargs)
258
+ else:
259
+ # If no prompt parameter, try tokenizing first
260
+ inputs = self.tokenizer(text_input, return_tensors="pt", truncation=True, padding=True)
261
+ return custom_model.generate(input_ids=inputs.input_ids, **generation_kwargs)
262
+ else:
263
+ logger.warning("Custom model doesn't have generate method, falling back to pretrained")
264
+ except Exception as e:
265
+ logger.error(f"Error using custom model: {e}")
266
+ else:
267
+ # Use the Pretrained model (GPT-2) for general topics
268
+ pre = registry.get(PRETRAINED_MODEL)
269
+ if pre:
270
+ try:
271
+ logger.info("Using GPT-2 pretrained model for general prompt")
272
+ # Try to use the pretrained model's generate method
273
+ if hasattr(pre, "generate"):
274
+ # Check the signature of the generate method to determine correct parameters
275
+ import inspect
276
+ sig = inspect.signature(pre.generate)
277
+ if "prompt" in sig.parameters:
278
+ return pre.generate(prompt=text_input, **generation_kwargs)
279
+ else:
280
+ # If no prompt parameter, try tokenizing first
281
+ inputs = self.tokenizer(text_input, return_tensors="pt", truncation=True, padding=True)
282
+ return pre.generate(input_ids=inputs.input_ids, **generation_kwargs) # Explicitly pass as input_ids
283
+ else:
284
+ logger.warning("Pretrained model doesn't have generate method")
285
+ except Exception as e:
286
+ logger.error(f"Error using pretrained model: {e}")
287
+
288
+ # Fall back to using the custom model if needed
289
+ if self.model:
290
+ try:
291
+ logger.info("Using custom model for generation")
292
+
293
+ # Check if the model is expecting a prompt parameter or input_ids
294
+ import inspect
295
+ if hasattr(self.model, "generate"):
296
+ sig = inspect.signature(self.model.generate)
297
+ if "prompt" in sig.parameters:
298
+ # Model accepts prompt parameter directly
299
+ return self.model.generate(prompt=text_input, **generation_kwargs) # Explicitly pass as prompt
300
+ else:
301
+ # Model expects tokenized input_ids instead
302
+ logger.info("Model expects tokenized input - converting prompt to input_ids")
303
+ inputs = self.tokenizer(text_input, return_tensors="pt", truncation=True, padding=True)
304
+ return self.model.generate(input_ids=inputs.input_ids, **generation_kwargs) # Explicitly pass as input_ids
305
+ else:
306
+ logger.error("Model has no generate method")
307
+ # Simple fallback for models without generate
308
+ return f"I'm processing your request about '{text_input[:30]}...'"
309
+ except Exception as e:
310
+ logger.error(f"Error using custom model: {e}")
311
+
312
+ # Add last-chance fallback with generic response
313
+ return f"I apologize, but I'm experiencing some technical difficulties processing your request about '{text_input[:30]}...'. (Error: {str(e)})"
314
+
315
+ # Final fallback
316
+ return f"I apologize, but I'm unable to process your request about '{text_input[:30]}...' at this time."
317
+ except Exception as e:
318
+ logger.error(f"Error in generate method: {e}")
319
+ return f"An error occurred while generating text: {str(e)}"
config.json CHANGED
@@ -241,8 +241,19 @@
241
  "HIDDEN_DIM": 768,
242
  "MAX_CACHE_SIZE": 10
243
  },
 
 
 
 
 
244
  "MAX_ACTIVE_MODELS": 5,
245
  "MODEL_IDLE_THRESHOLD": 600,
246
  "MAX_MEMORY_USAGE": 0.8,
247
- "TOP_K": 3
 
 
 
 
 
 
248
  }
 
241
  "HIDDEN_DIM": 768,
242
  "MAX_CACHE_SIZE": 10
243
  },
244
+ "MODEL_PRIORITY": {
245
+ "PRIMARY": "model_Custm",
246
+ "SECONDARY": "model_PrTr",
247
+ "USE_PRETRAINED_FALLBACK": true
248
+ },
249
  "MAX_ACTIVE_MODELS": 5,
250
  "MODEL_IDLE_THRESHOLD": 600,
251
  "MAX_MEMORY_USAGE": 0.8,
252
+ "TOP_K": 3,
253
+ "TOPIC_KEYWORDS": {
254
+ "programming": ["python", "java", "javascript"],
255
+ "computer_science": ["algorithm", "complexity"],
256
+ "software_engineering": ["design pattern", "architecture"],
257
+ "web_development": ["frontend", "backend"]
258
+ }
259
  }
config.py CHANGED
@@ -388,6 +388,12 @@ class AppConfig(BaseModel):
388
  MAX_ACTIVE_MODELS: int = Field(default=2)
389
  MODEL_IDLE_THRESHOLD: int = Field(default=600)
390
 
 
 
 
 
 
 
391
  def load_config() -> AppConfig:
392
  config_path = os.path.join(os.path.dirname(__file__), "config.json")
393
  logger.info(f"Loading config from {config_path}")
@@ -395,14 +401,31 @@ def load_config() -> AppConfig:
395
  with open(config_path, "r") as f:
396
  raw = json.load(f)
397
 
398
- # helper to convert a dict into an object with attribute access
399
- class AttrDict(dict):
400
- __getattr__ = dict.get
401
- __setattr__ = dict.__setitem__
402
-
403
- # wrap TRANSFORMER_CONFIG if it's a dict
404
  if isinstance(raw.get("TRANSFORMER_CONFIG"), dict):
405
  raw["TRANSFORMER_CONFIG"] = AttrDict(raw["TRANSFORMER_CONFIG"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  except Exception as e:
407
  logger.error(f"Failed to read config.json: {e}", exc_info=True)
408
  raise
 
388
  MAX_ACTIVE_MODELS: int = Field(default=2)
389
  MODEL_IDLE_THRESHOLD: int = Field(default=600)
390
 
391
+ class AttrDict(dict):
392
+ """Dictionary subclass with attribute-style access"""
393
+ __getattr__ = dict.get
394
+ __setattr__ = dict.__setitem__
395
+ __delattr__ = dict.__delitem__
396
+
397
  def load_config() -> AppConfig:
398
  config_path = os.path.join(os.path.dirname(__file__), "config.json")
399
  logger.info(f"Loading config from {config_path}")
 
401
  with open(config_path, "r") as f:
402
  raw = json.load(f)
403
 
404
+ # Wrap TRANSFORMER_CONFIG (when present as a dict) for attribute access
 
 
 
 
 
405
  if isinstance(raw.get("TRANSFORMER_CONFIG"), dict):
406
  raw["TRANSFORMER_CONFIG"] = AttrDict(raw["TRANSFORMER_CONFIG"])
407
+
408
+ # Ensure GPT-2 parameters
409
+ if not isinstance(raw["TRANSFORMER_CONFIG"].get("VOCAB_SIZE"), int) or raw["TRANSFORMER_CONFIG"]["VOCAB_SIZE"] != 50257:
410
+ raw["TRANSFORMER_CONFIG"]["VOCAB_SIZE"] = 50257 # Standard GPT-2 vocab size
411
+
412
+ if raw["TRANSFORMER_CONFIG"].get("MODEL_NAME") != "gpt2":
413
+ raw["TRANSFORMER_CONFIG"]["MODEL_NAME"] = "gpt2"
414
+
415
+ # Ensure OUTPUT_SIZE matches VOCAB_SIZE
416
+ raw["TRANSFORMER_CONFIG"]["OUTPUT_SIZE"] = raw["TRANSFORMER_CONFIG"]["VOCAB_SIZE"]
417
+
418
+ # Add generation parameters if missing
419
+ if "GENERATION_CONFIG" not in raw:
420
+ raw["GENERATION_CONFIG"] = {
421
+ "temperature": 0.7,
422
+ "top_p": 0.95,
423
+ "top_k": 50,
424
+ "repetition_penalty": 1.3,
425
+ "no_repeat_ngram_size": 3,
426
+ "do_sample": True,
427
+ "penalty_alpha": 0.6
428
+ }
429
  except Exception as e:
430
  logger.error(f"Failed to read config.json: {e}", exc_info=True)
431
  raise
model_Custm.py CHANGED
@@ -81,15 +81,15 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
81
  - SmartHybridAttention for better context handling"""
82
  def __init__(
83
  self,
84
- vocab_size=30522, # Default BERT vocab size
85
  specialization="general",
86
  dataset_path=None,
87
- model_name="Wildnerve-tlm01_Hybrid_Model", # Primary model name
88
  embedding_dim=768,
89
  num_heads=12,
90
  hidden_dim=768,
91
  num_layers=6,
92
- output_size=768,
93
  dropout=0.1,
94
  max_seq_length=512,
95
  pooling_mode="mean",
@@ -123,18 +123,16 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
123
  self.tokenizer = registry.get(TOKENIZER)
124
  else:
125
  try:
126
- from transformers import AutoTokenizer
127
- self.tokenizer = AutoTokenizer.from_pretrained("Wildnerve-tlm01_Hybrid_Model")
128
- logger.info("Loaded primary tokenizer: Wildnerve-tlm01_Hybrid_Model")
 
 
 
129
  except Exception as e:
130
- logger.warning(f"Primary tokenizer load failed: {e}")
131
- try:
132
- from transformers import BertTokenizer
133
- self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
134
- logger.info("Loaded fallback tokenizer: bert-base-uncased")
135
- except Exception as e2:
136
- logger.error(f"Fallback tokenizer load failed: {e2}")
137
- self.tokenizer = None
138
  registry.register(TOKENIZER, self.tokenizer, overwrite=True)
139
 
140
  # Register this model instance in the registry by specialization
@@ -363,6 +361,7 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
363
  embeddings = self.encode_sentences([sentence1, sentence2])
364
  return np.dot(embeddings[0], embeddings[1]) / (np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
365
 
 
366
  def generate(
367
  self,
368
  prompt=None,
@@ -373,6 +372,17 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
373
  **kwargs
374
  ) -> str:
375
  """Generate text using the model, supporting either prompt string or input_ids."""
 
 
 
 
 
 
 
 
 
 
 
376
  # Log what we're working with
377
  logger.info(f"Generate called with: prompt={type(prompt).__name__ if prompt else None}, input_ids={type(input_ids).__name__ if input_ids else None}")
378
 
 
81
  - SmartHybridAttention for better context handling"""
82
  def __init__(
83
  self,
84
+ vocab_size=50257, # Updated to GPT-2 vocab size
85
  specialization="general",
86
  dataset_path=None,
87
+ model_name="gpt2", # Standardized to GPT-2
88
  embedding_dim=768,
89
  num_heads=12,
90
  hidden_dim=768,
91
  num_layers=6,
92
+ output_size=50257, # Updated to GPT-2 vocab size
93
  dropout=0.1,
94
  max_seq_length=512,
95
  pooling_mode="mean",
 
123
  self.tokenizer = registry.get(TOKENIZER)
124
  else:
125
  try:
126
+ from transformers import GPT2Tokenizer
127
+ self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
128
+ # Ensure pad_token is set for GPT-2
129
+ if self.tokenizer.pad_token_id is None:
130
+ self.tokenizer.pad_token = self.tokenizer.eos_token
131
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
132
  except Exception as e:
133
+ logger.warning(f"Failed to load GPT-2 tokenizer: {e}")
134
+ from utils.transformer_utils import get_tokenizer
135
+ self.tokenizer = get_tokenizer()
 
 
 
 
 
136
  registry.register(TOKENIZER, self.tokenizer, overwrite=True)
137
 
138
  # Register this model instance in the registry by specialization
 
361
  embeddings = self.encode_sentences([sentence1, sentence2])
362
  return np.dot(embeddings[0], embeddings[1]) / (np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
363
 
364
+ # Update generate to use adapter_layer as the primary generation point
365
  def generate(
366
  self,
367
  prompt=None,
 
372
  **kwargs
373
  ) -> str:
374
  """Generate text using the model, supporting either prompt string or input_ids."""
375
+ # Try to use adapter_layer.generate if available
376
+ adapter_layer = registry.get("adapter_layer")
377
+ if adapter_layer and hasattr(adapter_layer, "generate"):
378
+ if prompt:
379
+ return adapter_layer.generate(prompt, max_length=max_length, temperature=temperature, **kwargs)
380
+ elif input_ids is not None and self.tokenizer:
381
+ # Convert input_ids back to text to use centralized generation
382
+ decoded_prompt = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
383
+ return adapter_layer.generate(decoded_prompt, max_length=max_length, temperature=temperature, **kwargs)
384
+
385
+ # Fall back to direct generation if adapter_layer is not available
386
  # Log what we're working with
387
  logger.info(f"Generate called with: prompt={type(prompt).__name__ if prompt else None}, input_ids={type(input_ids).__name__ if input_ids else None}")
388
 
model_List.py CHANGED
@@ -32,22 +32,46 @@ class PromptAnalyzer:
32
  - SmartHybridAttention for analyzing complex or long prompts
33
  - Performance tracking and caching for efficiency
34
  """
35
- def __init__(self):
36
  self.logger = logging.getLogger(__name__)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # For caching and performance tracking
39
  self._model_cache = {}
40
  self._performance_metrics = {}
41
 
42
- # Define topic keywords for the simple approach
43
- self.predefined_topics = {
44
- "programming": ["code", "function", "class", "algorithm", "programming", "python", "javascript", "java", "c++", "developer", "api"],
45
- "science": ["science", "physics", "chemistry", "biology", "scientific", "experiment", "hypothesis", "theory"],
46
- "mathematics": ["math", "equation", "calculus", "algebra", "geometry", "theorem", "mathematical"],
47
- "history": ["history", "historical", "ancient", "century", "war", "civilization", "empire"],
48
- "general": ["how", "what", "when", "where", "why", "who", "can you", "please", "thanks", "hello"]
49
- }
50
-
51
  # Initialize model_class attribute
52
  self.model_class = None
53
 
@@ -67,30 +91,88 @@ class PromptAnalyzer:
67
  except Exception:
68
  pass
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def _init_advanced_tools(self):
71
- """Initialize advanced analysis tools with proper error handling"""
72
  self.sentence_model = None
73
  self.gpt2_model = None
74
  self.gpt2_tokenizer = None
75
 
76
- # Only initialize if enabled by environment variable
77
- if os.environ.get("LOAD_PRETRAINED_MODELS", "0") == "1":
78
- try:
79
- from utils.transformer_utils import get_sentence_transformer
80
- self.sentence_model = get_sentence_transformer('sentence-transformers/all-MiniLM-L6-v2')
81
- self.logger.info("Loaded SentenceTransformer model successfully")
82
- except Exception as e:
83
- self.logger.warning(f"Failed to load SentenceTransformer: {e}")
 
 
 
 
 
 
 
 
 
84
 
85
- try:
86
- from transformers import AutoModelForCausalLM, AutoTokenizer
87
- self.gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
88
- self.gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2")
89
- self.gpt2_model.eval()
90
- self.logger.info("Loaded GPT-2 model for perplexity calculation")
91
- except Exception as e:
92
- self.logger.warning(f"Failed to load GPT-2: {e}")
93
 
 
 
 
 
 
 
 
94
  # Initialize SmartHybridAttention
95
  try:
96
  attention_config = get_hybrid_attention_config()
@@ -253,173 +335,123 @@ class PromptAnalyzer:
253
  self.logger.error(f"Error in attention-based analysis: {e}")
254
  return None
255
 
256
- def analyze_prompt(self, prompt):
257
- """
258
- Enhanced prompt analysis with SmartHybridAttention for complex prompts
259
- """
260
- # Start with simple keyword-based classification
261
  prompt_lower = prompt.lower()
262
- topic_scores = {}
 
263
 
264
- for topic, keywords in self.predefined_topics.items():
265
- score = sum(1 for keyword in keywords if keyword in prompt_lower)
266
- topic_scores[topic] = score
 
 
267
 
268
- # For complex prompts, use attention-based analysis
269
- is_complex = len(prompt) > 100 or prompt.count('.') > 2 # Basic heuristic for complexity
270
 
271
- if is_complex and self.attention:
272
- attention_scores = self._analyze_with_attention(prompt)
273
- if attention_scores:
274
- # Combine scores with attention-based analysis
275
- for topic, score in attention_scores.items():
276
- if topic in topic_scores:
277
- # Weighted combination based on prompt complexity
278
- complexity_factor = min(0.7, len(prompt) / 1000)
279
- topic_scores[topic] = (topic_scores[topic] * (1-complexity_factor)) + (score * complexity_factor)
 
280
 
281
- # Advanced analysis if available
282
- try:
283
- if self.sentence_model:
284
- # Get embedding and boost scores based on embedding similarity
285
- embedding = self.sentence_model.encode(prompt)
286
-
287
- # We could have reference embeddings for each topic and calculate similarity
288
- # For now, we'll just use the embedding magnitude as a complexity measure
289
- complexity = np.linalg.norm(embedding)
290
-
291
- # Adjust scores based on complexity
292
- if complexity > 15: # High complexity
293
- if topic_scores.get("programming", 0) > 0:
294
- topic_scores["programming"] *= 1.5
295
- if topic_scores.get("science", 0) > 0:
296
- topic_scores["science"] *= 1.4
297
- if topic_scores.get("mathematics", 0) > 0:
298
- topic_scores["mathematics"] *= 1.3
299
-
300
- if self.gpt2_model and self.gpt2_tokenizer:
301
- # Calculate perplexity for another dimension of analysis
302
- try:
303
- inputs = self.gpt2_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
304
- with torch.no_grad():
305
- outputs = self.gpt2_model(**inputs, labels=inputs["input_ids"])
306
-
307
- loss = outputs.loss.item()
308
- perplexity = math.exp(loss)
309
-
310
- # Adjust scores based on perplexity
311
- if perplexity > 100: # Very specialized/technical content
312
- if topic_scores.get("programming", 0) > 0:
313
- topic_scores["programming"] *= 1.4
314
- if topic_scores.get("science", 0) > 0:
315
- topic_scores["science"] *= 1.3
316
- if topic_scores.get("mathematics", 0) > 0:
317
- topic_scores["mathematics"] *= 1.2
318
- except Exception as e:
319
- logger.warning(f"Error in perplexity calculation: {e}")
320
- except Exception as e:
321
- logger.warning(f"Advanced analysis failed: {e}")
322
 
323
- # Find the topic with the highest score
324
- if not topic_scores or max(topic_scores.values()) == 0:
325
- return "general", []
 
 
 
 
 
 
326
 
327
- primary_topic = max(topic_scores.items(), key=lambda x: x[1])[0]
 
 
 
328
 
329
- # Get secondary topics (any with non-zero scores except primary)
330
- secondary_topics = [t for t, s in topic_scores.items()
331
- if s > 0 and t != primary_topic]
332
 
333
- return primary_topic, secondary_topics
334
-
335
- def analyze(self, prompt: str) -> int:
336
- """
337
- Analyze prompt complexity with attention-enhanced analysis
338
- """
339
- # First check if we can use attention-based analysis for complex prompts
340
- if self.attention and len(prompt) > 150:
341
  try:
342
- # Get sentence embeddings
343
- sentences = nltk.sent_tokenize(prompt)
344
- if len(sentences) > 1:
345
- # Apply attention to understand cross-sentence relationships
346
- sentence_embeddings = [self.sentence_model.encode(s) for s in sentences]
347
- embeddings_tensor = torch.tensor(sentence_embeddings).unsqueeze(1)
348
-
349
- # Use attention to focus on important parts of the prompt
350
- attended_embeddings, _ = self.attention(
351
- query=embeddings_tensor,
352
- key=embeddings_tensor,
353
- value=embeddings_tensor
354
- )
355
-
356
- # Calculate complexity based on attention-weighted embeddings
357
- complexity = torch.norm(attended_embeddings.mean(dim=0)).item()
358
- logger.info(f"Computed attention-weighted complexity: {complexity}")
359
-
360
- # Return candidate index based on complexity
361
- if complexity < 12:
362
- return 0 # Simpler model
363
- elif complexity < 24:
364
- return 1 # Moderate model
365
- else:
366
- return 2 # Complex model
367
  except Exception as e:
368
- logger.warning(f"Attention-based analysis failed: {e}")
369
-
370
- # Use embeddings if available for complexity analysis
371
- if self.sentence_model:
372
- try:
373
- # Get embedding and calculate complexity based on vector properties
374
- embedding = self.sentence_model.encode(prompt)
375
- # Calculate complexity (vector magnitude)
376
- complexity = np.linalg.norm(embedding)
377
- logger.info(f"Computed embedding complexity: {complexity}")
378
-
379
- # Return appropriate index based on complexity
380
- if complexity < 10:
381
- return 0 # Less complex, use simpler model
382
- elif complexity < 20:
383
- return 1 # Moderate complexity
384
- else:
385
- return 2 # High complexity, use specialized model
386
- except Exception as e:
387
- logger.warning(f"Embedding-based analysis failed: {e}")
388
 
389
- # Use perplexity as a fallback
390
- if self.gpt2_model and self.gpt2_tokenizer:
391
- try:
392
- inputs = self.gpt2_tokenizer(prompt, return_tensors="pt", truncation=True)
393
- with torch.no_grad():
394
- outputs = self.gpt2_model(**inputs, labels=inputs["input_ids"])
395
- loss = outputs.loss.item()
396
- perplexity = math.exp(loss)
397
- logger.info(f"Computed perplexity: {perplexity}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
 
399
- # Example threshold-based candidate selection:
400
- if perplexity < 50:
401
- return 0 # Less perplexing, use simpler model
402
- elif perplexity < 100:
403
- return 1 # Moderate perplexity
404
- else:
405
- return 2 # High perplexity, use specialized model
406
- except Exception as e:
407
- logger.warning(f"Perplexity calculation failed: {e}")
408
-
409
- # Fallback to simple keyword-based analysis
410
- primary_topic, secondary_topics = self.analyze_prompt(prompt)
411
-
412
- # Map topics to model indices
413
- topic_to_index = {
414
- "general": 0,
415
- "history": 0,
416
- "programming": 1,
417
- "science": 1,
418
- "mathematics": 2
419
- }
420
 
421
- # Return appropriate index or 0 if topic not in mapping
422
- return topic_to_index.get(primary_topic, 0)
 
 
 
423
 
424
  def choose_model(self, prompt: str = None) -> Type:
425
  """Enhanced model selection that combines config and analysis"""
 
32
  - SmartHybridAttention for analyzing complex or long prompts
33
  - Performance tracking and caching for efficiency
34
  """
35
+ def __init__(self, model_name=None, dataset_path=None, specialization=None, hidden_dim=None):
36
  self.logger = logging.getLogger(__name__)
37
 
38
+ # Load config
39
+ self.config = load_config(config_file="config.json")
40
+
41
+ # Use provided values or config values
42
+ self.model_name = model_name or self.config.PROMPT_ANALYZER_CONFIG.MODEL_NAME
43
+ self.dataset_path = dataset_path or self.config.PROMPT_ANALYZER_CONFIG.DATASET_PATH
44
+ self.specialization = specialization or self.config.PROMPT_ANALYZER_CONFIG.SPECIALIZATION
45
+ self.hidden_dim = hidden_dim or self.config.PROMPT_ANALYZER_CONFIG.HIDDEN_DIM
46
+
47
+ self.logger.info(f"Initialized PromptAnalyzer with {self.model_name}")
48
+ self._model_cache: Dict[str, Type] = {}
49
+ self._performance_metrics: Dict[str, Dict[str, float]] = {}
50
+
51
+ # Load predefined topics from config or fall back to defaults
52
+ self._load_predefined_topics()
53
+
54
+ # Always use a proper SentenceTransformer model - fix this to avoid warnings
55
+ if hasattr(self, 'sentence_model'):
56
+ del self.sentence_model # Remove any existing instance
57
+
58
+ # Use a proper SentenceTransformer model
59
+ self.sentence_model = get_sentence_transformer('sentence-transformers/all-MiniLM-L6-v2')
60
+ self.logger.info(f"Using SentenceTransformer model: sentence-transformers/all-MiniLM-L6-v2")
61
+
62
+ # Use GPT-2 for perplexity calculation
63
+ self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
64
+ self.model = AutoModelForCausalLM.from_pretrained("gpt2")
65
+ self.model.eval()
66
+
67
+ logger.info(f"Initialized PromptAnalyzer with {self.model_name}, specialization: {self.specialization}, hidden_dim: {self.hidden_dim}")
68
+ if self.dataset_path:
69
+ logger.info(f"Using dataset from: {self.dataset_path}")
70
+
71
  # For caching and performance tracking
72
  self._model_cache = {}
73
  self._performance_metrics = {}
74
 
 
 
 
 
 
 
 
 
 
75
  # Initialize model_class attribute
76
  self.model_class = None
77
 
 
91
  except Exception:
92
  pass
93
 
94
+ def _load_predefined_topics(self):
95
+ """Load topic keywords from config file or use defaults with caching"""
96
+ # Try to load from config first
97
+ try:
98
+ if hasattr(app_config, 'TOPIC_KEYWORDS') and app_config.TOPIC_KEYWORDS:
99
+ logger.info("Loading topic keywords from config")
100
+ self.predefined_topics = app_config.TOPIC_KEYWORDS
101
+ return
102
+
103
+ # Try loading from a JSON file in the data directory
104
+ topic_file = os.path.join(app_config.DATA_DIR, "topic_keywords.json")
105
+ if os.path.exists(topic_file):
106
+ with open(topic_file, 'r') as f:
107
+ self.predefined_topics = json.load(f)
108
+ logger.info(f"Loaded {len(self.predefined_topics)} topic categories from {topic_file}")
109
+ return
110
+ except Exception as e:
111
+ logger.warning(f"Error loading topic keywords: {e}, using defaults")
112
+
113
+ # Fall back to default hardcoded topics
114
+ logger.info("Using default hardcoded topic keywords")
115
+ self.predefined_topics = {
116
+ "programming": [
117
+ "python", "java", "javascript", "typescript", "rust", "go", "golang",
118
+ # ...existing keywords...
119
+ ],
120
+ "computer_science": [
121
+ # ...existing keywords...
122
+ ],
123
+ "software_engineering": [
124
+ # ...existing keywords...
125
+ ],
126
+ "web_development": [
127
+ # ...existing keywords...
128
+ ]
129
+ }
130
+
131
+ # Cache the topics to a file for future use
132
+ try:
133
+ os.makedirs(app_config.DATA_DIR, exist_ok=True)
134
+ with open(os.path.join(app_config.DATA_DIR, "topic_keywords.json"), 'w') as f:
135
+ json.dump(self.predefined_topics, f, indent=2)
136
+ except Exception as e:
137
+ logger.debug(f"Could not cache topic keywords: {e}")
138
+
139
  def _init_advanced_tools(self):
140
+ """Initialize advanced analysis tools with proper error handling and fallbacks"""
141
  self.sentence_model = None
142
  self.gpt2_model = None
143
  self.gpt2_tokenizer = None
144
 
145
+ # For embedding model, implement multiple fallbacks
146
+ MAX_RETRIES = 3
147
+ embedding_models = [
148
+ 'sentence-transformers/all-MiniLM-L6-v2', # Primary choice
149
+ 'sentence-transformers/paraphrase-MiniLM-L3-v2', # Smaller fallback
150
+ 'sentence-transformers/distilbert-base-nli-mean-tokens' # Last resort
151
+ ]
152
+
153
+ for retry in range(MAX_RETRIES):
154
+ for model_name in embedding_models:
155
+ try:
156
+ from utils.transformer_utils import get_sentence_transformer
157
+ self.sentence_model = get_sentence_transformer(model_name)
158
+ self.logger.info(f"Successfully loaded SentenceTransformer: {model_name}")
159
+ break
160
+ except Exception as e:
161
+ self.logger.warning(f"Failed to load embedding model {model_name}: {e}")
162
 
163
+ if self.sentence_model:
164
+ break
165
+
166
+ # Wait before retry
167
+ time.sleep(2)
 
 
 
168
 
169
+ # Create keyword-based fallback if embedding loading completely fails
170
+ if not self.sentence_model:
171
+ self.logger.warning("All embedding models failed to load - using keyword fallback")
172
+ self._use_keyword_fallback = True
173
+ else:
174
+ self._use_keyword_fallback = False
175
+
176
  # Initialize SmartHybridAttention
177
  try:
178
  attention_config = get_hybrid_attention_config()
 
335
  self.logger.error(f"Error in attention-based analysis: {e}")
336
  return None
337
 
338
+ def _analyze_with_keywords(self, prompt: str) -> Tuple[str, float]:
339
+ """Analyze prompt using only keywords when embeddings are unavailable"""
 
 
 
340
  prompt_lower = prompt.lower()
341
+ technical_matches = 0
342
+ total_words = len(prompt_lower.split())
343
 
344
+ # Count matches across all technical categories
345
+ for category, keywords in self.predefined_topics.items():
346
+ for keyword in keywords:
347
+ if keyword in prompt_lower:
348
+ technical_matches += 1
349
 
350
+ # Simple ratio calculation
351
+ match_ratio = technical_matches / max(1, min(15, total_words))
352
 
353
+ if match_ratio > 0.1: # Even a single match in a short query is significant
354
+ return "model_Custm", match_ratio
355
+ else:
356
+ return "model_PrTr", 0.7
357
+
358
+ def analyze_prompt(self, prompt: str) -> Tuple[str, float]:
359
+ """Analyze if a prompt is technical or general and return the appropriate model type and confidence score."""
360
+ # Check if we need to use keyword fallback due to embedding failure
361
+ if hasattr(self, '_use_keyword_fallback') and self._use_keyword_fallback:
362
+ return self._analyze_with_keywords(prompt)
363
 
364
+ # Convert prompt to lowercase for case-insensitive matching
365
+ prompt_lower = prompt.lower()
366
+
367
+ # Check for technical keywords from predefined topics - use memory-efficient approach
368
+ technical_matches = 0
369
+ word_count = len(prompt_lower.split())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
+ # Use a set-based intersection approach for better performance on longer texts
372
+ prompt_words = set(prompt_lower.split())
373
+
374
+ # Count keyword matches across all technical categories more efficiently
375
+ for category, keywords in self.predefined_topics.items():
376
+ # Convert keywords to set for O(1) lookups - helps with longer texts
377
+ keywords_set = set(keywords)
378
+ matches = prompt_words.intersection(keywords_set)
379
+ technical_matches += len(matches)
380
 
381
+ # Also check for multi-word keywords not caught by simple splitting
382
+ for keyword in keywords:
383
+ if " " in keyword and keyword in prompt_lower:
384
+ technical_matches += 1
385
 
386
+ # Calculate keyword match ratio (normalized by word count)
387
+ keyword_ratio = technical_matches / max(1, min(20, word_count))
 
388
 
389
+ # Get attention-based analysis for complex prompts
390
+ attention_scores = None
391
+ if len(prompt) > 100 and self.attention: # Only use attention for longer prompts
 
 
 
 
 
392
  try:
393
+ attention_scores = self._analyze_with_attention(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  except Exception as e:
395
+ self.logger.warning(f"Error in attention analysis: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
+ # Use embedding similarity for semantic understanding
398
+ try:
399
+ # Get embedding of the prompt
400
+ prompt_embedding = self.sentence_model.encode(prompt)
401
+
402
+ # Example technical and general reference texts
403
+ technical_reference = "Write code to solve a programming problem using algorithms and data structures."
404
+ general_reference = "Tell me about daily life topics like weather, food, or general conversation."
405
+
406
+ # Get embeddings for reference texts
407
+ technical_embedding = self.sentence_model.encode(technical_reference)
408
+ general_embedding = self.sentence_model.encode(general_reference)
409
+
410
+ # Calculate cosine similarities
411
+ technical_similarity = cosine_similarity([prompt_embedding], [technical_embedding])[0][0]
412
+ general_similarity = cosine_similarity([prompt_embedding], [general_embedding])[0][0]
413
+
414
+ # Calculate technical score combining all signals:
415
+ # 1. Keyword matching (30%)
416
+ # 2. Semantic similarity (40%)
417
+ # 3. Attention analysis if available (30%)
418
+ technical_score = 0.3 * keyword_ratio + 0.4 * technical_similarity
419
+
420
+ # Add attention score contribution if available
421
+ if attention_scores:
422
+ # Calculate tech score from attention - sum of programming/computer_science categories
423
+ tech_attention_score = (
424
+ attention_scores.get("programming", 0) +
425
+ attention_scores.get("computer_science", 0) +
426
+ attention_scores.get("software_engineering", 0) +
427
+ attention_scores.get("web_development", 0)
428
+ ) / 4.0 # Normalize
429
+ technical_score += 0.3 * tech_attention_score
430
+
431
+ # Decide based on combined score
432
+ if technical_score > 0.3: # Threshold - tune this as needed
433
+ return "model_Custm", technical_score
434
+ else:
435
+ return "model_PrTr", 1.0 - technical_score
436
 
437
+ except Exception as e:
438
+ self.logger.error(f"Error in prompt analysis: {e}")
439
+
440
+ # Fallback to simple keyword matching
441
+ if technical_matches > 0:
442
+ return "model_Custm", 0.7
443
+ else:
444
+ return "model_PrTr", 0.7
445
+
446
+ def analyze(self, prompt: str) -> int:
447
+ """Legacy compatibility method that returns a candidate index."""
448
+ model_type, confidence = self.analyze_prompt(prompt)
 
 
 
 
 
 
 
 
 
449
 
450
+ # Map model_type to candidate index
451
+ if model_type == "model_Custm":
452
+ return 0 # Index 0 corresponds to model_Custm
453
+ else:
454
+ return 1 # Index 1 corresponds to model_PrTr
455
 
456
  def choose_model(self, prompt: str = None) -> Type:
457
  """Enhanced model selection that combines config and analysis"""
model_PrTr.py CHANGED
@@ -58,15 +58,15 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
58
  The model uses the GPT-2 tokenizer for consistent tokenization."""
59
  def __init__(
60
  self,
61
- vocab_size: int = 50257, # Updated to GPT-2 vocab size
62
  specialization: str = "general",
63
  dataset_path: str = None,
64
- model_name: str = "gpt2", # Changed from bert-base-uncased to gpt2
65
  embedding_dim: int = 768,
66
  num_heads: int = 12,
67
  hidden_dim: int = 768,
68
  num_layers: int = 6,
69
- output_size: int = 50257, # Match GPT-2 vocab size
70
  dropout: float = 0.1,
71
  max_seq_length: int = 1024, # GPT-2 supports longer contexts
72
  pooling_mode: str = "last", # GPT-2 typically uses last token
@@ -99,15 +99,18 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
99
  # Initialize the model and tokenizer
100
  self.gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
101
 
102
- # Use tokenizer from params, registry, or create new GPT-2 tokenizer
103
  if tokenizer is not None:
104
  self.tokenizer = tokenizer
105
  elif registry.has(TOKENIZER):
106
  self.tokenizer = registry.get(TOKENIZER)
107
  else:
108
  self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
109
- if self.tokenizer.pad_token_id is None:
110
- self.tokenizer.pad_token = self.tokenizer.eos_token
 
 
 
111
 
112
  logger.info(f"Successfully loaded GPT-2 model: {model_name}")
113
 
@@ -135,36 +138,54 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
135
  return outputs.logits
136
 
137
  # Update generate to handle both direct prompt and tokenized input
138
- def generate(self, prompt=None, input_ids=None, **kwargs):
139
  """Generate text using the GPT-2 model"""
140
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  # Handle either string prompt or direct input_ids
142
  if isinstance(prompt, str) and input_ids is None:
143
  inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
144
  input_ids = inputs.input_ids
145
  elif input_ids is None:
146
  raise ValueError("Either prompt or input_ids must be provided")
147
-
148
- # Set default parameters if not provided
149
- generation_kwargs = {
150
- "max_length": kwargs.get("max_length", min(self.max_length, 1024)),
151
- "max_new_tokens": kwargs.get("max_new_tokens", 512), # Added max_new_tokens
152
- "temperature": kwargs.get("temperature", 0.7),
153
- "top_p": kwargs.get("top_p", 0.9),
154
- "top_k": kwargs.get("top_k", 50),
155
- "repetition_penalty": kwargs.get("repetition_penalty", 1.0),
156
- "do_sample": kwargs.get("do_sample", True),
157
- "num_return_sequences": kwargs.get("num_return_sequences", 1),
158
- "pad_token_id": self.tokenizer.pad_token_id
159
- }
160
 
161
  # Use max_new_tokens instead of max_length if input is longer than max_length-50
162
- if input_ids.shape[1] > (generation_kwargs["max_length"] - 50):
163
  logger.info(f"Input length {input_ids.shape[1]} is close to max_length, using max_new_tokens instead")
164
- del generation_kwargs["max_length"]
165
 
166
  # Generate output using the full GPT-2 model
167
- output_ids = self.gpt2_model.generate(input_ids, **generation_kwargs)
168
 
169
  # Decode the output and ensure it's a string, not a tensor
170
  if torch.is_tensor(output_ids):
@@ -358,7 +379,7 @@ class Wildnerve_tlm01:
358
  """
359
  def __init__(
360
  self,
361
- model_name="distilbert-base-uncased",
362
  tokenizer=None,
363
  device=None,
364
  **kwargs
@@ -408,4 +429,62 @@ try:
408
  registry.register(PRETRAINED_MODEL, model, overwrite=True)
409
  logger.info("Registered pretrained model in service registry")
410
  except Exception as e:
411
- logger.error(f"Failed to register pretrained model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  The model uses the GPT-2 tokenizer for consistent tokenization."""
59
  def __init__(
60
  self,
61
+ vocab_size: int = 50257, # Standardized GPT-2 vocab size
62
  specialization: str = "general",
63
  dataset_path: str = None,
64
+ model_name: str = "gpt2", # Standardized to GPT-2
65
  embedding_dim: int = 768,
66
  num_heads: int = 12,
67
  hidden_dim: int = 768,
68
  num_layers: int = 6,
69
+ output_size: int = 50257, # Standardized GPT-2 vocab size
70
  dropout: float = 0.1,
71
  max_seq_length: int = 1024, # GPT-2 supports longer contexts
72
  pooling_mode: str = "last", # GPT-2 typically uses last token
 
99
  # Initialize the model and tokenizer
100
  self.gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
101
 
102
+ # Ensure proper tokenizer setup for GPT-2
103
  if tokenizer is not None:
104
  self.tokenizer = tokenizer
105
  elif registry.has(TOKENIZER):
106
  self.tokenizer = registry.get(TOKENIZER)
107
  else:
108
  self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
109
+
110
+ # Ensure GPT-2 tokenizer has pad_token set (critical fix)
111
+ if self.tokenizer.pad_token_id is None:
112
+ self.tokenizer.pad_token = self.tokenizer.eos_token
113
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
114
 
115
  logger.info(f"Successfully loaded GPT-2 model: {model_name}")
116
 
 
138
  return outputs.logits
139
 
140
  # Update generate to handle both direct prompt and tokenized input
141
+ def generate(self, prompt=None, input_ids=None, max_length=None, **kwargs):
142
  """Generate text using the GPT-2 model"""
143
  try:
144
+ # Try to use adapter_layer.generate if available (consolidate generation paths)
145
+ adapter_layer = registry.get("adapter_layer")
146
+ if adapter_layer and hasattr(adapter_layer, "generate"):
147
+ if prompt:
148
+ return adapter_layer.generate(prompt, max_length=max_length, **kwargs)
149
+ elif input_ids is not None and self.tokenizer:
150
+ # Convert input_ids back to text
151
+ prompt = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
152
+ return adapter_layer.generate(prompt, max_length=max_length, **kwargs)
153
+
154
+ # Continue with direct generation if adapter_layer not available
155
+ # Enhanced generation parameters
156
+ generation_config = {
157
+ "max_length": max_length or 150,
158
+ "temperature": kwargs.get('temperature', 0.7),
159
+ "top_p": kwargs.get('top_p', 0.95),
160
+ "top_k": kwargs.get('top_k', 50),
161
+ "repetition_penalty": kwargs.get('repetition_penalty', 1.3),
162
+ "no_repeat_ngram_size": kwargs.get('no_repeat_ngram_size', 3),
163
+ "do_sample": True,
164
+ "pad_token_id": self.tokenizer.pad_token_id,
165
+ "eos_token_id": self.tokenizer.eos_token_id,
166
+ "early_stopping": True,
167
+ "penalty_alpha": 0.6 # Add penalty alpha for better response quality
168
+ }
169
+
170
  # Handle either string prompt or direct input_ids
171
  if isinstance(prompt, str) and input_ids is None:
172
  inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
173
  input_ids = inputs.input_ids
174
  elif input_ids is None:
175
  raise ValueError("Either prompt or input_ids must be provided")
176
+
177
+ # Add user-provided kwargs that we didn't explicitly set
178
+ for k, v in kwargs.items():
179
+ if k not in generation_config and k not in ('prompt', 'context'):
180
+ generation_config[k] = v
 
 
 
 
 
 
 
 
181
 
182
  # Use max_new_tokens instead of max_length if input is longer than max_length-50
183
+ if input_ids.shape[1] > (generation_config["max_length"] - 50):
184
  logger.info(f"Input length {input_ids.shape[1]} is close to max_length, using max_new_tokens instead")
185
+ del generation_config["max_length"]
186
 
187
  # Generate output using the full GPT-2 model
188
+ output_ids = self.gpt2_model.generate(input_ids, **generation_config)
189
 
190
  # Decode the output and ensure it's a string, not a tensor
191
  if torch.is_tensor(output_ids):
 
379
  """
380
  def __init__(
381
  self,
382
+ model_name="gpt2",
383
  tokenizer=None,
384
  device=None,
385
  **kwargs
 
429
  registry.register(PRETRAINED_MODEL, model, overwrite=True)
430
  logger.info("Registered pretrained model in service registry")
431
  except Exception as e:
432
+ logger.error(f"Failed to register pretrained model: {e}")
433
+
434
+ def initialize_system():
435
+ """Initialize all components in the correct order"""
436
+ logger.info("Starting system initialization")
437
+
438
+ # First tokenizer - Use GPT-2 tokenizer instead of BERT
439
+ try:
440
+ from transformers import GPT2Tokenizer
441
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
442
+ # GPT-2 tokenizer doesn't have a pad_token by default, so we set it
443
+ if tokenizer.pad_token is None:
444
+ tokenizer.pad_token = tokenizer.eos_token
445
+ except Exception as e:
446
+ logger.warning(f"Could not load GPT-2 tokenizer, falling back to wrapper: {e}")
447
+ from tokenizer import TokenizerWrapper
448
+ tokenizer = TokenizerWrapper(model_name="gpt2")
449
+
450
+ # Then register tokenizer
451
+ from service_registry import registry, TOKENIZER, PRETRAINED_MODEL
452
+ registry.register(TOKENIZER, tokenizer, overwrite=True)
453
+ logger.info("Tokenizer registered")
454
+
455
+ # Initialize pretrained model (GPT-2)
456
+ try:
457
+ from model_PrTr import Wildnerve_tlm01 as PretrainedModel
458
+ pretrained = PretrainedModel(model_name="gpt2", tokenizer=tokenizer)
459
+ registry.register(PRETRAINED_MODEL, pretrained, overwrite=True)
460
+ logger.info("GPT-2 pretrained model registered")
461
+ except Exception as e:
462
+ logger.error(f"Failed to initialize GPT-2 model: {e}", exc_info=True)
463
+
464
+ # Now load custom model
465
+ try:
466
+ from model_Custm import Wildnerve_tlm01
467
+ model = Wildnerve_tlm01(
468
+ vocab_size=50257, # Match GPT-2 vocab size
469
+ specialization="general",
470
+ dataset_path=None,
471
+ model_name="gpt2", # Use GPT-2 compatibility
472
+ embedding_dim=768,
473
+ num_heads=12,
474
+ hidden_dim=768,
475
+ num_layers=2,
476
+ output_size=50257, # Match GPT-2 vocab
477
+ dropout=0.1,
478
+ max_seq_length=128,
479
+ pooling_mode="mean",
480
+ tokenizer=tokenizer
481
+ )
482
+
483
+ # Register model
484
+ from service_registry import MODEL
485
+ registry.register(MODEL, model, overwrite=True)
486
+ logger.info("Custom model registered successfully")
487
+ return True
488
+ except Exception as e:
489
+ logger.error(f"Failed to initialize custom model: {e}", exc_info=True)
490
+ return False
service_registry.py CHANGED
@@ -137,18 +137,18 @@ def ensure_models_registered():
137
  tok = registry.get(TOKENIZER)
138
  if not tok:
139
  from tokenizer import TokenizerWrapper
140
- tok = TokenizerWrapper(model_name="gpt2")
141
  registry.register(TOKENIZER, tok, overwrite=True)
142
 
143
  # Create pretrained model
144
  model = model_class(
145
- model_name="gpt2",
146
  tokenizer=tok
147
  )
148
 
149
  # Register as pretrained model
150
  registry.register(PRETRAINED_MODEL, model, overwrite=True)
151
- logger.info("Successfully registered pretrained model")
152
  return True
153
 
154
  logger.error(f"model_PrTr.py not found at {model_path}")
 
137
  tok = registry.get(TOKENIZER)
138
  if not tok:
139
  from tokenizer import TokenizerWrapper
140
+ tok = TokenizerWrapper(model_name="gpt2") # Changed from bert-base-uncased
141
  registry.register(TOKENIZER, tok, overwrite=True)
142
 
143
  # Create pretrained model
144
  model = model_class(
145
+ model_name="gpt2", # Explicitly use gpt2
146
  tokenizer=tok
147
  )
148
 
149
  # Register as pretrained model
150
  registry.register(PRETRAINED_MODEL, model, overwrite=True)
151
+ logger.info("Successfully registered GPT-2 pretrained model")
152
  return True
153
 
154
  logger.error(f"model_PrTr.py not found at {model_path}")