kan0621 commited on
Commit
6b699cc
·
verified ·
1 Parent(s): 839e4d6

Update backend.py

Browse files
Files changed (1) hide show
  1. backend.py +238 -74
backend.py CHANGED
@@ -1,17 +1,16 @@
1
- import openai
2
- import pandas as pd
3
  import json
4
-
5
- import time
6
  import random
7
- import requests
8
- from flask import request, jsonify
9
  from abc import ABC, abstractmethod
10
- import threading
11
- import multiprocessing
12
  from multiprocessing import Process, Queue
13
- import queue
14
- import traceback
 
 
 
15
 
16
  # Set OpenAI API key
17
  # openai.api_key = ""
@@ -66,9 +65,19 @@ Please help me construct one item as stimuli for a psycholinguistic experiment b
66
 
67
  Experimental stimuli design: {experiment_design}
68
 
69
- Existing stimuli: {previous_stimuli}
 
 
 
70
 
71
- Requirement: {generation_requirements} Please return in JSON format.
 
 
 
 
 
 
 
72
  """
73
 
74
  # ---- Agent 2 Prompt ----
@@ -132,7 +141,7 @@ class OpenAIClient(ModelClient):
132
  self.api_key = api_key
133
  if api_key:
134
  openai.api_key = api_key
135
- print(f"OpenAI API key set successfully, length: {len(api_key)}")
136
  else:
137
  print("Warning: No OpenAI API key provided!")
138
 
@@ -140,8 +149,6 @@ class OpenAIClient(ModelClient):
140
  """API call function, will be called by multiprocessing"""
141
  # set API key in subprocess
142
  openai.api_key = api_key
143
- print(
144
- f"OpenAI API key in subprocess: {api_key[:10]}..." if api_key else "None")
145
 
146
  return openai.ChatCompletion.create(
147
  model=params["model"],
@@ -181,12 +188,18 @@ class OpenAIClient(ModelClient):
181
  return json.loads(response['choices'][0]['message']['content'])
182
  except json.JSONDecodeError as e:
183
  print(f"Failed to parse OpenAI JSON response: {e}")
184
- return {"error": "Failed to parse response"}
185
- except Exception as e:
186
  print(f"OpenAI API error attempt {attempt + 1}/3: {e}")
187
  if attempt == 2:
188
- return {"error": f"API error after 3 attempts: {str(e)}"}
189
  time.sleep(2 ** attempt)
 
 
 
 
 
 
190
 
191
  def get_default_params(self):
192
  return {"model": "gpt-4o"}
@@ -273,21 +286,29 @@ class CustomModelClient(ModelClient):
273
 
274
  def _api_call(self, request_data, headers):
275
  """API call function, will be called by multiprocessing"""
276
- response = requests.post(
277
- self.api_url,
278
- headers=headers,
279
- json=request_data,
280
- timeout=60 # timeout for requests
281
- )
282
- response.raise_for_status()
283
- return response.json()
 
 
 
 
 
 
 
 
 
284
 
285
  def generate_completion(self, prompt, properties, params=None):
286
 
287
  is_deepseek = self.api_url.strip().startswith("https://api.deepseek.com")
288
 
289
  if is_deepseek:
290
- import time
291
  rand_stamp = int(time.time())
292
  # Generate field list
293
  field_list = ', '.join([f'"{k}"' for k in properties.keys()])
@@ -347,7 +368,7 @@ class CustomModelClient(ModelClient):
347
  json.dumps(request_data, indent=2))
348
 
349
  result = call_with_timeout(
350
- self._api_call, (request_data, headers), {}, 60)
351
 
352
  if isinstance(result, dict) and "error" in result:
353
  print(f"Custom API timeout attempt {attempt + 1}/3")
@@ -361,15 +382,28 @@ class CustomModelClient(ModelClient):
361
  content = result["choices"][0]["message"]["content"]
362
  return json.loads(content)
363
 
364
- except (json.JSONDecodeError, KeyError) as e:
365
- print(f"Custom API parsing error attempt {attempt + 1}/3: {e}")
 
366
  if attempt == 2:
367
- return {"error": f"API parsing error after 3 attempts: {str(e)}"}
368
  time.sleep(2 ** attempt)
369
- except Exception as e:
370
- print(f"Custom API error attempt {attempt + 1}/3: {e}")
 
 
 
 
 
 
 
371
  if attempt == 2:
372
- return {"error": f"API error after 3 attempts: {str(e)}"}
 
 
 
 
 
373
  time.sleep(2 ** attempt)
374
 
375
  def get_default_params(self):
@@ -384,7 +418,6 @@ def create_model_client(model_choice, settings=None):
384
  """Factory function to create appropriate model client"""
385
  if model_choice == 'GPT-4o':
386
  api_key = settings.get('api_key') if settings else None
387
- print(f"OpenAI API key length: {len(api_key) if api_key else 0}")
388
  return OpenAIClient(api_key)
389
  elif model_choice == 'custom':
390
  if not settings:
@@ -411,8 +444,15 @@ def check_stimulus_repetition(new_stimulus_dict, previous_stimuli_list):
411
  for existing_stimulus in previous_stimuli_list:
412
  for key, new_value in new_stimulus_dict.items():
413
  # If the key exists in existing_stimulus and the values are the same, it is considered a repetition
414
- if key in existing_stimulus and existing_stimulus[key].lower() == str(new_value.lower()):
415
- return True
 
 
 
 
 
 
 
416
 
417
  return False
418
 
@@ -422,6 +462,7 @@ def agent_1_generate_stimulus(
422
  experiment_design,
423
  previous_stimuli,
424
  properties,
 
425
  prompt_template=AGENT_1_PROMPT_TEMPLATE,
426
  params=None,
427
  stop_event=None):
@@ -433,11 +474,15 @@ def agent_1_generate_stimulus(
433
  return {"stimulus": "STOPPED"}
434
 
435
  # Use fixed generation_requirements
436
- generation_requirements = "Please generate a new stimulus in the same format as the existing stimuli, and ensure that the new stimulus is different from those in the existing stimuli."
 
 
 
437
 
438
  prompt = prompt_template.format(
439
  experiment_design=experiment_design,
440
  previous_stimuli=previous_stimuli,
 
441
  generation_requirements=generation_requirements
442
  )
443
 
@@ -454,8 +499,11 @@ def agent_1_generate_stimulus(
454
  return {"stimulus": "ERROR/ERROR"}
455
 
456
  return result
457
- except Exception as e:
458
- print(f"Error in agent_1_generate_stimulus: {e}")
 
 
 
459
  return {"stimulus": "ERROR/ERROR"}
460
 
461
 
@@ -498,9 +546,12 @@ def agent_2_validate_stimulus(
498
  return {"error": f"Failed to validate stimulus: {result.get('error', 'Unknown error')}"}
499
 
500
  return result
501
- except Exception as e:
502
- print(f"Error in agent_2_validate_stimulus: {e}")
503
- return {"error": "Failed to validate stimulus"}
 
 
 
504
 
505
 
506
  def agent_2_validate_stimulus_individual(
@@ -614,9 +665,12 @@ Please return in JSON format with only one field: "{property_name}" (boolean: tr
614
  "validator", "All criteria passed successfully!")
615
  return validation_results
616
 
617
- except Exception as e:
618
- print(f"Error in agent_2_validate_stimulus_individual: {e}")
619
- return {"error": "Failed to validate stimulus individually"}
 
 
 
620
 
621
 
622
  def generate_scoring_requirements(properties):
@@ -677,8 +731,11 @@ def agent_3_score_stimulus(
677
  return {field: 0 for field in properties.keys()}
678
 
679
  return result
680
- except Exception as e:
681
- print(f"Error in agent_3_score_stimulus: {e}")
 
 
 
682
  return {field: 0 for field in properties.keys()}
683
 
684
 
@@ -810,8 +867,11 @@ Please return in JSON format with only one field: "{aspect_name}" (integer score
810
  "scorer", f"Individual scoring completed! Total: {total_score}/{max_possible}")
811
  return scoring_results
812
 
813
- except Exception as e:
814
- print(f"Error in agent_3_score_stimulus_individual: {e}")
 
 
 
815
  return {field: 0 for field in properties.keys()}
816
 
817
 
@@ -860,11 +920,47 @@ def generate_stimuli(settings):
860
  return True
861
  return False
862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863
  # Immediately check if stopped
864
  if check_stop("Generation stopped before starting."):
865
  return None, None
866
 
867
  record_list = []
 
868
  agent_1_properties = settings.get('agent_1_properties', {})
869
  print("Agent 1 Properties:", agent_1_properties)
870
  if websocket_callback:
@@ -923,8 +1019,9 @@ def generate_stimuli(settings):
923
  # Get actual total iterations
924
  total_iter_value = total_iterations.value
925
  for iteration_num in range(total_iter_value):
926
- if check_stop():
927
- return None, None
 
928
 
929
  round_message = f"=== No. {iteration_num + 1} Round ==="
930
  print(round_message)
@@ -932,9 +1029,11 @@ def generate_stimuli(settings):
932
  websocket_callback("all", round_message)
933
 
934
  # Step 1: Generate stimulus
 
935
  while True:
936
- if check_stop():
937
- return None, None
 
938
 
939
  try:
940
  stimuli = agent_1_generate_stimulus(
@@ -942,27 +1041,52 @@ def generate_stimuli(settings):
942
  experiment_design=experiment_design,
943
  previous_stimuli=previous_stimuli,
944
  properties=agent_1_properties,
 
945
  prompt_template=AGENT_1_PROMPT_TEMPLATE,
946
  params=custom_params,
947
  stop_event=stop_event
948
  )
949
 
950
  if isinstance(stimuli, dict) and stimuli.get('stimulus') == 'STOPPED':
951
- if check_stop("Generation stopped after 'Generator'."):
952
- return None, None
 
 
 
 
 
 
 
 
 
 
953
 
954
  print("Agent 1 Output:", stimuli)
955
  if websocket_callback:
956
  websocket_callback(
957
  "generator", f"Generator's Output: {json.dumps(stimuli, indent=2)}")
958
 
959
- if check_stop("Generation stopped after 'Generator'."):
960
- return None, None
 
 
961
 
962
  # Step 1.5: Check if stimulus already exists
963
 
964
  if check_stimulus_repetition(stimuli, previous_stimuli):
965
  repetition_count += 1
 
 
 
 
 
 
 
 
 
 
 
 
966
  if ablation["use_agent_2"]:
967
  print("Detected repeated stimulus, regenerating...")
968
 
@@ -977,8 +1101,9 @@ def generate_stimuli(settings):
977
  websocket_callback(
978
  "generator", "Ablation: Skipping Agent 2 (Repetition Check)")
979
 
980
- if check_stop():
981
- return None, None
 
982
 
983
  # Step 2: Validate stimulus
984
  # Check if individual validation is enabled
@@ -1011,16 +1136,20 @@ def generate_stimuli(settings):
1011
  )
1012
 
1013
  if isinstance(validation_result, dict) and validation_result.get('error') == 'Stopped by user':
1014
- if check_stop("Generation stopped after 'Validator'."):
1015
- return None, None
 
 
1016
 
1017
  print("Agent 2 Output:", validation_result)
1018
  if websocket_callback:
1019
  websocket_callback(
1020
  "validator", f"Validator's Output: {json.dumps(validation_result, indent=2)}")
1021
 
1022
- if check_stop("Generation stopped after 'Validator'."):
1023
- return None, None
 
 
1024
 
1025
  # Check if there was an error first
1026
  if 'error' in validation_result:
@@ -1037,12 +1166,41 @@ def generate_stimuli(settings):
1037
  if failed_fields:
1038
  # Some fields failed validation
1039
  validation_fails += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1040
  print(
1041
  f"Failed validation for fields: {failed_fields}, regenerating...")
1042
  if websocket_callback:
1043
  websocket_callback(
1044
  "validator", f"Failed validation for fields: {failed_fields}, regenerating...")
1045
 
 
 
 
 
 
 
 
 
 
 
1046
  if ablation["use_agent_2"]:
1047
  continue # Regenerate
1048
  else:
@@ -1080,8 +1238,6 @@ def generate_stimuli(settings):
1080
  df['error_occurred'] = True
1081
  df['error_message'] = str(e)
1082
 
1083
- import os
1084
-
1085
  os.makedirs("outputs", exist_ok=True)
1086
  suggested_filename = os.path.join(
1087
  "outputs", f"experiment_stimuli_results_{session_id}_{timestamp}_{unique_id}.csv")
@@ -1090,12 +1246,16 @@ def generate_stimuli(settings):
1090
  else:
1091
  raise e
1092
 
1093
- if check_stop("Generation stopped after 'Validator'."):
1094
- return None, None
 
 
1095
 
1096
  try:
1097
- if check_stop("Generation stopped before Scorer."):
1098
- return None, None
 
 
1099
 
1100
  # Step 3: Score
1101
  if ablation["use_agent_3"]:
@@ -1130,16 +1290,20 @@ def generate_stimuli(settings):
1130
 
1131
  if isinstance(scores, dict) and all(v == 0 for v in scores.values()):
1132
  if stop_event.is_set():
1133
- if check_stop("Generation stopped after 'Scorer'."):
1134
- return None, None
 
 
1135
 
1136
  print("Agent 3 Output:", scores)
1137
  if websocket_callback:
1138
  websocket_callback(
1139
  "scorer", f"Scorer's Output: {json.dumps(scores, indent=2)}")
1140
 
1141
- if check_stop("Generation stopped after 'Scorer'."):
1142
- return None, None
 
 
1143
  else:
1144
  print("Ablation: Skipping Agent 3 (Scoring)")
1145
  if websocket_callback:
 
 
 
1
  import json
2
+ import os
3
+ import queue
4
  import random
5
+ import time
6
+ import traceback
7
  from abc import ABC, abstractmethod
 
 
8
  from multiprocessing import Process, Queue
9
+
10
+ import openai
11
+ import pandas as pd
12
+ import requests
13
+ from requests.exceptions import RequestException, Timeout, ConnectionError as RequestsConnectionError
14
 
15
  # Set OpenAI API key
16
  # openai.api_key = ""
 
65
 
66
  Experimental stimuli design: {experiment_design}
67
 
68
+ Existing stimuli (DO NOT repeat any of these): {previous_stimuli}
69
+
70
+ Previously rejected stimuli with validation feedback (learn from these failures and avoid similar issues):
71
+ {rejected_stimuli}
72
 
73
+ CRITICAL REQUIREMENTS:
74
+ 1. Generate a COMPLETELY NEW and UNIQUE stimulus that is DIFFERENT from ALL existing stimuli above.
75
+ 2. Do NOT repeat or slightly modify any existing stimulus - create something entirely original.
76
+ 3. Avoid any content that overlaps with existing or rejected stimuli.
77
+ 4. Learn from the rejected stimuli above - understand why they failed validation and avoid making similar mistakes.
78
+ {generation_requirements}
79
+
80
+ Please return in JSON format.
81
  """
82
 
83
  # ---- Agent 2 Prompt ----
 
141
  self.api_key = api_key
142
  if api_key:
143
  openai.api_key = api_key
144
+ print("OpenAI API key configured successfully")
145
  else:
146
  print("Warning: No OpenAI API key provided!")
147
 
 
149
  """API call function, will be called by multiprocessing"""
150
  # set API key in subprocess
151
  openai.api_key = api_key
 
 
152
 
153
  return openai.ChatCompletion.create(
154
  model=params["model"],
 
188
  return json.loads(response['choices'][0]['message']['content'])
189
  except json.JSONDecodeError as e:
190
  print(f"Failed to parse OpenAI JSON response: {e}")
191
+ return {"error": f"Failed to parse response: {str(e)}"}
192
+ except (openai.error.APIError, openai.error.RateLimitError) as e:
193
  print(f"OpenAI API error attempt {attempt + 1}/3: {e}")
194
  if attempt == 2:
195
+ return {"error": f"OpenAI API error after 3 attempts: {str(e)}"}
196
  time.sleep(2 ** attempt)
197
+ except openai.error.AuthenticationError as e:
198
+ print(f"OpenAI authentication error: {e}")
199
+ return {"error": f"Authentication failed: {str(e)}"}
200
+ except openai.error.InvalidRequestError as e:
201
+ print(f"OpenAI invalid request: {e}")
202
+ return {"error": f"Invalid request: {str(e)}"}
203
 
204
  def get_default_params(self):
205
  return {"model": "gpt-4o"}
 
286
 
287
  def _api_call(self, request_data, headers):
288
  """API call function, will be called by multiprocessing"""
289
+ try:
290
+ response = requests.post(
291
+ self.api_url,
292
+ headers=headers,
293
+ json=request_data,
294
+ timeout=60 # timeout for requests
295
+ )
296
+ response.raise_for_status()
297
+ return response.json()
298
+ except Timeout:
299
+ raise Timeout(
300
+ f"Request to {self.api_url} timed out after 60 seconds")
301
+ except RequestsConnectionError as e:
302
+ raise RequestsConnectionError(
303
+ f"Failed to connect to {self.api_url}: {str(e)}")
304
+ except RequestException as e:
305
+ raise RequestException(f"Request failed: {str(e)}")
306
 
307
  def generate_completion(self, prompt, properties, params=None):
308
 
309
  is_deepseek = self.api_url.strip().startswith("https://api.deepseek.com")
310
 
311
  if is_deepseek:
 
312
  rand_stamp = int(time.time())
313
  # Generate field list
314
  field_list = ', '.join([f'"{k}"' for k in properties.keys()])
 
368
  json.dumps(request_data, indent=2))
369
 
370
  result = call_with_timeout(
371
+ self._api_call, (request_data, headers), {}, 600)
372
 
373
  if isinstance(result, dict) and "error" in result:
374
  print(f"Custom API timeout attempt {attempt + 1}/3")
 
382
  content = result["choices"][0]["message"]["content"]
383
  return json.loads(content)
384
 
385
+ except json.JSONDecodeError as e:
386
+ print(
387
+ f"Custom API JSON parsing error attempt {attempt + 1}/3: {e}")
388
  if attempt == 2:
389
+ return {"error": f"API JSON parsing error after 3 attempts: {str(e)}"}
390
  time.sleep(2 ** attempt)
391
+ except KeyError as e:
392
+ print(
393
+ f"Custom API response missing expected key attempt {attempt + 1}/3: {e}")
394
+ if attempt == 2:
395
+ return {"error": f"API response missing expected key after 3 attempts: {str(e)}"}
396
+ time.sleep(2 ** attempt)
397
+ except (Timeout, RequestsConnectionError) as e:
398
+ print(
399
+ f"Custom API connection error attempt {attempt + 1}/3: {e}")
400
  if attempt == 2:
401
+ return {"error": f"API connection error after 3 attempts: {str(e)}"}
402
+ time.sleep(2 ** attempt)
403
+ except RequestException as e:
404
+ print(f"Custom API request error attempt {attempt + 1}/3: {e}")
405
+ if attempt == 2:
406
+ return {"error": f"API request error after 3 attempts: {str(e)}"}
407
  time.sleep(2 ** attempt)
408
 
409
  def get_default_params(self):
 
418
  """Factory function to create appropriate model client"""
419
  if model_choice == 'GPT-4o':
420
  api_key = settings.get('api_key') if settings else None
 
421
  return OpenAIClient(api_key)
422
  elif model_choice == 'custom':
423
  if not settings:
 
444
  for existing_stimulus in previous_stimuli_list:
445
  for key, new_value in new_stimulus_dict.items():
446
  # If the key exists in existing_stimulus and the values are the same, it is considered a repetition
447
+ if key in existing_stimulus:
448
+ try:
449
+ existing_val = str(existing_stimulus[key]).lower()
450
+ new_val = str(new_value).lower()
451
+ if existing_val == new_val:
452
+ return True
453
+ except (AttributeError, TypeError):
454
+ # Skip comparison if values can't be converted to string
455
+ continue
456
 
457
  return False
458
 
 
462
  experiment_design,
463
  previous_stimuli,
464
  properties,
465
+ rejected_stimuli=None,
466
  prompt_template=AGENT_1_PROMPT_TEMPLATE,
467
  params=None,
468
  stop_event=None):
 
474
  return {"stimulus": "STOPPED"}
475
 
476
  # Use fixed generation_requirements
477
+ generation_requirements = "5. Follow the same JSON format as the existing stimuli."
478
+
479
+ if rejected_stimuli is None:
480
+ rejected_stimuli = []
481
 
482
  prompt = prompt_template.format(
483
  experiment_design=experiment_design,
484
  previous_stimuli=previous_stimuli,
485
+ rejected_stimuli=rejected_stimuli,
486
  generation_requirements=generation_requirements
487
  )
488
 
 
499
  return {"stimulus": "ERROR/ERROR"}
500
 
501
  return result
502
+ except (json.JSONDecodeError, KeyError, TypeError) as e:
503
+ print(f"Error parsing response in agent_1_generate_stimulus: {e}")
504
+ return {"stimulus": "ERROR/ERROR"}
505
+ except (RequestException, Timeout) as e:
506
+ print(f"Network error in agent_1_generate_stimulus: {e}")
507
  return {"stimulus": "ERROR/ERROR"}
508
 
509
 
 
546
  return {"error": f"Failed to validate stimulus: {result.get('error', 'Unknown error')}"}
547
 
548
  return result
549
+ except (json.JSONDecodeError, KeyError, TypeError) as e:
550
+ print(f"Error parsing validation response: {e}")
551
+ return {"error": f"Failed to parse validation response: {str(e)}"}
552
+ except (RequestException, Timeout) as e:
553
+ print(f"Network error in validation: {e}")
554
+ return {"error": f"Network error during validation: {str(e)}"}
555
 
556
 
557
  def agent_2_validate_stimulus_individual(
 
665
  "validator", "All criteria passed successfully!")
666
  return validation_results
667
 
668
+ except (json.JSONDecodeError, KeyError, TypeError) as e:
669
+ print(f"Error parsing individual validation response: {e}")
670
+ return {"error": f"Failed to parse validation response: {str(e)}"}
671
+ except (RequestException, Timeout) as e:
672
+ print(f"Network error in individual validation: {e}")
673
+ return {"error": f"Network error during validation: {str(e)}"}
674
 
675
 
676
  def generate_scoring_requirements(properties):
 
731
  return {field: 0 for field in properties.keys()}
732
 
733
  return result
734
+ except (json.JSONDecodeError, KeyError, TypeError) as e:
735
+ print(f"Error parsing scoring response: {e}")
736
+ return {field: 0 for field in properties.keys()}
737
+ except (RequestException, Timeout) as e:
738
+ print(f"Network error in scoring: {e}")
739
  return {field: 0 for field in properties.keys()}
740
 
741
 
 
867
  "scorer", f"Individual scoring completed! Total: {total_score}/{max_possible}")
868
  return scoring_results
869
 
870
+ except (json.JSONDecodeError, KeyError, TypeError) as e:
871
+ print(f"Error parsing individual scoring response: {e}")
872
+ return {field: 0 for field in properties.keys()}
873
+ except (RequestException, Timeout) as e:
874
+ print(f"Network error in individual scoring: {e}")
875
  return {field: 0 for field in properties.keys()}
876
 
877
 
 
920
  return True
921
  return False
922
 
923
+ # Helper function to create partial result when error or stop occurs
924
+ def create_partial_result(record_list, message, is_error=True):
925
+ nonlocal total_iterations
926
+ if len(record_list) > 0:
927
+ df = pd.DataFrame(record_list)
928
+ session_id = settings.get('session_id', 'default')
929
+ timestamp = int(time.time())
930
+ unique_id = ''.join(random.choice('0123456789abcdef')
931
+ for _ in range(6))
932
+ suffix = "_partial" if is_error else "_stopped"
933
+ suggested_filename = f"experiment_stimuli_results_{session_id}_{timestamp}_{unique_id}{suffix}.csv"
934
+
935
+ df['generation_timestamp'] = timestamp
936
+ df['batch_id'] = unique_id
937
+ df['total_iterations'] = total_iterations.value
938
+ df['stopped_by_user'] = not is_error
939
+ df['error_occurred'] = is_error
940
+ df['message'] = message
941
+ df['completed_iterations'] = len(record_list)
942
+
943
+ os.makedirs("outputs", exist_ok=True)
944
+ suggested_filename = os.path.join("outputs", suggested_filename)
945
+
946
+ return df, suggested_filename
947
+ return None, None
948
+
949
+ # Helper function to check stop and return partial data if available
950
+ def check_stop_and_return(message="Generation stopped by user."):
951
+ if stop_event.is_set():
952
+ print(message)
953
+ if websocket_callback:
954
+ websocket_callback("all", message)
955
+ return True, create_partial_result(record_list, message, is_error=False)
956
+ return False, (None, None)
957
+
958
  # Immediately check if stopped
959
  if check_stop("Generation stopped before starting."):
960
  return None, None
961
 
962
  record_list = []
963
+ rejected_stimuli_memory = []
964
  agent_1_properties = settings.get('agent_1_properties', {})
965
  print("Agent 1 Properties:", agent_1_properties)
966
  if websocket_callback:
 
1019
  # Get actual total iterations
1020
  total_iter_value = total_iterations.value
1021
  for iteration_num in range(total_iter_value):
1022
+ stopped, partial_result = check_stop_and_return()
1023
+ if stopped:
1024
+ return partial_result
1025
 
1026
  round_message = f"=== No. {iteration_num + 1} Round ==="
1027
  print(round_message)
 
1029
  websocket_callback("all", round_message)
1030
 
1031
  # Step 1: Generate stimulus
1032
+ current_retry_count = 0 # Retry counter for this iteration
1033
  while True:
1034
+ stopped, partial_result = check_stop_and_return()
1035
+ if stopped:
1036
+ return partial_result
1037
 
1038
  try:
1039
  stimuli = agent_1_generate_stimulus(
 
1041
  experiment_design=experiment_design,
1042
  previous_stimuli=previous_stimuli,
1043
  properties=agent_1_properties,
1044
+ rejected_stimuli=rejected_stimuli_memory,
1045
  prompt_template=AGENT_1_PROMPT_TEMPLATE,
1046
  params=custom_params,
1047
  stop_event=stop_event
1048
  )
1049
 
1050
  if isinstance(stimuli, dict) and stimuli.get('stimulus') == 'STOPPED':
1051
+ stopped, partial_result = check_stop_and_return(
1052
+ "Generation stopped after 'Generator'.")
1053
+ if stopped:
1054
+ return partial_result
1055
+
1056
+ # Skip validation if Agent 1 returned an error
1057
+ if isinstance(stimuli, dict) and stimuli.get('stimulus') == 'ERROR/ERROR':
1058
+ print("Agent 1 returned ERROR, regenerating...")
1059
+ if websocket_callback:
1060
+ websocket_callback(
1061
+ "generator", "Generator returned ERROR, regenerating...")
1062
+ continue
1063
 
1064
  print("Agent 1 Output:", stimuli)
1065
  if websocket_callback:
1066
  websocket_callback(
1067
  "generator", f"Generator's Output: {json.dumps(stimuli, indent=2)}")
1068
 
1069
+ stopped, partial_result = check_stop_and_return(
1070
+ "Generation stopped after 'Generator'.")
1071
+ if stopped:
1072
+ return partial_result
1073
 
1074
  # Step 1.5: Check if stimulus already exists
1075
 
1076
  if check_stimulus_repetition(stimuli, previous_stimuli):
1077
  repetition_count += 1
1078
+ current_retry_count += 1
1079
+
1080
+ # Add retry limit to avoid infinite loops (but never accept duplicates)
1081
+ max_repetition_retries = 50
1082
+ if current_retry_count > max_repetition_retries:
1083
+ error_msg = f"Failed to generate unique stimulus after {max_repetition_retries} attempts. Consider adjusting experiment design or reducing target count."
1084
+ print(error_msg)
1085
+ if websocket_callback:
1086
+ websocket_callback("generator", error_msg)
1087
+ # Return partial results instead of raising exception
1088
+ return create_partial_result(record_list, error_msg)
1089
+
1090
  if ablation["use_agent_2"]:
1091
  print("Detected repeated stimulus, regenerating...")
1092
 
 
1101
  websocket_callback(
1102
  "generator", "Ablation: Skipping Agent 2 (Repetition Check)")
1103
 
1104
+ stopped, partial_result = check_stop_and_return()
1105
+ if stopped:
1106
+ return partial_result
1107
 
1108
  # Step 2: Validate stimulus
1109
  # Check if individual validation is enabled
 
1136
  )
1137
 
1138
  if isinstance(validation_result, dict) and validation_result.get('error') == 'Stopped by user':
1139
+ stopped, partial_result = check_stop_and_return(
1140
+ "Generation stopped after 'Validator'.")
1141
+ if stopped:
1142
+ return partial_result
1143
 
1144
  print("Agent 2 Output:", validation_result)
1145
  if websocket_callback:
1146
  websocket_callback(
1147
  "validator", f"Validator's Output: {json.dumps(validation_result, indent=2)}")
1148
 
1149
+ stopped, partial_result = check_stop_and_return(
1150
+ "Generation stopped after 'Validator'.")
1151
+ if stopped:
1152
+ return partial_result
1153
 
1154
  # Check if there was an error first
1155
  if 'error' in validation_result:
 
1166
  if failed_fields:
1167
  # Some fields failed validation
1168
  validation_fails += 1
1169
+ current_retry_count += 1
1170
+
1171
+ # Add to rejected memory (only if it's a valid stimulus, not an error)
1172
+ is_error_stimulus = (
1173
+ isinstance(stimuli, dict) and
1174
+ stimuli.get('stimulus') in ['ERROR/ERROR', 'STOPPED']
1175
+ )
1176
+ if not is_error_stimulus:
1177
+ rejected_item = {
1178
+ "stimulus": stimuli,
1179
+ "validation_result": validation_result,
1180
+ "failed_fields": failed_fields
1181
+ }
1182
+ rejected_stimuli_memory.append(rejected_item)
1183
+ # Limit memory size to prevent unbounded growth
1184
+ MAX_REJECTED_MEMORY = 20
1185
+ if len(rejected_stimuli_memory) > MAX_REJECTED_MEMORY:
1186
+ rejected_stimuli_memory = rejected_stimuli_memory[-MAX_REJECTED_MEMORY:]
1187
+
1188
  print(
1189
  f"Failed validation for fields: {failed_fields}, regenerating...")
1190
  if websocket_callback:
1191
  websocket_callback(
1192
  "validator", f"Failed validation for fields: {failed_fields}, regenerating...")
1193
 
1194
+ # Check retry limit to avoid infinite loops
1195
+ max_retries = 50
1196
+ if current_retry_count > max_retries:
1197
+ error_msg = f"Failed to generate valid stimulus after {max_retries} attempts. Consider adjusting validation criteria."
1198
+ print(error_msg)
1199
+ if websocket_callback:
1200
+ websocket_callback("validator", error_msg)
1201
+ # Return partial results instead of raising exception
1202
+ return create_partial_result(record_list, error_msg)
1203
+
1204
  if ablation["use_agent_2"]:
1205
  continue # Regenerate
1206
  else:
 
1238
  df['error_occurred'] = True
1239
  df['error_message'] = str(e)
1240
 
 
 
1241
  os.makedirs("outputs", exist_ok=True)
1242
  suggested_filename = os.path.join(
1243
  "outputs", f"experiment_stimuli_results_{session_id}_{timestamp}_{unique_id}.csv")
 
1246
  else:
1247
  raise e
1248
 
1249
+ stopped, partial_result = check_stop_and_return(
1250
+ "Generation stopped after 'Validator'.")
1251
+ if stopped:
1252
+ return partial_result
1253
 
1254
  try:
1255
+ stopped, partial_result = check_stop_and_return(
1256
+ "Generation stopped before Scorer.")
1257
+ if stopped:
1258
+ return partial_result
1259
 
1260
  # Step 3: Score
1261
  if ablation["use_agent_3"]:
 
1290
 
1291
  if isinstance(scores, dict) and all(v == 0 for v in scores.values()):
1292
  if stop_event.is_set():
1293
+ stopped, partial_result = check_stop_and_return(
1294
+ "Generation stopped after 'Scorer'.")
1295
+ if stopped:
1296
+ return partial_result
1297
 
1298
  print("Agent 3 Output:", scores)
1299
  if websocket_callback:
1300
  websocket_callback(
1301
  "scorer", f"Scorer's Output: {json.dumps(scores, indent=2)}")
1302
 
1303
+ stopped, partial_result = check_stop_and_return(
1304
+ "Generation stopped after 'Scorer'.")
1305
+ if stopped:
1306
+ return partial_result
1307
  else:
1308
  print("Ablation: Skipping Agent 3 (Scoring)")
1309
  if websocket_callback: