import os
import pickle

import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template


# --- Part 1: Re-define the Model Architecture ---
# This class definition must be EXACTLY the same as in your training script.
class ResidualLSTMModel(nn.Module):
    """Embedding -> two stacked LSTMs with a residual sum -> dropout -> linear.

    The attribute names below are part of the saved model's state, so they
    must not be renamed.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        """Build the layers.

        Args:
            vocab_size: Number of embedding rows and size of the output layer.
            embedding_dim: Width of each embedding vector.
            hidden_units: Hidden size shared by both LSTM layers.
            dropout_prob: Dropout probability applied before the output layer.
        """
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.lstm2 = nn.LSTM(
            input_size=hidden_units,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        """Return per-position vocabulary logits for a batch of token ids.

        Args:
            x: Integer tensor of token ids, batch dimension first.

        Returns:
            Tensor of logits with one ``vocab_size``-wide row per input
            position.
        """
        embedded = self.embedding(x)
        out1, _ = self.lstm1(embedded)
        out2, _ = self.lstm2(out1)
        # Residual connection: both LSTMs share hidden_units, so the outputs
        # can be added element-wise.
        residual_sum = out1 + out2
        dropped_out = self.dropout(residual_sum)
        logits = self.fc(dropped_out)
        return logits


# --- Part 2: Helper Functions for Processing Text ---
def text_to_sequence(text, vocab, max_length):
    """Convert whitespace-separated text into a padded tensor of token ids.

    Args:
        text: Input string; tokens are obtained with ``str.split()``.
        vocab: Mapping from token string to integer id.
        max_length: Exact length of the returned sequence. Longer inputs are
            cut to their first ``max_length`` tokens, shorter ones are padded
            on the right.

    Returns:
        A ``torch.long`` tensor of shape ``(1, max_length)``.
    """
    # NOTE(review): the special-token keys are empty strings. They look like
    # placeholders (e.g. an unknown/padding marker) whose text was lost;
    # confirm against the vocabulary produced by the training script. With
    # the keys as written, the fallbacks 1 (unknown) and 0 (padding) apply
    # unless the vocabulary contains the empty string.
    unknown_id = vocab.get('', 1)
    numericalized = [vocab.get(token, unknown_id) for token in text.split()]
    if len(numericalized) > max_length:
        numericalized = numericalized[:max_length]
    pad_id = vocab.get('', 0)
    padding_needed = max_length - len(numericalized)
    padded = numericalized + [pad_id] * padding_needed
    return torch.tensor([padded], dtype=torch.long)


def sequence_to_text(sequence, vocab):
    """Convert an iterable of id tensors back into a space-joined string.

    Args:
        sequence: Iterable whose elements expose ``.item()`` (for example
            0-dim tensors) yielding integer token ids.
        vocab: Mapping from token string to integer id.

    Returns:
        The decoded tokens joined by single spaces. Padding ids are skipped
        and ids missing from ``vocab`` decode to an empty string.
    """
    # NOTE(review): same empty-string padding key as in text_to_sequence;
    # confirm the intended special token.
    pad_id = vocab.get('', 0)
    id_to_token = {id_val: token for token, id_val in vocab.items()}
    ids = (element.item() for element in sequence)
    tokens = [id_to_token.get(token_id, '') for token_id in ids if token_id != pad_id]
    return " ".join(tokens)


# --- Part 3: Main Prediction Logic ---
def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Return the ``top_k`` most likely next tokens for ``text``.

    Args:
        model: Callable module mapping a ``(1, max_length)`` id tensor to
            logits indexed as ``[batch, position, vocab]``.
        text: Whitespace-tokenized input.
        vocab: Mapping from token string to integer id.
        device: Torch device the input tensor is moved to.
        max_length: Sequence length fed to the model. When the input has more
            tokens than this, only the most recent ``max_length`` tokens are
            used as context.
        top_k: Maximum number of suggestions to return.

    Returns:
        A list of decoded token strings, best first. Empty when ``text``
        contains no tokens.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    model.eval()
    tokens = text.split()
    if not tokens:
        # Nothing to condition on; skip the forward pass entirely.
        return []
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    # The input tensor holds at most max_length tokens, so the position of
    # the last real token must be computed from the truncated context.
    # Indexing with the untruncated token count ran past the end of the
    # logits for long inputs. Keeping the tail means the prediction follows
    # the text the user typed last.
    if len(tokens) > max_length:
        tokens = tokens[-max_length:]
    context = " ".join(tokens)
    last_position = len(tokens) - 1

    with torch.no_grad():
        input_tensor = text_to_sequence(context, vocab, max_length).to(device)
        logits = model(input_tensor)
        last_token_logits = logits[0, last_position, :]
        # torch.topk raises when k exceeds the dimension size.
        k = min(top_k, last_token_logits.size(-1))
        _, top_k_ids = torch.topk(last_token_logits, k, dim=-1)

    return [sequence_to_text([token_id], vocab) for token_id in top_k_ids]


# --- Part 4: Flask App Initialization ---
app = Flask(__name__)

# --- Configuration and Model Loading ---
MODEL_PATH = 'model.pt'
VOCAB_PATH = 'vocab.pkl'
MAX_LENGTH = 1000
device = torch.device("cpu")  # Use CPU for inference on a typical web server

# Load vocabulary
# SECURITY: pickle.load executes arbitrary code from the file. Only load a
# vocabulary file that you produced yourself.
try:
    with open(VOCAB_PATH, 'rb') as f:
        vocab = pickle.load(f)
    print("Vocabulary loaded.")
except FileNotFoundError:
    print(f"Error: Vocabulary file not found at {VOCAB_PATH}")
    vocab = None
except Exception as e:
    # A corrupt or incompatible pickle should disable prediction, not crash
    # the import; this mirrors the model-loading handler below.
    print(f"An error occurred while loading the vocabulary: {e}")
    vocab = None

# Load the model
# SECURITY: weights_only=False unpickles arbitrary objects. Only load a model
# file that you produced yourself.
try:
    # Since the model was saved as a whole object, we need weights_only=False
    model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    model.eval()  # Set model to evaluation mode
    print("Model loaded.")
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}")
    model = None
except Exception as e:
    print(f"An error occurred while loading the model: {e}")
    model = None


# --- Flask Routes ---
@app.route('/')
def home():
    """Serve the front-end page."""
    return render_template('index.html')


@app.route('/predict', methods=['POST'])
def predict():
    """Return next-token suggestions for the posted code snippet.

    Expects a JSON object with a string field ``code``. Responds with
    ``{'suggestions': [...]}`` on success, 400 for a malformed request and
    500 when the model is unavailable or prediction fails.
    """
    if model is None or not vocab:
        return jsonify({'error': 'Model or vocabulary not loaded. Check server logs.'}), 500

    # silent=True yields None instead of raising for a missing or invalid
    # JSON body, so malformed requests get a clear 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    code_snippet = data.get('code', '')
    if not isinstance(code_snippet, str):
        return jsonify({'error': "Field 'code' must be a string."}), 400
    if not code_snippet.strip():
        return jsonify({'suggestions': []})

    try:
        suggestions = predict_next_tokens(
            model, code_snippet, vocab, device, max_length=MAX_LENGTH
        )
        return jsonify({'suggestions': suggestions})
    except Exception as e:
        print(f"Prediction error: {e}")
        return jsonify({'error': 'Failed to get prediction.'}), 500


if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution, so debug mode is
    # opt-in via the FLASK_DEBUG environment variable rather than always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    app.run(debug=debug_enabled)