# Ayush
# Added the code files
# 8f40d24
import torch
import torch.nn as nn
import pickle
from flask import Flask, request, jsonify, render_template
# --- Part 1: Re-define the Model Architecture ---
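# This class must match the training script's definition (same class name and attribute names): model.pt stores a whole pickled module, and unpickling resolves the class by its module and name.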
class ResidualLSTMModel(nn.Module):
    """Embedding -> two stacked LSTMs with a residual skip -> per-position logits.

    The submodule attribute names (``embedding``, ``lstm1``, ``lstm2``,
    ``dropout``, ``fc``) must not change: the served model is a pickled
    instance whose submodules are restored under these names, and
    ``forward`` reaches them through those attributes.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embedding_dim: Width of each token embedding.
        hidden_units: Hidden size of both LSTM layers.
        dropout_prob: Dropout probability applied before the output layer.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        super().__init__()
        # Token id 0 is the padding id: its embedding row is excluded from
        # gradient updates (padding_idx=0).
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        # Same width in and out, so its output can be summed with lstm1's.
        self.lstm2 = nn.LSTM(
            input_size=hidden_units,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq) to logits (batch, seq, vocab_size)."""
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        first_pass = self.lstm1(self.embedding(x))[0]
        second_pass = self.lstm2(first_pass)[0]
        # Residual connection: lstm2's output is added to its own input.
        return self.fc(self.dropout(first_pass + second_pass))
# --- Part 2: Helper Functions for Processing Text ---
def text_to_sequence(text, vocab, max_length):
    """Encode whitespace-separated ``text`` as a padded batch of one id sequence.

    Args:
        text: Input string; tokens are produced with ``str.split()``.
        vocab: Mapping token -> integer id. Tokens not in it map to
            ``vocab['<UNK>']`` (1 if that key is absent).
        max_length: Target length. Longer inputs are truncated; shorter ones
            are right-padded with ``vocab['<PAD>']`` (0 if that key is absent).

    Returns:
        A ``torch.long`` tensor of shape ``(1, max_length)`` (for non-negative
        ``max_length``).
    """
    unk_id = vocab.get('<UNK>', 1)
    pad_id = vocab.get('<PAD>', 0)
    # Slicing is a no-op when the input already fits, so truncation needs no branch.
    token_ids = [vocab.get(tok, unk_id) for tok in text.split()][:max_length]
    # A negative count yields an empty list, so no guard is needed here.
    token_ids.extend([pad_id] * (max_length - len(token_ids)))
    return torch.tensor([token_ids], dtype=torch.long)
def sequence_to_text(sequence, vocab):
    """Decode a sequence of token ids back into a space-joined string.

    Args:
        sequence: Iterable of token ids. Elements may be tensor scalars
            (anything with ``.item()``, e.g. items of a 1-D tensor) or plain
            Python ints.
        vocab: Mapping token -> id; it is inverted here for decoding.

    Returns:
        The decoded tokens joined by single spaces. Ids equal to
        ``vocab['<PAD>']`` (0 if that key is absent) are skipped; ids not
        present in ``vocab`` decode to ``'<UNK>'``.
    """
    id_to_token = {id_val: token for token, id_val in vocab.items()}
    pad_id = vocab.get('<PAD>', 0)
    # Unwrap tensor scalars once; plain ints have no .item() and previously
    # raised AttributeError here.
    ids = (tok.item() if hasattr(tok, 'item') else tok for tok in sequence)
    return " ".join(id_to_token.get(i, '<UNK>') for i in ids if i != pad_id)
# --- Part 3: Main Prediction Logic ---
def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Return the ``top_k`` most likely tokens to follow ``text``.

    Args:
        model: Module mapping a ``(1, max_length)`` LongTensor of token ids to
            logits indexed as ``logits[batch, position, token_id]``.
        text: Input string, tokenized with ``str.split()``.
        vocab: Mapping token -> id, used for both encoding and decoding.
        device: Device the input tensor is moved to (presumably the device the
            model already lives on).
        max_length: Length the input is truncated/padded to before inference.
        top_k: Number of suggestions requested; clamped to the size of the
            logits' last dimension, since ``torch.topk`` rejects a larger k.

    Returns:
        A list of up to ``top_k`` decoded tokens, most likely first. It is
        empty when ``text`` has no tokens or ``max_length`` is below 1.
    """
    num_input_tokens = len(text.split())
    # text_to_sequence keeps at most max_length tokens, so the last real token
    # sits at min(num_input_tokens, max_length) - 1. Indexing with
    # num_input_tokens - 1 raised IndexError for inputs longer than max_length.
    last_pos = min(num_input_tokens, max_length) - 1
    if last_pos < 0:
        return []
    model.eval()
    with torch.no_grad():
        input_tensor = text_to_sequence(text, vocab, max_length).to(device)
        logits = model(input_tensor)
        last_token_logits = logits[0, last_pos, :]
        k = min(top_k, last_token_logits.size(-1))
        _, top_k_ids = torch.topk(last_token_logits, k, dim=-1)
    return [sequence_to_text([token_id], vocab) for token_id in top_k_ids]
# --- Part 4: Flask App Initialization ---
app = Flask(__name__)

# --- Configuration and Model Loading ---
MODEL_PATH = 'model.pt'
VOCAB_PATH = 'vocab.pkl'
MAX_LENGTH = 1000
device = torch.device("cpu")  # Use CPU for inference on a typical web server

# Load vocabulary.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load a vocab file produced by your own training run.
try:
    with open(VOCAB_PATH, 'rb') as f:
        vocab = pickle.load(f)
    print("Vocabulary loaded.")
except FileNotFoundError:
    print(f"Error: Vocabulary file not found at {VOCAB_PATH}")
    vocab = None
except Exception as e:
    # Corrupt or incompatible files are reported the same way as model-loading
    # failures, so /predict answers "not loaded" instead of the import failing.
    print(f"An error occurred while loading the vocabulary: {e}")
    vocab = None

# Load the model.
try:
    # Since the model was saved as a whole object, we need weights_only=False.
    # SECURITY: weights_only=False unpickles arbitrary objects (and can run
    # arbitrary code), so only load a model file you trust.
    # NOTE(review): unpickling resolves ResidualLSTMModel by the module it was
    # saved from; if training saved it from __main__, loading may only work
    # when this file is run directly as a script. Confirm against the
    # training script.
    model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    model.eval()  # Set model to evaluation mode (disables dropout)
    print("Model loaded.")
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}")
    model = None
except Exception as e:
    print(f"An error occurred while loading the model: {e}")
    model = None
# --- Flask Routes ---
@app.route('/')
def home():
    """Serve the front-end page rendered from the ``index.html`` template."""
    page_template = 'index.html'
    return render_template(page_template)
@app.route('/predict', methods=['POST'])
def predict():
    """Return next-token suggestions for the code snippet in the JSON body.

    Expects a JSON object such as ``{"code": "<snippet>"}``.

    Responses:
        200 ``{"suggestions": [...]}``; the list is empty for a blank snippet.
        400 when the body is not a JSON object or ``code`` is not a string.
        500 when the model/vocabulary failed to load or prediction raised.
    """
    # The model is checked explicitly against None; the vocab keeps the
    # truthiness check so an empty vocab is also treated as not loaded.
    if model is None or not vocab:
        return jsonify({'error': 'Model or vocabulary not loaded. Check server logs.'}), 500
    # silent=True returns None instead of raising on a missing or invalid JSON
    # body, so malformed requests get a clear 400 rather than an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400
    code_snippet = data.get('code', '')
    if not isinstance(code_snippet, str):
        return jsonify({'error': "'code' must be a string."}), 400
    if not code_snippet.strip():
        return jsonify({'suggestions': []})
    try:
        suggestions = predict_next_tokens(model, code_snippet, vocab, device, max_length=MAX_LENGTH)
    except Exception as e:
        print(f"Prediction error: {e}")
        return jsonify({'error': 'Failed to get prediction.'}), 500
    return jsonify({'suggestions': suggestions})
if __name__ == '__main__':
    # Debug mode enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution from the browser, so it is not hard-coded on.
    # app.run() honours the FLASK_DEBUG environment variable; set
    # FLASK_DEBUG=1 for local development.
    app.run()