Spaces:
Sleeping
Sleeping
File size: 4,373 Bytes
8f40d24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import torch
import torch.nn as nn
import pickle
from flask import Flask, request, jsonify, render_template
# --- Part 1: Re-define the Model Architecture ---
# This class definition must be EXACTLY the same as in your training script.
class ResidualLSTMModel(nn.Module):
    """Embedding -> two stacked single-layer LSTMs with a residual sum -> dropout -> vocab projection.

    NOTE: keep the attribute names (embedding, lstm1, lstm2, dropout, fc) and the
    order in which the submodules are created unchanged; per the comment above,
    this definition must match the training script exactly.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        super().__init__()
        # padding_idx=0: the embedding row for id 0 receives no gradient updates.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm1 = nn.LSTM(embedding_dim, hidden_units, num_layers=1, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_units, hidden_units, num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        """Return per-position logits over the vocabulary for token ids `x`.

        With batch_first=True, `x` is expected to be (batch, seq) and the result
        (batch, seq, vocab_size).
        """
        first_hidden = self.lstm1(self.embedding(x))[0]
        second_hidden = self.lstm2(first_hidden)[0]
        # Residual connection: the second LSTM refines, rather than replaces, the first.
        return self.fc(self.dropout(first_hidden + second_hidden))
# --- Part 2: Helper Functions for Processing Text ---
def text_to_sequence(text, vocab, max_length):
    """Encode whitespace-separated `text` as a (1, max_length) LongTensor of token ids.

    Tokens missing from `vocab` map to vocab['<UNK>'] (1 if absent). Only the
    first `max_length` tokens are kept; shorter inputs are right-padded with
    vocab['<PAD>'] (0 if absent).
    """
    unk_id = vocab.get('<UNK>', 1)
    pad_id = vocab.get('<PAD>', 0)
    ids = [vocab.get(word, unk_id) for word in text.split()][:max_length]
    ids.extend([pad_id] * (max_length - len(ids)))
    return torch.tensor([ids], dtype=torch.long)
def sequence_to_text(sequence, vocab):
    """Decode a sequence of token ids into a single space-joined string.

    Args:
        sequence: iterable of ids, either 0-d tensors (e.g. from iterating a
            1-D LongTensor) or plain Python ints.
        vocab: token -> id mapping; it is inverted here to look ids up.

    Returns:
        The decoded tokens joined by single spaces. Ids equal to
        vocab['<PAD>'] (0 if absent) are dropped; ids not present in the
        vocabulary decode to '<UNK>'.
    """
    id_to_token = {id_val: token for token, id_val in vocab.items()}
    # Hoisted: the pad id is loop-invariant.
    pad_id = vocab.get('<PAD>', 0)
    tokens = []
    for raw in sequence:
        # Tensors (and NumPy scalars) expose .item(); plain ints are used as-is.
        id_val = raw.item() if hasattr(raw, 'item') else raw
        if id_val != pad_id:
            tokens.append(id_to_token.get(id_val, '<UNK>'))
    return " ".join(tokens)
# --- Part 3: Main Prediction Logic ---
def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Suggest the `top_k` most likely tokens to follow `text`, best first.

    Args:
        model: called as model(ids) on a (1, max_length) LongTensor; its output
            is indexed as logits[0, position, token_id].
        text: input snippet, tokenized with str.split() like text_to_sequence.
        vocab: token -> id mapping.
        device: device the input tensor is moved to before the forward pass.
        max_length: sequence length fed to the model. Only the LAST
            `max_length` tokens of `text` are used as context.
        top_k: maximum number of suggestions to return.

    Returns:
        A list of up to `top_k` token strings. It is empty when `text` has no
        tokens or when `max_length`/`top_k` is not positive. The padding id is
        never suggested.
    """
    tokens = text.split()
    # Bail out before running the model when there is nothing to predict from.
    if not tokens or max_length <= 0 or top_k <= 0:
        return []

    # Keep the tail. Next-token prediction needs the most recent context, and
    # the last real token's position then always lies inside the
    # max_length-long logits. Keeping the head (text_to_sequence's truncation)
    # and indexing at len(tokens) - 1 would go out of range for long inputs.
    context = tokens[-max_length:]
    last_position = len(context) - 1

    model.eval()
    with torch.no_grad():
        input_tensor = text_to_sequence(" ".join(context), vocab, max_length).to(device)
        logits = model(input_tensor)
        scores = logits[0, last_position, :]
        # Ask for one extra candidate so that dropping the padding id still
        # leaves top_k suggestions. Clamp k so topk never exceeds the vocab size.
        k = min(top_k + 1, scores.size(-1))
        _, candidate_ids = torch.topk(scores, k, dim=-1)

    pad_id = vocab.get('<PAD>', 0)
    # The padding id decodes to "" via sequence_to_text, which is useless as a suggestion.
    chosen = [tid for tid in candidate_ids if tid.item() != pad_id][:top_k]
    return [sequence_to_text([tid], vocab) for tid in chosen]
# --- Part 4: Flask App Initialization ---
# --- Configuration ---
# Training artifacts, and the fixed sequence length passed to predict_next_tokens.
MODEL_PATH = 'model.pt'
VOCAB_PATH = 'vocab.pkl'
MAX_LENGTH = 1000

# Inference runs on CPU; no GPU is assumed on the serving host.
device = torch.device("cpu")

app = Flask(__name__)
# Load the vocabulary: the pickled token -> id mapping used by the helpers above.
# SECURITY NOTE: pickle.load can execute arbitrary code, so VOCAB_PATH must
# point to a file you produced or otherwise trust.
try:
    with open(VOCAB_PATH, 'rb') as f:
        vocab = pickle.load(f)
    print("Vocabulary loaded.")
except FileNotFoundError:
    print(f"Error: Vocabulary file not found at {VOCAB_PATH}")
    vocab = None
except Exception as e:
    # A corrupt or incompatible pickle would otherwise crash the app at import
    # time. Degrade like the model loader does, so /predict reports the
    # "not loaded" error instead.
    print(f"An error occurred while loading the vocabulary: {e}")
    vocab = None
# Load the model. It was saved as a whole pickled object rather than a
# state_dict, so torch.load needs weights_only=False. Unpickling also needs the
# class definition, which is presumably why ResidualLSTMModel is redefined in
# this file.
# SECURITY NOTE: weights_only=False unpickles arbitrary objects, so MODEL_PATH
# must point to a trusted file.
def _load_model():
    """Return the loaded model in eval mode, or None (after printing why) if loading fails."""
    try:
        loaded = torch.load(MODEL_PATH, map_location=device, weights_only=False)
        loaded.eval()  # evaluation mode, e.g. disables dropout
        print("Model loaded.")
        return loaded
    except FileNotFoundError:
        print(f"Error: Model file not found at {MODEL_PATH}")
    except Exception as e:
        print(f"An error occurred while loading the model: {e}")
    return None


model = _load_model()
# --- Flask Routes ---
@app.route('/')
def home():
    """Serve the front-end page rendered from the 'index.html' template."""
    page = render_template('index.html')
    return page
@app.route('/predict', methods=['POST'])
def predict():
    """Return next-token suggestions for the snippet posted as JSON.

    Expects a JSON object such as {"code": "<snippet>"}. Responses:
      * 200 {"suggestions": [...]} on success (empty list for blank code);
      * 400 {"error": ...} when the body is not a JSON object or 'code' is
        not a string;
      * 500 {"error": ...} when the model or vocabulary failed to load, or
        prediction raised.
    """
    # Explicit None check for the model: the question is "did loading succeed",
    # not the object's truthiness. An empty vocab is still treated as unusable.
    if model is None or not vocab:
        return jsonify({'error': 'Model or vocabulary not loaded. Check server logs.'}), 500

    # silent=True returns None for a missing or invalid JSON body instead of
    # raising, so the client always receives a JSON error rather than an HTML page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400
    code_snippet = data.get('code', '')
    if not isinstance(code_snippet, str):
        return jsonify({'error': "'code' must be a string."}), 400
    if not code_snippet.strip():
        return jsonify({'suggestions': []})

    try:
        suggestions = predict_next_tokens(model, code_snippet, vocab, device, max_length=MAX_LENGTH)
    except Exception as e:
        # Log the details server-side only; the client gets a generic message.
        print(f"Prediction error: {e}")
        return jsonify({'error': 'Failed to get prediction.'}), 500
    return jsonify({'suggestions': suggestions})
if __name__ == '__main__':
    # Do not hard-enable debug mode: the Werkzeug debugger lets anyone who can
    # reach the page execute arbitrary Python code. Called without `debug`,
    # app.run() honours the FLASK_DEBUG environment variable, so set
    # FLASK_DEBUG=1 for local development.
    app.run()
|