WildnerveAI committed on
Commit
f23efb2
·
verified ·
1 Parent(s): f71f4f5

Upload model_Custm.py

Browse files
Files changed (1) hide show
  1. model_Custm.py +37 -37
model_Custm.py CHANGED
@@ -469,9 +469,7 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
469
  return f"Error generating response: {str(e)}"
470
 
471
  def generate_tokens(self, input_ids, max_length=None, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0, **kwargs):
472
- """
473
- Generate tokens autoregressively without recursion.
474
- """
475
  logger.info(f"generate_tokens called with tensor of shape: {input_ids.shape if hasattr(input_ids, 'shape') else 'unknown'}")
476
 
477
  try:
@@ -485,51 +483,53 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
485
  if input_ids.dim() == 1:
486
  input_ids = input_ids.unsqueeze(0)
487
 
488
- # Get device from input tensor
489
- device = input_ids.device
490
-
491
  # Set reasonable defaults for missing parameters
492
  if max_length is None:
493
  max_length = min(getattr(self, 'max_seq_length', 1024), 1024)
494
- max_length = min(max_length, 1024) # Reasonable maximum
495
-
496
- # Check if we're already at or beyond max length
497
- if input_ids.shape[1] >= max_length:
498
- return input_ids # Return without change
499
 
500
- # Create attention mask if needed
501
- attention_mask = None
502
- if hasattr(self, 'transformer') and getattr(self, 'transformer', None) is not None:
503
- attention_mask = torch.ones((input_ids.shape[0], input_ids.shape[1]), dtype=torch.long, device=device)
504
-
505
  # Initialize generated sequences with input_ids
506
  generated_sequences = input_ids.clone()
507
 
508
- # Get end token ID (use EOS token if model has one, otherwise use default)
509
- eos_token_id = None
510
- if hasattr(self, 'tokenizer') and self.tokenizer is not None and hasattr(self.tokenizer, 'eos_token_id'):
511
- eos_token_id = self.tokenizer.eos_token_id
512
-
513
- # Simply append a few tokens to avoid the recursive call
514
- # For a production system, you would implement proper token generation here
515
- current_len = input_ids.shape[1]
516
- new_tokens_needed = min(10, max_length - current_len)
517
-
518
- # Create some dummy token IDs (this will be basic but avoid errors)
519
- batch_size = input_ids.shape[0]
520
- dummy_tokens = torch.ones((batch_size, new_tokens_needed), dtype=torch.long, device=device) * (eos_token_id or 50256) # GPT-2 EOS token
521
-
522
- # Concatenate new tokens to input_ids
523
- output_ids = torch.cat([input_ids, dummy_tokens], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
- logger.info(f"Simple generate_tokens returning output of shape {output_ids.shape}")
526
- return output_ids
527
 
528
  except Exception as e:
529
  logger.error(f"Error in generate_tokens: {e}")
530
-
531
- # Return input as fallback to prevent errors
532
- return input_ids
533
 
534
  def generate_with_decoding(self, input_ids=None, prompt=None, **kwargs):
535
  """
 
469
  return f"Error generating response: {str(e)}"
470
 
471
  def generate_tokens(self, input_ids, max_length=None, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0, **kwargs):
472
+ """Generate tokens autoregressively."""
 
 
473
  logger.info(f"generate_tokens called with tensor of shape: {input_ids.shape if hasattr(input_ids, 'shape') else 'unknown'}")
474
 
475
  try:
 
483
  if input_ids.dim() == 1:
484
  input_ids = input_ids.unsqueeze(0)
485
 
 
 
 
486
  # Set reasonable defaults for missing parameters
487
  if max_length is None:
488
  max_length = min(getattr(self, 'max_seq_length', 1024), 1024)
489
+ max_length = min(max_length, 1024)
 
 
 
 
490
 
 
 
 
 
 
491
  # Initialize generated sequences with input_ids
492
  generated_sequences = input_ids.clone()
493
 
494
+ # Auto-regressive generation loop
495
+ for step in range(max_length - input_ids.shape[1]):
496
+ # Forward pass through the model
497
+ with torch.no_grad():
498
+ outputs = self(generated_sequences)
499
+
500
+ # FIX: Handle both 2D and 3D output formats
501
+ if outputs.dim() == 3: # [batch_size, seq_length, vocab_size]
502
+ # Original 3D format
503
+ next_token_logits = outputs[:, -1, :]
504
+ else: # outputs.dim() == 2: [batch_size, vocab_size]
505
+ # Direct 2D output format
506
+ next_token_logits = outputs
507
+
508
+ # Apply temperature
509
+ if temperature > 0:
510
+ next_token_logits = next_token_logits / temperature
511
+
512
+ # Apply top-k filtering
513
+ if top_k > 0:
514
+ top_k_values, top_k_indices = torch.topk(next_token_logits, top_k)
515
+ next_token_logits = torch.full_like(next_token_logits, float("-inf"))
516
+ for batch_idx in range(generated_sequences.shape[0]):
517
+ next_token_logits[batch_idx, top_k_indices[batch_idx]] = top_k_values[batch_idx]
518
+
519
+ # Sample next token
520
+ probs = torch.softmax(next_token_logits, dim=-1)
521
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
522
+
523
+ # Add to sequence
524
+ generated_sequences = torch.cat([generated_sequences, next_tokens.unsqueeze(-1)], dim=1)
525
+
526
+ # Optional stopping criteria could be added here
527
 
528
+ return generated_sequences
 
529
 
530
  except Exception as e:
531
  logger.error(f"Error in generate_tokens: {e}")
532
+ return input_ids # Return input as fallback
 
 
533
 
534
  def generate_with_decoding(self, input_ids=None, prompt=None, **kwargs):
535
  """