redneural / tasks /sort /utils.py
Elliotasdasdasfasas's picture
Config: Added Azure SSHFS persistence via cloudflared
addc107
import torch
import torch.nn.functional as F
def decode_predictions(predictions, blank_label=0, return_wait_times=False):
"""
Decodes the predictions using greedy decoding (best path), correctly handling duplicates.
Args:
predictions: A tensor of shape [B, C, L] representing the logits.
blank_label: The index of the blank label.
Returns:
A list of tensors, where each tensor is the decoded sequence.
"""
batch_size, num_classes, prediction_length = predictions.shape
decoded_sequences = []
wait_times_all = []
probs = F.softmax(predictions, dim=1) # Probabilities
for b in range(batch_size):
best_path = torch.argmax(probs[b], dim=0) # Best path indices
decoded = []
wait_times = []
prev_char = -1 # Keep track of the previous character
wait_time_now = 0
for t in range(prediction_length):
char_idx = best_path[t].item() # Get index as integer
if char_idx != blank_label and char_idx != prev_char: # Skip blanks and duplicates
decoded.append(char_idx)
prev_char = char_idx # Update previous character
wait_times.append(wait_time_now)
wait_time_now = 0
else:
wait_time_now += 1
decoded_sequences.append(torch.tensor(decoded, device=predictions.device))
if return_wait_times: wait_times_all.append(torch.tensor(wait_times, device=predictions.device))
if return_wait_times: return decoded_sequences, wait_times_all
return decoded_sequences
def compute_ctc_accuracy(predictions, targets, blank_label=0):
"""
Computes the accuracy of the predictions given the targets, considering CTC decoding.
Args:
predictions: A tensor of shape [B, C, L] representing the logits.
targets: A list of tensors, each of shape [T_i], representing a target sequence.
blank_label: The index of the blank label.
Returns:
The accuracy (a float).
"""
batch_size, num_classes, prediction_length = predictions.shape
total_correct = 0
# 1. Get predicted sequences (decoded from logits):
predicted_sequences = decode_predictions(predictions, blank_label)
# 2. Compare predicted sequences to targets:
for i in range(batch_size):
target = targets[i]
predicted = predicted_sequences[i]
if torch.equal(predicted, target): # Direct comparison of tensors
total_correct += 1
accuracy = total_correct / batch_size if batch_size > 0 else 0.0
return accuracy