Spaces:
Sleeping
Sleeping
import torch
import torch.nn as nn
import pickle
from flask import Flask, request, jsonify, render_template
# --- Part 1: Re-define the Model Architecture ---
# This class definition must be EXACTLY the same as in your training script.
class ResidualLSTMModel(nn.Module):
    """Two-layer LSTM language model with a residual (skip) connection.

    The outputs of the two stacked LSTMs are summed element-wise before
    dropout and the final projection onto the vocabulary.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        super().__init__()
        # Token embedding; index 0 is reserved for padding.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        # Two separate single-layer LSTMs so the skip connection can be
        # applied between them (attribute names must stay stable so pickled
        # checkpoints keep loading).
        self.lstm1 = nn.LSTM(embedding_dim, hidden_units, num_layers=1, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_units, hidden_units, num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        """Map token ids (batch, seq) to per-position logits (batch, seq, vocab)."""
        emb = self.embedding(x)
        first_out, _ = self.lstm1(emb)
        second_out, _ = self.lstm2(first_out)
        combined = first_out + second_out  # residual skip connection
        return self.fc(self.dropout(combined))
# --- Part 2: Helper Functions for Processing Text ---
def text_to_sequence(text, vocab, max_length):
    """Convert whitespace-separated text into a (1, max_length) LongTensor.

    Tokens missing from vocab map to the '<UNK>' id (falling back to 1);
    the tail is filled with the '<PAD>' id (falling back to 0). Inputs
    longer than max_length are truncated.
    """
    unk_id = vocab.get('<UNK>', 1)
    pad_id = vocab.get('<PAD>', 0)
    ids = [vocab.get(tok, unk_id) for tok in text.split()][:max_length]
    ids.extend([pad_id] * (max_length - len(ids)))
    return torch.tensor([ids], dtype=torch.long)
def sequence_to_text(sequence, vocab):
    """Decode an iterable of 0-d id tensors into a space-joined string.

    Padding ids are skipped; ids absent from vocab decode to '<UNK>'.
    """
    reverse_vocab = {idx: tok for tok, idx in vocab.items()}
    pad_id = vocab.get('<PAD>', 0)
    words = []
    for element in sequence:
        idx = element.item()
        if idx != pad_id:
            words.append(reverse_vocab.get(idx, '<UNK>'))
    return " ".join(words)
# --- Part 3: Main Prediction Logic ---
def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Return the top_k most likely next tokens for the given text.

    Args:
        model: the loaded language model (called as model(tensor)).
        text: whitespace-tokenized input string.
        vocab: token -> id mapping used for both encode and decode.
        device: torch device the model lives on.
        max_length: padded sequence length fed to the model.
        top_k: number of candidate tokens to return.

    Returns:
        A list of up to top_k token strings (empty list for empty input).

    Bug fix: when the input has more tokens than max_length the sequence is
    truncated by text_to_sequence, so the logit index must be clamped to the
    last real position (previously this indexed out of range and raised
    IndexError).
    """
    model.eval()
    num_input_tokens = len(text.split())
    if num_input_tokens == 0:
        return []  # nothing to condition on; skip the forward pass entirely
    with torch.no_grad():
        input_tensor = text_to_sequence(text, vocab, max_length).to(device)
        logits = model(input_tensor)
        # Clamp to the last position actually present in the (possibly
        # truncated) input tensor.
        last_pos = min(num_input_tokens, max_length) - 1
        last_token_logits = logits[0, last_pos, :]
        _, top_k_ids = torch.topk(last_token_logits, top_k, dim=-1)
        return [sequence_to_text([token_id], vocab) for token_id in top_k_ids]
# --- Part 4: Flask App Initialization ---
app = Flask(__name__)

# --- Configuration and Model Loading ---
MODEL_PATH = 'model.pt'   # full pickled model object (not just a state_dict)
VOCAB_PATH = 'vocab.pkl'  # token -> id dict; presumably produced by the training script
MAX_LENGTH = 1000         # must match the training sequence length — TODO confirm
device = torch.device("cpu")  # Use CPU for inference on a typical web server

# Load vocabulary; on failure leave vocab = None so the /predict route can
# return a clean HTTP 500 instead of crashing at import time.
try:
    with open(VOCAB_PATH, 'rb') as f:
        vocab = pickle.load(f)
    print("Vocabulary loaded.")
except FileNotFoundError:
    print(f"Error: Vocabulary file not found at {VOCAB_PATH}")
    vocab = None

# Load the model; same None-on-failure convention as the vocabulary.
try:
    # Since the model was saved as a whole object, we need weights_only=False.
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects —
    # only load model files from a trusted source.
    model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    model.eval()  # Set model to evaluation mode (disables dropout)
    print("Model loaded.")
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}")
    model = None
except Exception as e:
    # Catch-all boundary: a broken checkpoint must not prevent the server
    # from starting; the error is surfaced per-request instead.
    print(f"An error occurred while loading the model: {e}")
    model = None
# --- Flask Routes ---
@app.route('/')
def home():
    """Serve the single-page UI from templates/index.html.

    Fix: the function was never registered with Flask — without the
    @app.route decorator the app exposes no routes at all.
    """
    return render_template('index.html')
@app.route('/predict', methods=['POST'])
def predict():
    """JSON API: request body {'code': str} -> {'suggestions': [str, ...]}.

    Returns HTTP 500 with an 'error' key when the model/vocab failed to load
    or when prediction raises.

    Fixes: (1) the function was never registered with Flask — the
    @app.route decorator was missing; (2) request.get_json() returns None
    for a non-JSON body, which previously raised AttributeError outside the
    try block — use silent=True and default to an empty dict.
    """
    if not model or not vocab:
        return jsonify({'error': 'Model or vocabulary not loaded. Check server logs.'}), 500
    data = request.get_json(silent=True) or {}
    code_snippet = data.get('code', '')
    if not code_snippet.strip():
        return jsonify({'suggestions': []})  # blank input: no suggestions, not an error
    try:
        suggestions = predict_next_tokens(model, code_snippet, vocab, device, max_length=MAX_LENGTH)
        return jsonify({'suggestions': suggestions})
    except Exception as e:
        # Boundary handler: log server-side, return an opaque error to the client.
        print(f"Prediction error: {e}")
        return jsonify({'error': 'Failed to get prediction.'}), 500
if __name__ == '__main__':
    # Development server only. NOTE(review): debug=True enables the Werkzeug
    # interactive debugger (arbitrary code execution) — disable in production
    # and serve via a WSGI server instead.
    app.run(debug=True)