Spaces:
Sleeping
Sleeping
File size: 3,698 Bytes
8f40d24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import torch
import torch.nn as nn
import pickle
# --- Part 1: Re-define the Model Architecture ---
# This class definition must be EXACTLY the same as in your training script.
class ResidualLSTMModel(nn.Module):
    """Embedding -> two stacked single-layer LSTMs with a residual skip -> vocab logits.

    The outputs of the first and second LSTM are summed elementwise (both have
    ``hidden_units`` features), passed through dropout, and projected by a
    linear layer to ``vocab_size`` logits at every position.

    NOTE: the submodule attribute names (``embedding``, ``lstm1``, ``lstm2``,
    ``dropout``, ``fc``) are looked up by ``forward`` and must match the
    training script / saved checkpoint, so do not rename them.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        super().__init__()
        # padding_idx=0: the row for id 0 is zero-initialised and excluded from
        # gradient updates. Assumes the vocab maps '<PAD>' to 0 — TODO confirm.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        shared_lstm_options = dict(num_layers=1, batch_first=True)
        self.lstm1 = nn.LSTM(embedding_dim, hidden_units, **shared_lstm_options)
        self.lstm2 = nn.LSTM(hidden_units, hidden_units, **shared_lstm_options)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        """Map a (batch, seq_len) id tensor to (batch, seq_len, vocab_size) logits."""
        first_pass, _ = self.lstm1(self.embedding(x))
        second_pass, _ = self.lstm2(first_pass)
        # Residual connection around the second LSTM.
        return self.fc(self.dropout(first_pass + second_pass))
# --- Part 2: Helper Functions for Processing Text ---
def text_to_sequence(text, vocab, max_length):
    """Encode whitespace-separated code as a (1, max_length) LongTensor of ids.

    Tokens not found in ``vocab`` map to ``vocab['<UNK>']``. If there are more
    than ``max_length`` tokens only the first ``max_length`` are kept; shorter
    inputs are right-padded with ``vocab['<PAD>']``.
    """
    ids = [vocab.get(tok, vocab['<UNK>']) for tok in text.split()][:max_length]
    pad = vocab['<PAD>']
    return torch.tensor([ids + [pad] * (max_length - len(ids))], dtype=torch.long)
def sequence_to_text(sequence, vocab):
    """Decode token ids back into a space-joined string.

    Args:
        sequence: Iterable of token ids — a 1-D tensor, a list of 0-d tensors,
            or plain Python/numpy integers.
        vocab: Mapping of token string -> id; must contain '<PAD>'.

    Returns:
        The decoded tokens joined by single spaces. Ids equal to
        ``vocab['<PAD>']`` are dropped; ids not present in ``vocab`` decode
        as '<UNK>'.
    """
    id_to_token = {id_val: token for token, id_val in vocab.items()}
    pad_id = vocab['<PAD>']
    # int() accepts 0-d tensors and plain ints alike (.item() rejected ints),
    # and converts each element only once.
    ids = (int(id_val) for id_val in sequence)
    return " ".join(id_to_token.get(i, '<UNK>') for i in ids if i != pad_id)
# --- Part 3: Main Prediction Logic ---
def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Predict the ``top_k`` most likely tokens to follow ``text``.

    Args:
        model: Module that maps a (1, max_length) id tensor to
            (1, max_length, vocab_size) logits.
        text: Whitespace-tokenized input code.
        vocab: Mapping token -> id containing '<PAD>' and '<UNK>'.
        device: Device the model lives on; the input tensor is moved there.
        max_length: Sequence length fed to the model. Longer inputs keep only
            their LAST ``max_length`` tokens, so the prediction still follows
            the end of the input.
        top_k: Number of suggestions, capped at the vocabulary size.

    Returns:
        List of decoded token strings, most likely first.

    Raises:
        ValueError: if ``text`` has no tokens or ``max_length`` < 1.
    """
    tokens = text.split()
    if not tokens:
        # Otherwise the index below would be -1 and silently read a padding slot.
        raise ValueError("text must contain at least one token")
    if max_length < 1:
        raise ValueError("max_length must be >= 1")
    # text_to_sequence keeps the FIRST max_length tokens; keep the last ones
    # here instead so the prediction follows the end of the input and the
    # position index below stays inside the sequence.
    tokens = tokens[-max_length:]

    model.eval()
    with torch.no_grad():
        input_tensor = text_to_sequence(" ".join(tokens), vocab, max_length).to(device)
        logits = model(input_tensor)
        # Logits at the last real (non-padding) position predict the next token.
        last_token_logits = logits[0, len(tokens) - 1, :]
        k = min(top_k, last_token_logits.size(-1))
        _, top_k_ids = torch.topk(last_token_logits, k, dim=-1)
        return [sequence_to_text([token_id], vocab) for token_id in top_k_ids]
if __name__ == '__main__':
    # --- Configuration ---
    MODEL_PATH = 'model.pt'
    VOCAB_PATH = 'vocab.pkl'
    MAX_LENGTH = 1000
    TOP_K = 5

    # --- Load everything ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # SECURITY: pickle.load can execute arbitrary code — only load a vocab
    # file from a trusted source.
    with open(VOCAB_PATH, 'rb') as f:
        vocab = pickle.load(f)
    print("Vocabulary loaded.")

    # The loaded object is used directly as the model below, i.e. the whole
    # nn.Module was pickled; that requires weights_only=False.
    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # checkpoints you trust.
    model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    print("Model loaded.")

    # --- Make a Prediction ---
    input_code = "import numpy as"  # Example input
    print(f"\nInput code: '{input_code}'")
    suggestions = predict_next_tokens(
        model, input_code, vocab, device, max_length=MAX_LENGTH, top_k=TOP_K
    )
    print(f"\nTop {len(suggestions)} suggestions:")
    for rank, suggestion in enumerate(suggestions, start=1):
        print(f"{rank}. {suggestion}")