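"""LSTM next-word prediction chatbot.

Loads a trained LSTM language model from a safetensors checkpoint together with
a pickled word-to-index vocabulary (word2idx.pkl), then generates a sentence
using temperature and top-k sampling.

Example invocation (the script filename here is assumed; substitute your own):

    python lstm_chatbot.py --temp 0.8 --top_k 20 --start_sequence "Once upon a time"
"""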
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from safetensors.torch import load_file
import logging
import argparse

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters (these must match the values used at training time, or
# load_state_dict below will fail with shape mismatches)
embedding_dim = 128
hidden_dim = 256
num_layers = 2
sequence_length = 10

# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
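        # Score the next word using only the hidden state at the final time step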
        logits = self.fc(lstm_out[:, -1, :])
        return logits

# Function to predict the next word with temperature and top-k sampling
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
    model.eval()
    with torch.no_grad():
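        # Map out-of-vocabulary words to <UNK> (assumed to be present in word2idx)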
        seq_idx = [word2idx.get(word, word2idx['<UNK>']) for word in sequence.split()]
        seq_idx = seq_idx[-sequence_length:]  # Keep only the most recent sequence_length tokens
        seq_tensor = torch.tensor(seq_idx, dtype=torch.long).unsqueeze(0)
        outputs = model(seq_tensor)
        outputs = outputs / temp  # Temperature scaling: temp < 1 sharpens the distribution, temp > 1 flattens it
        probs = F.softmax(outputs, dim=1).squeeze()
        top_k = min(top_k, probs.size(-1))  # Guard against top_k exceeding the vocabulary size
        top_k_probs, top_k_idx = torch.topk(probs, top_k)
        # torch.multinomial treats its input as unnormalized weights, so no renormalization is needed
        predicted_idx = torch.multinomial(top_k_probs, 1).item()
        predicted_word = idx2word[top_k_idx[predicted_idx].item()]
        return predicted_word

# Function to generate a sentence
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length):
    sentence = start_sequence
    for _ in range(max_length):
        next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
        # Stop before appending the padding token so it never shows up in the output
        if next_word == '<pad>':
            break
        sentence += ' ' + next_word
    return sentence

# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser(description='LSTM Next Word Prediction Chatbot')
    parser.add_argument('--temp', type=float, default=1.0, help='Sampling temperature (lower values make output more deterministic)')
    parser.add_argument('--top_k', type=int, default=10, help='Sample from the k most likely next words')
    parser.add_argument('--model_file', type=str, default='lstm_model.safetensors', help='Path to the safetensors model file')
    parser.add_argument('--start_sequence', type=str, default='Once upon a time', help='Starting sequence for sentence generation')
    parser.add_argument('--max_length', type=int, default=50, help='Maximum number of words to generate')
    return parser.parse_args()

# Main function
def main():
    args = parse_args()
    temp = args.temp
    top_k = args.top_k
    model_file = args.model_file
    start_sequence = args.start_sequence
    max_length = args.max_length

    logging.info(f'Loading the model and vocabulary from {model_file}...')
    model_state_dict = load_file(model_file)
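    # The vocabulary path is hard-coded; word2idx.pkl is expected in the working directory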
    with open('word2idx.pkl', 'rb') as f:
        word2idx = pickle.load(f)

    # Generate idx2word from word2idx
    idx2word = {idx: word for word, idx in word2idx.items()}

    vocab_size = len(word2idx)
    model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
    model.load_state_dict(model_state_dict)
    model.eval()

    logging.info('Model and vocabulary loaded successfully.')
    logging.info(f'Starting sequence: {start_sequence}')
    logging.info(f'Temperature: {temp}, Top-k: {top_k}, Max Length: {max_length}')
    generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length)
    logging.info(f'Generated sentence: {generated_sentence}')

if __name__ == '__main__':
    main()