Upload 2 files
Browse files- model_Custm.py +76 -261
- model_manager.py +28 -1
model_Custm.py
CHANGED
|
@@ -263,6 +263,9 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 263 |
self.classifier = nn.Linear(embedding_dim, self.vocab_size)
|
| 264 |
self.dropout_layer = nn.Dropout(dropout)
|
| 265 |
|
|
|
|
|
|
|
|
|
|
| 266 |
self.init_weights()
|
| 267 |
|
| 268 |
def init_weights(self) -> None:
|
|
@@ -280,17 +283,19 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 280 |
|
| 281 |
def forward(
|
| 282 |
self,
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
src_key_padding_mask=None,
|
| 289 |
-
tgt_key_padding_mask=None,
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
try:
|
| 295 |
# Log input shapes for debugging
|
| 296 |
logger.info(f"Input shapes - src: {src.shape if src is not None else None}, tgt: {tgt.shape if tgt is not None else None}")
|
|
@@ -308,288 +313,98 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 308 |
|
| 309 |
# Pass through encoder layers
|
| 310 |
memory = src_embeddings
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
# Pass through decoder layers (ensuring we maintain 3D shape)
|
| 315 |
-
# This maintains [batch_size, seq_length, hidden_dim]
|
| 316 |
-
output = memory
|
| 317 |
-
for dec_layer in self.decoder_layers:
|
| 318 |
-
output = dec_layer(output)
|
| 319 |
-
|
| 320 |
-
# Apply final projection to vocabulary space
|
| 321 |
-
# This should result in [batch_size, seq_length, vocab_size]
|
| 322 |
-
output = self.final_layer(output)
|
| 323 |
-
|
| 324 |
-
# CRITICAL: Ensure we keep the 3D shape for language modeling
|
| 325 |
-
# Check if we have a 2D tensor and reshape if needed
|
| 326 |
-
if output.dim() == 2:
|
| 327 |
-
# If 2D tensor [batch_size, vocab_size], reshape to 3D [batch_size, 1, vocab_size]
|
| 328 |
-
batch_size, vocab_size = output.shape
|
| 329 |
-
logger.info(f"Reshaping 2D output {output.shape} to 3D tensor")
|
| 330 |
-
output = output.unsqueeze(1) # Add sequence dimension
|
| 331 |
-
logger.info(f"Reshaped output: {output.shape}")
|
| 332 |
-
|
| 333 |
-
# Record the output shape and dimensions for debugging
|
| 334 |
-
logger.info(f"Output shape: {output.shape}, dimensions: {output.dim()}")
|
| 335 |
-
|
| 336 |
-
# Calculate loss if labels are provided
|
| 337 |
-
loss = None
|
| 338 |
-
if labels is not None:
|
| 339 |
-
# Check output shape
|
| 340 |
-
if output.dim() == 3: # [batch_size, seq_length, vocab_size]
|
| 341 |
-
batch_size, seq_length, vocab_size = output.shape
|
| 342 |
-
logger.info(f"3D tensor: batch_size={batch_size}, seq_length={seq_length}, vocab_size={vocab_size}")
|
| 343 |
-
|
| 344 |
-
# Check if labels are flattened
|
| 345 |
-
expected_flattened_size = batch_size * seq_length
|
| 346 |
-
is_flattened_labels = (labels.dim() == 1 and labels.size(0) == expected_flattened_size)
|
| 347 |
-
|
| 348 |
-
if is_flattened_labels:
|
| 349 |
-
# Reshape output to match flattened labels
|
| 350 |
-
output_reshaped = output.reshape(-1, vocab_size)
|
| 351 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 352 |
-
loss = loss_fct(output_reshaped, labels)
|
| 353 |
-
else:
|
| 354 |
-
# If labels are 2D [batch_size, seq_length], reshape them
|
| 355 |
-
if labels.dim() == 2:
|
| 356 |
-
# Need to reshape labels to 1D for CrossEntropyLoss
|
| 357 |
-
labels_reshaped = labels.reshape(-1)
|
| 358 |
-
output_reshaped = output.reshape(-1, vocab_size)
|
| 359 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 360 |
-
loss = loss_fct(output_reshaped, labels_reshaped)
|
| 361 |
-
else:
|
| 362 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 363 |
-
loss = loss_fct(output.view(-1, vocab_size), labels.view(-1))
|
| 364 |
-
elif output.dim() == 2: # [batch_size, vocab_size]
|
| 365 |
-
batch_size, vocab_size = output.shape
|
| 366 |
-
logger.info(f"2D tensor: batch_size={batch_size}, vocab_size={vocab_size}")
|
| 367 |
-
|
| 368 |
-
# Handle 2D output with 1D labels (single prediction per sequence)
|
| 369 |
-
if labels.dim() == 1 and labels.size(0) == batch_size:
|
| 370 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 371 |
-
loss = loss_fct(output, labels)
|
| 372 |
-
else:
|
| 373 |
-
# If we have 1D labels but with seq_length*batch_size elements
|
| 374 |
-
logger.warning(f"Label shape {labels.shape} incompatible with output {output.shape}")
|
| 375 |
-
# Take just the first token prediction for each sequence
|
| 376 |
-
labels_reshaped = labels.view(batch_size, -1)[:, 0]
|
| 377 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 378 |
-
loss = loss_fct(output, labels_reshaped)
|
| 379 |
-
|
| 380 |
-
# Return the proper format
|
| 381 |
-
if loss is not None:
|
| 382 |
-
logger.info(f"Returning loss tensor: {loss.item()}")
|
| 383 |
-
return loss, output
|
| 384 |
-
else:
|
| 385 |
-
return output
|
| 386 |
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
logger.error(f"Traceback: {traceback.format_exc()}")
|
| 390 |
|
| 391 |
-
# Log input shapes for debugging
|
| 392 |
-
logger.error(f"Input shapes - src: {src.shape if src is not None else None}, input_ids: {input_ids.shape if input_ids is not None else None}")
|
| 393 |
-
|
| 394 |
-
# Ensure we return a proper tuple with correct types even in error case
|
| 395 |
-
dummy_output = torch.zeros(1)
|
| 396 |
-
dummy_loss = torch.tensor(float('nan'))
|
| 397 |
-
return dummy_loss, dummy_output
|
| 398 |
-
|
| 399 |
-
def forward(
|
| 400 |
-
self,
|
| 401 |
-
src: torch.Tensor = None,
|
| 402 |
-
tgt: Optional[torch.Tensor] = None,
|
| 403 |
-
token_type_ids: Optional[torch.Tensor] = None,
|
| 404 |
-
src_mask: Optional[torch.Tensor] = None,
|
| 405 |
-
tgt_mask: Optional[torch.Tensor] = None,
|
| 406 |
-
src_key_padding_mask: Optional[torch.Tensor] = None,
|
| 407 |
-
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
| 408 |
-
return_sequence: bool = False,
|
| 409 |
-
# Add Hugging Face compatibility parameters
|
| 410 |
-
input_ids: Optional[torch.Tensor] = None,
|
| 411 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 412 |
-
labels: Optional[torch.Tensor] = None,
|
| 413 |
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], ModelOutput]:
|
| 414 |
-
try:
|
| 415 |
-
# Add this at the start of the forward method to log input shapes
|
| 416 |
-
logger.info(f"Input shapes - src: {src.shape if src is not None else None}, tgt: {tgt.shape if tgt is not None else None}")
|
| 417 |
-
|
| 418 |
-
# Use Hugging Face parameters if provided
|
| 419 |
-
if src is None and input_ids is not None:
|
| 420 |
-
src = input_ids
|
| 421 |
-
if src_key_padding_mask is None and attention_mask is not None:
|
| 422 |
-
src_key_padding_mask = attention_mask
|
| 423 |
-
|
| 424 |
-
# Handle input shape - our layers expect batch_first=True format
|
| 425 |
-
if src.dim() == 2:
|
| 426 |
-
# src is already [batch_size, seq_len]
|
| 427 |
-
pass
|
| 428 |
-
elif src.dim() == 3 and src.size(0) > src.size(1):
|
| 429 |
-
# src is [seq_len, batch_size, dim] - need to transpose
|
| 430 |
-
src = src.transpose(0, 1)
|
| 431 |
-
|
| 432 |
-
# ----------------------------
|
| 433 |
-
# Encoder: Custom processing of source
|
| 434 |
-
# ----------------------------
|
| 435 |
-
src_emb = self.embedding(src) * math.sqrt(self.embedding_dim)
|
| 436 |
-
src_emb = self.pos_encoder(src_emb.transpose(0, 1)).transpose(0, 1) # Apply positional encoding
|
| 437 |
-
|
| 438 |
-
# Use hybrid attention if sequence length is above the threshold
|
| 439 |
if src.size(1) > 256 and hasattr(self, 'hybrid_attention'):
|
| 440 |
# Prepare inputs for hybrid attention
|
| 441 |
-
query =
|
| 442 |
key = query
|
| 443 |
value = query
|
| 444 |
|
| 445 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
hybrid_outputs = self.hybrid_attention(
|
| 447 |
query=query,
|
| 448 |
key=key,
|
| 449 |
value=value,
|
| 450 |
key_padding_mask=src_key_padding_mask,
|
| 451 |
-
attn_mask=src_mask,
|
| 452 |
prompt_length=src.size(1),
|
| 453 |
-
prompt_complexity=0.5
|
| 454 |
)
|
| 455 |
|
| 456 |
-
#
|
| 457 |
-
|
| 458 |
-
# If it returns a tuple, the first element is the attended output
|
| 459 |
-
attended_output = hybrid_outputs[0]
|
| 460 |
-
logger.debug(f"Hybrid attention returned tuple of length {len(hybrid_outputs)}")
|
| 461 |
-
else:
|
| 462 |
-
# If it returns a tensor directly
|
| 463 |
-
attended_output = hybrid_outputs
|
| 464 |
-
logger.debug("Hybrid attention returned single tensor")
|
| 465 |
-
|
| 466 |
-
# Convert back to expected format
|
| 467 |
-
encoded_src = attended_output.transpose(0, 1)
|
| 468 |
-
else:
|
| 469 |
-
# Use standard transformer encoder for shorter sequences
|
| 470 |
-
encoded_src = self.transformer_encoder(src_emb, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
|
| 471 |
-
|
| 472 |
-
# Process through adapter layer
|
| 473 |
-
adapted = self.adapter(encoded_src)
|
| 474 |
-
|
| 475 |
-
# ----------------------------
|
| 476 |
-
# Decoder / Output
|
| 477 |
-
# ----------------------------
|
| 478 |
-
if tgt is not None:
|
| 479 |
-
# Handle tgt shape for batch_first format
|
| 480 |
-
if tgt.dim() == 2:
|
| 481 |
-
# tgt is already [batch_size, seq_len]
|
| 482 |
-
pass
|
| 483 |
-
elif tgt.dim() == 3 and tgt.size(0) > tgt.size(1):
|
| 484 |
-
# tgt is [seq_len, batch_size, dim] - need to transpose
|
| 485 |
-
tgt = tgt.transpose(0, 1)
|
| 486 |
-
|
| 487 |
-
tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.embedding_dim)
|
| 488 |
-
tgt_emb = self.pos_decoder(tgt_emb.transpose(0, 1)).transpose(0, 1) # Apply positional encoding
|
| 489 |
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
tgt_key_padding_mask=tgt_key_padding_mask
|
| 496 |
-
)
|
| 497 |
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
#
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
pooled = encoded_src.mean(dim=1)
|
| 512 |
-
pooled = self.dropout_layer(pooled)
|
| 513 |
-
output = self.classifier(pooled)
|
| 514 |
|
| 515 |
# Calculate loss if labels are provided
|
| 516 |
loss = None
|
| 517 |
if labels is not None:
|
| 518 |
-
#
|
| 519 |
-
if
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
# Check output shape and handle multiple possible dimensions
|
| 524 |
-
output_dim = output.dim()
|
| 525 |
-
logger.info(f"Output shape: {output.shape}, dimensions: {output_dim}")
|
| 526 |
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
# Check labels shape
|
| 534 |
-
expected_flattened_size = batch_size * seq_length
|
| 535 |
-
is_flattened_labels = (labels.dim() == 1 and labels.size(0) == expected_flattened_size)
|
| 536 |
-
|
| 537 |
-
# Reshape output from [batch_size, seq_length, vocab_size] to [batch_size*seq_length, vocab_size]
|
| 538 |
-
output_reshaped = output.reshape(-1, vocab_size)
|
| 539 |
-
|
| 540 |
-
# Calculate loss
|
| 541 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 542 |
-
loss = loss_fct(output_reshaped, labels)
|
| 543 |
-
elif output_dim == 2: # [batch_size, vocab_size]
|
| 544 |
-
# Already shaped for loss calculation
|
| 545 |
-
batch_size, vocab_size = output.shape
|
| 546 |
-
logger.info(f"2D tensor: batch_size={batch_size}, vocab_size={vocab_size}")
|
| 547 |
-
|
| 548 |
-
# Need to reshape labels to 1D - this is the critical fix
|
| 549 |
-
if labels.dim() > 1: # If labels are multi-dimensional
|
| 550 |
-
# Language modeling usually needs the last token prediction
|
| 551 |
-
# Get the last token's label from each sequence
|
| 552 |
-
labels = labels[:, -1] # Take the last token from each sequence
|
| 553 |
-
logger.info(f"Reshaped labels to {labels.shape}")
|
| 554 |
-
|
| 555 |
-
# Now calculate loss with properly shaped tensors
|
| 556 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 557 |
-
loss = loss_fct(output, labels)
|
| 558 |
-
else:
|
| 559 |
-
logger.error(f"Unexpected output dimensions: {output_dim}")
|
| 560 |
-
# Create a dummy loss to avoid breaking the training loop
|
| 561 |
-
loss = torch.tensor(0.0, requires_grad=True, device=output.device)
|
| 562 |
|
| 563 |
# Return the proper format
|
| 564 |
if loss is not None:
|
| 565 |
-
logger.info(f"Returning loss tensor: {loss.item()}")
|
| 566 |
return loss, output
|
| 567 |
else:
|
| 568 |
return output
|
| 569 |
-
|
| 570 |
except Exception as e:
|
| 571 |
-
|
| 572 |
-
logger.error(f"Error in forward pass: {e}")
|
| 573 |
logger.error(f"Traceback: {traceback.format_exc()}")
|
| 574 |
-
logger.error(f"Input shapes - src: {src.shape if src is not None else None}, "
|
| 575 |
-
f"input_ids: {input_ids.shape if input_ids is not None else None}")
|
| 576 |
|
| 577 |
-
#
|
| 578 |
-
|
| 579 |
-
dummy_output = torch.zeros((batch_size, self.output_size), device=next(self.parameters()).device)
|
| 580 |
|
| 581 |
-
#
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
# Add sentence transformer methods
|
| 594 |
def encode_sentences(self, sentences, batch_size=32, normalize_embeddings=True):
|
| 595 |
"""Encode sentences into vectors (sentence transformer functionality)"""
|
|
|
|
| 263 |
self.classifier = nn.Linear(embedding_dim, self.vocab_size)
|
| 264 |
self.dropout_layer = nn.Dropout(dropout)
|
| 265 |
|
| 266 |
+
# This is a standard linear layer that reshapes 3D input to 2D output:
|
| 267 |
+
self.final_layer = nn.Linear(hidden_dim, vocab_size)
|
| 268 |
+
|
| 269 |
self.init_weights()
|
| 270 |
|
| 271 |
def init_weights(self) -> None:
|
|
|
|
| 283 |
|
| 284 |
def forward(
|
| 285 |
self,
|
| 286 |
+
src: torch.Tensor = None,
|
| 287 |
+
tgt: Optional[torch.Tensor] = None,
|
| 288 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 289 |
+
src_mask: Optional[torch.Tensor] = None, # Make sure to include this parameter
|
| 290 |
+
tgt_mask: Optional[torch.Tensor] = None,
|
| 291 |
+
src_key_padding_mask: Optional[torch.Tensor] = None,
|
| 292 |
+
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
| 293 |
+
return_sequence: bool = False,
|
| 294 |
+
# Add Hugging Face compatibility parameters
|
| 295 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 296 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 297 |
+
labels: Optional[torch.Tensor] = None,
|
| 298 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], ModelOutput]:
|
| 299 |
try:
|
| 300 |
# Log input shapes for debugging
|
| 301 |
logger.info(f"Input shapes - src: {src.shape if src is not None else None}, tgt: {tgt.shape if tgt is not None else None}")
|
|
|
|
| 313 |
|
| 314 |
# Pass through encoder layers
|
| 315 |
memory = src_embeddings
|
| 316 |
+
# Ensure memory maintains 3 dimensions [batch_size, seq_length, hidden_dim]
|
| 317 |
+
if memory.dim() == 2:
|
| 318 |
+
memory = memory.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
+
# Use self.transformer_encoder instead of self.encoder_layers (which doesn't exist)
|
| 321 |
+
encoded_src = self.transformer_encoder(memory)
|
|
|
|
| 322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
if src.size(1) > 256 and hasattr(self, 'hybrid_attention'):
|
| 324 |
# Prepare inputs for hybrid attention
|
| 325 |
+
query = src_embeddings.transpose(0, 1)
|
| 326 |
key = query
|
| 327 |
value = query
|
| 328 |
|
| 329 |
+
# IMPORTANT: Initialize src_mask if it's None
|
| 330 |
+
if src_mask is None and src is not None:
|
| 331 |
+
# Create a default mask that allows all tokens to attend to all other tokens
|
| 332 |
+
src_seq_len = src.size(1)
|
| 333 |
+
src_mask = torch.zeros((src_seq_len, src_seq_len), device=src.device, dtype=torch.bool)
|
| 334 |
+
|
| 335 |
+
# Actually using the hybrid attention here!
|
| 336 |
hybrid_outputs = self.hybrid_attention(
|
| 337 |
query=query,
|
| 338 |
key=key,
|
| 339 |
value=value,
|
| 340 |
key_padding_mask=src_key_padding_mask,
|
| 341 |
+
attn_mask=src_mask, # Now src_mask is properly defined
|
| 342 |
prompt_length=src.size(1),
|
| 343 |
+
prompt_complexity=0.5
|
| 344 |
)
|
| 345 |
|
| 346 |
+
# Process the hybrid attention outputs
|
| 347 |
+
encoded_src = hybrid_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
+
# Pass through decoder layers
|
| 350 |
+
output = encoded_src
|
| 351 |
+
# Ensure output maintains 3 dimensions [batch_size, seq_length, hidden_dim]
|
| 352 |
+
if output.dim() == 2:
|
| 353 |
+
output = output.unsqueeze(1)
|
|
|
|
|
|
|
| 354 |
|
| 355 |
+
# Apply final projection to vocabulary space
|
| 356 |
+
output = self.final_layer(output)
|
| 357 |
+
|
| 358 |
+
# CRITICAL: Ensure output is always 3D [batch_size, seq_length, vocab_size]
|
| 359 |
+
if output.dim() == 2:
|
| 360 |
+
# If 2D tensor [batch_size, vocab_size], reshape to 3D [batch_size, 1, vocab_size]
|
| 361 |
+
batch_size, vocab_size = output.shape
|
| 362 |
+
logger.info(f"2D tensor: batch_size={batch_size}, vocab_size={vocab_size}")
|
| 363 |
+
output = output.unsqueeze(1) # Add sequence dimension
|
| 364 |
+
logger.info(f"Reshaped 2D output to 3D tensor: {output.shape}")
|
| 365 |
+
|
| 366 |
+
# Record the output shape and dimensions for debugging
|
| 367 |
+
logger.info(f"Output shape: {output.shape}, dimensions: {output.dim()}")
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
# Calculate loss if labels are provided
|
| 370 |
loss = None
|
| 371 |
if labels is not None:
|
| 372 |
+
# Reshape labels to 1D if needed
|
| 373 |
+
if labels.dim() > 1:
|
| 374 |
+
labels = labels.reshape(-1)
|
| 375 |
+
logger.info(f"Reshaped labels to {labels.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
+
# Calculate loss with properly shaped tensors
|
| 378 |
+
batch_size, seq_length, vocab_size = output.shape
|
| 379 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 380 |
+
loss = loss_fct(output.reshape(-1, vocab_size), labels)
|
| 381 |
+
logger.info(f"Returning loss tensor: {loss.item()}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
|
| 383 |
# Return the proper format
|
| 384 |
if loss is not None:
|
|
|
|
| 385 |
return loss, output
|
| 386 |
else:
|
| 387 |
return output
|
| 388 |
+
|
| 389 |
except Exception as e:
|
| 390 |
+
logger.error(f"Error in forward pass: {str(e)}")
|
|
|
|
| 391 |
logger.error(f"Traceback: {traceback.format_exc()}")
|
|
|
|
|
|
|
| 392 |
|
| 393 |
+
# Log input shapes for debugging
|
| 394 |
+
logger.error(f"Input shapes - src: {src.shape if src is not None else None}, input_ids: {input_ids.shape if input_ids is not None else None}")
|
|
|
|
| 395 |
|
| 396 |
+
# Create minimal dummy outputs in correct format
|
| 397 |
+
dummy_batch = 1
|
| 398 |
+
if src is not None:
|
| 399 |
+
dummy_batch = src.shape[0]
|
| 400 |
+
elif input_ids is not None:
|
| 401 |
+
dummy_batch = input_ids.shape[0]
|
| 402 |
+
|
| 403 |
+
# CRITICAL: Return a proper 3D tensor even in error case
|
| 404 |
+
dummy_output = torch.zeros((dummy_batch, 1, self.vocab_size), device=next(self.parameters()).device)
|
| 405 |
+
dummy_loss = torch.tensor(float('nan'), device=next(self.parameters()).device)
|
| 406 |
+
return dummy_loss, dummy_output
|
| 407 |
+
|
| 408 |
# Add sentence transformer methods
|
| 409 |
def encode_sentences(self, sentences, batch_size=32, normalize_embeddings=True):
|
| 410 |
"""Encode sentences into vectors (sentence transformer functionality)"""
|
model_manager.py
CHANGED
|
@@ -341,7 +341,34 @@ class ModelManager:
|
|
| 341 |
return model
|
| 342 |
|
| 343 |
def route_input(self, input_text: str) -> dict:
|
|
|
|
| 344 |
input_embedding = self.embedding_model.encode(input_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
similarities = {}
|
| 346 |
for spec in self.specializations:
|
| 347 |
model = self.get_model(spec)
|
|
@@ -778,4 +805,4 @@ def register_models():
|
|
| 778 |
return True
|
| 779 |
except Exception as e:
|
| 780 |
logger.error(f"Failed to register models: {e}")
|
| 781 |
-
return False
|
|
|
|
| 341 |
return model
|
| 342 |
|
| 343 |
def route_input(self, input_text: str) -> dict:
|
| 344 |
+
# Create embedding for input text
|
| 345 |
input_embedding = self.embedding_model.encode(input_text)
|
| 346 |
+
|
| 347 |
+
# NEW: Process input through SmartHybridAttention for enhanced understanding
|
| 348 |
+
if hasattr(self, 'smart_attention') and self.smart_attention:
|
| 349 |
+
try:
|
| 350 |
+
# Convert embedding to tensor format needed by attention
|
| 351 |
+
import torch
|
| 352 |
+
input_tensor = torch.tensor(input_embedding).unsqueeze(0).unsqueeze(0) # [1, 1, dim]
|
| 353 |
+
|
| 354 |
+
# Process through attention mechanism to extract key patterns
|
| 355 |
+
# This helps identify which parts of input are most relevant
|
| 356 |
+
enhanced = self.smart_attention(
|
| 357 |
+
query=input_tensor,
|
| 358 |
+
key=input_tensor,
|
| 359 |
+
value=input_tensor
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# Convert back to numpy for similarity calculations
|
| 363 |
+
if isinstance(enhanced, torch.Tensor):
|
| 364 |
+
enhanced_embedding = enhanced.squeeze().cpu().numpy()
|
| 365 |
+
# Use enhanced embedding for similarity calculation
|
| 366 |
+
input_embedding = enhanced_embedding
|
| 367 |
+
logger.info("Using SmartHybridAttention for enhanced prompt routing")
|
| 368 |
+
except Exception as e:
|
| 369 |
+
logger.warning(f"Error using SmartHybridAttention: {e}")
|
| 370 |
+
|
| 371 |
+
# Continue with existing similarity calculation
|
| 372 |
similarities = {}
|
| 373 |
for spec in self.specializations:
|
| 374 |
model = self.get_model(spec)
|
|
|
|
| 805 |
return True
|
| 806 |
except Exception as e:
|
| 807 |
logger.error(f"Failed to register models: {e}")
|
| 808 |
+
return False
|