vinay0123 committed on
Commit a29ef81 · verified · 1 Parent(s): c1f1f3f

Update app.py

Files changed (1)
  1. app.py +235 -52

app.py CHANGED
@@ -1,35 +1,58 @@
  import torch
  import torch.nn as nn
- import torch.optim as optim
  import pandas as pd
- from torch.utils.data import Dataset, DataLoader
  from flask import Flask, request, jsonify, Response, stream_with_context
  from sklearn.model_selection import train_test_split
  import os
  import time
  import json

  url = "https://drive.google.com/uc?id=1RCZShB5ohy1HdU-mogcP16TbeVv9txpY"
  df = pd.read_csv(url)
- # Tokenizer
  class ScratchTokenizer:
      def __init__(self):
-         self.word2idx = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
-         self.idx2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
          self.vocab_size = 4

      def build_vocab(self, texts):
          for text in texts:
-             for word in text.split():
-                 if word not in self.word2idx:
-                     self.word2idx[word] = self.vocab_size
-                     self.idx2word[self.vocab_size] = word
-                     self.vocab_size += 1

      def encode(self, text, max_len=200):
          tokens = [self.word2idx.get(word, 3) for word in text.split()]
          tokens = [1] + tokens[:max_len - 2] + [2]
-         return tokens + [0] * (max_len - len(tokens))

      def decode(self, tokens):
          return " ".join([self.idx2word.get(idx, "<UNK>") for idx in tokens if idx > 0])
@@ -41,86 +64,246 @@ train_data, test_data = train_test_split(df, test_size=0.2, random_state=42)
  tokenizer = ScratchTokenizer()
  tokenizer.build_vocab(train_data["instruction"].tolist() + train_data["response"].tolist())

- # Model
  class GPTModel(nn.Module):
      def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, max_len=200):
          super(GPTModel, self).__init__()
          self.embedding = nn.Embedding(vocab_size, embed_size)
          self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_size))
-         self.transformer = nn.TransformerDecoder(
-             nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads),
-             num_layers=num_layers
          )
          self.fc_out = nn.Linear(embed_size, vocab_size)

      def forward(self, src, tgt):
          src_emb = self.embedding(src) + self.pos_embedding[:, :src.size(1), :]
          tgt_emb = self.embedding(tgt) + self.pos_embedding[:, :tgt.size(1), :]
          tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
-         output = self.transformer(tgt_emb.permute(1, 0, 2), src_emb.permute(1, 0, 2), tgt_mask=tgt_mask)
          return self.fc_out(output.permute(1, 0, 2))

  # Load model
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model = GPTModel(tokenizer.vocab_size).to(device)

  def load_model(model, path="gpt_model.pth"):
      if os.path.exists(path):
-         model.load_state_dict(torch.load(path, map_location=device))
          model.eval()
          print("Model loaded successfully.")
      else:
-         print("Model file not found!")
-

- def generate_response_stream(model, query, max_length=200):
      model.eval()
      with torch.no_grad():
-         src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
-         tgt = torch.tensor([[1]]).to(device)  # <SOS>

-         for _ in range(max_length):
-             output = model(src, tgt)
-             next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
-             tgt = torch.cat([tgt, next_token], dim=1)
-
-             # Get the current word
-             current_word = tokenizer.idx2word.get(next_token.item(), "<UNK>")
-             if current_word != "<PAD>":
-                 yield current_word + " "
-
-             if next_token.item() == 2:  # <EOS>
                  break

- # Flask App
  app = Flask(__name__)

  @app.route("/")
  def home():
-     return {"message": "Streaming Transformer-based Response Generator API is running!"}

  @app.route("/intent")
  def intents():
-     return jsonify({"intents": list(set(df['intent'].dropna()))})

  @app.route("/query", methods=["POST"])
  def query_model():
-     data = request.get_json()
-     query = data.get("query", "")
-     if not query:
-         return jsonify({"error": "Query cannot be empty"}), 400
-
-     def generate():
-         start = time.time()
-         for word in generate_response_stream(model, query):
-             response_data = {
-                 "word": word,
-                 "timestamp": time.time() - start
              }
-             yield f"data: {json.dumps(response_data)}\n\n"
-
-     return Response(stream_with_context(generate()), mimetype='text/event-stream')

  if __name__ == "__main__":
      load_model(model)
-     app.run(host="0.0.0.0", port=7860)
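
The new side of the diff, which follows in full, replaces the greedy argmax loop above with temperature-scaled top-k sampling. A self-contained sketch of that decoding step, assuming only a raw logits tensor (the function name and defaults are illustrative):

    import torch

    def sample_top_k(logits, k=5, temperature=0.8):
        # logits: (batch, vocab_size) raw scores for the next token
        scaled = logits / temperature           # <1 sharpens, >1 flattens the distribution
        top = torch.topk(scaled, k=k)           # keep the k highest-scoring tokens
        probs = torch.softmax(top.values, dim=-1)
        choice = torch.multinomial(probs, 1)    # sample one of the k candidates
        return top.indices.gather(-1, choice)   # map back to vocabulary ids; shape (batch, 1)

Greedy argmax always emits the single most likely token and tends to repeat itself; sampling among the top few candidates trades a little determinism for more varied responses.
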
  import torch
  import torch.nn as nn
  import pandas as pd
  from flask import Flask, request, jsonify, Response, stream_with_context
  from sklearn.model_selection import train_test_split
  import os
  import time
  import json
+ import threading
+ from queue import Queue
+ import multiprocessing
+
+ # Optimize for Hugging Face Spaces CPU limits
+ num_cores = min(multiprocessing.cpu_count(), 4)  # HF Spaces usually have 2-4 cores
+ torch.set_num_threads(num_cores)
+ torch.set_num_interop_threads(num_cores)

  url = "https://drive.google.com/uc?id=1RCZShB5ohy1HdU-mogcP16TbeVv9txpY"
  df = pd.read_csv(url)
+
+ # Optimized Tokenizer with caching
  class ScratchTokenizer:
      def __init__(self):
+         self.word2idx = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
+         self.idx2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
          self.vocab_size = 4
+         self._encode_cache = {}  # Cache for faster encoding

      def build_vocab(self, texts):
+         # Optimized vocabulary building
+         unique_words = set()
          for text in texts:
+             unique_words.update(text.split())
+
+         for word in sorted(unique_words):  # Sort for consistent ordering
+             if word not in self.word2idx:
+                 self.word2idx[word] = self.vocab_size
+                 self.idx2word[self.vocab_size] = word
+                 self.vocab_size += 1

      def encode(self, text, max_len=200):
+         # Use cache for repeated queries
+         cache_key = (text, max_len)
+         if cache_key in self._encode_cache:
+             return self._encode_cache[cache_key]
+
          tokens = [self.word2idx.get(word, 3) for word in text.split()]
          tokens = [1] + tokens[:max_len - 2] + [2]
+         encoded = tokens + [0] * (max_len - len(tokens))
+
+         # Cache result
+         if len(self._encode_cache) < 1000:  # Limit cache size
+             self._encode_cache[cache_key] = encoded
+
+         return encoded

      def decode(self, tokens):
          return " ".join([self.idx2word.get(idx, "<UNK>") for idx in tokens if idx > 0])
 
  tokenizer = ScratchTokenizer()
  tokenizer.build_vocab(train_data["instruction"].tolist() + train_data["response"].tolist())

+ # Optimized Model for HF Spaces
  class GPTModel(nn.Module):
      def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, max_len=200):
          super(GPTModel, self).__init__()
+         # Reduced model size for HF Spaces memory limits
          self.embedding = nn.Embedding(vocab_size, embed_size)
          self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_size))
+
+         decoder_layer = nn.TransformerDecoderLayer(
+             d_model=embed_size,
+             nhead=num_heads,
+             dim_feedforward=embed_size * 2,  # Reduced from 4x to 2x
+             dropout=0.1,
+             activation='gelu',
+             batch_first=False,
+             norm_first=True  # Pre-norm for better stability
          )
+
+         self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
          self.fc_out = nn.Linear(embed_size, vocab_size)
+         self.max_len = max_len

      def forward(self, src, tgt):
          src_emb = self.embedding(src) + self.pos_embedding[:, :src.size(1), :]
          tgt_emb = self.embedding(tgt) + self.pos_embedding[:, :tgt.size(1), :]
+
          tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
+
+         output = self.transformer(
+             tgt_emb.permute(1, 0, 2),
+             src_emb.permute(1, 0, 2),
+             tgt_mask=tgt_mask
+         )
          return self.fc_out(output.permute(1, 0, 2))

  # Load model
+ device = torch.device("cpu")  # HF Spaces typically CPU-only
  model = GPTModel(tokenizer.vocab_size).to(device)

+ # Try to optimize with torch.jit if available
+ try:
+     # Create a traced model for faster inference
+     sample_src = torch.randint(0, tokenizer.vocab_size, (1, 50))
+     sample_tgt = torch.randint(0, tokenizer.vocab_size, (1, 10))
+     traced_model = torch.jit.trace(model, (sample_src, sample_tgt))
+     model = traced_model
+     print("Model traced with TorchScript for faster inference")
+ except Exception as e:
+     print(f"TorchScript tracing failed: {e}, using regular model")
+
  def load_model(model, path="gpt_model.pth"):
      if os.path.exists(path):
+         # Load with CPU mapping for HF Spaces
+         checkpoint = torch.load(path, map_location='cpu')
+
+         # Handle different checkpoint formats
+         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+             model.load_state_dict(checkpoint['state_dict'])
+         else:
+             model.load_state_dict(checkpoint)
+
          model.eval()
          print("Model loaded successfully.")
      else:
+         print("Model file not found! Using randomly initialized model.")
 
+ # Optimized generation with batching and early stopping
+ def generate_response_stream_fast(model, query, max_length=200, chunk_size=3):
+     """Optimized generation for HF Spaces"""
      model.eval()
+
      with torch.no_grad():
+         # Use smaller sequences for HF Spaces
+         src = torch.tensor(tokenizer.encode(query, max_len=200)).unsqueeze(0).to(device)
+         tgt = torch.tensor([[1]]).to(device)  # SOS token

+         words_buffer = []
+         consecutive_repeats = 0
+         last_word = ""
+
+         for step in range(max_length):
+             try:
+                 output = model(src, tgt)
+
+                 # Use top-k sampling instead of greedy for better responses
+                 logits = output[:, -1, :] / 0.8  # Temperature scaling
+                 top_k = torch.topk(logits, k=5)
+                 probs = torch.softmax(top_k.values, dim=-1)
+                 next_token_idx = torch.multinomial(probs, 1)
+                 next_token = top_k.indices.gather(-1, next_token_idx)
+
+                 tgt = torch.cat([tgt, next_token], dim=1)
+
+                 token_id = next_token.item()
+                 if token_id == 2:  # EOS
+                     break
+
+                 word = tokenizer.idx2word.get(token_id, "<UNK>")
+
+                 # Skip special tokens and repeated words
+                 if word in ["<PAD>", "<SOS>", "<EOS>", "<UNK>"]:
+                     continue
+
+                 # Prevent infinite loops
+                 if word == last_word:
+                     consecutive_repeats += 1
+                     if consecutive_repeats > 2:
+                         continue
+                 else:
+                     consecutive_repeats = 0
+                 last_word = word
+
+                 words_buffer.append(word)
+
+                 # Stream in chunks for better perceived performance
+                 if len(words_buffer) >= chunk_size:
+                     chunk_text = " ".join(words_buffer) + " "
+                     words_buffer = []
+                     yield chunk_text
+
+             except Exception as e:
+                 print(f"Generation error at step {step}: {e}")
                  break
+
+     # Yield remaining words
+     if words_buffer:
+         yield " ".join(words_buffer) + " "
+
+ # Simple request queue for better CPU utilization
+ request_queue = Queue(maxsize=10)
+ processing_lock = threading.Lock()

+ # Flask App optimized for HF Spaces
  app = Flask(__name__)

  @app.route("/")
  def home():
+     return {
+         "message": "HF Spaces Optimized Transformer API",
+         "status": "running",
+         "device": str(device),
+         "vocab_size": tokenizer.vocab_size
+     }
+
+ @app.route("/health")
+ def health():
+     return {"status": "healthy", "model_loaded": True}

  @app.route("/intent")
  def intents():
+     try:
+         return jsonify({"intents": list(set(df['intent'].dropna()))})
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500

  @app.route("/query", methods=["POST"])
  def query_model():
+     try:
+         data = request.get_json()
+         query = data.get("query", "").strip()
+
+         if not query:
+             return jsonify({"error": "Query cannot be empty"}), 400
+
+         if len(query) > 500:  # Limit input length for HF Spaces
+             query = query[:500]
+
+         def generate():
+             start_time = time.time()
+             word_count = 0
+
+             try:
+                 for chunk in generate_response_stream_fast(model, query, max_length=50):
+                     word_count += len(chunk.split())
+                     response_data = {
+                         "chunk": chunk,
+                         "timestamp": time.time() - start_time,
+                         "word_count": word_count
+                     }
+                     yield f"data: {json.dumps(response_data)}\n\n"
+
+                     # Prevent very long responses on HF Spaces
+                     if word_count > 100:
+                         break
+
+             except Exception as e:
+                 error_data = {
+                     "error": f"Generation failed: {str(e)}",
+                     "timestamp": time.time() - start_time
+                 }
+                 yield f"data: {json.dumps(error_data)}\n\n"
+
+         return Response(
+             stream_with_context(generate()),
+             mimetype='text/event-stream',
+             headers={
+                 'Cache-Control': 'no-cache',
+                 'Connection': 'keep-alive',
+                 'Access-Control-Allow-Origin': '*'
              }
+         )
+
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500
+
+ @app.route("/simple_query", methods=["POST"])
+ def simple_query():
+     """Non-streaming endpoint for simpler clients"""
+     try:
+         data = request.get_json()
+         query = data.get("query", "").strip()
+
+         if not query:
+             return jsonify({"error": "Query cannot be empty"}), 400
+
+         start_time = time.time()
+         response_text = ""
+
+         for chunk in generate_response_stream_fast(model, query, max_length=50):
+             response_text += chunk
+
+         return jsonify({
+             "query": query,
+             "response": response_text.strip(),
+             "processing_time": time.time() - start_time
+         })
+
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500

  if __name__ == "__main__":
+     print("Loading model...")
      load_model(model)
+     print("Starting HF Spaces optimized server...")
+
+     # HF Spaces compatible settings
+     port = int(os.environ.get("PORT", 7860))
+     app.run(
+         host="0.0.0.0",
+         port=port,
+         debug=False,  # Disable debug for production
+         threaded=True
+     )
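
For reference, a hypothetical client for the /query endpoint added above, reading the server-sent event stream with the requests library; the host, port, and sample query are assumptions, not part of the commit:

    import json
    import requests

    resp = requests.post(
        "http://localhost:7860/query",
        json={"query": "how do I reset my password"},
        stream=True,  # keep the connection open so SSE chunks arrive as they are generated
    )
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            payload = json.loads(line[len("data: "):])
            if "chunk" in payload:
                print(payload["chunk"], end="", flush=True)
            else:
                print(payload.get("error"))  # error events carry an "error" key instead

The /simple_query endpoint returns a single JSON object instead, so a plain requests.post(...).json() call suffices there.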