WildnerveAI committed on
Commit
2de4a17
·
verified ·
1 Parent(s): b9e3afb

Upload 2 files

Browse files
Files changed (2) hide show
  1. model_Custm.py +76 -261
  2. model_manager.py +28 -1
model_Custm.py CHANGED
@@ -263,6 +263,9 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
263
  self.classifier = nn.Linear(embedding_dim, self.vocab_size)
264
  self.dropout_layer = nn.Dropout(dropout)
265
 
 
 
 
266
  self.init_weights()
267
 
268
  def init_weights(self) -> None:
@@ -280,17 +283,19 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
280
 
281
  def forward(
282
  self,
283
- input_ids=None,
284
- attention_mask=None,
285
- labels=None,
286
- src=None,
287
- tgt=None,
288
- src_key_padding_mask=None,
289
- tgt_key_padding_mask=None,
290
- memory_key_padding_mask=None,
291
- return_sequence=False,
292
- **kwargs
293
- ):
 
 
294
  try:
295
  # Log input shapes for debugging
296
  logger.info(f"Input shapes - src: {src.shape if src is not None else None}, tgt: {tgt.shape if tgt is not None else None}")
@@ -308,288 +313,98 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
308
 
309
  # Pass through encoder layers
310
  memory = src_embeddings
311
- for enc_layer in self.encoder_layers:
312
- memory = enc_layer(memory)
313
-
314
- # Pass through decoder layers (ensuring we maintain 3D shape)
315
- # This maintains [batch_size, seq_length, hidden_dim]
316
- output = memory
317
- for dec_layer in self.decoder_layers:
318
- output = dec_layer(output)
319
-
320
- # Apply final projection to vocabulary space
321
- # This should result in [batch_size, seq_length, vocab_size]
322
- output = self.final_layer(output)
323
-
324
- # CRITICAL: Ensure we keep the 3D shape for language modeling
325
- # Check if we have a 2D tensor and reshape if needed
326
- if output.dim() == 2:
327
- # If 2D tensor [batch_size, vocab_size], reshape to 3D [batch_size, 1, vocab_size]
328
- batch_size, vocab_size = output.shape
329
- logger.info(f"Reshaping 2D output {output.shape} to 3D tensor")
330
- output = output.unsqueeze(1) # Add sequence dimension
331
- logger.info(f"Reshaped output: {output.shape}")
332
-
333
- # Record the output shape and dimensions for debugging
334
- logger.info(f"Output shape: {output.shape}, dimensions: {output.dim()}")
335
-
336
- # Calculate loss if labels are provided
337
- loss = None
338
- if labels is not None:
339
- # Check output shape
340
- if output.dim() == 3: # [batch_size, seq_length, vocab_size]
341
- batch_size, seq_length, vocab_size = output.shape
342
- logger.info(f"3D tensor: batch_size={batch_size}, seq_length={seq_length}, vocab_size={vocab_size}")
343
-
344
- # Check if labels are flattened
345
- expected_flattened_size = batch_size * seq_length
346
- is_flattened_labels = (labels.dim() == 1 and labels.size(0) == expected_flattened_size)
347
-
348
- if is_flattened_labels:
349
- # Reshape output to match flattened labels
350
- output_reshaped = output.reshape(-1, vocab_size)
351
- loss_fct = nn.CrossEntropyLoss()
352
- loss = loss_fct(output_reshaped, labels)
353
- else:
354
- # If labels are 2D [batch_size, seq_length], reshape them
355
- if labels.dim() == 2:
356
- # Need to reshape labels to 1D for CrossEntropyLoss
357
- labels_reshaped = labels.reshape(-1)
358
- output_reshaped = output.reshape(-1, vocab_size)
359
- loss_fct = nn.CrossEntropyLoss()
360
- loss = loss_fct(output_reshaped, labels_reshaped)
361
- else:
362
- loss_fct = nn.CrossEntropyLoss()
363
- loss = loss_fct(output.view(-1, vocab_size), labels.view(-1))
364
- elif output.dim() == 2: # [batch_size, vocab_size]
365
- batch_size, vocab_size = output.shape
366
- logger.info(f"2D tensor: batch_size={batch_size}, vocab_size={vocab_size}")
367
-
368
- # Handle 2D output with 1D labels (single prediction per sequence)
369
- if labels.dim() == 1 and labels.size(0) == batch_size:
370
- loss_fct = nn.CrossEntropyLoss()
371
- loss = loss_fct(output, labels)
372
- else:
373
- # If we have 1D labels but with seq_length*batch_size elements
374
- logger.warning(f"Label shape {labels.shape} incompatible with output {output.shape}")
375
- # Take just the first token prediction for each sequence
376
- labels_reshaped = labels.view(batch_size, -1)[:, 0]
377
- loss_fct = nn.CrossEntropyLoss()
378
- loss = loss_fct(output, labels_reshaped)
379
-
380
- # Return the proper format
381
- if loss is not None:
382
- logger.info(f"Returning loss tensor: {loss.item()}")
383
- return loss, output
384
- else:
385
- return output
386
 
387
- except Exception as e:
388
- logger.error(f"Error in forward pass: {str(e)}")
389
- logger.error(f"Traceback: {traceback.format_exc()}")
390
 
391
- # Log input shapes for debugging
392
- logger.error(f"Input shapes - src: {src.shape if src is not None else None}, input_ids: {input_ids.shape if input_ids is not None else None}")
393
-
394
- # Ensure we return a proper tuple with correct types even in error case
395
- dummy_output = torch.zeros(1)
396
- dummy_loss = torch.tensor(float('nan'))
397
- return dummy_loss, dummy_output
398
-
399
- def forward(
400
- self,
401
- src: torch.Tensor = None,
402
- tgt: Optional[torch.Tensor] = None,
403
- token_type_ids: Optional[torch.Tensor] = None,
404
- src_mask: Optional[torch.Tensor] = None,
405
- tgt_mask: Optional[torch.Tensor] = None,
406
- src_key_padding_mask: Optional[torch.Tensor] = None,
407
- tgt_key_padding_mask: Optional[torch.Tensor] = None,
408
- return_sequence: bool = False,
409
- # Add Hugging Face compatibility parameters
410
- input_ids: Optional[torch.Tensor] = None,
411
- attention_mask: Optional[torch.Tensor] = None,
412
- labels: Optional[torch.Tensor] = None,
413
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], ModelOutput]:
414
- try:
415
- # Add this at the start of the forward method to log input shapes
416
- logger.info(f"Input shapes - src: {src.shape if src is not None else None}, tgt: {tgt.shape if tgt is not None else None}")
417
-
418
- # Use Hugging Face parameters if provided
419
- if src is None and input_ids is not None:
420
- src = input_ids
421
- if src_key_padding_mask is None and attention_mask is not None:
422
- src_key_padding_mask = attention_mask
423
-
424
- # Handle input shape - our layers expect batch_first=True format
425
- if src.dim() == 2:
426
- # src is already [batch_size, seq_len]
427
- pass
428
- elif src.dim() == 3 and src.size(0) > src.size(1):
429
- # src is [seq_len, batch_size, dim] - need to transpose
430
- src = src.transpose(0, 1)
431
-
432
- # ----------------------------
433
- # Encoder: Custom processing of source
434
- # ----------------------------
435
- src_emb = self.embedding(src) * math.sqrt(self.embedding_dim)
436
- src_emb = self.pos_encoder(src_emb.transpose(0, 1)).transpose(0, 1) # Apply positional encoding
437
-
438
- # Use hybrid attention if sequence length is above the threshold
439
  if src.size(1) > 256 and hasattr(self, 'hybrid_attention'):
440
  # Prepare inputs for hybrid attention
441
- query = src_emb.transpose(0, 1) # Ensure shape is [seq_len, batch, dim]
442
  key = query
443
  value = query
444
 
445
- # Apply smart hybrid attention - FIX: properly handle any return format
 
 
 
 
 
 
446
  hybrid_outputs = self.hybrid_attention(
447
  query=query,
448
  key=key,
449
  value=value,
450
  key_padding_mask=src_key_padding_mask,
451
- attn_mask=src_mask,
452
  prompt_length=src.size(1),
453
- prompt_complexity=0.5 # Default value, can be computed based on input
454
  )
455
 
456
- # FIX: Handle all possible return types from hybrid_attention
457
- if isinstance(hybrid_outputs, tuple):
458
- # If it returns a tuple, the first element is the attended output
459
- attended_output = hybrid_outputs[0]
460
- logger.debug(f"Hybrid attention returned tuple of length {len(hybrid_outputs)}")
461
- else:
462
- # If it returns a tensor directly
463
- attended_output = hybrid_outputs
464
- logger.debug("Hybrid attention returned single tensor")
465
-
466
- # Convert back to expected format
467
- encoded_src = attended_output.transpose(0, 1)
468
- else:
469
- # Use standard transformer encoder for shorter sequences
470
- encoded_src = self.transformer_encoder(src_emb, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
471
-
472
- # Process through adapter layer
473
- adapted = self.adapter(encoded_src)
474
-
475
- # ----------------------------
476
- # Decoder / Output
477
- # ----------------------------
478
- if tgt is not None:
479
- # Handle tgt shape for batch_first format
480
- if tgt.dim() == 2:
481
- # tgt is already [batch_size, seq_len]
482
- pass
483
- elif tgt.dim() == 3 and tgt.size(0) > tgt.size(1):
484
- # tgt is [seq_len, batch_size, dim] - need to transpose
485
- tgt = tgt.transpose(0, 1)
486
-
487
- tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.embedding_dim)
488
- tgt_emb = self.pos_decoder(tgt_emb.transpose(0, 1)).transpose(0, 1) # Apply positional encoding
489
 
490
- decoded = self.transformer_decoder(
491
- tgt_emb,
492
- adapted,
493
- tgt_mask=tgt_mask,
494
- memory_key_padding_mask=src_key_padding_mask,
495
- tgt_key_padding_mask=tgt_key_padding_mask
496
- )
497
 
498
- output = self.classifier(decoded) # [batch_size, seq_len, output_size]
499
-
500
- if not return_sequence:
501
- output = output.mean(dim=1) # Average over sequence dimension
502
- else:
503
- # For encoder-only tasks (e.g., classification)
504
- if self.pooling_mode == "mean":
505
- pooled = encoded_src.mean(dim=1)
506
- elif self.pooling_mode == "max":
507
- pooled = torch.max(encoded_src, dim=1)[0]
508
- elif self.pooling_mode == "cls":
509
- pooled = encoded_src[:, 0] # Use first token (CLS) - batch_first format
510
- else:
511
- pooled = encoded_src.mean(dim=1)
512
- pooled = self.dropout_layer(pooled)
513
- output = self.classifier(pooled)
514
 
515
  # Calculate loss if labels are provided
516
  loss = None
517
  if labels is not None:
518
- # More defensive shape handling - check shape dimensions first
519
- if not isinstance(output, torch.Tensor):
520
- logger.error(f"Output is not a tensor, got {type(output)}")
521
- return torch.tensor(0.0), output # Return dummy loss
522
-
523
- # Check output shape and handle multiple possible dimensions
524
- output_dim = output.dim()
525
- logger.info(f"Output shape: {output.shape}, dimensions: {output_dim}")
526
 
527
- if output_dim == 3: # [batch_size, seq_length, vocab_size]
528
- batch_size, seq_length, vocab_size = output.shape
529
-
530
- # Debug logging
531
- logger.info(f"3D tensor: batch_size={batch_size}, seq_length={seq_length}, vocab_size={vocab_size}")
532
-
533
- # Check labels shape
534
- expected_flattened_size = batch_size * seq_length
535
- is_flattened_labels = (labels.dim() == 1 and labels.size(0) == expected_flattened_size)
536
-
537
- # Reshape output from [batch_size, seq_length, vocab_size] to [batch_size*seq_length, vocab_size]
538
- output_reshaped = output.reshape(-1, vocab_size)
539
-
540
- # Calculate loss
541
- loss_fct = nn.CrossEntropyLoss()
542
- loss = loss_fct(output_reshaped, labels)
543
- elif output_dim == 2: # [batch_size, vocab_size]
544
- # Already shaped for loss calculation
545
- batch_size, vocab_size = output.shape
546
- logger.info(f"2D tensor: batch_size={batch_size}, vocab_size={vocab_size}")
547
-
548
- # Need to reshape labels to 1D - this is the critical fix
549
- if labels.dim() > 1: # If labels are multi-dimensional
550
- # Language modeling usually needs the last token prediction
551
- # Get the last token's label from each sequence
552
- labels = labels[:, -1] # Take the last token from each sequence
553
- logger.info(f"Reshaped labels to {labels.shape}")
554
-
555
- # Now calculate loss with properly shaped tensors
556
- loss_fct = nn.CrossEntropyLoss()
557
- loss = loss_fct(output, labels)
558
- else:
559
- logger.error(f"Unexpected output dimensions: {output_dim}")
560
- # Create a dummy loss to avoid breaking the training loop
561
- loss = torch.tensor(0.0, requires_grad=True, device=output.device)
562
 
563
  # Return the proper format
564
  if loss is not None:
565
- logger.info(f"Returning loss tensor: {loss.item()}")
566
  return loss, output
567
  else:
568
  return output
569
-
570
  except Exception as e:
571
- # Detailed error logging for debugging
572
- logger.error(f"Error in forward pass: {e}")
573
  logger.error(f"Traceback: {traceback.format_exc()}")
574
- logger.error(f"Input shapes - src: {src.shape if src is not None else None}, "
575
- f"input_ids: {input_ids.shape if input_ids is not None else None}")
576
 
577
- # Create minimal output to prevent cascading errors, matching expected return format
578
- batch_size = src.shape[0] if src is not None else (input_ids.shape[0] if input_ids is not None else 1)
579
- dummy_output = torch.zeros((batch_size, self.output_size), device=next(self.parameters()).device)
580
 
581
- # Return in expected format to avoid "too many values to unpack" errors
582
- if labels is not None:
583
- # Match (loss, logits) format
584
- dummy_loss = torch.tensor(999.0, device=next(self.parameters()).device)
585
- return (dummy_loss, dummy_output)
586
- else:
587
- # Match object with logits attribute
588
- class SimpleModelOutput:
589
- def __init__(self, logits):
590
- self.logits = logits
591
- return SimpleModelOutput(dummy_output)
592
-
593
  # Add sentence transformer methods
594
  def encode_sentences(self, sentences, batch_size=32, normalize_embeddings=True):
595
  """Encode sentences into vectors (sentence transformer functionality)"""
 
263
  self.classifier = nn.Linear(embedding_dim, self.vocab_size)
264
  self.dropout_layer = nn.Dropout(dropout)
265
 
266
+ # This is a standard linear layer that reshapes 3D input to 2D output:
267
+ self.final_layer = nn.Linear(hidden_dim, vocab_size)
268
+
269
  self.init_weights()
270
 
271
  def init_weights(self) -> None:
 
283
 
284
  def forward(
285
  self,
286
+ src: torch.Tensor = None,
287
+ tgt: Optional[torch.Tensor] = None,
288
+ token_type_ids: Optional[torch.Tensor] = None,
289
+ src_mask: Optional[torch.Tensor] = None, # Make sure to include this parameter
290
+ tgt_mask: Optional[torch.Tensor] = None,
291
+ src_key_padding_mask: Optional[torch.Tensor] = None,
292
+ tgt_key_padding_mask: Optional[torch.Tensor] = None,
293
+ return_sequence: bool = False,
294
+ # Add Hugging Face compatibility parameters
295
+ input_ids: Optional[torch.Tensor] = None,
296
+ attention_mask: Optional[torch.Tensor] = None,
297
+ labels: Optional[torch.Tensor] = None,
298
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], ModelOutput]:
299
  try:
300
  # Log input shapes for debugging
301
  logger.info(f"Input shapes - src: {src.shape if src is not None else None}, tgt: {tgt.shape if tgt is not None else None}")
 
313
 
314
  # Pass through encoder layers
315
  memory = src_embeddings
316
+ # Ensure memory maintains 3 dimensions [batch_size, seq_length, hidden_dim]
317
+ if memory.dim() == 2:
318
+ memory = memory.unsqueeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
+ # Use self.transformer_encoder instead of self.encoder_layers (which doesn't exist)
321
+ encoded_src = self.transformer_encoder(memory)
 
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  if src.size(1) > 256 and hasattr(self, 'hybrid_attention'):
324
  # Prepare inputs for hybrid attention
325
+ query = src_embeddings.transpose(0, 1)
326
  key = query
327
  value = query
328
 
329
+ # IMPORTANT: Initialize src_mask if it's None
330
+ if src_mask is None and src is not None:
331
+ # Create a default mask that allows all tokens to attend to all other tokens
332
+ src_seq_len = src.size(1)
333
+ src_mask = torch.zeros((src_seq_len, src_seq_len), device=src.device, dtype=torch.bool)
334
+
335
+ # Actually using the hybrid attention here!
336
  hybrid_outputs = self.hybrid_attention(
337
  query=query,
338
  key=key,
339
  value=value,
340
  key_padding_mask=src_key_padding_mask,
341
+ attn_mask=src_mask, # Now src_mask is properly defined
342
  prompt_length=src.size(1),
343
+ prompt_complexity=0.5
344
  )
345
 
346
+ # Process the hybrid attention outputs
347
+ encoded_src = hybrid_outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
+ # Pass through decoder layers
350
+ output = encoded_src
351
+ # Ensure output maintains 3 dimensions [batch_size, seq_length, hidden_dim]
352
+ if output.dim() == 2:
353
+ output = output.unsqueeze(1)
 
 
354
 
355
+ # Apply final projection to vocabulary space
356
+ output = self.final_layer(output)
357
+
358
+ # CRITICAL: Ensure output is always 3D [batch_size, seq_length, vocab_size]
359
+ if output.dim() == 2:
360
+ # If 2D tensor [batch_size, vocab_size], reshape to 3D [batch_size, 1, vocab_size]
361
+ batch_size, vocab_size = output.shape
362
+ logger.info(f"2D tensor: batch_size={batch_size}, vocab_size={vocab_size}")
363
+ output = output.unsqueeze(1) # Add sequence dimension
364
+ logger.info(f"Reshaped 2D output to 3D tensor: {output.shape}")
365
+
366
+ # Record the output shape and dimensions for debugging
367
+ logger.info(f"Output shape: {output.shape}, dimensions: {output.dim()}")
 
 
 
368
 
369
  # Calculate loss if labels are provided
370
  loss = None
371
  if labels is not None:
372
+ # Reshape labels to 1D if needed
373
+ if labels.dim() > 1:
374
+ labels = labels.reshape(-1)
375
+ logger.info(f"Reshaped labels to {labels.shape}")
 
 
 
 
376
 
377
+ # Calculate loss with properly shaped tensors
378
+ batch_size, seq_length, vocab_size = output.shape
379
+ loss_fct = nn.CrossEntropyLoss()
380
+ loss = loss_fct(output.reshape(-1, vocab_size), labels)
381
+ logger.info(f"Returning loss tensor: {loss.item()}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
  # Return the proper format
384
  if loss is not None:
 
385
  return loss, output
386
  else:
387
  return output
388
+
389
  except Exception as e:
390
+ logger.error(f"Error in forward pass: {str(e)}")
 
391
  logger.error(f"Traceback: {traceback.format_exc()}")
 
 
392
 
393
+ # Log input shapes for debugging
394
+ logger.error(f"Input shapes - src: {src.shape if src is not None else None}, input_ids: {input_ids.shape if input_ids is not None else None}")
 
395
 
396
+ # Create minimal dummy outputs in correct format
397
+ dummy_batch = 1
398
+ if src is not None:
399
+ dummy_batch = src.shape[0]
400
+ elif input_ids is not None:
401
+ dummy_batch = input_ids.shape[0]
402
+
403
+ # CRITICAL: Return a proper 3D tensor even in error case
404
+ dummy_output = torch.zeros((dummy_batch, 1, self.vocab_size), device=next(self.parameters()).device)
405
+ dummy_loss = torch.tensor(float('nan'), device=next(self.parameters()).device)
406
+ return dummy_loss, dummy_output
407
+
408
  # Add sentence transformer methods
409
  def encode_sentences(self, sentences, batch_size=32, normalize_embeddings=True):
410
  """Encode sentences into vectors (sentence transformer functionality)"""
model_manager.py CHANGED
@@ -341,7 +341,34 @@ class ModelManager:
341
  return model
342
 
343
  def route_input(self, input_text: str) -> dict:
 
344
  input_embedding = self.embedding_model.encode(input_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  similarities = {}
346
  for spec in self.specializations:
347
  model = self.get_model(spec)
@@ -778,4 +805,4 @@ def register_models():
778
  return True
779
  except Exception as e:
780
  logger.error(f"Failed to register models: {e}")
781
- return False
 
341
  return model
342
 
343
  def route_input(self, input_text: str) -> dict:
344
+ # Create embedding for input text
345
  input_embedding = self.embedding_model.encode(input_text)
346
+
347
+ # NEW: Process input through SmartHybridAttention for enhanced understanding
348
+ if hasattr(self, 'smart_attention') and self.smart_attention:
349
+ try:
350
+ # Convert embedding to tensor format needed by attention
351
+ import torch
352
+ input_tensor = torch.tensor(input_embedding).unsqueeze(0).unsqueeze(0) # [1, 1, dim]
353
+
354
+ # Process through attention mechanism to extract key patterns
355
+ # This helps identify which parts of input are most relevant
356
+ enhanced = self.smart_attention(
357
+ query=input_tensor,
358
+ key=input_tensor,
359
+ value=input_tensor
360
+ )
361
+
362
+ # Convert back to numpy for similarity calculations
363
+ if isinstance(enhanced, torch.Tensor):
364
+ enhanced_embedding = enhanced.squeeze().cpu().numpy()
365
+ # Use enhanced embedding for similarity calculation
366
+ input_embedding = enhanced_embedding
367
+ logger.info("Using SmartHybridAttention for enhanced prompt routing")
368
+ except Exception as e:
369
+ logger.warning(f"Error using SmartHybridAttention: {e}")
370
+
371
+ # Continue with existing similarity calculation
372
  similarities = {}
373
  for spec in self.specializations:
374
  model = self.get_model(spec)
 
805
  return True
806
  except Exception as e:
807
  logger.error(f"Failed to register models: {e}")
808
+ return False