Upload model_Custm.py
Browse files — model_Custm.py (+37, -37)
model_Custm.py
CHANGED
|
@@ -469,9 +469,7 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 469 |
return f"Error generating response: {str(e)}"
|
| 470 |
|
| 471 |
def generate_tokens(self, input_ids, max_length=None, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0, **kwargs):
|
| 472 |
-
"""
|
| 473 |
-
Generate tokens autoregressively without recursion.
|
| 474 |
-
"""
|
| 475 |
logger.info(f"generate_tokens called with tensor of shape: {input_ids.shape if hasattr(input_ids, 'shape') else 'unknown'}")
|
| 476 |
|
| 477 |
try:
|
|
@@ -485,51 +483,53 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 485 |
if input_ids.dim() == 1:
|
| 486 |
input_ids = input_ids.unsqueeze(0)
|
| 487 |
|
| 488 |
-
# Get device from input tensor
|
| 489 |
-
device = input_ids.device
|
| 490 |
-
|
| 491 |
# Set reasonable defaults for missing parameters
|
| 492 |
if max_length is None:
|
| 493 |
max_length = min(getattr(self, 'max_seq_length', 1024), 1024)
|
| 494 |
-
max_length = min(max_length, 1024)
|
| 495 |
-
|
| 496 |
-
# Check if we're already at or beyond max length
|
| 497 |
-
if input_ids.shape[1] >= max_length:
|
| 498 |
-
return input_ids # Return without change
|
| 499 |
|
| 500 |
-
# Create attention mask if needed
|
| 501 |
-
attention_mask = None
|
| 502 |
-
if hasattr(self, 'transformer') and getattr(self, 'transformer', None) is not None:
|
| 503 |
-
attention_mask = torch.ones((input_ids.shape[0], input_ids.shape[1]), dtype=torch.long, device=device)
|
| 504 |
-
|
| 505 |
# Initialize generated sequences with input_ids
|
| 506 |
generated_sequences = input_ids.clone()
|
| 507 |
|
| 508 |
-
#
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
|
| 525 |
-
|
| 526 |
-
return output_ids
|
| 527 |
|
| 528 |
except Exception as e:
|
| 529 |
logger.error(f"Error in generate_tokens: {e}")
|
| 530 |
-
|
| 531 |
-
# Return input as fallback to prevent errors
|
| 532 |
-
return input_ids
|
| 533 |
|
| 534 |
def generate_with_decoding(self, input_ids=None, prompt=None, **kwargs):
|
| 535 |
"""
|
|
|
|
| 469 |
return f"Error generating response: {str(e)}"
|
| 470 |
|
| 471 |
def generate_tokens(self, input_ids, max_length=None, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0, **kwargs):
    """Generate tokens autoregressively from ``input_ids``.

    Args:
        input_ids: Token-id tensor of shape ``(seq_len,)`` or ``(batch, seq_len)``;
            a 1-D input is promoted to a batch of one.
        max_length: Maximum total sequence length (prompt + generated), capped at
            1024. Defaults to ``min(self.max_seq_length, 1024)`` when omitted.
        temperature: Softmax temperature; applied only when > 0.
        top_k: Keep only the ``top_k`` highest-probability tokens (0 disables).
        top_p: Nucleus-sampling threshold; keep the smallest set of tokens whose
            cumulative probability exceeds ``top_p`` (values outside (0, 1) disable).
        repetition_penalty: CTRL-style penalty > 1.0 discourages tokens already
            present in the sequence (1.0 disables).
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        Tensor of shape ``(batch, <= max_length)`` containing the prompt followed
        by sampled tokens. On any exception, the (possibly unsqueezed) input is
        returned unchanged as a best-effort fallback.
    """
    logger.info(f"generate_tokens called with tensor of shape: {input_ids.shape if hasattr(input_ids, 'shape') else 'unknown'}")

    try:
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        # Set reasonable defaults for missing parameters
        if max_length is None:
            max_length = min(getattr(self, 'max_seq_length', 1024), 1024)
        max_length = min(max_length, 1024)

        # Initialize generated sequences with input_ids
        generated_sequences = input_ids.clone()

        # Auto-regressive generation loop (no-op if the prompt already
        # reaches max_length, since the range is empty).
        for _ in range(max_length - input_ids.shape[1]):
            # Forward pass through the model
            with torch.no_grad():
                outputs = self(generated_sequences)

            # Handle both 2D and 3D output formats
            if outputs.dim() == 3:  # [batch_size, seq_length, vocab_size]
                next_token_logits = outputs[:, -1, :]
            else:  # [batch_size, vocab_size]
                next_token_logits = outputs

            # FIX: repetition_penalty was accepted but silently ignored.
            # CTRL-style: divide positive / multiply negative logits of tokens
            # already generated. Out-of-place scatter so the model's output
            # tensor is never mutated.
            if repetition_penalty != 1.0:
                prev_scores = torch.gather(next_token_logits, 1, generated_sequences)
                prev_scores = torch.where(
                    prev_scores > 0,
                    prev_scores / repetition_penalty,
                    prev_scores * repetition_penalty,
                )
                next_token_logits = next_token_logits.scatter(1, generated_sequences, prev_scores)

            # Apply temperature
            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            # Apply top-k filtering.
            # FIX: clamp k to the vocab size (torch.topk raises when
            # k > dim size) and replace the per-batch Python loop with a
            # vectorized mask against the k-th largest logit.
            if top_k > 0:
                k = min(top_k, next_token_logits.size(-1))
                kth_value = torch.topk(next_token_logits, k)[0][..., -1, None]
                next_token_logits = next_token_logits.masked_fill(
                    next_token_logits < kth_value, float("-inf")
                )

            # FIX: top_p was accepted but silently ignored. Standard nucleus
            # sampling: sort descending, drop the tail once cumulative
            # probability exceeds top_p, always keeping the most likely token.
            if 0.0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_remove = cumulative_probs > top_p
                # Shift right so the token that crosses the threshold is kept.
                sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
                sorted_remove[..., 0] = False
                remove_mask = sorted_remove.scatter(1, sorted_indices, sorted_remove)
                next_token_logits = next_token_logits.masked_fill(remove_mask, float("-inf"))

            # Sample next token
            probs = torch.softmax(next_token_logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

            # Add to sequence
            generated_sequences = torch.cat([generated_sequences, next_tokens.unsqueeze(-1)], dim=1)

            # Optional stopping criteria (e.g. EOS detection) could be added here

        return generated_sequences

    except Exception as e:
        logger.error(f"Error in generate_tokens: {e}")
        return input_ids  # Return input as fallback
|
|
|
|
|
|
|
| 533 |
|
| 534 |
def generate_with_decoding(self, input_ids=None, prompt=None, **kwargs):
|
| 535 |
"""
|