Naisong Zhou commited on
Commit
d089d7b
·
1 Parent(s): 856481b

data save position fixed to id-determined row

Browse files
Files changed (3) hide show
  1. app.py +23 -19
  2. constants.py +4 -0
  3. save_data.py +27 -62
app.py CHANGED
@@ -1,19 +1,19 @@
1
  import gradio as gr
2
  from utils import *
3
- from save_data import add_new_data, get_sheet_service
4
  from instructions import *
5
  from user_groups import user_data
6
- from constants import SDG_DETAILS, GPT_PROMPT_parallel, GPT_PROMPT_sequential, GPT_PROMPT_reverse_sequential
7
  from html_codes import *
8
 
9
  class SessionManager:
10
  def __init__(self):
11
- self.sessions = []
12
 
13
- def add_session(self, cooperate_style, task):
14
  if cooperate_style == "sequential":
15
  session = {
16
- "user_identification_code": None,
17
  "task": task,
18
  "cooperate_style": cooperate_style,
19
  "human_initial_answer": None,
@@ -22,7 +22,7 @@ class SessionManager:
22
  }
23
  elif cooperate_style == "reverse_sequential":
24
  session = {
25
- "user_identification_code": None,
26
  "task": task,
27
  "cooperate_style": cooperate_style,
28
  "ai_initial_answer": None,
@@ -32,7 +32,7 @@ class SessionManager:
32
  }
33
  elif cooperate_style == "parallel":
34
  session = {
35
- "user_identification_code": None,
36
  "task": task,
37
  "cooperate_style": cooperate_style,
38
  "ai_initial_answer": None,
@@ -40,8 +40,8 @@ class SessionManager:
40
  "merged_final_answer": None,
41
  "evaluation": None
42
  }
43
- self.sessions.append(session)
44
- return len(self.sessions) - 1
45
 
46
  def update(self, index, output_content, key='final_output'):
47
  self.sessions[index][key] = output_content
@@ -51,14 +51,20 @@ class SessionManager:
51
 
52
  def save_session_to_sheet(self, index, service, SHEET_ID):
53
  session = self.sessions[index]
 
54
  new_row = list(session.values())
55
- add_new_data(new_row, service, SHEET_ID, num_of_columns=len(new_row)) # 动态列数
 
 
 
 
 
56
 
57
 
58
 
59
  def handle_create_sequential(task, human_input, session_manager, api_key, identification_code):
60
  cooperate_style = "sequential"
61
- session_index = session_manager.add_session(task=task, cooperate_style=cooperate_style)
62
  session_manager.update(session_index, human_input, 'human_initial_answer')
63
  session_manager.update(session_index, identification_code, 'user_identification_code')
64
  if word_limit_validation(human_input):
@@ -71,7 +77,7 @@ def handle_create_sequential(task, human_input, session_manager, api_key, identi
71
 
72
  def handle_create_parallel(task, human_input, session_manager, api_key, identification_code):
73
  cooperate_style = "parallel"
74
- session_index = session_manager.add_session(task=task, cooperate_style=cooperate_style)
75
  if word_limit_validation(human_input):
76
  ai_initial_answer = word_limit_validation(human_input)
77
  final_answer = word_limit_validation(human_input)
@@ -86,7 +92,7 @@ def handle_create_parallel(task, human_input, session_manager, api_key, identifi
86
 
87
  def handle_create_reverse_sequential(task, session_manager, api_key, identification_code):
88
  cooperate_style = "reverse_sequential"
89
- session_index = session_manager.add_session(task=task, cooperate_style=cooperate_style)
90
  ai_initial_answer = generate_ai_initial_answer(task, api_key)
91
  session_manager.update(session_index, ai_initial_answer, 'ai_initial_answer')
92
  session_manager.update(session_index, identification_code, 'user_identification_code')
@@ -134,7 +140,7 @@ def login(identification_code):
134
 
135
  def word_limit_validation(human_input):
136
  words = human_input.split()
137
- if len(words) < 50:
138
  return f"Error: Please enter at least 50 words."
139
  return None
140
 
@@ -151,9 +157,8 @@ def check_initial_generated(initial_answer):
151
  return None
152
 
153
  if __name__ == "__main__":
154
- deploy_local = True
155
- api_key = get_api_key(local=deploy_local)
156
- service, SHEET_IDs = get_sheet_service(local=deploy_local)
157
  SHEET_ID1, SHEET_ID2, SHEET_ID3 = SHEET_IDs
158
 
159
  session_manager = SessionManager()
@@ -220,8 +225,7 @@ if __name__ == "__main__":
220
  )
221
 
222
 
223
- # evaluate same for every group
224
- #evaluate_btn = gr.Button("Evaluate", visible=False)
225
  # Evaluate without showing
226
  evaluation_result = gr.Textbox(label="Evaluation Result", visible=False, interactive = False)
227
 
 
1
  import gradio as gr
2
  from utils import *
3
+ from save_data import add_or_update_row_at_fixed_position, get_sheet_service
4
  from instructions import *
5
  from user_groups import user_data
6
+ from constants import SDG_DETAILS, WORD_LIMIT_MIN, GROUP_SEPERATION, LOCAL_PARAMS
7
  from html_codes import *
8
 
9
  class SessionManager:
10
  def __init__(self):
11
+ self.sessions = {}
12
 
13
+ def add_session(self, cooperate_style, task, identification_code):
14
  if cooperate_style == "sequential":
15
  session = {
16
+ "user_identification_code": identification_code,
17
  "task": task,
18
  "cooperate_style": cooperate_style,
19
  "human_initial_answer": None,
 
22
  }
23
  elif cooperate_style == "reverse_sequential":
24
  session = {
25
+ "user_identification_code": identification_code,
26
  "task": task,
27
  "cooperate_style": cooperate_style,
28
  "ai_initial_answer": None,
 
32
  }
33
  elif cooperate_style == "parallel":
34
  session = {
35
+ "user_identification_code": identification_code,
36
  "task": task,
37
  "cooperate_style": cooperate_style,
38
  "ai_initial_answer": None,
 
40
  "merged_final_answer": None,
41
  "evaluation": None
42
  }
43
+ self.sessions[identification_code] = session
44
+ return identification_code
45
 
46
  def update(self, index, output_content, key='final_output'):
47
  self.sessions[index][key] = output_content
 
51
 
52
  def save_session_to_sheet(self, index, service, SHEET_ID):
53
  session = self.sessions[index]
54
+ row_id = int(index) % GROUP_SEPERATION + 2 # user data starts from row 2
55
  new_row = list(session.values())
56
+ add_or_update_row_at_fixed_position(
57
+ row_id = row_id,
58
+ new_row = new_row,
59
+ service = service,
60
+ SPREADSHEET_ID = SHEET_ID,
61
+ num_of_columns=len(new_row))
62
 
63
 
64
 
65
  def handle_create_sequential(task, human_input, session_manager, api_key, identification_code):
66
  cooperate_style = "sequential"
67
+ session_index = session_manager.add_session(task=task, cooperate_style=cooperate_style, identification_code = identification_code)
68
  session_manager.update(session_index, human_input, 'human_initial_answer')
69
  session_manager.update(session_index, identification_code, 'user_identification_code')
70
  if word_limit_validation(human_input):
 
77
 
78
  def handle_create_parallel(task, human_input, session_manager, api_key, identification_code):
79
  cooperate_style = "parallel"
80
+ session_index = session_manager.add_session(task=task, cooperate_style=cooperate_style, identification_code = identification_code)
81
  if word_limit_validation(human_input):
82
  ai_initial_answer = word_limit_validation(human_input)
83
  final_answer = word_limit_validation(human_input)
 
92
 
93
  def handle_create_reverse_sequential(task, session_manager, api_key, identification_code):
94
  cooperate_style = "reverse_sequential"
95
+ session_index = session_manager.add_session(task=task, cooperate_style=cooperate_style, identification_code = identification_code)
96
  ai_initial_answer = generate_ai_initial_answer(task, api_key)
97
  session_manager.update(session_index, ai_initial_answer, 'ai_initial_answer')
98
  session_manager.update(session_index, identification_code, 'user_identification_code')
 
140
 
141
  def word_limit_validation(human_input):
142
  words = human_input.split()
143
+ if len(words) < WORD_LIMIT_MIN:
144
  return f"Error: Please enter at least 50 words."
145
  return None
146
 
 
157
  return None
158
 
159
  if __name__ == "__main__":
160
+ api_key = get_api_key(local=LOCAL_PARAMS)
161
+ service, SHEET_IDs = get_sheet_service(local=LOCAL_PARAMS)
 
162
  SHEET_ID1, SHEET_ID2, SHEET_ID3 = SHEET_IDs
163
 
164
  session_manager = SessionManager()
 
225
  )
226
 
227
 
228
+
 
229
  # Evaluate without showing
230
  evaluation_result = gr.Textbox(label="Evaluation Result", visible=False, interactive = False)
231
 
constants.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  SDG_DETAILS = """
2
  1. No Poverty: Eradicate all forms of poverty.
3
  2. Zero Hunger: End hunger and promote sustainable agriculture.
 
1
+ WORD_LIMIT_MIN = 50
2
+ GROUP_SEPERATION = 1000
3
+ LOCAL_PARAMS = False
4
+
5
  SDG_DETAILS = """
6
  1. No Poverty: Eradicate all forms of poverty.
7
  2. Zero Hunger: End hunger and promote sustainable agriculture.
save_data.py CHANGED
@@ -21,7 +21,7 @@ def load_envs(local=False):
21
 
22
  def get_sheet_service(local=False):
23
  """Get the google sheet service object."""
24
- service_account_info, SHEET_ID = load_envs(local=local)
25
  # verify the service_account_info
26
  credentials = service_account.Credentials.from_service_account_info(
27
  service_account_info,
@@ -30,15 +30,8 @@ def get_sheet_service(local=False):
30
 
31
  # build the service object
32
  service = build('sheets', 'v4', credentials=credentials)
33
- return service, SHEET_ID
34
 
35
- def col_letter(col_num):
36
- """Convert a column number to a column letter (1-indexed)."""
37
- letter = ''
38
- while col_num > 0:
39
- col_num, remainder = divmod(col_num - 1, 26)
40
- letter = chr(65 + remainder) + letter
41
- return letter
42
 
43
  def col_letter(col_num):
44
  """Convert a column number to its corresponding Excel-style letter."""
@@ -48,64 +41,36 @@ def col_letter(col_num):
48
  letter = chr(65 + remainder) + letter
49
  return letter
50
 
51
- def add_new_data(new_row, service, SPREADSHEET_ID, num_of_columns = 5):
52
- """Add new data to the spreadsheet.
53
- new_row: list of data to be added. """
54
- # read the existing data
55
- range_to_read = f'Sheet1!A:{col_letter(num_of_columns)}'
56
-
57
- result = service.spreadsheets().values().get(
 
 
 
 
 
 
 
 
 
 
 
 
58
  spreadsheetId=SPREADSHEET_ID,
59
- range=range_to_read
 
 
60
  ).execute()
61
-
62
-
63
- # search for first columns match
64
- match_found = False
65
- update_idx = None
66
- for idx, row in enumerate(result.get('values', [])):
67
- if row[0] == new_row[0]:
68
- match_found = True
69
- range_to_write = f'Sheet1!A{idx + 1}:{col_letter(num_of_columns)}{idx + 1}'
70
- update_idx = idx + 1
71
- break
72
-
73
- values = result.get('values', [])
74
- number_of_rows = len(values)
75
- new_row = [new_row]
76
- if not match_found:
77
- range_to_write = f'Sheet1!A{number_of_rows + 1}'
78
-
79
- request_body = {
80
- 'values': new_row
81
- }
82
- if match_found:
83
- service.spreadsheets().values().clear(
84
- spreadsheetId=SPREADSHEET_ID,
85
- range=range_to_write,
86
- body={},
87
- ).execute()
88
- response = service.spreadsheets().values().update(
89
- spreadsheetId=SPREADSHEET_ID,
90
- range=range_to_write,
91
- valueInputOption='RAW',
92
- body=request_body
93
- ).execute()
94
- print(f"Updated row at position {update_idx}")
95
- else:
96
- response = service.spreadsheets().values().append(
97
- spreadsheetId=SPREADSHEET_ID,
98
- range=range_to_write,
99
- valueInputOption='RAW',
100
- insertDataOption='INSERT_ROWS',
101
- body=request_body
102
- ).execute()
103
- print(f"Added new row at position {number_of_rows + 1}")
104
 
 
105
 
106
 
107
  if __name__ == "__main__":
108
- service, SHEET_ID = get_sheet_service(local=True)
109
  new_row = ["test1", "test2", "test3", "test4", "test5"]
110
- add_new_data(new_row, service, SHEET_ID)
111
 
 
21
 
22
  def get_sheet_service(local=False):
23
  """Get the google sheet service object."""
24
+ service_account_info, SHEET_IDs = load_envs(local=local)
25
  # verify the service_account_info
26
  credentials = service_account.Credentials.from_service_account_info(
27
  service_account_info,
 
30
 
31
  # build the service object
32
  service = build('sheets', 'v4', credentials=credentials)
33
+ return service, SHEET_IDs
34
 
 
 
 
 
 
 
 
35
 
36
  def col_letter(col_num):
37
  """Convert a column number to its corresponding Excel-style letter."""
 
41
  letter = chr(65 + remainder) + letter
42
  return letter
43
 
44
+ def add_or_update_row_at_fixed_position(row_id, new_row, service, SPREADSHEET_ID, num_of_columns=5):
45
+ """
46
+ Add or update data at a fixed row position in Google Sheet.
47
+
48
+ Args:
49
+ row_id: int, the row number (1-based index) where the data should be written.
50
+ new_row: list, the new data to be added or updated.
51
+ service: Google Sheets API service instance.
52
+ SPREADSHEET_ID: str, the spreadsheet ID.
53
+ num_of_columns: int, the number of columns to consider (default: 5).
54
+ """
55
+ # Ensure `new_row` has exactly `num_of_columns` elements
56
+ new_row = new_row + [""] * (num_of_columns - len(new_row)) if len(new_row) < num_of_columns else new_row[:num_of_columns]
57
+
58
+ # Determine the write range for the specified row
59
+ range_to_write = f"Sheet1!A{row_id}:{col_letter(num_of_columns)}{row_id}"
60
+
61
+ # Directly write the data to the specified row
62
+ response = service.spreadsheets().values().update(
63
  spreadsheetId=SPREADSHEET_ID,
64
+ range=range_to_write,
65
+ valueInputOption='RAW',
66
+ body={"values": [new_row]}
67
  ).execute()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ print(f"Data written to row {row_id}. Response: {response}")
70
 
71
 
72
  if __name__ == "__main__":
73
+ service, SHEET_IDs = get_sheet_service(local=True)
74
  new_row = ["test1", "test2", "test3", "test4", "test5"]
75
+ add_or_update_row_at_fixed_position(12, new_row, service, SHEET_IDs[0])
76