WildnerveAI committed on
Commit
4c26014
·
verified ·
1 Parent(s): 812bb66

Upload model_Custm.py

Browse files
Files changed (1) hide show
  1. model_Custm.py +48 -51
model_Custm.py CHANGED
@@ -393,78 +393,75 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
393
  # Calculate loss if labels are provided
394
  loss = None
395
  if labels is not None:
396
- # Get shapes for debugging
397
- logger.debug(f"Output shape: {output.shape}, Labels shape: {labels.shape}")
398
-
399
  # Create loss function
400
  loss_fct = nn.CrossEntropyLoss()
401
 
402
- # Handle shape mismatches properly
403
- if output.dim() == 3: # [batch, seq, vocab]
 
404
  batch_size, seq_len, vocab_size = output.size()
405
- # Reshape to [batch*seq, vocab]
406
- output_flat = output.reshape(-1, vocab_size)
407
 
408
- # If labels are way bigger than our batch size, something is wrong
409
- # in the training loop, but we'll try to handle it gracefully
410
- if labels.size(0) > batch_size * seq_len:
411
- # Calculate target size
412
- target_size = output_flat.size(0)
413
- # Take just enough labels to match our flattened output
414
- if labels.size(0) >= target_size:
415
- labels = labels[:target_size]
416
- else:
417
- # Pad labels if needed
418
- padding = torch.zeros(target_size - labels.size(0),
419
- device=labels.device,
420
- dtype=labels.dtype)
421
- labels = torch.cat([labels, padding])
422
- # Calculate loss with proper shapes
423
- loss = loss_fct(output_flat, labels.view(-1))
424
- else: # output is [batch, vocab]
425
- # Handle excessive label size similar to above
426
- if labels.size(0) > output.size(0):
427
- labels = labels[:output.size(0)]
428
  loss = loss_fct(output, labels)
429
 
 
 
 
 
 
430
  # Return in HuggingFace format
431
  if loss is not None:
432
- return (loss, output)
433
  else:
434
- # Create a compatible output object
435
- class SimpleModelOutput:
436
- def __init__(self, logits):
437
- self.logits = logits
438
- def __getitem__(self, idx):
439
- if idx == 0: return None # Return None for loss
440
- elif idx == 1: return self.logits
441
- raise IndexError("Index out of range")
442
 
443
- return SimpleModelOutput(output)
444
-
445
  except Exception as e:
446
- # Detailed error logging for debugging
447
- logger.error(f"Error in forward pass: {e}")
448
- logger.error(f"Traceback: {traceback.format_exc()}")
449
- logger.error(f"Input shapes - src: {src.shape if src is not None else None}, "
450
- f"input_ids: {input_ids.shape if input_ids is not None else None}")
451
 
452
- # Create minimal output to prevent cascading errors, matching expected return format
453
- batch_size = src.shape[0] if src is not None else (input_ids.shape[0] if input_ids is not None else 1)
454
- dummy_output = torch.zeros((batch_size, self.output_size), device=next(self.parameters()).device)
455
 
456
- # Return in expected format to avoid "too many values to unpack" errors
 
 
 
 
 
 
 
 
457
  if labels is not None:
458
- # Match (loss, logits) format
459
- dummy_loss = torch.tensor(999.0, device=next(self.parameters()).device)
 
 
 
 
 
 
 
460
  return (dummy_loss, dummy_output)
461
  else:
462
- # Match object with logits attribute
463
  class SimpleModelOutput:
464
  def __init__(self, logits):
465
  self.logits = logits
466
  return SimpleModelOutput(dummy_output)
467
-
468
  # Add sentence transformer methods
469
  def encode_sentences(self, sentences, batch_size=32, normalize_embeddings=True):
470
  """Encode sentences into vectors (sentence transformer functionality)"""
 
393
  # Calculate loss if labels are provided
394
  loss = None
395
  if labels is not None:
 
 
 
396
  # Create loss function
397
  loss_fct = nn.CrossEntropyLoss()
398
 
399
+ # CRITICAL FIX: Debug shape information
400
+ batch_size, seq_len = None, None
401
+ if output.dim() == 3:
402
  batch_size, seq_len, vocab_size = output.size()
403
+ logger.debug(f"3D Output shape: {output.shape}, Labels shape: {labels.shape}")
 
404
 
405
+ # Fix for the target batch size mismatch (12 vs 9204, 16 vs 12272, etc.)
406
+ # If labels are flattened but output isn't, reshape output to match
407
+ if labels.size(0) == batch_size * seq_len:
408
+ # This means labels are already flattened to [batch_size*seq_len]
409
+ flattened_output = output.view(-1, output.size(-1))
410
+ loss = loss_fct(flattened_output, labels)
411
+ # Return explicitly formatted for HuggingFace compatibility
412
+ return (loss, output)
413
+ else:
414
+ # Regular case - reshape both
415
+ flattened_output = output.view(-1, output.size(-1))
416
+ flattened_labels = labels.view(-1)
417
+ loss = loss_fct(flattened_output, flattened_labels)
418
+ else:
419
+ # For classification (2D output)
 
 
 
 
 
420
  loss = loss_fct(output, labels)
421
 
422
+ # Simple object with logits attribute for HuggingFace compatibility
423
+ class SimpleModelOutput:
424
+ def __init__(self, logits):
425
+ self.logits = logits
426
+
427
  # Return in HuggingFace format
428
  if loss is not None:
429
+ return (loss, output) # Return tuple
430
  else:
431
+ return SimpleModelOutput(output) # Return object with logits attribute
 
 
 
 
 
 
 
432
 
 
 
433
  except Exception as e:
434
+ logger.error(f"Error in forward pass: {e}", exc_info=True)
 
 
 
 
435
 
436
+ # Create fallback outputs that match expected formats
437
+ device = next(self.parameters()).device if hasattr(self, 'parameters') else torch.device('cpu')
 
438
 
439
+ # Get batch size from inputs
440
+ if src is not None:
441
+ batch_size = src.size(0)
442
+ elif input_ids is not None:
443
+ batch_size = input_ids.size(0)
444
+ else:
445
+ batch_size = 1
446
+
447
+ # Log input/target shapes for debugging
448
  if labels is not None:
449
+ logger.error(f"Input shapes - batch_size: {batch_size}, labels: {labels.shape}")
450
+
451
+ # Create dummy output with correct vocab size
452
+ vocab_size = self.output_size if hasattr(self, 'output_size') else 50257
453
+ dummy_output = torch.zeros((batch_size, vocab_size), device=device)
454
+
455
+ # Match the expected return format
456
+ if labels is not None:
457
+ dummy_loss = torch.tensor(999.0, device=device)
458
  return (dummy_loss, dummy_output)
459
  else:
 
460
  class SimpleModelOutput:
461
  def __init__(self, logits):
462
  self.logits = logits
463
  return SimpleModelOutput(dummy_output)
464
+
465
  # Add sentence transformer methods
466
  def encode_sentences(self, sentences, batch_size=32, normalize_embeddings=True):
467
  """Encode sentences into vectors (sentence transformer functionality)"""