Stanley03 committed on
Commit
eb9dd30
·
verified ·
1 Parent(s): 1585d87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -131
app.py CHANGED
@@ -6,6 +6,7 @@ import time
6
  import logging
7
  import os
8
  import json
 
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
@@ -24,15 +25,20 @@ logger.info(f"🚀 Running on Hugging Face Spaces: {ON_SPACES}")
24
 
25
  # Use the smallest Qwen model available
26
  MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
27
- # Or even smaller alternative: "Qwen/Qwen2.5-Coder-0.5B-Instruct"
28
 
29
  model = None
30
  tokenizer = None
31
  model_loaded = False
 
32
 
33
  def load_model_fast():
34
  """Fast model loading optimized for Spaces"""
35
- global model, tokenizer, model_loaded
 
 
 
 
 
36
 
37
  try:
38
  logger.info(f"🔄 Loading {MODEL_NAME}...")
@@ -71,13 +77,11 @@ def load_model_fast():
71
  model_loaded = True
72
  logger.info("✅ Model loaded successfully!")
73
 
74
- # Test a quick generation
75
- test_response = generate_quick("Hello")
76
- logger.info(f"🧪 Test generation: {test_response[:50]}...")
77
-
78
  except Exception as e:
79
  logger.error(f"❌ Model loading failed: {str(e)[:200]}")
80
  model_loaded = False
 
 
81
 
82
  # ============================================================================
83
  # OPTIMIZED GENERATION FUNCTIONS
@@ -86,12 +90,16 @@ def load_model_fast():
86
  def generate_quick(user_message, max_tokens=256):
87
  """Ultra-fast generation with minimal overhead"""
88
  if not model_loaded:
89
- return "Model is still loading, please wait..."
90
 
91
  try:
 
 
 
 
92
  # Format the prompt for Qwen chat template
93
  messages = [
94
- {"role": "system", "content": "You are Stanley AI, a helpful assistant."},
95
  {"role": "user", "content": user_message}
96
  ]
97
 
@@ -102,7 +110,7 @@ def generate_quick(user_message, max_tokens=256):
102
  add_generation_prompt=True
103
  )
104
 
105
- # Tokenize
106
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
107
 
108
  # Move to device
@@ -132,74 +140,23 @@ def generate_quick(user_message, max_tokens=256):
132
 
133
  except Exception as e:
134
  logger.error(f"Generation error: {e}")
135
- return f"I encountered an error: {str(e)[:100]}"
136
-
137
- def generate_streaming(user_message, max_tokens=256):
138
- """Streaming response for better UX"""
139
- if not model_loaded:
140
- yield "data: Model is still loading, please wait...\n\n"
141
- return
142
-
143
- try:
144
- # Format prompt
145
- messages = [
146
- {"role": "system", "content": "You are Stanley AI, a helpful assistant."},
147
- {"role": "user", "content": user_message}
148
- ]
149
-
150
- text = tokenizer.apply_chat_template(
151
- messages,
152
- tokenize=False,
153
- add_generation_prompt=True
154
- )
155
-
156
- inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
157
- device = model.device
158
- inputs = {k: v.to(device) for k, v in inputs.items()}
159
-
160
- # Generate token by token
161
- with torch.no_grad():
162
- generated = inputs['input_ids'].clone()
163
- for _ in range(max_tokens):
164
- outputs = model(
165
- input_ids=generated,
166
- attention_mask=torch.ones_like(generated) if 'attention_mask' not in inputs else None,
167
- use_cache=True
168
- )
169
-
170
- # Get next token
171
- next_token_logits = outputs.logits[:, -1, :]
172
- next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
173
-
174
- # Check for eos
175
- if next_token.item() == tokenizer.eos_token_id:
176
- break
177
-
178
- # Decode and yield
179
- generated = torch.cat([generated, next_token], dim=-1)
180
- token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
181
-
182
- yield f"data: {json.dumps({'token': token_text})}\n\n"
183
-
184
- except Exception as e:
185
- logger.error(f"Streaming error: {e}")
186
- yield f"data: {json.dumps({'error': str(e)[:100]})}\n\n"
187
 
188
  # ============================================================================
189
- # CACHE SYSTEM FOR REPEATED QUERIES
190
  # ============================================================================
191
 
192
  response_cache = {}
193
- CACHE_SIZE = 100
194
 
195
  def get_cached_response(query):
196
  """Get response from cache"""
197
- key = query.lower().strip()[:100]
198
  return response_cache.get(key)
199
 
200
  def cache_response(query, response):
201
  """Cache response"""
202
- key = query.lower().strip()[:100]
203
  if len(response_cache) >= CACHE_SIZE:
204
  # Remove oldest
205
  response_cache.pop(next(iter(response_cache)))
@@ -212,22 +169,22 @@ def cache_response(query, response):
212
  @app.route('/')
213
  def home():
214
  return jsonify({
215
- "name": "Stanley AI - Text Only",
216
- "version": "4.0",
217
  "model": MODEL_NAME,
218
  "status": "ready" if model_loaded else "loading",
219
- "optimized_for": "huggingface-spaces",
220
  "endpoints": {
221
- "chat": "/api/chat",
222
- "stream": "/api/chat/stream",
223
- "status": "/api/status"
224
  },
225
- "note": "Ultra-fast text-only version using Qwen 0.5B"
226
  })
227
 
228
  @app.route('/api/chat', methods=['POST', 'GET'])
229
  def chat():
230
- """Main chat endpoint - supports both POST and GET for testing"""
231
  start_time = time.time()
232
 
233
  try:
@@ -243,6 +200,19 @@ def chat():
243
  if not user_message:
244
  return jsonify({"error": "No message provided"}), 400
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  # Check cache first
247
  cached = get_cached_response(user_message)
248
  if cached:
@@ -251,7 +221,8 @@ def chat():
251
  "response": cached,
252
  "cached": True,
253
  "response_time": round(time.time() - start_time, 3),
254
- "model": MODEL_NAME
 
255
  })
256
 
257
  # Generate response
@@ -278,76 +249,61 @@ def chat():
278
  "status": "error"
279
  }), 500
280
 
281
- @app.route('/api/chat/stream')
282
- def chat_stream():
283
- """Streaming chat endpoint"""
284
- user_message = request.args.get('message', 'Hello')
285
-
286
- def generate():
287
- for token in generate_streaming(user_message):
288
- yield token
289
- yield "data: [DONE]\n\n"
290
-
291
- return app.response_class(
292
- generate(),
293
- mimetype='text/event-stream',
294
- headers={
295
- 'Cache-Control': 'no-cache',
296
- 'X-Accel-Buffering': 'no'
297
- }
298
- )
299
-
300
  @app.route('/api/status')
301
  def status():
302
  """Health check"""
303
  return jsonify({
304
  "model_loaded": model_loaded,
 
305
  "model_name": MODEL_NAME,
306
  "device": str(model.device) if model_loaded else "none",
307
  "cache_size": len(response_cache),
308
  "timestamp": time.time(),
309
- "memory_allocated": f"{torch.cuda.memory_allocated() / 1024**2:.1f} MB" if torch.cuda.is_available() else "CPU mode"
 
310
  })
311
 
312
  @app.route('/api/test')
313
  def test():
314
  """Quick test endpoint"""
315
- test_queries = [
316
- "Hello, how are you?",
317
- "What is AI?",
318
- "Tell me a joke",
319
- "Explain quantum computing simply"
320
- ]
321
-
322
- results = []
323
- for query in test_queries[:2]: # Test only 2 to be fast
324
- start = time.time()
325
- response = generate_quick(query, max_tokens=100)
326
- time_taken = round(time.time() - start, 3)
327
- results.append({
328
- "query": query,
329
- "response": response[:100] + "..." if len(response) > 100 else response,
330
- "time": time_taken
331
  })
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  return jsonify({
334
- "tests": results,
335
- "average_time": round(sum(r['time'] for r in results) / len(results), 3) if results else 0
336
  })
337
 
338
  # ============================================================================
339
- # STARTUP OPTIMIZATION
340
  # ============================================================================
341
 
342
- @app.before_first_request
343
- def startup():
344
- """Load model on first request to avoid startup timeout"""
345
- if not model_loaded:
346
- load_model_fast()
347
-
348
- # Preload model immediately if not on Spaces (for local testing)
349
- if not ON_SPACES:
350
- logger.info("🌍 Local mode - loading model immediately")
351
  load_model_fast()
352
 
353
  # ============================================================================
@@ -356,19 +312,13 @@ if not ON_SPACES:
356
 
357
  if __name__ == '__main__':
358
  print("=" * 50)
359
- print("🚀 STANLEY AI - Ultra Fast Text Edition")
360
  print(f"📦 Model: {MODEL_NAME}")
361
  print(f"🌍 Platform: {'Hugging Face Spaces' if ON_SPACES else 'Local'}")
362
- print(f"⚡ Optimized for: {'CPU' if not torch.cuda.is_available() else 'GPU'}")
 
363
  print("=" * 50)
364
 
365
- # Load model in background thread to avoid timeout
366
- import threading
367
- if ON_SPACES and not model_loaded:
368
- print("🔄 Loading model in background thread...")
369
- thread = threading.Thread(target=load_model_fast, daemon=True)
370
- thread.start()
371
-
372
  # Run app
373
  port = int(os.environ.get('PORT', 7860))
374
  app.run(
 
6
  import logging
7
  import os
8
  import json
9
+ import threading
10
 
11
  # Configure logging
12
  logging.basicConfig(level=logging.INFO)
 
25
 
26
  # Use the smallest Qwen model available
27
  MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
 
28
 
29
  model = None
30
  tokenizer = None
31
  model_loaded = False
32
+ model_loading = False
33
 
34
  def load_model_fast():
35
  """Fast model loading optimized for Spaces"""
36
+ global model, tokenizer, model_loaded, model_loading
37
+
38
+ if model_loading or model_loaded:
39
+ return
40
+
41
+ model_loading = True
42
 
43
  try:
44
  logger.info(f"🔄 Loading {MODEL_NAME}...")
 
77
  model_loaded = True
78
  logger.info("✅ Model loaded successfully!")
79
 
 
 
 
 
80
  except Exception as e:
81
  logger.error(f"❌ Model loading failed: {str(e)[:200]}")
82
  model_loaded = False
83
+ finally:
84
+ model_loading = False
85
 
86
  # ============================================================================
87
  # OPTIMIZED GENERATION FUNCTIONS
 
90
  def generate_quick(user_message, max_tokens=256):
91
  """Ultra-fast generation with minimal overhead"""
92
  if not model_loaded:
93
+ return "Model is still loading, please wait a few seconds and try again..."
94
 
95
  try:
96
+ # Truncate long messages
97
+ if len(user_message) > 1000:
98
+ user_message = user_message[:1000]
99
+
100
  # Format the prompt for Qwen chat template
101
  messages = [
102
+ {"role": "system", "content": "You are Stanley AI, a helpful and knowledgeable assistant. Keep responses concise and helpful."},
103
  {"role": "user", "content": user_message}
104
  ]
105
 
 
110
  add_generation_prompt=True
111
  )
112
 
113
+ # Tokenize with truncation
114
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
115
 
116
  # Move to device
 
140
 
141
  except Exception as e:
142
  logger.error(f"Generation error: {e}")
143
+ return f"I encountered an error. Please try again or rephrase your question."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  # ============================================================================
146
+ # SIMPLE CACHE SYSTEM
147
  # ============================================================================
148
 
149
  response_cache = {}
150
+ CACHE_SIZE = 50
151
 
152
  def get_cached_response(query):
153
  """Get response from cache"""
154
+ key = query.lower().strip()[:80]
155
  return response_cache.get(key)
156
 
157
  def cache_response(query, response):
158
  """Cache response"""
159
+ key = query.lower().strip()[:80]
160
  if len(response_cache) >= CACHE_SIZE:
161
  # Remove oldest
162
  response_cache.pop(next(iter(response_cache)))
 
169
  @app.route('/')
170
  def home():
171
  return jsonify({
172
+ "name": "Stanley AI",
173
+ "version": "4.1",
174
  "model": MODEL_NAME,
175
  "status": "ready" if model_loaded else "loading",
176
+ "platform": "huggingface-spaces",
177
  "endpoints": {
178
+ "chat": "POST /api/chat",
179
+ "status": "GET /api/status",
180
+ "test": "GET /api/test"
181
  },
182
+ "note": "Fast text generation with Qwen 0.5B model"
183
  })
184
 
185
  @app.route('/api/chat', methods=['POST', 'GET'])
186
  def chat():
187
+ """Main chat endpoint"""
188
  start_time = time.time()
189
 
190
  try:
 
200
  if not user_message:
201
  return jsonify({"error": "No message provided"}), 400
202
 
203
+ # If model is not loaded yet
204
+ if not model_loaded:
205
+ # Start loading if not already loading
206
+ if not model_loading:
207
+ thread = threading.Thread(target=load_model_fast, daemon=True)
208
+ thread.start()
209
+
210
+ return jsonify({
211
+ "response": "Model is loading... Please wait a few seconds and try again.",
212
+ "status": "loading",
213
+ "response_time": round(time.time() - start_time, 3)
214
+ })
215
+
216
  # Check cache first
217
  cached = get_cached_response(user_message)
218
  if cached:
 
221
  "response": cached,
222
  "cached": True,
223
  "response_time": round(time.time() - start_time, 3),
224
+ "model": MODEL_NAME,
225
+ "tokens": len(cached.split())
226
  })
227
 
228
  # Generate response
 
249
  "status": "error"
250
  }), 500
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  @app.route('/api/status')
253
  def status():
254
  """Health check"""
255
  return jsonify({
256
  "model_loaded": model_loaded,
257
+ "model_loading": model_loading,
258
  "model_name": MODEL_NAME,
259
  "device": str(model.device) if model_loaded else "none",
260
  "cache_size": len(response_cache),
261
  "timestamp": time.time(),
262
+ "on_spaces": ON_SPACES,
263
+ "memory": f"{torch.cuda.memory_allocated() / 1024**2:.1f} MB" if torch.cuda.is_available() and model_loaded else "CPU mode"
264
  })
265
 
266
  @app.route('/api/test')
267
  def test():
268
  """Quick test endpoint"""
269
+ if not model_loaded:
270
+ return jsonify({
271
+ "status": "model_not_loaded",
272
+ "message": "Model is still loading. Try /api/chat endpoint in a few seconds."
 
 
 
 
 
 
 
 
 
 
 
 
273
  })
274
 
275
+ test_query = "Hello, who are you?"
276
+ start = time.time()
277
+ response = generate_quick(test_query, max_tokens=100)
278
+ time_taken = round(time.time() - start, 3)
279
+
280
+ return jsonify({
281
+ "test": "success",
282
+ "query": test_query,
283
+ "response_preview": response[:200] + "..." if len(response) > 200 else response,
284
+ "response_time": time_taken,
285
+ "model": MODEL_NAME
286
+ })
287
+
288
+ @app.route('/api/health')
289
+ def health():
290
+ """Simple health check for Spaces"""
291
  return jsonify({
292
+ "status": "healthy",
293
+ "timestamp": time.time()
294
  })
295
 
296
  # ============================================================================
297
+ # STARTUP
298
  # ============================================================================
299
 
300
+ # Start model loading in background when app starts
301
+ if ON_SPACES:
302
+ logger.info("Starting model load in background thread...")
303
+ thread = threading.Thread(target=load_model_fast, daemon=True)
304
+ thread.start()
305
+ else:
306
+ # Load immediately for local testing
 
 
307
  load_model_fast()
308
 
309
  # ============================================================================
 
312
 
313
  if __name__ == '__main__':
314
  print("=" * 50)
315
+ print("🚀 STANLEY AI - Hugging Face Spaces Edition")
316
  print(f"📦 Model: {MODEL_NAME}")
317
  print(f"🌍 Platform: {'Hugging Face Spaces' if ON_SPACES else 'Local'}")
318
+ print(f"⚡ Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
319
+ print(f"📊 Model Status: {'Loaded' if model_loaded else 'Loading...'}")
320
  print("=" * 50)
321
 
 
 
 
 
 
 
 
322
  # Run app
323
  port = int(os.environ.get('PORT', 7860))
324
  app.run(