WildnerveAI committed on
Commit
1acab08
·
verified ·
1 Parent(s): 2de4a17

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.py +6 -4
  2. model_Custm.py +22 -21
  3. model_manager.py +2 -3
config.py CHANGED
@@ -511,12 +511,14 @@ def get_model_architecture_params():
511
  """Get model architecture parameters from config file"""
512
  if hasattr(app_config, "TRANSFORMER_CONFIG"):
513
  tc = app_config.TRANSFORMER_CONFIG
 
 
514
  return {
515
  "vocab_size": getattr(tc, "VOCAB_SIZE", 50257),
516
- "embedding_dim": getattr(tc, "EMBEDDING_DIM", 768),
517
- "num_heads": getattr(tc, "NUM_HEADS", 12),
518
- "hidden_dim": getattr(tc, "HIDDEN_DIM", 768),
519
- "num_layers": getattr(tc, "NUM_LAYERS", 12),
520
  "output_size": getattr(tc, "VOCAB_SIZE", 50257),
521
  "dropout": getattr(tc, "DROPOUT", 0.1),
522
  "max_seq_length": getattr(tc, "MAX_SEQ_LENGTH", 512)
 
511
  """Get model architecture parameters from config file"""
512
  if hasattr(app_config, "TRANSFORMER_CONFIG"):
513
  tc = app_config.TRANSFORMER_CONFIG
514
+ # CRITICAL: Ensure we ALWAYS get 768 for embedding_dim and hidden_dim
515
+ # This avoids issues with dimension mismatches between 512 and 768
516
  return {
517
  "vocab_size": getattr(tc, "VOCAB_SIZE", 50257),
518
+ "embedding_dim": 768, # Fixed to 768 to prevent mismatches
519
+ "num_heads": 12, # 12 heads works with 768 (768/12=64)
520
+ "hidden_dim": 768, # Fixed to 768 to prevent mismatches
521
+ "num_layers": getattr(tc, "NUM_LAYERS", 12),
522
  "output_size": getattr(tc, "VOCAB_SIZE", 50257),
523
  "dropout": getattr(tc, "DROPOUT", 0.1),
524
  "max_seq_length": getattr(tc, "MAX_SEQ_LENGTH", 512)
model_Custm.py CHANGED
@@ -283,19 +283,17 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
283
 
284
  def forward(
285
  self,
286
- src: torch.Tensor = None,
287
- tgt: Optional[torch.Tensor] = None,
288
- token_type_ids: Optional[torch.Tensor] = None,
289
- src_mask: Optional[torch.Tensor] = None, # Make sure to include this parameter
290
- tgt_mask: Optional[torch.Tensor] = None,
291
- src_key_padding_mask: Optional[torch.Tensor] = None,
292
- tgt_key_padding_mask: Optional[torch.Tensor] = None,
293
- return_sequence: bool = False,
294
- # Add Hugging Face compatibility parameters
295
- input_ids: Optional[torch.Tensor] = None,
296
- attention_mask: Optional[torch.Tensor] = None,
297
- labels: Optional[torch.Tensor] = None,
298
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], ModelOutput]:
299
  try:
300
  # Log input shapes for debugging
301
  logger.info(f"Input shapes - src: {src.shape if src is not None else None}, tgt: {tgt.shape if tgt is not None else None}")
@@ -312,13 +310,8 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
312
  src_embeddings = self.pos_encoder(src_embeddings)
313
 
314
  # Pass through encoder layers
315
- memory = src_embeddings
316
- # Ensure memory maintains 3 dimensions [batch_size, seq_length, hidden_dim]
317
- if memory.dim() == 2:
318
- memory = memory.unsqueeze(1)
319
-
320
- # Use self.transformer_encoder instead of self.encoder_layers (which doesn't exist)
321
- encoded_src = self.transformer_encoder(memory)
322
 
323
  if src.size(1) > 256 and hasattr(self, 'hybrid_attention'):
324
  # Prepare inputs for hybrid attention
@@ -347,7 +340,15 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
347
  encoded_src = hybrid_outputs
348
 
349
  # Pass through decoder layers
350
- output = encoded_src
 
 
 
 
 
 
 
 
351
  # Ensure output maintains 3 dimensions [batch_size, seq_length, hidden_dim]
352
  if output.dim() == 2:
353
  output = output.unsqueeze(1)
 
283
 
284
  def forward(
285
  self,
286
+ input_ids=None,
287
+ attention_mask=None,
288
+ labels=None,
289
+ src=None,
290
+ tgt=None,
291
+ src_key_padding_mask=None,
292
+ tgt_key_padding_mask=None,
293
+ memory_key_padding_mask=None,
294
+ return_sequence=False,
295
+ **kwargs
296
+ ):
 
 
297
  try:
298
  # Log input shapes for debugging
299
  logger.info(f"Input shapes - src: {src.shape if src is not None else None}, tgt: {tgt.shape if tgt is not None else None}")
 
310
  src_embeddings = self.pos_encoder(src_embeddings)
311
 
312
  # Pass through encoder layers
313
+ memory = self.transformer_encoder(src_embeddings,
314
+ src_key_padding_mask=src_key_padding_mask)
 
 
 
 
 
315
 
316
  if src.size(1) > 256 and hasattr(self, 'hybrid_attention'):
317
  # Prepare inputs for hybrid attention
 
340
  encoded_src = hybrid_outputs
341
 
342
  # Pass through decoder layers
343
+ if tgt is not None:
344
+ tgt_embeddings = self.tgt_embedding(tgt)
345
+ tgt_embeddings = self.pos_decoder(tgt_embeddings)
346
+ output = self.transformer_decoder(tgt_embeddings, memory,
347
+ tgt_key_padding_mask=tgt_key_padding_mask,
348
+ memory_key_padding_mask=memory_key_padding_mask)
349
+ else:
350
+ output = memory
351
+
352
  # Ensure output maintains 3 dimensions [batch_size, seq_length, hidden_dim]
353
  if output.dim() == 2:
354
  output = output.unsqueeze(1)
model_manager.py CHANGED
@@ -344,7 +344,7 @@ class ModelManager:
344
  # Create embedding for input text
345
  input_embedding = self.embedding_model.encode(input_text)
346
 
347
- # NEW: Process input through SmartHybridAttention for enhanced understanding
348
  if hasattr(self, 'smart_attention') and self.smart_attention:
349
  try:
350
  # Convert embedding to tensor format needed by attention
@@ -352,8 +352,7 @@ class ModelManager:
352
  input_tensor = torch.tensor(input_embedding).unsqueeze(0).unsqueeze(0) # [1, 1, dim]
353
 
354
  # Process through attention mechanism to extract key patterns
355
- # This helps identify which parts of input are most relevant
356
- enhanced = self.smart_attention(
357
  query=input_tensor,
358
  key=input_tensor,
359
  value=input_tensor
 
344
  # Create embedding for input text
345
  input_embedding = self.embedding_model.encode(input_text)
346
 
347
+ # Process input through SmartHybridAttention for enhanced understanding
348
  if hasattr(self, 'smart_attention') and self.smart_attention:
349
  try:
350
  # Convert embedding to tensor format needed by attention
 
352
  input_tensor = torch.tensor(input_embedding).unsqueeze(0).unsqueeze(0) # [1, 1, dim]
353
 
354
  # Process through attention mechanism to extract key patterns
355
+ enhanced, _ = self.smart_attention( # FIXED: Properly unpack tuple
 
356
  query=input_tensor,
357
  key=input_tensor,
358
  value=input_tensor