vinay0123 committed
Commit 5154835 · verified · 1 Parent(s): f7a4422

Update app.py

Files changed (1)
  1. app.py +92 -222
app.py CHANGED
@@ -1,58 +1,43 @@
  import torch
  import torch.nn as nn
  import pandas as pd
  from flask import Flask, request, jsonify, Response, stream_with_context
  from sklearn.model_selection import train_test_split
  import os
  import time
  import json
- import threading
- from queue import Queue
- import multiprocessing

- # Optimize for Hugging Face Spaces CPU limits
- num_cores = min(multiprocessing.cpu_count(), 4)  # HF Spaces usually have 2-4 cores
- torch.set_num_threads(num_cores)
- torch.set_num_interop_threads(num_cores)

  url = "https://drive.google.com/uc?id=1RCZShB5ohy1HdU-mogcP16TbeVv9txpY"
  df = pd.read_csv(url)

- # Optimized Tokenizer with caching
  class ScratchTokenizer:
      def __init__(self):
          self.word2idx = {"<PAD>": 0, "< SOS >": 1, "<EOS>": 2, "<UNK>": 3}
          self.idx2word = {0: "<PAD>", 1: "< SOS >", 2: "<EOS>", 3: "<UNK>"}
          self.vocab_size = 4
-         self._encode_cache = {}  # Cache for faster encoding

      def build_vocab(self, texts):
-         # Optimized vocabulary building
-         unique_words = set()
          for text in texts:
-             unique_words.update(text.split())
-
-         for word in sorted(unique_words):  # Sort for consistent ordering
-             if word not in self.word2idx:
-                 self.word2idx[word] = self.vocab_size
-                 self.idx2word[self.vocab_size] = word
-                 self.vocab_size += 1

      def encode(self, text, max_len=200):
-         # Use cache for repeated queries
-         cache_key = (text, max_len)
-         if cache_key in self._encode_cache:
-             return self._encode_cache[cache_key]
-
          tokens = [self.word2idx.get(word, 3) for word in text.split()]
          tokens = [1] + tokens[:max_len - 2] + [2]
-         encoded = tokens + [0] * (max_len - len(tokens))
-
-         # Cache result
-         if len(self._encode_cache) < 1000:  # Limit cache size
-             self._encode_cache[cache_key] = encoded
-
-         return encoded

      def decode(self, tokens):
          return " ".join([self.idx2word.get(idx, "<UNK>") for idx in tokens if idx > 0])
@@ -64,246 +49,131 @@ train_data, test_data = train_test_split(df, test_size=0.2, random_state=42)
  tokenizer = ScratchTokenizer()
  tokenizer.build_vocab(train_data["instruction"].tolist() + train_data["response"].tolist())

- # Optimized Model for HF Spaces
  class GPTModel(nn.Module):
      def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, max_len=200):
          super(GPTModel, self).__init__()
-         # Reduced model size for HF Spaces memory limits
          self.embedding = nn.Embedding(vocab_size, embed_size)
          self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_size))
-
-         decoder_layer = nn.TransformerDecoderLayer(
-             d_model=embed_size,
-             nhead=num_heads,
-             dim_feedforward=embed_size * 2,  # Reduced from 4x to 2x
-             dropout=0.1,
-             activation='gelu',
-             batch_first=False,
-             norm_first=True  # Pre-norm for better stability
          )
-
-         self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
          self.fc_out = nn.Linear(embed_size, vocab_size)
-         self.max_len = max_len

      def forward(self, src, tgt):
          src_emb = self.embedding(src) + self.pos_embedding[:, :src.size(1), :]
          tgt_emb = self.embedding(tgt) + self.pos_embedding[:, :tgt.size(1), :]
-
          tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
-
-         output = self.transformer(
-             tgt_emb.permute(1, 0, 2),
-             src_emb.permute(1, 0, 2),
-             tgt_mask=tgt_mask
-         )
          return self.fc_out(output.permute(1, 0, 2))

  # Load model
- device = torch.device("cpu")  # HF Spaces typically CPU-only
  model = GPTModel(tokenizer.vocab_size).to(device)

- # Try to optimize with torch.jit if available
- try:
-     # Create a traced model for faster inference
-     sample_src = torch.randint(0, tokenizer.vocab_size, (1, 50))
-     sample_tgt = torch.randint(0, tokenizer.vocab_size, (1, 10))
-     traced_model = torch.jit.trace(model, (sample_src, sample_tgt))
-     model = traced_model
-     print("Model traced with TorchScript for faster inference")
- except Exception as e:
-     print(f"TorchScript tracing failed: {e}, using regular model")
-
  def load_model(model, path="gpt_model.pth"):
      if os.path.exists(path):
-         # Load with CPU mapping for HF Spaces
-         checkpoint = torch.load(path, map_location='cpu')
-
-         # Handle different checkpoint formats
-         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-             model.load_state_dict(checkpoint['state_dict'])
-         else:
-             model.load_state_dict(checkpoint)
-
          model.eval()
          print("Model loaded successfully.")
      else:
-         print("Model file not found! Using randomly initialized model.")

- # Optimized generation with batching and early stopping
- def generate_response_stream_fast(model, query, max_length=200, chunk_size=3):
-     """Optimized generation for HF Spaces"""
      model.eval()

      with torch.no_grad():
-         # Use smaller sequences for HF Spaces
-         src = torch.tensor(tokenizer.encode(query, max_len=200)).unsqueeze(0).to(device)
-         tgt = torch.tensor([[1]]).to(device)  # SOS token
-
-         words_buffer = []
-         consecutive_repeats = 0
-         last_word = ""
-
-         for step in range(max_length):
-             try:
                  output = model(src, tgt)

-                 # Use top-k sampling instead of greedy for better responses
-                 logits = output[:, -1, :] / 0.8  # Temperature scaling
-                 top_k = torch.topk(logits, k=5)
-                 probs = torch.softmax(top_k.values, dim=-1)
-                 next_token_idx = torch.multinomial(probs, 1)
-                 next_token = top_k.indices.gather(-1, next_token_idx)

-                 tgt = torch.cat([tgt, next_token], dim=1)
-
-                 token_id = next_token.item()
-                 if token_id == 2:  # EOS
                      break
-
-                 word = tokenizer.idx2word.get(token_id, "<UNK>")

-                 # Skip special tokens and repeated words
-                 if word in ["<PAD>", "< SOS >", "<EOS>", "<UNK>"]:
-                     continue
-
-                 # Prevent infinite loops
-                 if word == last_word:
-                     consecutive_repeats += 1
-                     if consecutive_repeats > 2:
-                         continue
-                 else:
-                     consecutive_repeats = 0
-                     last_word = word

-                 words_buffer.append(word)

-                 # Stream in chunks for better perceived performance
-                 if len(words_buffer) >= chunk_size:
-                     chunk_text = " ".join(words_buffer) + " "
-                     words_buffer = []
-                     yield chunk_text
-
-             except Exception as e:
-                 print(f"Generation error at step {step}: {e}")
-                 break
-
-         # Yield remaining words
-         if words_buffer:
-             yield " ".join(words_buffer) + " "
-
- # Simple request queue for better CPU utilization
- request_queue = Queue(maxsize=10)
- processing_lock = threading.Lock()

- # Flask App optimized for HF Spaces
  app = Flask(__name__)

  @app.route("/")
  def home():
-     return {
-         "message": "HF Spaces Optimized Transformer API",
-         "status": "running",
-         "device": str(device),
-         "vocab_size": tokenizer.vocab_size
-     }
-
- @app.route("/health")
- def health():
-     return {"status": "healthy", "model_loaded": True}

  @app.route("/intent")
  def intents():
-     try:
-         return jsonify({"intents": list(set(df['intent'].dropna()))})
-     except Exception as e:
-         return jsonify({"error": str(e)}), 500

  @app.route("/query", methods=["POST"])
  def query_model():
-     try:
-         data = request.get_json()
-         query = data.get("query", "").strip()
-
-         if not query:
-             return jsonify({"error": "Query cannot be empty"}), 400
-
-         if len(query) > 500:  # Limit input length for HF Spaces
-             query = query[:500]
-
-         def generate():
-             start_time = time.time()
-             word_count = 0
-
-             try:
-                 for chunk in generate_response_stream_fast(model, query, max_length=50):
-                     word_count += len(chunk.split())
-                     response_data = {
-                         "chunk": chunk,
-                         "timestamp": time.time() - start_time,
-                         "word_count": word_count
-                     }
-                     yield f"data: {json.dumps(response_data)}\n\n"
-
-                     # Prevent very long responses on HF Spaces
-                     if word_count > 100:
-                         break
-
-             except Exception as e:
-                 error_data = {
-                     "error": f"Generation failed: {str(e)}",
-                     "timestamp": time.time() - start_time
-                 }
-                 yield f"data: {json.dumps(error_data)}\n\n"
-
-         return Response(
-             stream_with_context(generate()),
-             mimetype='text/event-stream',
-             headers={
-                 'Cache-Control': 'no-cache',
-                 'Connection': 'keep-alive',
-                 'Access-Control-Allow-Origin': '*'
              }
-         )
-
-     except Exception as e:
-         return jsonify({"error": str(e)}), 500
-
- @app.route("/simple_query", methods=["POST"])
- def simple_query():
-     """Non-streaming endpoint for simpler clients"""
-     try:
-         data = request.get_json()
-         query = data.get("query", "").strip()
-
-         if not query:
-             return jsonify({"error": "Query cannot be empty"}), 400
-
-         start_time = time.time()
-         response_text = ""
-
-         for chunk in generate_response_stream_fast(model, query, max_length=50):
-             response_text += chunk
-
-         return jsonify({
-             "query": query,
-             "response": response_text.strip(),
-             "processing_time": time.time() - start_time
-         })
-
-     except Exception as e:
-         return jsonify({"error": str(e)}), 500

  if __name__ == "__main__":
-     print("Loading model...")
-     load_model(model)
-     print("Starting HF Spaces optimized server...")

-     # HF Spaces compatible settings
-     port = int(os.environ.get("PORT", 7860))
      app.run(
          host="0.0.0.0",
-         port=port,
-         debug=False,  # Disable debug for production
-         threaded=True
      )
 
  import torch
  import torch.nn as nn
+ import torch.optim as optim
  import pandas as pd
+ from torch.utils.data import Dataset, DataLoader
  from flask import Flask, request, jsonify, Response, stream_with_context
  from sklearn.model_selection import train_test_split
  import os
  import time
  import json

+ # Set PyTorch to use all available CPU threads
+ torch.set_num_threads(os.cpu_count())
+ torch.set_num_interop_threads(os.cpu_count())
+
+ # Enable optimizations
+ torch.backends.mkldnn.enabled = True if hasattr(torch.backends, 'mkldnn') else False

  url = "https://drive.google.com/uc?id=1RCZShB5ohy1HdU-mogcP16TbeVv9txpY"
  df = pd.read_csv(url)

+ # Tokenizer
  class ScratchTokenizer:
      def __init__(self):
          self.word2idx = {"<PAD>": 0, "< SOS >": 1, "<EOS>": 2, "<UNK>": 3}
          self.idx2word = {0: "<PAD>", 1: "< SOS >", 2: "<EOS>", 3: "<UNK>"}
          self.vocab_size = 4

      def build_vocab(self, texts):
          for text in texts:
+             for word in text.split():
+                 if word not in self.word2idx:
+                     self.word2idx[word] = self.vocab_size
+                     self.idx2word[self.vocab_size] = word
+                     self.vocab_size += 1

      def encode(self, text, max_len=200):
          tokens = [self.word2idx.get(word, 3) for word in text.split()]
          tokens = [1] + tokens[:max_len - 2] + [2]
+         return tokens + [0] * (max_len - len(tokens))

      def decode(self, tokens):
          return " ".join([self.idx2word.get(idx, "<UNK>") for idx in tokens if idx > 0])
 
  tokenizer = ScratchTokenizer()
  tokenizer.build_vocab(train_data["instruction"].tolist() + train_data["response"].tolist())

+ # Model
  class GPTModel(nn.Module):
      def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, max_len=200):
          super(GPTModel, self).__init__()
          self.embedding = nn.Embedding(vocab_size, embed_size)
          self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_size))
+         self.transformer = nn.TransformerDecoder(
+             nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads),
+             num_layers=num_layers
          )
          self.fc_out = nn.Linear(embed_size, vocab_size)

      def forward(self, src, tgt):
          src_emb = self.embedding(src) + self.pos_embedding[:, :src.size(1), :]
          tgt_emb = self.embedding(tgt) + self.pos_embedding[:, :tgt.size(1), :]
          tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
+         output = self.transformer(tgt_emb.permute(1, 0, 2), src_emb.permute(1, 0, 2), tgt_mask=tgt_mask)
          return self.fc_out(output.permute(1, 0, 2))

  # Load model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model = GPTModel(tokenizer.vocab_size).to(device)

  def load_model(model, path="gpt_model.pth"):
      if os.path.exists(path):
+         model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
          model.eval()
+         # Enable inference optimizations
+         if hasattr(torch.jit, 'optimize_for_inference'):
+             model = torch.jit.optimize_for_inference(torch.jit.script(model))
          print("Model loaded successfully.")
      else:
+         print("Model file not found!")
+     return model

+ def generate_response_stream(model, query, max_length=200):
      model.eval()

+     # Pre-encode the query once
+     src_tokens = tokenizer.encode(query)
+     src = torch.tensor(src_tokens).unsqueeze(0).to(device)
+     tgt = torch.tensor([[1]], dtype=torch.long).to(device)  # < SOS >
+
+     # Pre-allocate tensor for better memory efficiency
+     max_tgt_len = min(max_length, 200)
+
      with torch.no_grad():
+         # Use torch.inference_mode for better performance
+         with torch.inference_mode():
+             for step in range(max_length):
+                 # Forward pass
                  output = model(src, tgt)

+                 # Get next token more efficiently
+                 logits = output[:, -1, :]
+                 next_token = torch.argmax(logits, dim=-1, keepdim=True)

+                 # Check for EOS early
+                 if next_token.item() == 2:  # <EOS>
                      break

+                 # Concatenate token
+                 tgt = torch.cat([tgt, next_token], dim=1)

+                 # Get the current word
+                 current_word = tokenizer.idx2word.get(next_token.item(), "<UNK>")
+                 if current_word not in ["<PAD>", "<EOS>", "< SOS >"]:
+                     yield current_word + " "

+                 # Prevent infinite loops
+                 if tgt.size(1) >= max_tgt_len:
+                     break

+ # Flask App with threading optimizations
  app = Flask(__name__)

+ # Configure Flask for better performance
+ app.config['THREADED'] = True
+
  @app.route("/")
  def home():
+     return {"message": "Streaming Transformer-based Response Generator API is running!"}

  @app.route("/intent")
  def intents():
+     return jsonify({"intents": list(set(df['intent'].dropna()))})

  @app.route("/query", methods=["POST"])
  def query_model():
+     data = request.get_json()
+     query = data.get("query", "")
+     if not query:
+         return jsonify({"error": "Query cannot be empty"}), 400
+
+     def generate():
+         start = time.time()
+         word_count = 0
+         for word in generate_response_stream(model, query):
+             word_count += 1
+             response_data = {
+                 "word": word.strip(),
+                 "timestamp": time.time() - start,
+                 "word_count": word_count
              }
+             yield f"data: {json.dumps(response_data)}\n\n"
+
+     return Response(
+         stream_with_context(generate()),
+         mimetype='text/event-stream',
+         headers={
+             'Cache-Control': 'no-cache',
+             'Connection': 'keep-alive',
+             'X-Accel-Buffering': 'no'  # Disable nginx buffering if present
+         }
+     )

  if __name__ == "__main__":
+     # Load and optimize model
+     model = load_model(model)

+     # Run Flask with threading enabled and optimized worker settings
      app.run(
          host="0.0.0.0",
+         port=7860,
+         threaded=True,
+         processes=1,  # Use threading instead of multiprocessing for better memory sharing
+         debug=False  # Disable debug mode for better performance
      )
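
For reference, a minimal client sketch for consuming the streaming /query endpoint introduced in this revision. It is not part of the commit: the localhost URL, port, sample query, and use of the requests library are illustrative assumptions; it only relies on the SSE format ("data: {...}" lines with a "word" field) that the new generate() helper emits.

    import json
    import requests  # assumed client-side dependency, not part of app.py

    # Hypothetical local URL; point this at the deployed Space instead.
    with requests.post(
        "http://localhost:7860/query",
        json={"query": "how can I reset my password"},
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            # Each server-sent event arrives as a line of the form: data: {...}
            if line and line.startswith("data: "):
                event = json.loads(line[len("data: "):])
                print(event["word"], end=" ", flush=True)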