Rafs-an09002 commited on
Commit
3a461b8
·
verified ·
1 Parent(s): bdad329

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +372 -0
app.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GambitFlow Bridge API - HuggingFace Space
3
+ Unified API gateway with Firebase analytics and rate limiting
4
+ """
5
+
6
+ from flask import Flask, request, jsonify, Response
7
+ from flask_cors import CORS
8
+ import requests
9
+ import time
10
+ import os
11
+ from functools import wraps
12
+ import firebase_admin
13
+ from firebase_admin import credentials, db
14
+ import json
15
+
16
+ app = Flask(__name__)
17
+ CORS(app)
18
+
19
+ # ==================== FIREBASE SETUP ====================
20
+
21
+ def initialize_firebase():
22
+ """Initialize Firebase Admin SDK"""
23
+ try:
24
+ # Load credentials from environment variable
25
+ firebase_creds = os.getenv('FIREBASE_CREDENTIALS')
26
+ if firebase_creds:
27
+ cred_dict = json.loads(firebase_creds)
28
+ cred = credentials.Certificate(cred_dict)
29
+ else:
30
+ # Fallback to service account file
31
+ cred = credentials.Certificate('firebase-credentials.json')
32
+
33
+ firebase_admin.initialize_app(cred, {
34
+ 'databaseURL': os.getenv('FIREBASE_DATABASE_URL', 'https://YOUR-PROJECT.firebaseio.com')
35
+ })
36
+ print("✅ Firebase initialized successfully")
37
+ except Exception as e:
38
+ print(f"⚠️ Firebase initialization failed: {e}")
39
+
40
+ # Initialize Firebase
41
+ initialize_firebase()
42
+
43
+ # ==================== MODEL CONFIGURATION ====================
44
+
45
+ MODELS = {
46
+ 'nano': {
47
+ 'name': 'Nexus-Nano',
48
+ 'endpoint': os.getenv('NANO_ENDPOINT', 'https://gambitflow-nexus-nano-inference-api.hf.space'),
49
+ 'timeout': 30
50
+ },
51
+ 'core': {
52
+ 'name': 'Nexus-Core',
53
+ 'endpoint': os.getenv('CORE_ENDPOINT', 'https://gambitflow-nexus-core-inference-api.hf.space'),
54
+ 'timeout': 40
55
+ },
56
+ 'base': {
57
+ 'name': 'Synapse-Base',
58
+ 'endpoint': os.getenv('BASE_ENDPOINT', 'https://gambitflow-synapse-base-inference-api.hf.space'),
59
+ 'timeout': 60
60
+ }
61
+ }
62
+
63
+ # ==================== FIREBASE ANALYTICS ====================
64
+
65
+ def increment_stats(model_name, stat_type='moves'):
66
+ """
67
+ Increment statistics in Firebase
68
+ stat_type: 'moves' or 'matches'
69
+ """
70
+ try:
71
+ ref = db.reference('stats')
72
+
73
+ # Increment total stats
74
+ total_ref = ref.child('total').child(stat_type)
75
+ current = total_ref.get() or 0
76
+ total_ref.set(current + 1)
77
+
78
+ # Increment model-specific stats
79
+ model_ref = ref.child('models').child(model_name).child(stat_type)
80
+ current = model_ref.get() or 0
81
+ model_ref.set(current + 1)
82
+
83
+ # Update last_updated timestamp
84
+ ref.child('last_updated').set(int(time.time()))
85
+
86
+ except Exception as e:
87
+ print(f"Firebase stats update error: {e}")
88
+
89
+ def get_all_stats():
90
+ """Get all statistics from Firebase"""
91
+ try:
92
+ ref = db.reference('stats')
93
+ stats = ref.get() or {}
94
+
95
+ if not stats:
96
+ # Initialize default structure
97
+ stats = {
98
+ 'total': {'moves': 0, 'matches': 0},
99
+ 'models': {
100
+ 'nano': {'moves': 0, 'matches': 0},
101
+ 'core': {'moves': 0, 'matches': 0},
102
+ 'base': {'moves': 0, 'matches': 0}
103
+ },
104
+ 'last_updated': int(time.time())
105
+ }
106
+ ref.set(stats)
107
+
108
+ return stats
109
+ except Exception as e:
110
+ print(f"Firebase stats fetch error: {e}")
111
+ return {
112
+ 'total': {'moves': 0, 'matches': 0},
113
+ 'models': {
114
+ 'nano': {'moves': 0, 'matches': 0},
115
+ 'core': {'moves': 0, 'matches': 0},
116
+ 'base': {'moves': 0, 'matches': 0}
117
+ },
118
+ 'last_updated': int(time.time())
119
+ }
120
+
121
+ # ==================== CACHE ====================
122
+
123
+ class SimpleCache:
124
+ def __init__(self, ttl=300):
125
+ self.cache = {}
126
+ self.ttl = ttl
127
+
128
+ def get(self, key):
129
+ if key in self.cache:
130
+ value, timestamp = self.cache[key]
131
+ if time.time() - timestamp < self.ttl:
132
+ return value
133
+ del self.cache[key]
134
+ return None
135
+
136
+ def set(self, key, value):
137
+ self.cache[key] = (value, time.time())
138
+
139
+ def clear_old(self):
140
+ current_time = time.time()
141
+ expired = [k for k, (_, t) in self.cache.items() if current_time - t >= self.ttl]
142
+ for k in expired:
143
+ del self.cache[k]
144
+
145
+ cache = SimpleCache(ttl=300)
146
+
147
+ # ==================== ROUTES ====================
148
+
149
+ @app.route('/')
150
+ def index():
151
+ """API documentation"""
152
+ return jsonify({
153
+ 'name': 'GambitFlow Bridge API',
154
+ 'version': '1.0.0',
155
+ 'description': 'Unified gateway for all GambitFlow chess engines',
156
+ 'endpoints': {
157
+ '/predict': 'POST - Get best move prediction',
158
+ '/health': 'GET - Health check',
159
+ '/stats': 'GET - Get usage statistics',
160
+ '/models': 'GET - List available models'
161
+ },
162
+ 'models': list(MODELS.keys())
163
+ })
164
+
165
+ @app.route('/health')
166
+ def health():
167
+ """Health check endpoint"""
168
+ return jsonify({
169
+ 'status': 'healthy',
170
+ 'timestamp': int(time.time()),
171
+ 'models': len(MODELS),
172
+ 'cache_size': len(cache.cache)
173
+ })
174
+
175
+ @app.route('/stats')
176
+ def get_stats():
177
+ """Get usage statistics from Firebase"""
178
+ stats = get_all_stats()
179
+ return jsonify(stats)
180
+
181
+ @app.route('/models')
182
+ def list_models():
183
+ """List all available models"""
184
+ models_info = {}
185
+ for key, config in MODELS.items():
186
+ models_info[key] = {
187
+ 'name': config['name'],
188
+ 'endpoint': config['endpoint'],
189
+ 'timeout': config['timeout']
190
+ }
191
+ return jsonify({'models': models_info})
192
+
193
+ @app.route('/predict', methods=['POST'])
194
+ def predict():
195
+ """
196
+ Main prediction endpoint
197
+ Forwards request to appropriate model and tracks statistics
198
+ """
199
+ try:
200
+ data = request.get_json()
201
+
202
+ if not data:
203
+ return jsonify({'error': 'No data provided'}), 400
204
+
205
+ # Extract parameters
206
+ fen = data.get('fen')
207
+ model = data.get('model', 'core')
208
+ depth = data.get('depth', 5)
209
+ time_limit = data.get('time_limit', 3000)
210
+ track_stats = data.get('track_stats', True) # Allow disabling stats tracking
211
+
212
+ if not fen:
213
+ return jsonify({'error': 'FEN position required'}), 400
214
+
215
+ if model not in MODELS:
216
+ return jsonify({'error': f'Invalid model: {model}'}), 400
217
+
218
+ # Check cache
219
+ cache_key = f"{model}:{fen}:{depth}:{time_limit}"
220
+ cached = cache.get(cache_key)
221
+ if cached:
222
+ cached['from_cache'] = True
223
+ if track_stats:
224
+ increment_stats(model, 'moves')
225
+ return jsonify(cached)
226
+
227
+ # Forward to model API
228
+ model_config = MODELS[model]
229
+ endpoint = f"{model_config['endpoint']}/predict"
230
+
231
+ response = requests.post(
232
+ endpoint,
233
+ json={
234
+ 'fen': fen,
235
+ 'depth': depth,
236
+ 'time_limit': time_limit
237
+ },
238
+ timeout=model_config['timeout']
239
+ )
240
+
241
+ if response.status_code == 200:
242
+ result = response.json()
243
+
244
+ # Cache the result
245
+ cache.set(cache_key, result)
246
+
247
+ # Track statistics in Firebase
248
+ if track_stats:
249
+ increment_stats(model, 'moves')
250
+
251
+ result['from_cache'] = False
252
+ result['model'] = model
253
+
254
+ return jsonify(result)
255
+ else:
256
+ return jsonify({
257
+ 'error': 'Model API error',
258
+ 'status_code': response.status_code,
259
+ 'details': response.text
260
+ }), response.status_code
261
+
262
+ except requests.Timeout:
263
+ return jsonify({'error': 'Request timeout'}), 504
264
+ except Exception as e:
265
+ return jsonify({'error': str(e)}), 500
266
+
267
+ @app.route('/match/start', methods=['POST'])
268
+ def start_match():
269
+ """Track match start"""
270
+ try:
271
+ data = request.get_json()
272
+ model = data.get('model', 'core')
273
+
274
+ if model not in MODELS:
275
+ return jsonify({'error': 'Invalid model'}), 400
276
+
277
+ increment_stats(model, 'matches')
278
+
279
+ return jsonify({
280
+ 'success': True,
281
+ 'model': model,
282
+ 'message': 'Match started'
283
+ })
284
+ except Exception as e:
285
+ return jsonify({'error': str(e)}), 500
286
+
287
+ @app.route('/batch', methods=['POST'])
288
+ def batch_predict():
289
+ """
290
+ Batch prediction endpoint for multiple positions
291
+ """
292
+ try:
293
+ data = request.get_json()
294
+ positions = data.get('positions', [])
295
+ model = data.get('model', 'core')
296
+
297
+ if not positions:
298
+ return jsonify({'error': 'No positions provided'}), 400
299
+
300
+ if len(positions) > 10:
301
+ return jsonify({'error': 'Maximum 10 positions per batch'}), 400
302
+
303
+ results = []
304
+ for pos in positions:
305
+ fen = pos.get('fen')
306
+ depth = pos.get('depth', 5)
307
+ time_limit = pos.get('time_limit', 3000)
308
+
309
+ # Make individual request
310
+ pred_data = {
311
+ 'fen': fen,
312
+ 'model': model,
313
+ 'depth': depth,
314
+ 'time_limit': time_limit,
315
+ 'track_stats': False # Don't track for batch
316
+ }
317
+
318
+ result = predict_single(pred_data)
319
+ results.append(result)
320
+
321
+ # Track batch as single operation
322
+ increment_stats(model, 'moves')
323
+
324
+ return jsonify({
325
+ 'success': True,
326
+ 'count': len(results),
327
+ 'results': results
328
+ })
329
+
330
+ except Exception as e:
331
+ return jsonify({'error': str(e)}), 500
332
+
333
+ def predict_single(data):
334
+ """Helper function for single prediction"""
335
+ try:
336
+ fen = data.get('fen')
337
+ model = data.get('model', 'core')
338
+ depth = data.get('depth', 5)
339
+ time_limit = data.get('time_limit', 3000)
340
+
341
+ model_config = MODELS[model]
342
+ endpoint = f"{model_config['endpoint']}/predict"
343
+
344
+ response = requests.post(
345
+ endpoint,
346
+ json={
347
+ 'fen': fen,
348
+ 'depth': depth,
349
+ 'time_limit': time_limit
350
+ },
351
+ timeout=model_config['timeout']
352
+ )
353
+
354
+ if response.status_code == 200:
355
+ return response.json()
356
+ else:
357
+ return {'error': 'Prediction failed'}
358
+ except:
359
+ return {'error': 'Request failed'}
360
+
361
+ # ==================== CLEANUP ====================
362
+
363
+ @app.before_request
364
+ def before_request():
365
+ """Clean old cache entries before each request"""
366
+ cache.clear_old()
367
+
368
+ # ==================== RUN ====================
369
+
370
+ if __name__ == '__main__':
371
+ port = int(os.getenv('PORT', 7860))
372
+ app.run(host='0.0.0.0', port=port, debug=False)