Vincent-Tann commited on
Commit
b8de9e2
·
1 Parent(s): 1477c41

fix Gradio Chatbot bug; clear comments.

Browse files
Files changed (6) hide show
  1. _appyibu.py +10 -8
  2. app.py +12 -15
  3. code_interpreter.py +8 -7
  4. display_model.py +16 -16
  5. object_filter_gpt4.py +2 -2
  6. transcrib3d_main.py +26 -23
_appyibu.py CHANGED
@@ -13,23 +13,25 @@ new_glb_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned_AddBox.glb"
13
  objects_info_file = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
14
 
15
  def insert_user_none_between_assistant(messages):
16
- # 初始化结果列表
17
  result = []
18
- # 初始状态设置为"user",以确保列表第一个条目为"assistant"时能正确插入
 
19
  last_role = "user"
20
 
21
  for msg in messages:
22
- # 检查当前信息的角色
23
  current_role = msg["role"]
24
 
25
- # 如果上一个和当前信息均为"assistant",插入content为None的"user"信息
 
26
  if last_role == "assistant" and current_role == "assistant":
27
  result.append({"role": "user", "content": None})
28
 
29
- # 将当前信息添加到结果列表
30
  result.append(msg)
31
 
32
- # 更新上一条信息的角色
33
  last_role = current_role
34
 
35
  return result
@@ -122,11 +124,11 @@ with gr.Blocks() as demo:
122
  )
123
 
124
  # print("Type2:",type(model3d))
125
- # 直接在 inputs列表里写model3d,会导致实际传给callback函数的是str
126
  # bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
127
  bt.click(fn=process_instruction_callback, inputs=[user_instruction_textbox, gr.State(model3d), gr.State(dialogue)])#, outputs=[model3d,dialogue])
128
 
129
- # 直接用lambda函数定义一个映射
130
  # type(user_instruction_textbox.value)
131
  # user_instruction_textbox.
132
  # user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
 
13
  objects_info_file = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
14
 
15
  def insert_user_none_between_assistant(messages):
16
+ # Initialize the result list
17
  result = []
18
+ # Set the initial state to "user" so insertion works correctly
19
+ # when the first item in the list is "assistant"
20
  last_role = "user"
21
 
22
  for msg in messages:
23
+ # Check the role of the current message
24
  current_role = msg["role"]
25
 
26
+ # If both the previous and current messages are "assistant",
27
+ # insert a "user" message whose content is None
28
  if last_role == "assistant" and current_role == "assistant":
29
  result.append({"role": "user", "content": None})
30
 
31
+ # Add the current message to the result list
32
  result.append(msg)
33
 
34
+ # Update the role of the previous message
35
  last_role = current_role
36
 
37
  return result
 
124
  )
125
 
126
  # print("Type2:",type(model3d))
127
+ # Passing model3d directly in the inputs list causes the callback to receive a string
128
  # bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
129
  bt.click(fn=process_instruction_callback, inputs=[user_instruction_textbox, gr.State(model3d), gr.State(dialogue)])#, outputs=[model3d,dialogue])
130
 
131
+ # Define a mapping directly with a lambda function
132
  # type(user_instruction_textbox.value)
133
  # user_instruction_textbox.
134
  # user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
app.py CHANGED
@@ -49,23 +49,23 @@ def insert_user_blank_between_assistant(messages):
49
 
50
  def timer_check_update(code_interpreter, update_interval, stop_event):
51
  """
52
- 定时检查 code_interpreter.has_update 是否为True
53
- 如果为True,则触发界面更新逻辑并重置状态。
54
- 参数:
55
- - code_interpreter: CodeInterpreter的实例,预期包含has_update属性。
56
- - update_interval: 定时器检查间隔,以秒为单位。
57
- - stop_event: 一个threading.Event()实例,用于停止定时器线程。
58
  """
59
  while not stop_event.is_set():
60
  if code_interpreter.has_update:
61
- # 实现更新界面显示的逻辑
62
  print("Detected update, trigger UI refreshing...")
63
- # 在这里添加更新界面显示的代码
64
  # ...
65
- # 重置has_update状态
66
  code_interpreter.has_update = False
67
 
68
- # 等待下次检查
69
  time.sleep(update_interval)
70
 
71
  def process_instruction_callback(inp_api_key, instruction, llm_name):
@@ -194,13 +194,12 @@ with gr.Blocks() as demo:
194
  )
195
  # Right-5: Dialogue
196
  dialogue = gr.Chatbot(
197
- type="messages",
198
  height=470
199
  )
200
 
201
 
202
  # print("Type2:",type(model3d))
203
- # 直接在 inputs列表里写model3d,会导致实际传给callback函数的是str
204
  # bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
205
  bt.click(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox,llm_name_text], outputs=[model3d,dialogue])
206
  user_instruction_textbox.submit(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox, llm_name_text], outputs=[model3d,dialogue])
@@ -208,15 +207,13 @@ with gr.Blocks() as demo:
208
  scene_type_dropdown.select(fn=scene_type_dropdown_callback, inputs=scene_type_dropdown, outputs=model3d)
209
  llm_dropdown.select(fn=llm_dropdown_callback, inputs=llm_dropdown, outputs=llm_name_text)
210
 
211
- # 直接用lambda函数定义一个映射
212
  # type(user_instruction_textbox.value)
213
  # user_instruction_textbox.
214
  # user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
215
  # user_instruction_textbox.
216
  # bt.click(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
217
 
218
- # os.system('uname -a') # 显示所有系统信息
219
- # demo.launch()
220
 
221
 
222
 
 
49
 
50
  def timer_check_update(code_interpreter, update_interval, stop_event):
51
  """
52
+ Periodically check whether code_interpreter.has_update is True.
53
+ If it is True, trigger the UI refresh logic and reset the state.
54
+ Args:
55
+ - code_interpreter: A CodeInterpreter instance expected to have a has_update attribute.
56
+ - update_interval: Timer polling interval in seconds.
57
+ - stop_event: A threading.Event() instance used to stop the timer thread.
58
  """
59
  while not stop_event.is_set():
60
  if code_interpreter.has_update:
61
+ # Implement the UI refresh logic
62
  print("Detected update, trigger UI refreshing...")
63
+ # Add UI refresh code here
64
  # ...
65
+ # Reset the has_update flag
66
  code_interpreter.has_update = False
67
 
68
+ # Wait until the next check
69
  time.sleep(update_interval)
70
 
71
  def process_instruction_callback(inp_api_key, instruction, llm_name):
 
194
  )
195
  # Right-5: Dialogue
196
  dialogue = gr.Chatbot(
 
197
  height=470
198
  )
199
 
200
 
201
  # print("Type2:",type(model3d))
202
+ # Passing model3d directly in the inputs list causes the callback to receive a string
203
  # bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
204
  bt.click(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox,llm_name_text], outputs=[model3d,dialogue])
205
  user_instruction_textbox.submit(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox, llm_name_text], outputs=[model3d,dialogue])
 
207
  scene_type_dropdown.select(fn=scene_type_dropdown_callback, inputs=scene_type_dropdown, outputs=model3d)
208
  llm_dropdown.select(fn=llm_dropdown_callback, inputs=llm_dropdown, outputs=llm_name_text)
209
 
210
+ # Define a mapping directly with a lambda function
211
  # type(user_instruction_textbox.value)
212
  # user_instruction_textbox.
213
  # user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
214
  # user_instruction_textbox.
215
  # bt.click(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
216
 
 
 
217
 
218
 
219
 
code_interpreter.py CHANGED
@@ -12,9 +12,10 @@ class CodeInterpreter(Dialogue):
12
  super().__init__(**kwargs)
13
 
14
  def call_openai_with_code_interpreter(self, user_prompt,namespace_for_exec={},token_usage_total=0):
15
- # 如果gpt回复的内容包含python代码,则把代码的执行结果发送给gpt,继续等待其回复
16
- # 如果gpt回复的内容不包含python代码,则此函数返回全部结果
17
- # 每次递归统计使用的token数,最终返回总的token数
 
18
  assistant_response,token_usage = self.call_openai(user_prompt)
19
  token_usage_total+=token_usage
20
 
@@ -52,15 +53,15 @@ class CodeInterpreter(Dialogue):
52
  # f.close()
53
  # sys.stdout = sys.__stdout__
54
 
55
- #############利用保存文件的方式####################
56
- # # 将代码片段保存到 code_snippet.py 文件
57
  # with open("code_snippet.py", "w") as file:
58
  # file.write(code_snippet)
59
 
60
- # # 执行 code_snippet.py 并将输出重定向到临时文件
61
  # os.system("python code_snippet.py > output.txt")
62
 
63
- # # 从临时文件中读取结果
64
  # with open("output.txt", "r") as file:
65
  # code_exe_result = file.read()
66
  ##################################################
 
12
  super().__init__(**kwargs)
13
 
14
  def call_openai_with_code_interpreter(self, user_prompt,namespace_for_exec={},token_usage_total=0):
15
+ # If the GPT response contains Python code, execute it and send the result
16
+ # back to GPT, then continue waiting for its reply.
17
+ # If the GPT response does not contain Python code, return the full result.
18
+ # Accumulate token usage on each recursive call and return the total at the end.
19
  assistant_response,token_usage = self.call_openai(user_prompt)
20
  token_usage_total+=token_usage
21
 
 
53
  # f.close()
54
  # sys.stdout = sys.__stdout__
55
 
56
+ #############Using the file-saving approach####################
57
+ # # Save the code snippet to the code_snippet.py file
58
  # with open("code_snippet.py", "w") as file:
59
  # file.write(code_snippet)
60
 
61
+ # # Execute code_snippet.py and redirect the output to a temporary file
62
  # os.system("python code_snippet.py > output.txt")
63
 
64
+ # # Read the result from the temporary file
65
  # with open("output.txt", "r") as file:
66
  # code_exe_result = file.read()
67
  ##################################################
display_model.py CHANGED
@@ -62,13 +62,13 @@ def add_1box_to_ply(box, ply_file, new_ply_file, line_width=0.05, obj_id=1):
62
  box_vertices['blue'] = [color[2]] * 16
63
  box_vertices['alpha'] = [obj_id] * 16
64
 
65
- # 将新的顶点数据添加到原始顶点数据后面
66
  updated_vertices = np.concatenate((vertices, box_vertices))
67
 
68
- # 创建包含新顶点的PlyElement对象
69
  updated_vertex_element = PlyElement.describe(updated_vertices, 'vertex')
70
 
71
- # 将更新后的PlyElement对象替换原始的顶点数据
72
  # ply_data['vertex'] = updated_vertex_element
73
 
74
  # get the number of original vertices:
@@ -108,18 +108,18 @@ def add_1box_to_ply(box, ply_file, new_ply_file, line_width=0.05, obj_id=1):
108
  box_faces = np.zeros(len(box_connections), dtype=faces.dtype)
109
  box_faces['vertex_indices'] = box_connections
110
 
111
- # 将新的face数据添加到原始顶点数据后面
112
  updated_faces = np.concatenate((faces, box_faces))
113
 
114
- # 创建包含新顶点的PlyElement对象
115
  updated_face_element = PlyElement.describe(updated_faces, 'face')
116
 
117
- # 将更新后的PlyElement对象替换原始的顶点数据
118
  # ply_data['face'] = updated_face_element
119
 
120
  new_ply_data = PlyData([updated_vertex_element, updated_face_element])
121
 
122
- # 将更新后的PlyData对象写回Ply文件
123
  with open(new_ply_file, 'wb') as f:
124
  new_ply_data.write(f)
125
 
@@ -127,36 +127,36 @@ def add_1box_to_ply(box, ply_file, new_ply_file, line_width=0.05, obj_id=1):
127
 
128
 
129
  def ply_to_obj(ply_file, obj_file, mtl_file):
130
- # 读取PLY文件
131
  with open(ply_file, 'rb') as f:
132
  plydata = PlyData.read(f)
133
 
134
- # 获取顶点和面数据
135
  vertices = np.vstack([plydata['vertex'][prop] for prop in ['x', 'y', 'z']]).T
136
  colors = np.vstack([plydata['vertex'][prop] for prop in ['red', 'green', 'blue', 'alpha']]).T/255.0
137
  faces = plydata['face']['vertex_indices']
138
 
139
- # 写入OBJ文件
140
  with open(obj_file, 'w') as f:
141
- # 写入依赖的mtl文件(颜色)
142
  f.write("mtllib %s\n"%mtl_file.split('/')[-1])
143
 
144
- # 写入顶点信息
145
  for vertex in vertices:
146
  f.write(f"v {' '.join(map(str, vertex))}\n")
147
 
148
- # 写入颜色信息
149
  for idx in range(len(vertices)):
150
  f.write("usemtl mat%d\n"%(idx+1))
151
 
152
- # 写入面信息
153
  for face in faces:
154
  f.write("f")
155
  for vertex_index in face:
156
- f.write(f" {vertex_index + 1}") # OBJ文件索引从1开始
157
  f.write("\n")
158
 
159
- # 写入mtl文件
160
  with open(mtl_file, 'w') as f:
161
  for idx, color in enumerate(colors):
162
  f.write("newmtl mat%d\n" % (idx+1))
 
62
  box_vertices['blue'] = [color[2]] * 16
63
  box_vertices['alpha'] = [obj_id] * 16
64
 
65
+ # Append the new vertex data after the original vertex data
66
  updated_vertices = np.concatenate((vertices, box_vertices))
67
 
68
+ # Create a PlyElement object containing the new vertices
69
  updated_vertex_element = PlyElement.describe(updated_vertices, 'vertex')
70
 
71
+ # Replace the original vertex data with the updated PlyElement object
72
  # ply_data['vertex'] = updated_vertex_element
73
 
74
  # get the number of original vertices:
 
108
  box_faces = np.zeros(len(box_connections), dtype=faces.dtype)
109
  box_faces['vertex_indices'] = box_connections
110
 
111
+ # Append the new face data after the original face data
112
  updated_faces = np.concatenate((faces, box_faces))
113
 
114
+ # Create a PlyElement object containing the new faces
115
  updated_face_element = PlyElement.describe(updated_faces, 'face')
116
 
117
+ # Replace the original face data with the updated PlyElement object
118
  # ply_data['face'] = updated_face_element
119
 
120
  new_ply_data = PlyData([updated_vertex_element, updated_face_element])
121
 
122
+ # Write the updated PlyData object back to the PLY file
123
  with open(new_ply_file, 'wb') as f:
124
  new_ply_data.write(f)
125
 
 
127
 
128
 
129
  def ply_to_obj(ply_file, obj_file, mtl_file):
130
+ # Read the PLY file
131
  with open(ply_file, 'rb') as f:
132
  plydata = PlyData.read(f)
133
 
134
+ # Get the vertex and face data
135
  vertices = np.vstack([plydata['vertex'][prop] for prop in ['x', 'y', 'z']]).T
136
  colors = np.vstack([plydata['vertex'][prop] for prop in ['red', 'green', 'blue', 'alpha']]).T/255.0
137
  faces = plydata['face']['vertex_indices']
138
 
139
+ # Write the OBJ file
140
  with open(obj_file, 'w') as f:
141
+ # Write the referenced MTL file (material colors)
142
  f.write("mtllib %s\n"%mtl_file.split('/')[-1])
143
 
144
+ # Write the vertex data
145
  for vertex in vertices:
146
  f.write(f"v {' '.join(map(str, vertex))}\n")
147
 
148
+ # Write the material assignments
149
  for idx in range(len(vertices)):
150
  f.write("usemtl mat%d\n"%(idx+1))
151
 
152
+ # Write the face data
153
  for face in faces:
154
  f.write("f")
155
  for vertex_index in face:
156
+ f.write(f" {vertex_index + 1}") # OBJ indices start from 1
157
  f.write("\n")
158
 
159
+ # Write the MTL file
160
  with open(mtl_file, 'w') as f:
161
  for idx, color in enumerate(colors):
162
  f.write("newmtl mat%d\n" % (idx+1))
object_filter_gpt4.py CHANGED
@@ -40,7 +40,7 @@ class ObjectFilter(Dialogue):
40
  super().__init__(**config)
41
 
42
  def extract_all_int_lists_from_text(self,text) ->list:
43
- # 匹配方括号内的内容
44
  pattern = r'\[([^\[\]]+)\]'
45
  matches = re.findall(pattern, text)
46
 
@@ -134,7 +134,7 @@ if __name__ == "__main__":
134
  scanrefer_data=json.load(json_file)
135
 
136
  from datetime import datetime
137
- # 记录时间作为文件名
138
  current_time = datetime.now()
139
  formatted_time = current_time.strftime("%Y-%m-%d-%H-%M-%S")
140
  print("formatted_time:",formatted_time)
 
40
  super().__init__(**config)
41
 
42
  def extract_all_int_lists_from_text(self,text) ->list:
43
+ # Match the content inside square brackets
44
  pattern = r'\[([^\[\]]+)\]'
45
  matches = re.findall(pattern, text)
46
 
 
134
  scanrefer_data=json.load(json_file)
135
 
136
  from datetime import datetime
137
+ # Record the current time to use as the folder name
138
  current_time = datetime.now()
139
  formatted_time = current_time.strftime("%Y-%m-%d-%H-%M-%S")
140
  print("formatted_time:",formatted_time)
transcrib3d_main.py CHANGED
@@ -98,10 +98,10 @@ def gen_prompt(user_instruction, scan_id):
98
  objects_related = objects_info
99
 
100
 
101
- # 获取场景的中心坐标
102
  # scene_center=get_scene_center(objects_related)
103
- scene_center = get_scene_center(objects_info) # 注意这里应该用所有物体的信息,而不只是relevant
104
- # 生成prompt中的背景信息部分
105
  prompt = scan_id + ":objects with quantitative description based on right-hand Cartesian coordinate system with x-y-z axes, x-y plane=ground, z-axis=up/down. Coords format [x, y, z].\n\n"
106
  # if dataset == 'nr3d':
107
  # prompt = prompt + "Scene center:%s. If no direction vector, observer at center for objorientation.\n" % remove_spaces(str(scene_center))
@@ -114,57 +114,58 @@ def gen_prompt(user_instruction, scan_id):
114
  prompt = prompt + "Scene center:%s. If no direction vector, observer at center for obj orientation.\n\n" % remove_spaces(str(scene_center))
115
  prompt = prompt + "objs list:\n"
116
  lines = []
117
- # 生成prompt中对物体的定量描述部分(遍历所有相关物体)
118
  for obj in objects_related:
119
- # 位置信息,保留2位小数
120
  center_position = obj['center_position']
121
  center_position = round_list(center_position, 2)
122
- # size信息,保留2位小数
123
  size = obj['size']
124
  size = round_list(size, 2)
125
- # extension信息,保留2位小数
126
  extension = obj['extension']
127
  extension = round_list(extension, 2)
128
- # 方向信息,用方向向量表示. 注意,scanrefer由于用的不是scannet原始obj id,所以不能用方向信息
 
129
  if obj['has_front']:
130
  front_point = np.array(obj['front_point'])
131
  center = np.array(obj['obb'][0:3])
132
  direction_vector = front_point - center
133
  direction_vector_normalized = direction_vector / np.linalg.norm(direction_vector)
134
- # 再计算左和右的方向向量,全部保留两位小数
135
  front_vector = round_list(direction_vector_normalized, 2)
136
  up_vector = np.array([0, 0, 1])
137
  left_vector = round_list(np.cross(direction_vector_normalized, up_vector), 2)
138
  right_vector = round_list(np.cross(up_vector, direction_vector_normalized), 2)
139
  behind_vector = round_list(-np.array(front_vector), 2)
140
- # 生成方向信息
141
  direction_info = ";direction vectors:front=%s,left=%s,right=%s,behind=%s\n" %(front_vector, left_vector, right_vector, behind_vector)
142
  #
143
  else:
144
- direction_info = "\n" # 未知方向向量就啥都不写
145
 
146
- # sr3d,给出centersize
147
  # if dataset == 'sr3d':
148
  if False:
149
  line = f'{obj["label"]},id={obj["id"]},ctr={remove_spaces(str(center_position))},size={remove_spaces(str(size))}'
150
- # nr3d和scanrefer,给出centersizecolor
151
  else:
152
  rgb = obj['avg_rgba'][0:3]
153
  hsl = round_list(rgb_to_hsl(rgb), 2)
154
- # line="%s,id=%s,ctr=%s,size=%s,RGB=%s" %(obj['label'], obj['id'], self.remove_space(str(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(rgb) )) 原版rgb
155
- line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(obj['label'], obj['id'], remove_spaces(str(center_position)), remove_spaces(str(size)), remove_spaces(str(hsl)))#rgb换成hsl
156
- # line = "%s(relevant to %s),id=%s,ctr=%s,size=%s,HSL=%s" % (obj['label'],id_to_name_in_description[obj['id']], obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl))) # 格式:name=原名称(description里的名称)
157
  # if id_to_name_in_description[obj['id']]=='room':
158
  # name=obj['label']
159
  # else:
160
  # name=id_to_name_in_description[obj['id']]
161
- # line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(name, obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl) )) # 式:name=description里的名称
162
  lines.append(line + direction_info)
163
  # if self.obj_info_ablation_type == 4:
164
  # random.seed(0)
165
  # random.shuffle(lines)
166
  prompt += ''.join(lines)
167
- # prompt中的要求
168
  line = "\nInstruction:find the one described object in description: \n\"%s\"\n" % user_instruction
169
  prompt = prompt + line
170
 
@@ -208,21 +209,23 @@ def get_gpt_response(prompt: str, code_interpreter: CodeInterpreter):
208
  return response
209
 
210
  def extract_answer_id_from_last_line(last_line, random_choice_list=[0,]):
211
- # 如果没有按照预期格式回复则随机选取(Sr3d)或直接选成0(Nr3d和Scanrefer);按预期格式恢复则提取答案
 
212
  wrong_return_format = False
213
  last_line_split = last_line.split('--')
214
- # 使用正则表达式从字符串中提取字典部分
215
  pattern = r"\{[^\}]*\}"
216
  match = re.search(pattern, last_line_split[-1])
217
  if match:
218
- # 获取匹配的字典字符串
219
  matched_dict_str = match.group()
220
  try:
221
- # 解析字典字符串为字典对象
222
  extracted_dict = ast.literal_eval(matched_dict_str)
223
  print(extracted_dict)
224
  answer_id = extracted_dict['ID']
225
- # 如果确实以 Now the answer is complete -- {'ID': xxx} 的格式回复了,但是xxx不是数字(例如是None),也能随机选。
 
226
  if not isinstance(answer_id, int):
227
  if isinstance(answer_id, list) and all([isinstance(e, int) for e in answer_id]):
228
  print("Wrong answer format: %s. random choice from this list" % str(answer_id))
 
98
  objects_related = objects_info
99
 
100
 
101
+ # Get the center coordinates of the scene
102
  # scene_center=get_scene_center(objects_related)
103
+ scene_center = get_scene_center(objects_info) # Note: all object information should be used here, not just the relevant ones
104
+ # Generate the background information section of the prompt
105
  prompt = scan_id + ":objects with quantitative description based on right-hand Cartesian coordinate system with x-y-z axes, x-y plane=ground, z-axis=up/down. Coords format [x, y, z].\n\n"
106
  # if dataset == 'nr3d':
107
  # prompt = prompt + "Scene center:%s. If no direction vector, observer at center for objorientation.\n" % remove_spaces(str(scene_center))
 
114
  prompt = prompt + "Scene center:%s. If no direction vector, observer at center for obj orientation.\n\n" % remove_spaces(str(scene_center))
115
  prompt = prompt + "objs list:\n"
116
  lines = []
117
+ # Generate the quantitative object descriptions in the prompt (iterate over all relevant objects)
118
  for obj in objects_related:
119
+ # Position information, rounded to 2 decimal places
120
  center_position = obj['center_position']
121
  center_position = round_list(center_position, 2)
122
+ # Size information, rounded to 2 decimal places
123
  size = obj['size']
124
  size = round_list(size, 2)
125
+ # Extension information, rounded to 2 decimal places
126
  extension = obj['extension']
127
  extension = round_list(extension, 2)
128
+ # Direction information represented by direction vectors.
129
+ # Note: ScanRefer does not use the original ScanNet object IDs, so direction information cannot be used.
130
  if obj['has_front']:
131
  front_point = np.array(obj['front_point'])
132
  center = np.array(obj['obb'][0:3])
133
  direction_vector = front_point - center
134
  direction_vector_normalized = direction_vector / np.linalg.norm(direction_vector)
135
+ # Compute the left and right direction vectors as well, all rounded to 2 decimal places
136
  front_vector = round_list(direction_vector_normalized, 2)
137
  up_vector = np.array([0, 0, 1])
138
  left_vector = round_list(np.cross(direction_vector_normalized, up_vector), 2)
139
  right_vector = round_list(np.cross(up_vector, direction_vector_normalized), 2)
140
  behind_vector = round_list(-np.array(front_vector), 2)
141
+ # Generate the direction information
142
  direction_info = ";direction vectors:front=%s,left=%s,right=%s,behind=%s\n" %(front_vector, left_vector, right_vector, behind_vector)
143
  #
144
  else:
145
+ direction_info = "\n" # If the direction vector is unknown, leave this blank
146
 
147
+ # For sr3d, provide center and size
148
  # if dataset == 'sr3d':
149
  if False:
150
  line = f'{obj["label"]},id={obj["id"]},ctr={remove_spaces(str(center_position))},size={remove_spaces(str(size))}'
151
+ # For nr3d and ScanRefer, provide center, size, and color
152
  else:
153
  rgb = obj['avg_rgba'][0:3]
154
  hsl = round_list(rgb_to_hsl(rgb), 2)
155
+ # line="%s,id=%s,ctr=%s,size=%s,RGB=%s" %(obj['label'], obj['id'], self.remove_space(str(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(rgb) )) original RGB version
156
+ line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(obj['label'], obj['id'], remove_spaces(str(center_position)), remove_spaces(str(size)), remove_spaces(str(hsl)))#switched from RGB to HSL
157
+ # line = "%s(relevant to %s),id=%s,ctr=%s,size=%s,HSL=%s" % (obj['label'],id_to_name_in_description[obj['id']], obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl))) # Format: name=original name (the name used in the description)
158
  # if id_to_name_in_description[obj['id']]=='room':
159
  # name=obj['label']
160
  # else:
161
  # name=id_to_name_in_description[obj['id']]
162
+ # line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(name, obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl) )) # Format: name=the name used in the description
163
  lines.append(line + direction_info)
164
  # if self.obj_info_ablation_type == 4:
165
  # random.seed(0)
166
  # random.shuffle(lines)
167
  prompt += ''.join(lines)
168
+ # Requirements in the prompt
169
  line = "\nInstruction:find the one described object in description: \n\"%s\"\n" % user_instruction
170
  prompt = prompt + line
171
 
 
209
  return response
210
 
211
  def extract_answer_id_from_last_line(last_line, random_choice_list=[0,]):
212
+ # If the reply does not follow the expected format, choose randomly (Sr3d) or default to 0 (Nr3d and ScanRefer);
213
+ # otherwise, extract the answer from the expected format.
214
  wrong_return_format = False
215
  last_line_split = last_line.split('--')
216
+ # Use a regular expression to extract the dictionary portion from the string
217
  pattern = r"\{[^\}]*\}"
218
  match = re.search(pattern, last_line_split[-1])
219
  if match:
220
+ # Get the matched dictionary string
221
  matched_dict_str = match.group()
222
  try:
223
+ # Parse the dictionary string into a dictionary object
224
  extracted_dict = ast.literal_eval(matched_dict_str)
225
  print(extracted_dict)
226
  answer_id = extracted_dict['ID']
227
+ # If the response does follow the expected format but xxx is not a number
228
+ # (for example, None), still fall back to a random choice.
229
  if not isinstance(answer_id, int):
230
  if isinstance(answer_id, list) and all([isinstance(e, int) for e in answer_id]):
231
  print("Wrong answer format: %s. random choice from this list" % str(answer_id))