Spaces:
Sleeping
Sleeping
File size: 2,783 Bytes
e361ea8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 | import torch
import numpy as np
import json
def load_rnn_params(filepath):
with open(filepath, 'r') as f:
data = json.load(f)
return torch.tensor(data, dtype=torch.float32)
W_h = load_rnn_params("W_h.weight.json")
Wh_bias = load_rnn_params("W_h.bias.json").unsqueeze(1)
U_h = load_rnn_params("U_h.weight.json")
W_y = load_rnn_params("W_y.weight.json")
Embedding = load_rnn_params("embedding.weight.json")
with open("vocab.json", 'r') as f:
VOCAB = json.load(f)
# Знаходимо індекси спеціальних токенів
BOS_INDEX = VOCAB.index('[')
EOS_INDEX = VOCAB.index(']')
# Розміри
embedding_dim = 96
hidden_dim = 160
vocab_size = 132
@torch.no_grad()
def greedy_decode_rnn(W_h, Wh_bias, U_h, W_y, Embedding, VOCAB, BOS_INDEX, EOS_INDEX, max_len=100):
# Ініціалізація
h_t = torch.zeros(hidden_dim, 1, dtype=torch.float32)
# Початковий токен
current_token_index = BOS_INDEX
decoded_message_indices = [BOS_INDEX]
# Цикл декодування
for _ in range(max_len - 1): # -1, тому що перший токен вже є
# 1. Вхідний вектор (x_t)
# Розмірність: (embedding_dim) -> [96]
x_t = Embedding[current_token_index]
# Додаємо вимір для матричного множення: [96] -> [96, 1]
x_t = x_t.unsqueeze(1)
# 2. Обчислення нового прихованого стану (h_t)
# Використовуємо оператор @ для матричного множення (MatMul)
# W_h @ x_t: [160, 96] @ [96, 1] -> [160, 1]
term1 = W_h @ x_t
# U_h @ h_{t-1}: [160, 160] @ [160, 1] -> [160, 1]
term2 = U_h @ h_t
# Новий прихований стан (h_t)
h_t = torch.tanh(term1 + Wh_bias + term2)
# 3. Обчислення логітів виходу (y_t)
# y_t: [vocab_size, hidden_dim] @ [hidden_dim, 1] -> [vocab_size, 1]
y_t = W_y @ h_t
# 4. Жадібний декодинг (argmax)
# Знаходимо індекс токена з найбільшим логітом
current_token_index = torch.argmax(y_t).item()
# 5. Зупинка
if current_token_index == EOS_INDEX:
decoded_message_indices.append(current_token_index)
break
decoded_message_indices.append(current_token_index)
# 6. Декодування повідомлення
decoded_message = "".join([VOCAB[idx] for idx in decoded_message_indices])
return decoded_message
decoded_message = greedy_decode_rnn(W_h, Wh_bias, U_h, W_y, Embedding, VOCAB, BOS_INDEX, EOS_INDEX)
print(decoded_message)
|