File size: 4,373 Bytes
8f40d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn as nn
import pickle
from flask import Flask, request, jsonify, render_template

# --- Part 1: Re-define the Model Architecture ---
# This class definition must be EXACTLY the same as in your training script.
class ResidualLSTMModel(nn.Module):
    """Two-layer LSTM language model with a residual (skip) connection.

    The layer/attribute layout must stay identical to the training
    script: the saved model is a whole pickled object, and unpickling
    resolves attributes (``embedding``, ``lstm1``, ``lstm2``, ``dropout``,
    ``fc``) against this class definition.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        super().__init__()
        # Id 0 is reserved for <PAD>; padding_idx keeps its embedding at zero.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.lstm2 = nn.LSTM(
            input_size=hidden_units,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, seq) to per-position vocab logits."""
        first = self.lstm1(self.embedding(x))[0]
        second = self.lstm2(first)[0]
        # Residual sum: gradients (and features) can bypass the second LSTM.
        return self.fc(self.dropout(first + second))

# --- Part 2: Helper Functions for Processing Text ---
def text_to_sequence(text, vocab, max_length):
    """Tokenize *text* on whitespace and encode it as a (1, max_length) LongTensor.

    Tokens missing from *vocab* map to the ``<UNK>`` id (falling back to 1);
    the sequence is truncated to *max_length* and right-padded with the
    ``<PAD>`` id (falling back to 0).
    """
    unk_id = vocab.get('<UNK>', 1)
    pad_id = vocab.get('<PAD>', 0)
    ids = [vocab.get(tok, unk_id) for tok in text.split()][:max_length]
    ids.extend([pad_id] * (max_length - len(ids)))
    return torch.tensor([ids], dtype=torch.long)

def sequence_to_text(sequence, vocab):
    """Decode a sequence of id tensors back into a space-joined token string.

    ``<PAD>`` ids are dropped; ids absent from *vocab* render as ``<UNK>``.
    *sequence* must yield elements supporting ``.item()`` (0-d tensors).
    """
    inverse = {idx: tok for tok, idx in vocab.items()}
    pad_id = vocab.get('<PAD>', 0)
    words = []
    for element in sequence:
        token_id = element.item()
        if token_id != pad_id:
            words.append(inverse.get(token_id, '<UNK>'))
    return " ".join(words)

# --- Part 3: Main Prediction Logic ---
def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Return the *top_k* most likely next tokens for *text*.

    The text is encoded to a fixed-length tensor, run through the model,
    and the logits at the last real (non-padding) position are ranked.

    Fix: the input sequence is truncated to *max_length* by
    ``text_to_sequence``, but the original code indexed the logits at
    ``len(text.split()) - 1`` — an IndexError whenever the input held more
    than *max_length* tokens. The position is now clamped to the last
    position actually fed to the model.

    Returns an empty list for whitespace-only input.
    """
    model.eval()
    num_input_tokens = len(text.split())
    if num_input_tokens == 0:
        # Nothing to condition on — skip the model call entirely.
        return []
    with torch.no_grad():
        input_tensor = text_to_sequence(text, vocab, max_length).to(device)
        logits = model(input_tensor)  # (1, max_length, vocab_size)
        # Clamp: inputs longer than max_length were truncated, so the
        # last real position is at most max_length - 1.
        last_pos = min(num_input_tokens, max_length) - 1
        last_token_logits = logits[0, last_pos, :]
        _, top_k_ids = torch.topk(last_token_logits, top_k, dim=-1)
        # One-id decode per candidate; preserves the original behavior of
        # rendering a <PAD> candidate as "".
        top_k_tokens = [sequence_to_text([token_id], vocab) for token_id in top_k_ids]
        return top_k_tokens

# --- Part 4: Flask App Initialization ---
app = Flask(__name__)

# --- Configuration and Model Loading ---
# Paths are relative to the working directory the server is launched from.
MODEL_PATH = 'model.pt'
VOCAB_PATH = 'vocab.pkl'
# Must match the max_length the model was trained with — TODO confirm.
MAX_LENGTH = 1000

device = torch.device("cpu") # Use CPU for inference on a typical web server

# Load vocabulary
# On failure, vocab stays None and /predict responds 500 instead of crashing
# the server at import time.
try:
    with open(VOCAB_PATH, 'rb') as f:
        vocab = pickle.load(f)
    print("Vocabulary loaded.")
except FileNotFoundError:
    print(f"Error: Vocabulary file not found at {VOCAB_PATH}")
    vocab = None

# Load the model
try:
    # Since the model was saved as a whole object, we need weights_only=False
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # model files from a trusted source; prefer saving/loading a state_dict.
    model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    model.eval() # Set model to evaluation mode
    print("Model loaded.")
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}")
    model = None
except Exception as e:
    # Catch-all so a corrupt checkpoint degrades to a 500 on /predict
    # rather than killing the whole app at startup.
    print(f"An error occurred while loading the model: {e}")
    model = None


# --- Flask Routes ---
@app.route('/')
def home():
    """Serve the front-end page (templates/index.html)."""
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    """JSON API: accept ``{"code": <snippet>}``, return ``{"suggestions": [...]}``.

    Responds 500 if the model or vocabulary failed to load at startup, and
    an empty suggestion list for blank input.

    Fix: ``request.get_json()`` aborts the request with 400/415 on a
    missing or non-JSON body, and a JSON ``null`` code value crashed on
    ``.strip()``. Both now degrade gracefully to an empty result.
    """
    if not model or not vocab:
        return jsonify({'error': 'Model or vocabulary not loaded. Check server logs.'}), 500

    # silent=True: malformed/absent JSON yields None instead of aborting,
    # so we can answer with a controlled empty result.
    data = request.get_json(silent=True) or {}
    code_snippet = data.get('code') or ''

    if not code_snippet.strip():
        return jsonify({'suggestions': []})

    try:
        suggestions = predict_next_tokens(model, code_snippet, vocab, device, max_length=MAX_LENGTH)
        return jsonify({'suggestions': suggestions})
    except Exception as e:
        # Broad catch is deliberate at this request boundary: any model
        # failure becomes a clean JSON 500 instead of a Flask stack trace.
        print(f"Prediction error: {e}")
        return jsonify({'error': 'Failed to get prediction.'}), 500

if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger and
    # auto-reloader — fine for local development, never for production; use a
    # WSGI server (gunicorn/uwsgi) with debug off when deploying.
    app.run(debug=True)