Update inference.py
Browse files- inference.py +44 -31
inference.py
CHANGED
|
@@ -4,18 +4,16 @@ import torch.nn.functional as F
|
|
| 4 |
import pickle
|
| 5 |
from safetensors.torch import load_file
|
| 6 |
import logging
|
|
|
|
| 7 |
|
| 8 |
# Set up logging
|
| 9 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 10 |
|
| 11 |
# Hyperparameters
|
| 12 |
-
embedding_dim =
|
| 13 |
-
#change H number according to model used
|
| 14 |
hidden_dim = 256
|
| 15 |
-
num_layers =
|
| 16 |
-
sequence_length =
|
| 17 |
-
temp = 1.0 # Temperature parameter
|
| 18 |
-
top_k = 10 # Top-k sampling parameter
|
| 19 |
|
| 20 |
# LSTM Model
|
| 21 |
class LSTMModel(nn.Module):
|
|
@@ -31,21 +29,6 @@ class LSTMModel(nn.Module):
|
|
| 31 |
logits = self.fc(lstm_out[:, -1, :])
|
| 32 |
return logits
|
| 33 |
|
| 34 |
-
# Load the model and vocabulary
|
| 35 |
-
logging.info('Loading the model and vocabulary...')
|
| 36 |
-
model_state_dict = load_file('lstm_H256.safetensors')
|
| 37 |
-
with open('word2idx.pkl', 'rb') as f:
|
| 38 |
-
word2idx = pickle.load(f)
|
| 39 |
-
with open('idx2word.pkl', 'rb') as f:
|
| 40 |
-
idx2word = pickle.load(f)
|
| 41 |
-
|
| 42 |
-
vocab_size = len(word2idx)
|
| 43 |
-
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
|
| 44 |
-
model.load_state_dict(model_state_dict)
|
| 45 |
-
model.eval()
|
| 46 |
-
|
| 47 |
-
logging.info('Model and vocabulary loaded successfully.')
|
| 48 |
-
|
| 49 |
# Function to predict the next word with temperature and top-k sampling
|
| 50 |
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
|
| 51 |
model.eval()
|
|
@@ -62,22 +45,52 @@ def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp
|
|
| 62 |
return predicted_word
|
| 63 |
|
| 64 |
# Function to generate a sentence
|
| 65 |
-
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length
|
| 66 |
sentence = start_sequence
|
| 67 |
for _ in range(max_length):
|
| 68 |
next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
|
| 69 |
sentence += ' ' + next_word
|
| 70 |
-
if next_word == '<pad>'
|
| 71 |
break
|
| 72 |
return sentence
|
| 73 |
|
| 74 |
-
#
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
logging.info(f'Starting sequence: {start_sequence}')
|
| 81 |
-
logging.info(f'Temperature: {temp}, Top-k: {top_k}')
|
| 82 |
-
generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k)
|
| 83 |
-
logging.info(f'Generated sentence: {generated_sentence}')
|
|
|
|
| 4 |
import pickle
|
| 5 |
from safetensors.torch import load_file
|
| 6 |
import logging
|
| 7 |
+
import argparse
|
| 8 |
|
| 9 |
# Set up logging
|
| 10 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 11 |
|
| 12 |
# Hyperparameters
|
| 13 |
+
embedding_dim = 128
|
|
|
|
| 14 |
hidden_dim = 256
|
| 15 |
+
num_layers = 2
|
| 16 |
+
sequence_length = 10
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# LSTM Model
|
| 19 |
class LSTMModel(nn.Module):
|
|
|
|
| 29 |
logits = self.fc(lstm_out[:, -1, :])
|
| 30 |
return logits
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
# Function to predict the next word with temperature and top-k sampling
|
| 33 |
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
|
| 34 |
model.eval()
|
|
|
|
| 45 |
return predicted_word
|
| 46 |
|
| 47 |
# Function to generate a sentence
|
| 48 |
+
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length):
|
| 49 |
sentence = start_sequence
|
| 50 |
for _ in range(max_length):
|
| 51 |
next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
|
| 52 |
sentence += ' ' + next_word
|
| 53 |
+
if next_word == '<pad>':
|
| 54 |
break
|
| 55 |
return sentence
|
| 56 |
|
| 57 |
+
# Parse command-line arguments
|
| 58 |
+
def parse_args():
|
| 59 |
+
parser = argparse.ArgumentParser(description='LSTM Next Word Prediction Chatbot')
|
| 60 |
+
parser.add_argument('--temp', type=float, default=1.0, help='Temperature parameter')
|
| 61 |
+
parser.add_argument('--top_k', type=int, default=10, help='Top-k sampling parameter')
|
| 62 |
+
parser.add_argument('--model_file', type=str, default='lstm_model.safetensors', help='Path to the safetensors model file')
|
| 63 |
+
parser.add_argument('--start_sequence', type=str, default='Once upon a time', help='Starting sequence for sentence generation')
|
| 64 |
+
parser.add_argument('--max_length', type=int, default=50, help='Maximum number of words to generate')
|
| 65 |
+
return parser.parse_args()
|
| 66 |
+
|
| 67 |
+
# Main function
|
| 68 |
+
def main():
|
| 69 |
+
args = parse_args()
|
| 70 |
+
temp = args.temp
|
| 71 |
+
top_k = args.top_k
|
| 72 |
+
model_file = args.model_file
|
| 73 |
+
start_sequence = args.start_sequence
|
| 74 |
+
max_length = args.max_length
|
| 75 |
+
|
| 76 |
+
logging.info(f'Loading the model and vocabulary from {model_file}...')
|
| 77 |
+
model_state_dict = load_file(model_file)
|
| 78 |
+
with open('word2idx.pkl', 'rb') as f:
|
| 79 |
+
word2idx = pickle.load(f)
|
| 80 |
+
|
| 81 |
+
# Generate idx2word from word2idx
|
| 82 |
+
idx2word = {idx: word for word, idx in word2idx.items()}
|
| 83 |
+
|
| 84 |
+
vocab_size = len(word2idx)
|
| 85 |
+
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
|
| 86 |
+
model.load_state_dict(model_state_dict)
|
| 87 |
+
model.eval()
|
| 88 |
|
| 89 |
+
logging.info('Model and vocabulary loaded successfully.')
|
| 90 |
+
logging.info(f'Starting sequence: {start_sequence}')
|
| 91 |
+
logging.info(f'Temperature: {temp}, Top-k: {top_k}, Max Length: {max_length}')
|
| 92 |
+
generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length)
|
| 93 |
+
logging.info(f'Generated sentence: {generated_sentence}')
|
| 94 |
|
| 95 |
+
if __name__ == '__main__':
|
| 96 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|