import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from safetensors.torch import load_file
import logging
import argparse


# Root-logger setup: timestamped INFO-level lines for every logging call below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Model hyperparameters. These must match the architecture the checkpoint was
# trained with, otherwise load_state_dict() in main() fails on shape mismatches.
embedding_dim = 128   # size of each word-embedding vector
hidden_dim = 256      # LSTM hidden-state size
num_layers = 2        # number of stacked LSTM layers
sequence_length = 10  # number of trailing context words fed to the model
| |
|
| | |
class LSTMModel(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> vocab-sized linear head."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map a (batch, seq_len) tensor of token ids to (batch, vocab_size) logits."""
        hidden_states, _ = self.lstm(self.embedding(x))
        # Only the final time step feeds the classification head.
        return self.fc(hidden_states[:, -1, :])
| |
|
| | |
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
    """Sample the next word for *sequence* via temperature + top-k sampling.

    Args:
        model: module mapping a (1, seq_len) LongTensor of token ids to
            (1, vocab_size) logits.
        word2idx: word -> index mapping; must contain an '<UNK>' entry, which
            is used for out-of-vocabulary words.
        idx2word: inverse of word2idx.
        sequence: whitespace-separated context string.
        sequence_length: only the last `sequence_length` words are used.
        temp: softmax temperature; must be > 0 (values near 0 sharpen the
            distribution, values > 1 flatten it).
        top_k: number of highest-probability candidates to sample from;
            values larger than the vocabulary are clamped.

    Returns:
        The sampled word as a string.
    """
    model.eval()
    with torch.no_grad():
        seq_idx = [word2idx.get(word, word2idx['<UNK>']) for word in sequence.split()]
        seq_idx = seq_idx[-sequence_length:]
        seq_tensor = torch.tensor(seq_idx, dtype=torch.long).unsqueeze(0)
        outputs = model(seq_tensor)
        outputs = outputs / temp
        probs = F.softmax(outputs, dim=1).squeeze()
        # Clamp k: torch.topk raises when k exceeds the dimension size, so
        # callers may safely pass top_k larger than the vocabulary.
        k = min(top_k, probs.numel())
        top_k_probs, top_k_idx = torch.topk(probs, k)
        # multinomial accepts unnormalized non-negative weights, so the
        # truncated top-k probabilities need no renormalization.
        predicted_idx = torch.multinomial(top_k_probs, 1).item()
        predicted_word = idx2word[top_k_idx[predicted_idx].item()]
        return predicted_word
| |
|
| | |
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length):
    """Extend *start_sequence* one sampled word at a time.

    Generation stops after *max_length* words or as soon as the sampled
    word is the '<pad>' token (which is still appended before stopping).
    Returns the full space-joined sentence including the start sequence.
    """
    pieces = [start_sequence]
    for _ in range(max_length):
        context = ' '.join(pieces)
        word = predict_next_word(model, word2idx, idx2word, context, sequence_length, temp, top_k)
        pieces.append(word)
        if word == '<pad>':
            break
    return ' '.join(pieces)
| |
|
| | |
def parse_args():
    """Build the chatbot CLI and return the parsed command-line namespace."""
    cli = argparse.ArgumentParser(description='LSTM Next Word Prediction Chatbot')
    # (flag, type, default, help) — declared as data to keep the options scannable.
    option_specs = [
        ('--temp', float, 1.0, 'Temperature parameter'),
        ('--top_k', int, 10, 'Top-k sampling parameter'),
        ('--model_file', str, 'lstm_model.safetensors', 'Path to the safetensors model file'),
        ('--start_sequence', str, 'Once upon a time', 'Starting sequence for sentence generation'),
        ('--max_length', int, 50, 'Maximum number of words to generate'),
    ]
    for flag, value_type, default, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)
    return cli.parse_args()
| |
|
| | |
def main():
    """Load the trained LSTM and its vocabulary, then log one generated sentence."""
    args = parse_args()
    temp = args.temp
    top_k = args.top_k
    model_file = args.model_file
    start_sequence = args.start_sequence
    max_length = args.max_length

    logging.info(f'Loading the model and vocabulary from {model_file}...')
    state_dict = load_file(model_file)
    # NOTE(review): pickle.load is unsafe on untrusted data — acceptable here
    # only because 'word2idx.pkl' is a local artifact produced by training.
    with open('word2idx.pkl', 'rb') as vocab_file:
        word2idx = pickle.load(vocab_file)

    # Invert the vocabulary for index -> word lookups during sampling.
    idx2word = {index: word for word, index in word2idx.items()}

    # Architecture hyperparameters come from the module-level constants and
    # must match the checkpoint being loaded.
    model = LSTMModel(len(word2idx), embedding_dim, hidden_dim, num_layers)
    model.load_state_dict(state_dict)
    model.eval()

    logging.info('Model and vocabulary loaded successfully.')
    logging.info(f'Starting sequence: {start_sequence}')
    logging.info(f'Temperature: {temp}, Top-k: {top_k}, Max Length: {max_length}')
    generated_sentence = generate_sentence(
        model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length
    )
    logging.info(f'Generated sentence: {generated_sentence}')


if __name__ == '__main__':
    main()