Ayush committed on
Commit
8f40d24
·
1 Parent(s): f67957d

Added the code files

Browse files
Files changed (8) hide show
  1. Dockerfile +20 -0
  2. Procfile +1 -0
  3. app.py +127 -0
  4. main.py +107 -0
  5. model.pt +3 -0
  6. requirements.txt +3 -0
  7. templates/index.html +183 -0
  8. vocab.pkl +3 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Use an official Python runtime as a parent image
FROM python:3.9-slim

# Set the working directory in the container
WORKDIR /app

# Copy the requirements file into the container first so the pip-install
# layer is cached when only application code changes
COPY requirements.txt .

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application's code
COPY . .

# Expose the port the app runs on (must match the gunicorn bind below)
EXPOSE 7860

# Command to run the app using Gunicorn (serves the Flask `app` in app.py)
CMD ["gunicorn", "--bind", "0.0.0.0:7860", "app:app"]
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
+ web: gunicorn app:app
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pickle
4
+ from flask import Flask, request, jsonify, render_template
5
+
6
+ # --- Part 1: Re-define the Model Architecture ---
7
+ # This class definition must be EXACTLY the same as in your training script.
8
class ResidualLSTMModel(nn.Module):
    """Two stacked single-layer LSTMs whose outputs are summed (a residual
    skip connection) before dropout and a linear projection to vocab logits.

    NOTE: this definition must stay identical to the training-time one,
    because the checkpoint is loaded as a whole pickled model object.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        super().__init__()
        # padding_idx=0 keeps the <PAD> embedding frozen at zero.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.lstm2 = nn.LSTM(
            input_size=hidden_units,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        """Map (batch, seq_len) token ids to (batch, seq_len, vocab_size) logits."""
        emb = self.embedding(x)
        first_out, _ = self.lstm1(emb)
        second_out, _ = self.lstm2(first_out)
        # Residual connection: element-wise sum of both LSTM outputs.
        combined = first_out + second_out
        return self.fc(self.dropout(combined))
39
+
40
+ # --- Part 2: Helper Functions for Processing Text ---
41
def text_to_sequence(text, vocab, max_length):
    """Tokenize *text* on whitespace and return a (1, max_length) LongTensor.

    Unknown tokens map to the <UNK> id (falling back to 1) and the sequence
    is truncated to max_length, then right-padded with the <PAD> id
    (falling back to 0).
    """
    unk_id = vocab.get('<UNK>', 1)
    pad_id = vocab.get('<PAD>', 0)
    ids = [vocab.get(tok, unk_id) for tok in text.split()][:max_length]
    ids.extend([pad_id] * (max_length - len(ids)))
    return torch.tensor([ids], dtype=torch.long)
50
+
51
def sequence_to_text(sequence, vocab):
    """Map token-id tensors back to tokens, skipping <PAD>, joined by spaces."""
    inverse = {idx: tok for tok, idx in vocab.items()}
    pad_id = vocab.get('<PAD>', 0)
    words = []
    for token_id in sequence:
        value = token_id.item()
        if value != pad_id:
            # Ids absent from the vocabulary render as the unknown marker.
            words.append(inverse.get(value, '<UNK>'))
    return " ".join(words)
55
+
56
+ # --- Part 3: Main Prediction Logic ---
57
def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Return the top_k most likely next tokens for *text* as a list of strings.

    Returns [] for empty input. The input is whitespace-tokenized, encoded
    via text_to_sequence, and the logits at the last real token position are
    ranked with torch.topk.
    """
    tokens = text.split()
    if not tokens:
        return []
    model.eval()
    with torch.no_grad():
        input_tensor = text_to_sequence(text, vocab, max_length).to(device)
        logits = model(input_tensor)
        # text_to_sequence truncates to max_length, so clamp the index of the
        # "last real token"; the original `num_input_tokens - 1` indexed past
        # the sequence (IndexError) whenever len(tokens) > max_length.
        last_pos = min(len(tokens), max_length) - 1
        last_token_logits = logits[0, last_pos, :]
        _, top_k_ids = torch.topk(last_token_logits, top_k, dim=-1)
        return [sequence_to_text([token_id], vocab) for token_id in top_k_ids]
69
+
70
# --- Part 4: Flask App Initialization ---
app = Flask(__name__)

# --- Configuration and Model Loading ---
# Paths are relative to the working directory (/app inside the Docker image).
MODEL_PATH = 'model.pt'
VOCAB_PATH = 'vocab.pkl'
MAX_LENGTH = 1000  # presumably the training-time sequence length -- TODO confirm

device = torch.device("cpu") # Use CPU for inference on a typical web server

# Load vocabulary (a pickled token -> id mapping).
try:
    with open(VOCAB_PATH, 'rb') as f:
        vocab = pickle.load(f)
    print("Vocabulary loaded.")
except FileNotFoundError:
    # Leave vocab as None; /predict reports a 500 until it is available.
    print(f"Error: Vocabulary file not found at {VOCAB_PATH}")
    vocab = None

# Load the model
try:
    # Since the model was saved as a whole object, we need weights_only=False
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only
    # load checkpoints from a trusted source.
    model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    model.eval() # Set model to evaluation mode
    print("Model loaded.")
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}")
    model = None
except Exception as e:
    # Broad catch so the web app still starts and can report the failure.
    print(f"An error occurred while loading the model: {e}")
    model = None
101
+
102
+
103
+ # --- Flask Routes ---
104
@app.route('/')
def home():
    """Serve the single-page UI from templates/index.html."""
    return render_template('index.html')
107
+
108
@app.route('/predict', methods=['POST'])
def predict():
    """Return {'suggestions': [...]} for a JSON body of the form {'code': str}.

    Responds 500 with an 'error' key when the model/vocab failed to load at
    startup or when prediction raises; an empty/whitespace snippet yields
    {'suggestions': []}.
    """
    # Explicit None checks: truthiness of a loaded nn.Module (or of a dict
    # that happens to be empty) is not a reliable "loaded" signal.
    if model is None or vocab is None:
        return jsonify({'error': 'Model or vocabulary not loaded. Check server logs.'}), 500

    # silent=True makes get_json() return None instead of failing on a
    # missing or malformed JSON body; fall back to an empty dict.
    data = request.get_json(silent=True) or {}
    code_snippet = data.get('code', '')

    if not code_snippet.strip():
        return jsonify({'suggestions': []})

    try:
        suggestions = predict_next_tokens(model, code_snippet, vocab, device, max_length=MAX_LENGTH)
        return jsonify({'suggestions': suggestions})
    except Exception as e:
        # Log server-side; the client only sees a generic failure message.
        print(f"Prediction error: {e}")
        return jsonify({'error': 'Failed to get prediction.'}), 500
125
+
126
if __name__ == '__main__':
    # Local development server only; the Docker image runs gunicorn instead.
    # NOTE(review): debug=True enables the Werkzeug debugger -- never expose
    # this mode on a public interface.
    app.run(debug=True)
main.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pickle
4
+
5
+ # --- Part 1: Re-define the Model Architecture ---
6
+ # This class definition must be EXACTLY the same as in your training script.
7
+
8
class ResidualLSTMModel(nn.Module):
    """Residual two-layer LSTM language model.

    The outputs of two stacked single-layer LSTMs are added together (skip
    connection), passed through dropout, then projected to vocabulary logits.
    Must match the training-time definition exactly, since the checkpoint is
    a fully pickled model object.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        super(ResidualLSTMModel, self).__init__()
        # Id 0 is reserved for <PAD>; its embedding row stays zero.
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embedding_dim,
                                      padding_idx=0)
        self.lstm1 = nn.LSTM(input_size=embedding_dim,
                             hidden_size=hidden_units,
                             num_layers=1,
                             batch_first=True)
        self.lstm2 = nn.LSTM(input_size=hidden_units,
                             hidden_size=hidden_units,
                             num_layers=1,
                             batch_first=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        """x: (batch, seq_len) token ids -> (batch, seq_len, vocab_size) logits."""
        embedded = self.embedding(x)
        lower, _ = self.lstm1(embedded)
        upper, _ = self.lstm2(lower)
        skip_sum = lower + upper  # residual connection
        return self.fc(self.dropout(skip_sum))
39
+
40
+ # --- Part 2: Helper Functions for Processing Text ---
41
+
42
def text_to_sequence(text, vocab, max_length):
    """Converts a string of code into a (1, max_length) padded LongTensor.

    Unknown tokens map to the <UNK> id and short sequences are right-padded
    with the <PAD> id. Uses .get() fallbacks (1 and 0) so a vocabulary that
    lacks the special tokens cannot raise KeyError -- consistent with the
    serving-side helper in app.py.
    """
    tokens = text.split()
    unk_id = vocab.get('<UNK>', 1)
    numericalized = [vocab.get(token, unk_id) for token in tokens]

    # Truncate to the model's maximum sequence length.
    if len(numericalized) > max_length:
        numericalized = numericalized[:max_length]

    pad_id = vocab.get('<PAD>', 0)
    padded = numericalized + [pad_id] * (max_length - len(numericalized))

    return torch.tensor([padded], dtype=torch.long)
55
+
56
def sequence_to_text(sequence, vocab):
    """Converts an iterable of token-id tensors back into a space-joined string.

    <PAD> ids are skipped and unknown ids render as '<UNK>'. Uses a .get()
    fallback for the <PAD> id so a vocabulary without the special token does
    not raise KeyError -- consistent with the serving-side helper in app.py.
    """
    id_to_token = {id_val: token for token, id_val in vocab.items()}
    pad_id = vocab.get('<PAD>', 0)
    tokens = [id_to_token.get(id_val.item(), '<UNK>')
              for id_val in sequence if id_val.item() != pad_id]
    return " ".join(tokens)
61
+
62
+ # --- Part 3: Main Prediction Logic ---
63
+
64
def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Predicts the top_k next tokens for a given text input.

    Returns a list of up to top_k token strings, or [] for empty input.
    """
    tokens = text.split()
    # Guard against empty input: without it, index -1 below would silently
    # read the logits at the final *padded* position (app.py has this check).
    if not tokens:
        return []

    model.eval()
    with torch.no_grad():
        input_tensor = text_to_sequence(text, vocab, max_length).to(device)
        logits = model(input_tensor)

        # text_to_sequence truncates to max_length, so clamp the index of the
        # last real token; `len(tokens) - 1` alone can index out of bounds.
        last_pos = min(len(tokens), max_length) - 1
        last_token_logits = logits[0, last_pos, :]

        _, top_k_ids = torch.topk(last_token_logits, top_k, dim=-1)
        top_k_tokens = [sequence_to_text([token_id], vocab) for token_id in top_k_ids]

        return top_k_tokens
78
+
79
if __name__ == '__main__':
    # --- Configuration ---
    MODEL_PATH = 'model.pt'
    VOCAB_PATH = 'vocab.pkl' # <-- Updated to use .pkl
    MAX_LENGTH = 1000  # presumably the training-time sequence length -- TODO confirm

    # --- Load everything ---
    # Prefer GPU when available; map_location below moves weights accordingly.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load vocabulary using pickle
    with open(VOCAB_PATH, 'rb') as f: # <-- Use 'rb' for reading bytes
        vocab = pickle.load(f)
    print("Vocabulary loaded.")

    # Load the model
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only
    # load checkpoints from a trusted source.
    model = torch.load(MODEL_PATH, map_location=device , weights_only=False)
    print("Model loaded.")

    # --- Make a Prediction ---
    input_code = "import numpy as" # Example input

    print(f"\nInput code: '{input_code}'")

    suggestions = predict_next_tokens(model, input_code, vocab, device, max_length=MAX_LENGTH)

    print("\nTop 5 suggestions:")
    for i, suggestion in enumerate(suggestions):
        print(f"{i+1}. {suggestion}")
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85ba12ee7eccdd7aed642f5dbcd46094cd5f32c501e3c237fe1d4e85ea11ac00
3
+ size 45484701
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Flask
2
+ torch
3
+ gunicorn
templates/index.html ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Code Completion AI</title>
7
+ <script src="https://cdn.tailwindcss.com"></script>
8
+ <style>
9
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
10
+ html {
11
+ scroll-behavior: smooth;
12
+ }
13
+ body {
14
+ font-family: 'Inter', sans-serif;
15
+ background-color: #0f172a; /* slate-900 */
16
+ background-image: radial-gradient(circle at 1px 1px, rgba(255,255,255,0.05) 1px, transparent 0);
17
+ background-size: 2rem 2rem;
18
+ }
19
+ .suggestion-item {
20
+ transition: all 0.2s ease-in-out;
21
+ }
22
+ .info-card-grid {
23
+ display: grid;
24
+ grid-template-columns: repeat(auto-fit, minmax(140px, 1fr));
25
+ gap: 1rem;
26
+ }
27
+ </style>
28
+ </head>
29
+ <body class="text-gray-200 min-h-screen flex flex-col items-center justify-center p-4">
30
+
31
+ <main class="w-full max-w-5xl mx-auto grid grid-cols-1 lg:grid-cols-5 gap-8 lg:gap-12">
32
+
33
+ <!-- Left Column: Interaction -->
34
+ <div class="lg:col-span-3 bg-slate-900/50 backdrop-blur-sm border border-slate-700 rounded-2xl shadow-2xl p-6 md:p-8">
35
+ <div class="text-left mb-8">
36
+ <h1 class="text-4xl font-bold text-white tracking-tight">Code Completion AI</h1>
37
+ <p class="text-slate-400 mt-2">Enter a Python code snippet to get AI-powered suggestions.</p>
38
+ </div>
39
+
40
+ <div>
41
+ <label for="code-input" class="block text-sm font-medium text-slate-300 mb-2">Python Snippet</label>
42
+ <textarea id="code-input"
43
+ class="w-full h-48 p-4 bg-slate-900 border border-slate-700 rounded-lg text-slate-200 focus:ring-2 focus:ring-sky-500 focus:border-sky-500 transition duration-200 resize-none font-mono text-sm"
44
+ placeholder="e.g., import numpy as"></textarea>
45
+ </div>
46
+
47
+ <div class="mt-6 text-left">
48
+ <button id="predict-btn"
49
+ class="bg-sky-600 hover:bg-sky-700 text-white font-bold py-3 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-offset-slate-900 focus:ring-sky-500 flex items-center justify-center">
50
+ <span id="btn-text">Get Suggestions</span>
51
+ <span id="spinner" class="hidden">
52
+ <svg class="animate-spin h-5 w-5 text-white" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
53
+ <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
54
+ <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
55
+ </svg>
56
+ </span>
57
+ </button>
58
+ </div>
59
+
60
+ <div id="results-container" class="mt-8">
61
+ <h2 class="text-lg font-semibold text-white mb-3">Top 5 Suggestions</h2>
62
+ <div id="suggestions" class="bg-slate-900 p-3 rounded-lg min-h-[160px] border border-slate-700 space-y-1">
63
+ <p id="placeholder-text" class="text-slate-500 p-2">Suggestions will appear here...</p>
64
+ </div>
65
+ </div>
66
+ </div>
67
+
68
+ <!-- Right Column: Model Info -->
69
+ <div class="lg:col-span-2 space-y-6">
70
+ <div class="bg-slate-900/50 backdrop-blur-sm border border-slate-700 rounded-2xl shadow-xl p-6">
71
+ <h3 class="text-2xl font-bold text-white mb-4">Model Details</h3>
72
+ <p class="text-slate-400 mb-6">
73
+ This app uses a <span class="text-sky-400 font-semibold">Residual LSTM</span> model with two LSTM layers and a skip connection. It was trained on the Python subset of the <span class="text-sky-400">CodeXGlue</span> dataset to predict the next token in a sequence.
74
+ </p>
75
+ <div class="info-card-grid">
76
+ <div class="bg-slate-800 p-4 rounded-lg border border-slate-700">
77
+ <p class="text-sm text-slate-400">Top-5 Accuracy</p>
78
+ <p class="text-xl font-semibold text-white">86.82%</p>
79
+ </div>
80
+ <div class="bg-slate-800 p-4 rounded-lg border border-slate-700">
81
+ <p class="text-sm text-slate-400">Perplexity</p>
82
+ <p class="text-xl font-semibold text-white">4.19</p>
83
+ </div>
84
+ <div class="bg-slate-800 p-4 rounded-lg border border-slate-700">
85
+ <p class="text-sm text-slate-400">Embedding Dim</p>
86
+ <p class="text-xl font-semibold text-white">256</p>
87
+ </div>
88
+ <div class="bg-slate-800 p-4 rounded-lg border border-slate-700">
89
+ <p class="text-sm text-slate-400">Hidden Units</p>
90
+ <p class="text-xl font-semibold text-white">512</p>
91
+ </div>
92
+ <div class="bg-slate-800 p-4 rounded-lg border border-slate-700">
93
+ <p class="text-sm text-slate-400">Vocab Size</p>
94
+ <p class="text-xl font-semibold text-white">10,002</p>
95
+ </div>
96
+ <div class="bg-slate-800 p-4 rounded-lg border border-slate-700">
97
+ <p class="text-sm text-slate-400">Parameters</p>
98
+ <p class="text-xl font-semibold text-white">~15.5 M</p>
99
+ </div>
100
+ </div>
101
+ </div>
102
+ </div>
103
+ </main>
104
+
105
+ <footer class="w-full max-w-5xl mx-auto text-center text-slate-500 py-8 mt-4">
106
+ <p>Made with ❤️ by Ayush</p>
107
+ </footer>
108
+
109
<script>
    // Cached DOM references for the editor, button, and results areas.
    const codeInput = document.getElementById('code-input');
    const predictBtn = document.getElementById('predict-btn');
    const suggestionsDiv = document.getElementById('suggestions');
    const placeholderText = document.getElementById('placeholder-text');
    const btnText = document.getElementById('btn-text');
    const spinner = document.getElementById('spinner');

    let debounceTimer;

    // POST the current snippet to /predict and render the returned
    // suggestions; toggles the button's spinner while the request is
    // in flight.
    const getPredictions = async () => {
        const code = codeInput.value;
        if (code.trim() === '') {
            // Restore the placeholder when the editor is emptied.
            suggestionsDiv.innerHTML = '<p id="placeholder-text" class="text-slate-500 p-2">Suggestions will appear here...</p>';
            return;
        }

        // Show loading state and prevent double submits.
        btnText.classList.add('hidden');
        spinner.classList.remove('hidden');
        predictBtn.disabled = true;

        try {
            const response = await fetch('/predict', {
                method: 'POST',
                headers: { 'Content-Type': 'application/json' },
                body: JSON.stringify({ code: code }),
            });

            if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
            const data = await response.json();

            // Server-side failures arrive as {error: "..."} in the body.
            if (data.error) {
                suggestionsDiv.innerHTML = `<p class="text-red-400 p-2">${data.error}</p>`;
                return;
            }

            if (data.suggestions && data.suggestions.length > 0) {
                suggestionsDiv.innerHTML = '';
                data.suggestions.forEach(suggestion => {
                    const p = document.createElement('p');
                    p.textContent = suggestion;
                    p.className = 'suggestion-item p-2 rounded hover:bg-slate-700 cursor-pointer text-slate-300';
                    // Clicking a suggestion appends it to the snippet and
                    // immediately requests the next round of predictions.
                    p.onclick = () => {
                        const lastCharIsSpace = codeInput.value.slice(-1) === ' ';
                        codeInput.value += (lastCharIsSpace ? '' : ' ') + suggestion;
                        codeInput.focus();
                        getPredictions();
                    };
                    suggestionsDiv.appendChild(p);
                });
            } else {
                suggestionsDiv.innerHTML = '<p class="text-slate-500 p-2">No suggestions found.</p>';
            }

        } catch (error) {
            console.error('Error:', error);
            suggestionsDiv.innerHTML = '<p class="text-red-400 p-2">An error occurred. Check server logs.</p>';
        } finally {
            // Always restore the button, success or failure.
            btnText.classList.remove('hidden');
            spinner.classList.add('hidden');
            predictBtn.disabled = false;
        }
    };

    predictBtn.addEventListener('click', getPredictions);

    // Re-run predictions 500ms after the user stops typing.
    codeInput.addEventListener('input', () => {
        clearTimeout(debounceTimer);
        debounceTimer = setTimeout(getPredictions, 500); // 500ms debounce
    });

</script>
181
+ </body>
182
+ </html>
183
+
vocab.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c19847861d949bbb046b890ac5fe8b0b11117eeeca46801ca82815ae3f071dcf
3
+ size 131947