Spaces:
Paused
Paused
File size: 2,690 Bytes
addc107 d6e3f3c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | import torch
import torch.nn.functional as F
def decode_predictions(predictions, blank_label=0, return_wait_times=False):
    """
    Greedy (best-path) CTC decoding of a batch of logits.

    At every time step the highest-scoring class is taken; consecutive
    repeats of the same class are collapsed and blanks are removed. A blank
    between two identical labels separates them, so the best path
    ``a - a`` decodes to ``[a, a]`` while ``a a`` decodes to ``[a]``.

    Args:
        predictions: A tensor of shape [B, C, L] representing the logits
            (batch, classes, time steps).
        blank_label: The index of the blank label.
        return_wait_times: If True, also return, for every emitted label, the
            number of time steps (blanks or repeated frames) that elapsed
            since the previous emission.

    Returns:
        A list of B int64 tensors, each holding one decoded label sequence.
        If ``return_wait_times`` is True, a tuple ``(decoded, wait_times)``
        where ``wait_times`` is a list of B int64 tensors aligned
        element-for-element with ``decoded``.

    Raises:
        ValueError: If ``predictions`` is not 3-dimensional.
    """
    if predictions.dim() != 3:
        raise ValueError(
            f"predictions must have shape [B, C, L], got {tuple(predictions.shape)}"
        )
    device = predictions.device
    # Softmax is monotonic, so the argmax of the raw logits is the best path.
    # One .tolist() for the whole batch avoids a device sync per time step.
    best_paths = predictions.argmax(dim=1).tolist()

    decoded_sequences = []
    wait_times_all = []
    for best_path in best_paths:
        decoded = []
        wait_times = []
        prev_char = None  # raw label of the previous frame (blanks included)
        wait_time_now = 0
        for char_idx in best_path:
            if char_idx != blank_label and char_idx != prev_char:
                decoded.append(char_idx)
                wait_times.append(wait_time_now)
                wait_time_now = 0
            else:
                wait_time_now += 1
            # Track every frame, including blanks, so that a blank resets the
            # repeat check and "a - a" yields two labels.
            prev_char = char_idx
        # Explicit dtype: torch.tensor([]) would otherwise be float32.
        decoded_sequences.append(torch.tensor(decoded, dtype=torch.long, device=device))
        if return_wait_times:
            wait_times_all.append(
                torch.tensor(wait_times, dtype=torch.long, device=device)
            )
    if return_wait_times:
        return decoded_sequences, wait_times_all
    return decoded_sequences
def compute_ctc_accuracy(predictions, targets, blank_label=0):
    """
    Sequence-level accuracy of greedy CTC decoding against reference targets.

    A sample counts as correct only when its decoded label sequence equals
    its target exactly (same length, same labels, same order).

    Args:
        predictions: A tensor of shape [B, C, L] representing the logits.
        targets: An indexable collection of B targets, each a 1-D tensor (or
            int sequence) of shape [T_i] holding one unpadded label sequence.
        blank_label: The index of the blank label.

    Returns:
        The fraction of exactly matching sequences as a float (0.0 for an
        empty batch).
    """
    batch_size = predictions.shape[0]
    predicted_sequences = decode_predictions(predictions, blank_label)
    total_correct = 0
    for i in range(batch_size):
        # Compare as plain Python lists so the check works regardless of the
        # target's device or integer dtype (torch.equal raises on a device
        # mismatch and is not reliable across dtypes).
        target = torch.as_tensor(targets[i]).tolist()
        predicted = predicted_sequences[i].tolist()
        if predicted == target:
            total_correct += 1
    return total_correct / batch_size if batch_size > 0 else 0.0