kan0621 commited on
Commit
d2e0e3f
·
verified ·
1 Parent(s): bffa085

Version 1.0

Browse files
Files changed (1) hide show
  1. app.py +1162 -0
app.py ADDED
@@ -0,0 +1,1162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_file, send_from_directory, redirect, url_for
2
+ from flask_cors import CORS
3
+ import json
4
+ import uuid
5
+ import os
6
+ import pickle
7
+ from threading import Event, Thread, Lock
8
+ from multiprocessing import Value
9
+ from backend import generate_stimuli, custom_model_inference_handler
10
+ from collections import defaultdict
11
+ import io
12
+ from flask_socketio import SocketIO, emit
13
+ import time
14
+ from apscheduler.schedulers.background import BackgroundScheduler
15
+ import requests
16
+
17
+
18
# Flask application; static assets are served from ./static.
app = Flask(__name__, static_folder='static')
# Allow cross-origin requests (the frontend may be hosted on another origin).
CORS(app)
# Socket.IO layer used to push generation progress and log messages to the
# browser.  async_mode="threading" keeps it compatible with the plain
# threading.Thread workers used elsewhere in this file.
socketio = SocketIO(
    app,
    cors_allowed_origins="*",
    async_mode="threading",
    logger=False,
    engineio_logger=False,
    ping_timeout=30,  # Reduced timeout
    ping_interval=10,  # Reduced interval
    http_compression=False,
    manage_session=False,
    allow_upgrades=True,  # Allow transport upgrades
    transports=['polling', 'websocket'],  # Explicit transport order
)
33
+
34
# Create a session store directory next to this file.
SESSION_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'sessions')
# exist_ok avoids the check-then-create race when several workers start at once.
os.makedirs(SESSION_DIR, exist_ok=True)

# Global session state dictionary, save the object reference of active sessions
# Use a lock to protect access to the dictionary
active_sessions = {}
sessions_lock = Lock()
43
+
44
+ # Get the session file path
45
+
46
+
47
def get_session_file(session_id):
    """Return the pickle path for *session_id* inside SESSION_DIR.

    The id arrives straight from the URL, so it is reduced to its basename
    to keep a crafted value like '../../x' from escaping the sessions
    directory (path traversal on both reads and writes of pickled state).
    """
    safe_id = os.path.basename(str(session_id))
    return os.path.join(SESSION_DIR, f"{safe_id}.pkl")
49
+
50
+
51
def format_stimulus_content(stimulus_content):
    """
    Format stimulus_content, converting dictionary format to a readable string.
    From: {'key1': 'value1', 'key2': 'value2'}
    To:
        Key 1: key1
        Content: value1
        <blank line>
        Key 2: key2
        Content: value2

    Non-dict inputs (including JSON strings that parse to something other
    than an object, which previously crashed on .items()) fall back to
    their plain string representation.
    """
    if isinstance(stimulus_content, dict):
        stimulus_dict = stimulus_content
    elif isinstance(stimulus_content, str):
        # Try to parse a JSON string
        try:
            parsed = json.loads(stimulus_content)
        except (json.JSONDecodeError, TypeError):
            return str(stimulus_content)
        if not isinstance(parsed, dict):
            # Valid JSON but not an object (e.g. a list) - cannot be
            # rendered as key/value pairs.
            return str(stimulus_content)
        stimulus_dict = parsed
    else:
        return str(stimulus_content)

    formatted_lines = []
    total = len(stimulus_dict)
    for i, (key, value) in enumerate(stimulus_dict.items(), 1):
        formatted_lines.append(f"Key {i}: {key}")
        formatted_lines.append(f"Content: {value}")
        if i < total:  # Don't add an empty line after the last item
            formatted_lines.append("")

    return "\n".join(formatted_lines)
82
+
83
+
84
def expand_stimulus_content_to_columns(df):
    """
    Expand the 'stimulus_content' column into KeyN/ContentN columns.

    Returns the input frame untouched when it lacks that column; otherwise
    returns a copy with the dict entries spread into Key1/Content1 ... and
    the original 'stimulus_content' column dropped.  JSON strings that do
    not parse to an object are treated as empty.
    """
    if 'stimulus_content' not in df.columns:
        return df

    # Work on a copy so the caller's frame is never mutated.
    df_expanded = df.copy()

    max_items = 0
    expanded_data = []

    for stimulus_content in df_expanded['stimulus_content']:
        # Parse stimulus_content into a dict (empty dict on failure).
        if isinstance(stimulus_content, dict):
            stimulus_dict = stimulus_content
        elif isinstance(stimulus_content, str):
            try:
                parsed = json.loads(stimulus_content)
            except (json.JSONDecodeError, TypeError):
                parsed = {}
            stimulus_dict = parsed if isinstance(parsed, dict) else {}
        else:
            stimulus_dict = {}

        # Track the widest row to know how many column pairs are needed.
        max_items = max(max_items, len(stimulus_dict))

        row_data = {}
        for i, (key, value) in enumerate(stimulus_dict.items(), 1):
            row_data[f'Key{i}'] = key
            row_data[f'Content{i}'] = value
        expanded_data.append(row_data)

    # Create all new columns up front, defaulting to empty strings.
    for i in range(1, max_items + 1):
        df_expanded[f'Key{i}'] = ''
        df_expanded[f'Content{i}'] = ''

    # Fill per-row values.  Iterate the real index labels: the previous
    # version used positional enumerate() indices with label-based .at,
    # which silently misplaced (or crashed on) non-RangeIndex frames.
    for label, row_data in zip(df_expanded.index, expanded_data):
        for col, value in row_data.items():
            df_expanded.at[label, col] = value

    # Remove the original stimulus_content column.
    return df_expanded.drop(columns=['stimulus_content'])
138
+
139
+ # Save the session state to a file
140
+
141
+
142
def save_session(session_id, session_state):
    """Persist the picklable parts of a session's state to its file.

    Live objects (Event, Value, Thread, DataFrame) are flattened into plain
    values; the DataFrame is stored as a CSV string.  Failures are logged
    rather than raised so a save never kills the caller.
    """
    snapshot = {
        'generation_file': session_state['generation_file'],
        'error_message': session_state['error_message'],
        'current_iteration_value': session_state['current_iteration'].value,
        'total_iterations_value': session_state['total_iterations'].value,
        # Remember whether the stop flag was raised.
        'stop_event_is_set': session_state['stop_event'].is_set(),
    }

    dataframe = session_state.get('dataframe')
    if dataframe is not None:
        # Serialize the results frame as CSV text so the snapshot stays
        # pickle-friendly.
        try:
            snapshot['dataframe_csv'] = dataframe.to_csv(index=False)
        except Exception as e:
            print(
                f"Error serializing dataframe for session {session_id}: {str(e)}")

    try:
        with open(get_session_file(session_id), 'wb') as f:
            pickle.dump(snapshot, f)
    except Exception as e:
        print(f"Error saving session {session_id}: {str(e)}")
167
+
168
+ # Create a new session state
169
+
170
+
171
def create_session_state():
    """Build a fresh in-memory state record for one client session."""
    return {
        # Event used to interrupt an in-flight generation run.
        'stop_event': Event(),
        # Path of the generated results file, once one exists.
        'generation_file': None,
        # Last error message captured from the worker.
        'error_message': None,
        # Cross-thread shared counter: iterations finished so far.
        'current_iteration': Value('i', 0),
        # Cross-thread shared counter: total iterations requested.
        'total_iterations': Value('i', 1),
        # Handle on the background generation thread, when running.
        'generation_thread': None,
        # In-memory copy of the generated stimuli.
        'dataframe': None,
    }
184
+
185
+ # Load the session state from a file or use the state in the global dictionary
186
+
187
+
188
def load_session(session_id):
    """Return the session state for *session_id*, creating one if needed.

    Lookup order: in-memory active_sessions -> pickled snapshot on disk ->
    a brand-new state.  Whatever is found or created is cached back into
    active_sessions before returning, so subsequent calls hit memory.

    NOTE(review): pickle.load on a file whose name derives from the URL
    session_id deserializes attacker-controllable data if a file can be
    planted in SESSION_DIR - confirm session ids are validated upstream.
    """
    # First check if there is an active session in the global dictionary
    with sessions_lock:
        if session_id in active_sessions:
            return active_sessions[session_id]

    # If there is no active session in the global dictionary, try to load it from the file
    try:
        session_file = get_session_file(session_id)
        if os.path.exists(session_file):
            with open(session_file, 'rb') as f:
                serialized_state = pickle.load(f)

            # Create a complete session state, then overlay the persisted values
            session_state = create_session_state()
            session_state['generation_file'] = serialized_state.get(
                'generation_file')
            session_state['error_message'] = serialized_state.get(
                'error_message')

            with session_state['current_iteration'].get_lock():
                session_state['current_iteration'].value = serialized_state.get(
                    'current_iteration_value', 0)

            with session_state['total_iterations'].get_lock():
                session_state['total_iterations'].value = serialized_state.get(
                    'total_iterations_value', 1)

            # Restore the stop_event state
            if serialized_state.get('stop_event_is_set', False):
                session_state['stop_event'].set()
            else:
                session_state['stop_event'].clear()

            # Restore the dataframe (if it exists) from its CSV snapshot
            if 'dataframe_csv' in serialized_state and serialized_state['dataframe_csv']:
                try:
                    import pandas as pd
                    import io
                    csv_data = serialized_state['dataframe_csv']
                    session_state['dataframe'] = pd.read_csv(
                        io.StringIO(csv_data))
                except Exception as e:
                    print(
                        f"Error deserializing dataframe for session {session_id}: {str(e)}")

            # Add the loaded session state to the global dictionary
            with sessions_lock:
                active_sessions[session_id] = session_state

            return session_state
    except Exception as e:
        print(f"Error loading session {session_id}: {str(e)}")

    # If loading fails, return a new session state and add it to the global dictionary
    session_state = create_session_state()
    with sessions_lock:
        active_sessions[session_id] = session_state
    return session_state
247
+
248
+ # WebSocket callback function, used to send messages to the frontend
249
+
250
+
251
def websocket_send(session_id, message_type, message):
    """Push a 'stimulus_update' event into the given session's room.

    Long string payloads are truncated to 1000 characters.  All failures
    are swallowed and logged: a dead socket must never break generation.
    """
    try:
        if not session_id:
            return

        # Truncate oversized string messages before sending.
        message_preview = message
        if isinstance(message, str) and len(message) > 1000:
            message_preview = message[:1000] + "... [truncated]"

        payload = {
            'session_id': session_id,
            'type': message_type,
            'message': message_preview,
            'timestamp': time.time()
        }

        # Emit directly, with its own error guard.
        try:
            socketio.emit('stimulus_update', payload, room=session_id)
        except Exception as e:
            print(f"WebSocket send error: {str(e)}")

    except Exception as e:
        print(f"Error in websocket_send: {str(e)}")
280
+
281
+ # health check endpoint, used for cloud service monitoring
282
+
283
+
284
@app.route("/health")
def health_check():
    """Liveness probe for cloud service monitoring."""
    payload = {
        "status": "ok",
        "timestamp": time.time(),
        "version": "1.0.0",
        "service": "stimulus-generator",
    }
    return jsonify(payload)
292
+
293
+
294
@app.route("/")
def homepage():
    """Mint a fresh session id and redirect the client to its own URL."""
    fresh_id = str(uuid.uuid4())
    return redirect(f"/{fresh_id}")
300
+
301
+
302
@app.route("/<session_id>")
def session_homepage(session_id):
    """Serve the single-page frontend for any per-session URL."""
    app_root = os.path.dirname(os.path.abspath(__file__))
    return send_from_directory(app_root, 'index.html')
307
+
308
+
309
@app.route('/static/<path:filename>')
def serve_static(filename):
    """Serve frontend assets from the ./static directory."""
    static_root = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'static')
    return send_from_directory(static_root, filename)
313
+
314
+
315
@app.route('/<session_id>/generate_stimulus', methods=['POST'])
def generate_stimulus(session_id):
    """Kick off stimulus generation for a session in a background thread.

    Expects a JSON body with at least 'experimentDesign' and a positive
    integer 'iteration'.  Validates the request, resets any previous run's
    state, then launches a daemon thread that calls backend.generate_stimuli
    and streams progress/log messages to the browser over Socket.IO.
    Returns immediately with a 'started' response; completion is observed
    via /generation_status polling.
    """
    try:
        # Load or create the session state
        session_state = load_session(session_id)

        # Ensure the previous state is thoroughly cleaned
        with sessions_lock:
            if session_id in active_sessions:
                print(
                    f"Session {session_id} - Thoroughly cleaning previous state before new generation")

                # Stop any running threads
                session_state['stop_event'].set()
                if session_state['generation_thread'] and session_state['generation_thread'].is_alive():
                    try:
                        # Try to wait for the thread to end, but not too long
                        session_state['generation_thread'].join(timeout=0.5)
                    except Exception as e:
                        print(
                            f"Session {session_id} - Error joining previous thread: {str(e)}")

        # Reset all states
        # Clear the previous stop signal
        session_state['stop_event'].clear()
        # Reset the generated file path
        session_state['generation_file'] = None
        # Clear the error message before each generation
        session_state['error_message'] = None
        session_state['dataframe'] = None  # Reset the dataframe
        # Clear the old thread reference
        session_state['generation_thread'] = None

        # Ensure the current iteration count is reset to zero
        with session_state['current_iteration'].get_lock():
            session_state['current_iteration'].value = 0

        # Save the cleared state, ensuring persistence
        save_session(session_id, session_state)

        data = request.get_json()
        if not data:
            error_msg = "Missing request data"
            return jsonify({'status': 'error', 'message': error_msg}), 400

        # Record more detailed request information, including the iteration count
        iteration_count = data.get('iteration', 'unknown')
        print(
            f"Session {session_id} - Starting new generation with {iteration_count} iterations.")

        # WebSocket callback bound to this session's room
        def session_websocket_callback(message_type, message):
            websocket_send(session_id, message_type, message)

        # Verify that the necessary parameters exist
        required_fields = ['experimentDesign', 'iteration']
        for field in required_fields:
            if field not in data:
                error_msg = f"Missing required field: {field}"
                return jsonify({'status': 'error', 'message': error_msg}), 400

        # Verify that the iteration count is a positive integer
        try:
            iteration = int(data['iteration'])
            if iteration <= 0:
                return jsonify({'status': 'error', 'message': 'Iteration must be a positive integer'}), 400
        except ValueError:
            return jsonify({'status': 'error', 'message': 'Iteration must be a valid number'}), 400

        # Check the model choice
        model_choice = data.get('modelChoice', '')
        if not model_choice:
            return jsonify({'status': 'error', 'message': 'Please select a model'}), 400

        # Assemble the settings handed to backend.generate_stimuli.  The
        # shared Event/Value objects let this process observe/steer the run.
        settings = {
            'agent_1_properties': json.loads(data.get('agent1Properties', '{}')),
            'agent_2_properties': json.loads(data.get('agent2Properties', '{}')),
            'agent_3_properties': json.loads(data.get('agent3Properties', '{}')),
            'api_key': data.get('apiKey', ''),
            'model_choice': model_choice,
            'experiment_design': data['experimentDesign'],
            'previous_stimuli': json.loads(data.get('previousStimuli', '[]')),
            'iteration': iteration,
            'stop_event': session_state['stop_event'],
            'current_iteration': session_state['current_iteration'],
            'total_iterations': session_state['total_iterations'],
            'session_id': session_id,
            'websocket_callback': session_websocket_callback,
            'agent_2_individual_validation': data.get('agent2IndividualValidation', False),
            'agent_3_individual_scoring': data.get('agent3IndividualScoring', False),
        }

        # Add custom model parameters if custom model is selected
        if model_choice == 'custom':
            settings['apiUrl'] = data.get('apiUrl', '')
            settings['modelName'] = data.get('modelName', '')
            settings['params'] = data.get('params', None)
            print(
                f"Session {session_id} - Custom model parameters: URL={settings['apiUrl']}, Model={settings['modelName']}")

        # Initialize the total iteration count
        with session_state['total_iterations'].get_lock():
            session_state['total_iterations'].value = settings['iteration']

        # Save the updated session state
        save_session(session_id, session_state)

        # Callback invoked by the backend after each iteration: persists
        # state and pushes a progress percentage to the browser.
        def update_session_callback():
            # Save the updated session state
            save_session(session_id, session_state)
            with session_state['current_iteration'].get_lock(), session_state['total_iterations'].get_lock():
                # Ensure a valid denominator (not 0)
                denominator = max(1, session_state['total_iterations'].value)
                progress = (
                    session_state['current_iteration'].value / denominator) * 100
                # Avoid floating point precision issues causing abnormal progress display
                progress = min(100, max(0, progress))
                # Ensure the progress is an integer value, avoiding small numbers on the frontend
                progress = int(round(progress))

            print(
                f"Session {session_id} updated - Progress: {progress}%, Current: {session_state['current_iteration'].value}, Total: {session_state['total_iterations'].value}")
            # Send the progress update through WebSocket
            try:
                socketio.emit('progress_update', {
                    'session_id': session_id,
                    'progress': progress,
                    'timestamp': time.time()
                }, namespace='/', room=session_id)
            except Exception as e:
                print(f"Error sending progress update: {str(e)}")

        # Add the callback function to settings
        settings['session_update_callback'] = update_session_callback

        # Worker body: runs in a daemon thread, does the actual generation
        # and publishes results (including partial results after a stop).
        def run_generation():
            try:
                # Send the start generation message
                websocket_send(session_id, 'all',
                               "Starting generation process...")

                # Check if there is a stop signal
                if session_state['stop_event'].is_set():
                    print(
                        f"Session {session_id} - Stop detected before generation start.")
                    websocket_send(session_id, 'all',
                                   "Generation stopped before it started.")
                    return

                # Generate data
                settings["ablation"] = {
                    "use_agent_2": True,
                    "use_agent_3": True
                }
                df, filename = generate_stimuli(settings)

                # Check if there is a stop signal - but still save partial data if available
                is_stopped = session_state['stop_event'].is_set()
                if is_stopped:
                    print(
                        f"Session {session_id} - Stop detected after generation completed.")
                    # Don't return immediately - continue to save any partial data

                # Verify the returned results
                print(settings["ablation"])
                if df is None or filename is None:
                    if is_stopped:
                        # Stopped before any data was generated
                        print(
                            f"Session {session_id} - Stopped with no data generated")
                        websocket_send(
                            session_id, 'all', "Generation stopped. No data was generated.")
                        return
                    else:
                        error_msg = "Generation process returned None for dataframe or filename"
                        print(f"Session {session_id} - {error_msg}")
                        session_state['error_message'] = error_msg
                        session_state['generation_file'] = None
                        websocket_send(session_id, 'error', error_msg)
                        return

                # Verify the number of generated stimuli
                if not is_stopped and settings["ablation"]["use_agent_2"] == True:
                    if len(df) != settings['iteration']:
                        warning_msg = f"Warning: Expected {settings['iteration']} stimuli but got {len(df)}"
                        print(f"Session {session_id} - {warning_msg}")
                        websocket_send(session_id, 'all', warning_msg)

                # Force using a new timestamp, ensuring the file name is unique
                import time
                timestamp = int(time.time())
                updated_filename = f"experiment_stimuli_results_{session_id}_{timestamp}.csv"

                # Ensure the generated dataframe is new and contains complete data
                print(
                    f"Session {session_id} - Received dataframe with {len(df)} rows from generate_stimuli")

                # Lock the session access to ensure thread safety
                with sessions_lock:
                    if session_id in active_sessions:
                        # Clear the old dataframe and file information
                        print(
                            f"Session {session_id} - Cleaning up old dataframe and file information")
                        session_state['dataframe'] = None
                        session_state['generation_file'] = None

                        # Force saving the empty state, ensuring old data is cleared
                        save_session(session_id, session_state)

                        # Then set the new data
                        # Use deep copy, ensuring data is not shared
                        session_state['dataframe'] = df.copy()
                        # Use the newly generated file name
                        session_state['generation_file'] = updated_filename

                        # Record the completion status
                        if is_stopped:
                            print(
                                f"Session {session_id} - Generation stopped with partial data. File: {updated_filename}, Stimuli count: {len(df)}")
                            websocket_send(
                                session_id, 'all', f"Generation stopped. Saved {len(df)} stimuli.")
                            # Don't clear stop_event here - let status check handle it
                            # so it can detect this is a stopped generation with partial data
                        else:
                            print(
                                f"Session {session_id} - Generation completed successfully. New file: {updated_filename}, Stimuli count: {len(df)}")
                            websocket_send(
                                session_id, 'all', f"Generation completed. Generated {len(df)} stimuli.")

                        # Save the updated session state
                        save_session(session_id, session_state)
                    else:
                        print(
                            f"Session {session_id} - Warning: Session no longer active, cannot update state")
                        return
            except Exception as e:
                session_state['error_message'] = str(
                    e)  # Record the error message
                print(
                    f"Session {session_id} - Error during generation:", str(e))
                # Send the error message through WebSocket
                websocket_send(session_id, 'error', str(e))
                # Save the updated session state
                save_session(session_id, session_state)

        # Create and start the generation thread
        # Final verification before starting new thread
        with sessions_lock:
            if session_id in active_sessions:
                old_thread = session_state.get('generation_thread')
                if old_thread and old_thread.is_alive():
                    error_msg = f"Session {session_id} - Cannot start new generation: old thread is still running"
                    print(error_msg)
                    return jsonify({'status': 'error', 'message': 'Previous generation is still running. Please wait and try again.'}), 409

        session_state['generation_thread'] = Thread(target=run_generation)
        # Set as a daemon thread so it never blocks interpreter shutdown
        session_state['generation_thread'].daemon = True

        # Add thread information for debugging (ident is None before start)
        thread_id = session_state['generation_thread'].ident
        print(
            f"Session {session_id} - Starting new generation thread (ID will be available after start)")

        session_state['generation_thread'].start()

        # Get actual thread ID after start
        actual_thread_id = session_state['generation_thread'].ident
        print(
            f"Session {session_id} - New generation thread started with ID: {actual_thread_id}")

        return jsonify({
            'status': 'success',
            'message': 'Stimulus generation started.',
            'session_id': session_id,
            'total_iterations': settings['iteration']
        })
    except Exception as e:
        print(f"Unexpected error in generate_stimulus API: {str(e)}")
        return jsonify({'status': 'error', 'message': f'Server error: {str(e)}'}), 500
596
+
597
+
598
@app.route('/<session_id>/generation_status', methods=['GET'])
def generation_status(session_id):
    """Polling endpoint: report running/stopping/stopped/completed/error.

    Prefers the live in-memory state (which carries the real thread
    handle); falls back to the state persisted on disk when this process
    has no active entry for the session.

    NOTE(review): several f-strings below print the literal "(unknown)"
    where a {filename} placeholder was presumably intended - looks like
    extraction corruption; confirm against the original source.
    """
    # First try to get the session state from the global dictionary
    with sessions_lock:
        if session_id in active_sessions:
            session_state = active_sessions[session_id]

            # First check stop_event
            if session_state['stop_event'].is_set():
                # Check if the generation thread is still running
                thread_running = session_state['generation_thread'] and session_state['generation_thread'].is_alive(
                )

                if thread_running:
                    # Thread is still running, wait for it to save partial data
                    print(
                        f"Session {session_id} - Stop signal set, waiting for thread to save partial data...")
                    # Return running status so frontend keeps polling
                    progress = (session_state['current_iteration'].value /
                                session_state['total_iterations'].value) * 100
                    progress = min(100, max(0, progress))
                    return jsonify({'status': 'running', 'progress': progress, 'stopping': True})

                print(
                    f"Session {session_id} - Generation stopped by user (in-memory check).")
                # Thread has finished - check if there's partial data available for download
                if session_state['generation_file'] and session_state['dataframe'] is not None:
                    filename = session_state['generation_file']
                    row_count = len(session_state['dataframe'])
                    print(
                        f"Session {session_id} - Returning partial data file: (unknown) with {row_count} rows")
                    websocket_send(session_id, 'all',
                                   f"Generation stopped. {row_count} stimuli saved.")
                    # Clear stop event so we don't keep returning stopped status
                    session_state['stop_event'].clear()
                    save_session(session_id, session_state)
                    return jsonify({'status': 'completed', 'file': filename, 'partial': True})
                else:
                    # No data to save
                    websocket_send(session_id, 'all',
                                   "Generation stopped by user.")
                    session_state['generation_file'] = None
                    save_session(session_id, session_state)
                    return jsonify({'status': 'stopped'})

            if session_state['error_message']:
                return jsonify({'status': 'error', 'error_message': session_state['error_message']})

            with session_state['current_iteration'].get_lock(), session_state['total_iterations'].get_lock():
                # Check if the thread is still running
                thread_running = session_state['generation_thread'] and session_state['generation_thread'].is_alive(
                )

                # Get the current progress
                progress = (session_state['current_iteration'].value /
                            session_state['total_iterations'].value) * 100
                # Ensure the progress is within 0-100 range
                progress = min(100, max(0, progress))
                print(
                    f"Session {session_id} - Progress: {progress:.2f}%, Current: {session_state['current_iteration'].value}, Total: {session_state['total_iterations'].value}, Thread running: {thread_running}")

                # Only when the iteration is complete, the file is generated, and
                # the generation thread has finished is the run truly completed
                if (session_state['current_iteration'].value == session_state['total_iterations'].value
                        and session_state['generation_file']
                        and not thread_running):

                    # Check if the file name contains the current session ID,
                    # avoiding returning files from other sessions
                    filename = session_state['generation_file']
                    if session_id in filename:
                        # Output file information for debugging
                        print(f"Returning completed file: (unknown)")
                        return jsonify({'status': 'completed', 'file': filename})
                    else:
                        # The file name does not match the current session, possibly an incorrect file
                        print(
                            f"Warning: File (unknown) does not match session {session_id}")
                        session_state['error_message'] = "Generated file does not match current session"
                        return jsonify({'status': 'error', 'error_message': 'Generated file does not match session'})
                else:
                    # If the thread is completed but the progress is incomplete or
                    # no file is produced, it may be an error during generation
                    if progress >= 100 and not thread_running and not session_state['generation_file']:
                        session_state['error_message'] = "Generation completed but no file was produced"
                        return jsonify({'status': 'error', 'error_message': 'Generation completed but no file was produced. Please refresh the page and try again.'})

                    return jsonify({'status': 'running', 'progress': progress})

    # If there is no active session in the global dictionary, fall back to loading from the file
    session_state = load_session(session_id)

    # First check stop_event
    if session_state['stop_event'].is_set():
        print(f"Session {session_id} - Generation stopped by user.")
        # Check if there's partial data available for download
        if session_state['generation_file'] and session_state.get('dataframe') is not None:
            filename = session_state['generation_file']
            row_count = len(session_state['dataframe'])
            print(
                f"Session {session_id} - Returning partial data file (from disk): (unknown) with {row_count} rows")
            websocket_send(session_id, 'all',
                           f"Generation stopped. {row_count} stimuli saved.")
            session_state['stop_event'].clear()
            save_session(session_id, session_state)
            return jsonify({'status': 'completed', 'file': filename, 'partial': True})
        else:
            websocket_send(session_id, 'all', "Generation stopped by user.")
            session_state['generation_file'] = None
            save_session(session_id, session_state)
            return jsonify({'status': 'stopped'})

    if session_state['error_message']:
        return jsonify({'status': 'error', 'error_message': session_state['error_message']})

    with session_state['current_iteration'].get_lock(), session_state['total_iterations'].get_lock():
        progress = (session_state['current_iteration'].value /
                    session_state['total_iterations'].value) * 100
        # Ensure the progress is within 0-100 range
        progress = min(100, max(0, progress))
        print(
            f"Session {session_id} - Progress: {progress:.2f}%, Current: {session_state['current_iteration'].value}, Total: {session_state['total_iterations'].value}")

        # When loading from a file, there is no thread information, so only
        # progress and file presence can be used to judge completion
        if session_state['current_iteration'].value == session_state['total_iterations'].value and session_state['generation_file']:
            # Check if the file name contains the current session ID
            filename = session_state['generation_file']
            if session_id in filename:
                print(f"Returning completed file (from disk): (unknown)")
                return jsonify({'status': 'completed', 'file': filename})
            else:
                print(
                    f"Warning: File (unknown) does not match session {session_id}")
                session_state['error_message'] = "Generated file does not match current session"
                return jsonify({'status': 'error', 'error_message': 'Generated file does not match session'})
        else:
            return jsonify({'status': 'running', 'progress': progress})
732
+
733
+
734
@app.route('/<session_id>/stop_generation', methods=['POST'])
def stop_generation(session_id):
    """Ask the session's worker to stop; partial results are preserved.

    generation_file is deliberately left untouched so the worker thread can
    still publish whatever partial data it produced; the status endpoint
    then surfaces that file to the client.
    """
    response_message = 'Stopping generation. Partial data will be saved if available.'

    # Fast path: the session is live in this process, so the in-memory
    # Event object is the one the worker thread actually watches.
    with sessions_lock:
        live_state = active_sessions.get(session_id)
        if live_state is not None:
            live_state['stop_event'].set()
            print(
                f"Session {session_id} - Stop signal set directly in memory. Generation will be stopped and partial data saved.")
            # Tell the frontend what is happening over WebSocket.
            websocket_send(
                session_id, 'all', "Stopping... Please wait for partial data to be saved.")
            # Still persist to disk for durability.
            save_session(session_id, live_state)
            return jsonify({'message': response_message})

    # Slow path: rebuild the state from disk, then flag it stopped.
    disk_state = load_session(session_id)
    disk_state['stop_event'].set()
    websocket_send(session_id, 'all',
                   "Stopping... Please wait for partial data to be saved.")
    print(
        f"Session {session_id} - Stop signal set. Generation will be stopped.")
    save_session(session_id, disk_state)
    return jsonify({'message': response_message})
766
+
767
+
768
@app.route('/<session_id>/download/<filename>', methods=['GET'])
def download_file(session_id, filename):
    """Serve the generated CSV for *session_id* from the in-memory dataframe.

    Validates that *filename* matches both the session's recorded output file
    and the session id itself, strips internal metadata columns, streams the
    CSV from memory, and schedules a delayed cleanup of the session's result
    so stale data cannot be re-downloaded.
    """
    session_state = load_session(session_id)

    # Nothing to serve if no generation result is held for this session.
    if session_state.get('dataframe') is None:
        print(f"Error: No dataframe available for session {session_id}")
        return jsonify({'message': 'No data available for download.'}), 404

    # The requested name must match the file recorded for this session.
    stored_filename = session_state.get('generation_file')
    if stored_filename != filename:
        print(
            f"Warning: Requested file {filename} does not match current session file {stored_filename}")
        # Force verification: reject mismatched requests instead of continuing.
        return jsonify({'message': 'Requested file does not match current session file'}), 400

    try:
        # Defense in depth: generated filenames embed the session id.
        if session_id not in filename:
            print(
                f"Error: Requested file {filename} does not contain session ID {session_id}")
            return jsonify({'message': 'Invalid file request: session ID mismatch'}), 400

        # Work on a copy so the stored dataframe stays intact until cleanup.
        df_to_download = session_state['dataframe'].copy()

        # Internal bookkeeping columns are never exposed to the user.
        # errors='ignore' skips columns that are absent, replacing the old
        # per-column membership loop.
        metadata_columns = ['generation_timestamp', 'batch_id', 'total_iterations',
                            'download_timestamp', 'error_occurred', 'error_message']
        df_to_download = df_to_download.drop(columns=metadata_columns, errors='ignore')

        # Expand the packed stimulus_content column into flat columns.
        df_to_download = expand_stimulus_content_to_columns(df_to_download)

        # Serialize the CSV entirely in memory (no temp files on disk).
        buffer = io.StringIO()
        df_to_download.to_csv(buffer, index=False)

        print(
            f"Serving file {filename} with {len(df_to_download)} rows for session {session_id}")

        # After download, clear the dataframe and file name in the session to
        # avoid repeated download of old data.
        def delayed_cleanup():
            """Clear the served dataframe/filename after the download completes."""
            # Wait 2 seconds to ensure the file has been fully downloaded.
            time.sleep(2)
            with sessions_lock:
                if session_id in active_sessions:
                    state = active_sessions[session_id]
                    # Only clear if a newer generation has not replaced the file.
                    if state.get('generation_file') == filename:
                        print(
                            f"Cleaning up dataframe and filename for session {session_id} after download")
                        state['dataframe'] = None
                        state['generation_file'] = None
                        save_session(session_id, state)

        # Run the cleanup in a background thread so the response is not delayed.
        cleanup_thread = Thread(target=delayed_cleanup)
        cleanup_thread.daemon = True
        cleanup_thread.start()

        # Return the CSV file from memory. A fresh BytesIO is built per
        # attempt so a TypeError fallback never reuses a consumed stream.
        try:
            # Flask >= 2.0 uses the download_name parameter.
            return send_file(
                io.BytesIO(buffer.getvalue().encode()),
                as_attachment=True,
                download_name=filename,
                mimetype='text/csv'
            )
        except TypeError:
            # Older Flask versions use attachment_filename instead.
            return send_file(
                io.BytesIO(buffer.getvalue().encode()),
                as_attachment=True,
                attachment_filename=filename,
                mimetype='text/csv'
            )
    except Exception as e:
        print(f"Error generating download file: {str(e)}")
        return jsonify({'message': f'Error generating file: {str(e)}'}), 500
860
+
861
+
862
# WebSocket initialization event
@socketio.on('connect')
def handle_connect():
    """
    Handle WebSocket connection event - simplified to avoid WSGI conflicts.

    Logs the connecting client and always returns True to accept the
    connection; room joining is deferred to the explicit 'join_session'
    event so nothing here writes to the response.
    """
    try:
        # request.sid is only present inside a Socket.IO context.
        client_sid = request.sid if request and hasattr(
            request, 'sid') else None
        session_id = None

        try:
            if request and hasattr(request, 'args'):
                session_id = request.args.get('session_id')
        except Exception:
            # Narrowed from a bare `except:` so system-exiting exceptions
            # (KeyboardInterrupt, SystemExit) are no longer swallowed.
            pass

        print(f'Client connected: {client_sid}, Session ID: {session_id}')

        # Do NOT perform any operations that might write to the response here.
        # Just return True to allow the connection.
        return True

    except Exception as e:
        print(f"Error in connection handler: {str(e)}")
        # Still accept the connection even if logging failed.
        return True
889
+
890
+
891
@socketio.on('join_session')
def handle_join_session(data):
    """
    Handle explicit join session request from client after connection is established
    """
    try:
        sid = request.sid if request and hasattr(request, 'sid') else None
        if not sid:
            return

        room = data.get('session_id') if data else None
        if not room:
            return

        # Only join the room once the connection is fully established.
        from flask_socketio import join_room
        join_room(room)
        print(f'Client {sid} joined room {room}')

        # Acknowledge the join back to this client only.
        payload = {
            'status': 'connected',
            'message': f'Joined room {room}',
            'room_joined': True,
            'session_id': room,
        }
        socketio.emit('server_status', payload, room=sid)
        print(f'Sent confirmation to client {sid}')

    except Exception as e:
        print(f"Error in join_session handler: {str(e)}")
922
+
923
+
924
@socketio.on('disconnect')
def handle_disconnect(disconnect_reason=None):
    """
    Handle WebSocket disconnection event
    This function is called when a client disconnects from the WebSocket server
    """
    try:
        sid = request.sid if request and hasattr(request, 'sid') else "unknown"
        print(
            f'Client disconnected: {sid}, reason: {disconnect_reason}')
        # Socket.IO removes the client from its rooms automatically, so no
        # explicit room cleanup is required; extend here if needed later.
    except Exception as e:
        # The client is already gone: just log, never try to emit back.
        print(f"Error in disconnect handler: {str(e)}")

    # Return immediately to complete the disconnection process.
    return
945
+
946
+
947
@socketio.on('ping')
def handle_ping(data):
    """
    Handle ping messages from clients to keep connections alive
    """
    try:
        sid = request.sid if request and hasattr(request, 'sid') else None
        if sid is None:
            return

        # Echo the client's timestamp back alongside the server clock.
        client_time = data.get('time', 0) if data else 0
        reply = {'time': client_time, 'server_time': time.time()}

        try:
            socketio.emit('pong', reply, room=sid)
        except Exception as e:
            print(f"Pong error: {str(e)}")

    except Exception as e:
        print(f"Error in ping handler: {str(e)}")

    return
972
+
973
+
974
# Add session destruction function
def cleanup_session(session_id):
    """Remove the specified session from the global dictionary"""
    with sessions_lock:
        if session_id not in active_sessions:
            return False

        state = active_sessions[session_id]
        # Ask any running generation thread to stop before dropping the state.
        state['stop_event'].set()
        worker = state['generation_thread']
        if worker and worker.is_alive():
            try:
                # Give the thread a brief chance to exit on its own.
                worker.join(timeout=0.5)
            except Exception as e:
                print(
                    f"Error joining thread for session {session_id}: {str(e)}")

        # Drop the entry from the registry.
        del active_sessions[session_id]
        print(f"Session {session_id} removed from active sessions.")
        return True
995
+
996
+
997
# Modify the cleanup expired session function, including sessions in memory
def cleanup_sessions():
    """Delete session files older than 24 hours and drop orphaned in-memory sessions.

    A session file is considered expired when its mtime is more than 24 hours
    old; the corresponding in-memory session (if any) is stopped and removed
    before the file is deleted. Afterwards, any in-memory session whose file
    no longer exists on disk is also cleaned up.
    """
    import time
    import glob

    SESSION_MAX_AGE_SECONDS = 86400  # 24 hours

    # Get all session files
    session_files = glob.glob(os.path.join(SESSION_DIR, '*.pkl'))
    current_time = time.time()

    for session_file in session_files:
        # Expire based on the file's last-modified time.
        file_mod_time = os.path.getmtime(session_file)
        if current_time - file_mod_time > SESSION_MAX_AGE_SECONDS:
            try:
                # splitext (rather than split('.')[0]) keeps session ids that
                # themselves contain dots intact.
                session_id = os.path.splitext(os.path.basename(session_file))[0]
                # Stop/remove the in-memory session first, then the file.
                cleanup_session(session_id)
                os.remove(session_file)
                print(f"Removed expired session file: {session_file}")
            except Exception as e:
                print(
                    f"Failed to remove session file {session_file}: {str(e)}")

    # Drop in-memory sessions whose backing file has disappeared.
    with sessions_lock:
        active_session_ids = list(active_sessions.keys())

    for session_id in active_session_ids:
        session_file = get_session_file(session_id)
        if not os.path.exists(session_file):
            cleanup_session(session_id)
            print(f"Cleaned up orphaned session {session_id} from memory.")
1032
+
1033
+
1034
# Add error handlers for Socket.IO
@socketio.on_error()
def handle_error(e):
    """
    Global error handler for all Socket.IO events
    This function is called when an error occurs during Socket.IO event handling
    """
    # Log only; emitting an error message from here could itself fail.
    print(f"Socket.IO error: {str(e)}")
1043
+
1044
+
1045
@socketio.on_error_default
def handle_default_error(e):
    """
    Default error handler for Socket.IO events
    This function is called when an error occurs during Socket.IO event handling
    and no specific error handler exists
    """
    # Log only; emitting an error message from here could itself fail.
    print(f"Socket.IO default error: {str(e)}")
1054
+
1055
+
1056
# Clean up expired sessions once at import/startup time so stale session
# files left behind by previous runs do not accumulate on disk.
cleanup_sessions()
1058
+
1059
+
1060
@app.route('/api/custom_model_inference', methods=['POST'])
def custom_model_inference():
    """Forward a custom-model inference request to the shared handler."""
    payload = request.get_json()

    # The handler returns a (body, http_status) pair.
    result, status_code = custom_model_inference_handler(
        payload.get('session_id'),
        payload.get('prompt'),
        payload.get('model'),
        payload.get('api_url'),
        payload.get('api_key'),
        payload.get('params'),
    )
    return jsonify(result), status_code
1079
+
1080
+
1081
def restart_space():
    """
    Restart Huggingface Spaces to prevent automatic sleep after 48 hours.

    Requires HF_TOKEN and SPACE_ID in the environment; logs a warning and
    returns silently when either is missing or when the API call fails.
    """
    try:
        # Get Huggingface token from environment
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            print("Warning: HF_TOKEN not found in environment variables. Auto-restart will not work.")
            return

        # Get Space repository ID from environment (format: username/space-name)
        repo_id = os.environ.get("SPACE_ID")
        if not repo_id:
            print("Warning: SPACE_ID not found in environment variables. Auto-restart will not work.")
            return

        # Restart the space using Huggingface API
        api_url = f"https://huggingface.co/api/spaces/{repo_id}/restart"
        headers = {
            "Authorization": f"Bearer {hf_token}",
            "Content-Type": "application/json"
        }

        # timeout added so a hung HF API call cannot block the scheduler
        # thread indefinitely (requests has no default timeout).
        response = requests.post(api_url, headers=headers, timeout=30)
        if response.status_code == 200:
            print(f"Successfully restarted Huggingface Space: {repo_id}")
        else:
            print(f"Failed to restart Huggingface Space: {response.status_code} - {response.text}")

    except Exception as e:
        print(f"Error restarting Huggingface Space: {str(e)}")
1113
+
1114
+
1115
# Background scheduler for the anti-sleep functionality; the restart job is
# registered and the scheduler started inside the __main__ guard below.
scheduler = BackgroundScheduler()
1117
+
1118
+
1119
if __name__ == '__main__':
    # Startup banner describing the WebSocket configuration.
    print("Starting Stimulus Generator server...")
    print("WebSocket server configured with:")
    print(f" - Async mode: {socketio.async_mode}")
    print(f" - Ping interval: 10s")
    print(f" - Ping timeout: 30s")
    print(f" - HTTP Compression: Disabled")
    print(f" - Session Management: Disabled")

    # Schedule a periodic self-restart so the Huggingface Space never reaches
    # the 48-hour auto-sleep threshold (172,000 s is roughly 47.7 hours).
    try:
        scheduler.add_job(restart_space, "interval", seconds=172000)
        scheduler.start()
        print("Anti-sleep scheduler started - will restart every 47.7 hours to prevent auto-sleep")
    except Exception as e:
        print(f"Warning: Failed to start anti-sleep scheduler: {str(e)}")

    # PRODUCTION=true selects the quieter, non-debug server settings.
    is_production = os.environ.get('PRODUCTION', 'false').lower() == 'true'

    if is_production:
        print("Running in production mode")
        port = int(os.environ.get('PORT', 5000))
        debug = False
        log_output = False
    else:
        print("Running in development mode")
        port = 5000
        debug = True
        log_output = True

    # Single run call; only port/debug/log verbosity differ per environment.
    socketio.run(
        app,
        host='0.0.0.0',
        port=port,
        debug=debug,
        allow_unsafe_werkzeug=True,
        log_output=log_output,
    )