import pickle

import torch
import torch.nn as nn

# --- Part 1: Re-define the Model Architecture ---
# This class must match the one in the training script. torch.load(...,
# weights_only=False) in the __main__ section unpickles the whole model object,
# which looks the class up by name, and the stored layers/attributes must line
# up with the ones defined here.


class ResidualLSTMModel(nn.Module):
    """Two stacked LSTMs joined by a residual sum, projected to vocab logits.

    Args:
        vocab_size: Number of tokens; rows of the embedding table and width
            of the output logits.
        embedding_dim: Size of each token embedding.
        hidden_units: Hidden size of both LSTM layers. Both layers must share
            it so their outputs can be added element-wise.
        dropout_prob: Dropout probability applied to the residual sum.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        super().__init__()
        # padding_idx=0: id 0 is the embedding's padding index.
        # NOTE(review): text_to_sequence pads with vocab[''], which should be 0
        # to agree with this — confirm against the vocab file.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.lstm2 = nn.LSTM(
            input_size=hidden_units,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        """Return next-token logits for every input position.

        Args:
            x: LongTensor of token ids. The LSTMs use ``batch_first=True``, so
                this is expected as (batch, seq_len).

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        embedded = self.embedding(x)
        out1, _ = self.lstm1(embedded)
        out2, _ = self.lstm2(out1)
        # Residual connection: second layer refines the first layer's output.
        residual_sum = out1 + out2
        dropped_out = self.dropout(residual_sum)
        logits = self.fc(dropped_out)
        return logits


# --- Part 2: Helper Functions for Processing Text ---

def text_to_sequence(text, vocab, max_length):
    """Convert a whitespace-tokenized string into a padded (1, max_length) tensor.

    Tokens missing from ``vocab`` map to ``vocab['']``. The id list is
    truncated to its first ``max_length`` entries and right-padded with
    ``vocab['']``.

    NOTE(review): the same key '' serves as both the unknown-token and the
    padding id. These were probably distinct special tokens (e.g. '<unk>' and
    '<pad>') that got lost; confirm against the saved vocabulary.

    Raises:
        KeyError: if ``vocab`` has no '' entry.
    """
    unk_id = vocab['']
    pad_id = vocab['']
    ids = [vocab.get(token, unk_id) for token in text.split()][:max_length]
    ids += [pad_id] * (max_length - len(ids))
    return torch.tensor([ids], dtype=torch.long)


def sequence_to_text(sequence, vocab):
    """Convert an iterable of token-id tensors back into a space-joined string.

    Elements equal to the padding id ``vocab['']`` are dropped; ids that are
    not in the vocabulary become empty strings.
    """
    # Inverse lookup id -> token, rebuilt on each call.
    id_to_token = {id_val: token for token, id_val in vocab.items()}
    tokens = [
        id_to_token.get(id_val.item(), '')
        for id_val in sequence
        if id_val.item() != vocab['']
    ]
    return " ".join(tokens)


# --- Part 3: Main Prediction Logic ---

def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Predict the ``top_k`` most likely tokens to follow ``text``.

    Args:
        model: Module returning (batch, seq_len, vocab_size) logits, such as
            ResidualLSTMModel. It is switched to eval mode.
        text: Input code; tokens are whitespace-separated.
        vocab: Mapping token -> id, as consumed by ``text_to_sequence``.
        device: Device the input tensor is moved to.
        max_length: Length the input is padded to. If ``text`` has more tokens,
            only the last ``max_length`` are used, because those form the
            context immediately preceding the prediction.
        top_k: Number of suggestions; capped at the model's vocabulary size.

    Returns:
        List of predicted token strings, most likely first. A predicted padding
        id comes back as an empty string (see ``sequence_to_text``).

    Raises:
        ValueError: if ``text`` has no tokens or ``max_length`` < 1.
    """
    tokens = text.split()
    if not tokens:
        raise ValueError("text must contain at least one token")
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    # The prediction is read off the last real input position, so keep the
    # most recent tokens. Truncating from the front would predict from the
    # middle of the text, and indexing len(tokens) - 1 would run past the end
    # of the padded sequence.
    context = tokens[-max_length:]

    model.eval()
    with torch.no_grad():
        input_tensor = text_to_sequence(" ".join(context), vocab, max_length).to(device)
        logits = model(input_tensor)
        last_token_logits = logits[0, len(context) - 1, :]
        # torch.topk raises if k exceeds the dimension size.
        k = min(top_k, last_token_logits.size(-1))
        _, top_k_ids = torch.topk(last_token_logits, k, dim=-1)

    return [sequence_to_text([token_id], vocab) for token_id in top_k_ids]


if __name__ == '__main__':
    # --- Configuration ---
    MODEL_PATH = 'model.pt'
    VOCAB_PATH = 'vocab.pkl'
    MAX_LENGTH = 1000

    # --- Load everything ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load vocabularies from trusted sources.
    with open(VOCAB_PATH, 'rb') as f:
        vocab = pickle.load(f)
    print("Vocabulary loaded.")

    # weights_only=False unpickles the full nn.Module object (hence the class
    # definition above). Like pickle.load, this can run arbitrary code, so only
    # load trusted checkpoints.
    model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    print("Model loaded.")

    # --- Make a Prediction ---
    input_code = "import numpy as"  # Example input
    print(f"\nInput code: '{input_code}'")

    suggestions = predict_next_tokens(model, input_code, vocab, device, max_length=MAX_LENGTH)

    print(f"\nTop {len(suggestions)} suggestions:")
    for i, suggestion in enumerate(suggestions, start=1):
        print(f"{i}. {suggestion}")