Spaces:
Sleeping
Sleeping
Vincent-Tann commited on
Commit ·
b8de9e2
1
Parent(s): 1477c41
fix Gradio Chatbot bug; clear comments.
Browse files- _appyibu.py +10 -8
- app.py +12 -15
- code_interpreter.py +8 -7
- display_model.py +16 -16
- object_filter_gpt4.py +2 -2
- transcrib3d_main.py +26 -23
_appyibu.py
CHANGED
|
@@ -13,23 +13,25 @@ new_glb_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned_AddBox.glb"
|
|
| 13 |
objects_info_file = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
|
| 14 |
|
| 15 |
def insert_user_none_between_assistant(messages):
|
| 16 |
-
#
|
| 17 |
result = []
|
| 18 |
-
#
|
|
|
|
| 19 |
last_role = "user"
|
| 20 |
|
| 21 |
for msg in messages:
|
| 22 |
-
#
|
| 23 |
current_role = msg["role"]
|
| 24 |
|
| 25 |
-
#
|
|
|
|
| 26 |
if last_role == "assistant" and current_role == "assistant":
|
| 27 |
result.append({"role": "user", "content": None})
|
| 28 |
|
| 29 |
-
#
|
| 30 |
result.append(msg)
|
| 31 |
|
| 32 |
-
#
|
| 33 |
last_role = current_role
|
| 34 |
|
| 35 |
return result
|
|
@@ -122,11 +124,11 @@ with gr.Blocks() as demo:
|
|
| 122 |
)
|
| 123 |
|
| 124 |
# print("Type2:",type(model3d))
|
| 125 |
-
#
|
| 126 |
# bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
|
| 127 |
bt.click(fn=process_instruction_callback, inputs=[user_instruction_textbox, gr.State(model3d), gr.State(dialogue)])#, outputs=[model3d,dialogue])
|
| 128 |
|
| 129 |
-
#
|
| 130 |
# type(user_instruction_textbox.value)
|
| 131 |
# user_instruction_textbox.
|
| 132 |
# user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
|
|
|
| 13 |
objects_info_file = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
|
| 14 |
|
| 15 |
def insert_user_none_between_assistant(messages):
|
| 16 |
+
# Initialize the result list
|
| 17 |
result = []
|
| 18 |
+
# Set the initial state to "user" so insertion works correctly
|
| 19 |
+
# when the first item in the list is "assistant"
|
| 20 |
last_role = "user"
|
| 21 |
|
| 22 |
for msg in messages:
|
| 23 |
+
# Check the role of the current message
|
| 24 |
current_role = msg["role"]
|
| 25 |
|
| 26 |
+
# If both the previous and current messages are "assistant",
|
| 27 |
+
# insert a "user" message whose content is None
|
| 28 |
if last_role == "assistant" and current_role == "assistant":
|
| 29 |
result.append({"role": "user", "content": None})
|
| 30 |
|
| 31 |
+
# Add the current message to the result list
|
| 32 |
result.append(msg)
|
| 33 |
|
| 34 |
+
# Update the role of the previous message
|
| 35 |
last_role = current_role
|
| 36 |
|
| 37 |
return result
|
|
|
|
| 124 |
)
|
| 125 |
|
| 126 |
# print("Type2:",type(model3d))
|
| 127 |
+
# Passing model3d directly in the inputs list causes the callback to receive a string
|
| 128 |
# bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
|
| 129 |
bt.click(fn=process_instruction_callback, inputs=[user_instruction_textbox, gr.State(model3d), gr.State(dialogue)])#, outputs=[model3d,dialogue])
|
| 130 |
|
| 131 |
+
# Define a mapping directly with a lambda function
|
| 132 |
# type(user_instruction_textbox.value)
|
| 133 |
# user_instruction_textbox.
|
| 134 |
# user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
app.py
CHANGED
|
@@ -49,23 +49,23 @@ def insert_user_blank_between_assistant(messages):
|
|
| 49 |
|
| 50 |
def timer_check_update(code_interpreter, update_interval, stop_event):
|
| 51 |
"""
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
- code_interpreter: CodeInterpreter
|
| 56 |
-
- update_interval:
|
| 57 |
-
- stop_event:
|
| 58 |
"""
|
| 59 |
while not stop_event.is_set():
|
| 60 |
if code_interpreter.has_update:
|
| 61 |
-
#
|
| 62 |
print("Detected update, trigger UI refreshing...")
|
| 63 |
-
#
|
| 64 |
# ...
|
| 65 |
-
#
|
| 66 |
code_interpreter.has_update = False
|
| 67 |
|
| 68 |
-
#
|
| 69 |
time.sleep(update_interval)
|
| 70 |
|
| 71 |
def process_instruction_callback(inp_api_key, instruction, llm_name):
|
|
@@ -194,13 +194,12 @@ with gr.Blocks() as demo:
|
|
| 194 |
)
|
| 195 |
# Right-5: Dialogue
|
| 196 |
dialogue = gr.Chatbot(
|
| 197 |
-
type="messages",
|
| 198 |
height=470
|
| 199 |
)
|
| 200 |
|
| 201 |
|
| 202 |
# print("Type2:",type(model3d))
|
| 203 |
-
#
|
| 204 |
# bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
|
| 205 |
bt.click(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox,llm_name_text], outputs=[model3d,dialogue])
|
| 206 |
user_instruction_textbox.submit(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox, llm_name_text], outputs=[model3d,dialogue])
|
|
@@ -208,15 +207,13 @@ with gr.Blocks() as demo:
|
|
| 208 |
scene_type_dropdown.select(fn=scene_type_dropdown_callback, inputs=scene_type_dropdown, outputs=model3d)
|
| 209 |
llm_dropdown.select(fn=llm_dropdown_callback, inputs=llm_dropdown, outputs=llm_name_text)
|
| 210 |
|
| 211 |
-
#
|
| 212 |
# type(user_instruction_textbox.value)
|
| 213 |
# user_instruction_textbox.
|
| 214 |
# user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
| 215 |
# user_instruction_textbox.
|
| 216 |
# bt.click(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
| 217 |
|
| 218 |
-
# os.system('uname -a') # 显示所有系统信息
|
| 219 |
-
# demo.launch()
|
| 220 |
|
| 221 |
|
| 222 |
|
|
|
|
| 49 |
|
| 50 |
def timer_check_update(code_interpreter, update_interval, stop_event):
|
| 51 |
"""
|
| 52 |
+
Periodically check whether code_interpreter.has_update is True.
|
| 53 |
+
If it is True, trigger the UI refresh logic and reset the state.
|
| 54 |
+
Args:
|
| 55 |
+
- code_interpreter: A CodeInterpreter instance expected to have a has_update attribute.
|
| 56 |
+
- update_interval: Timer polling interval in seconds.
|
| 57 |
+
- stop_event: A threading.Event() instance used to stop the timer thread.
|
| 58 |
"""
|
| 59 |
while not stop_event.is_set():
|
| 60 |
if code_interpreter.has_update:
|
| 61 |
+
# Implement the UI refresh logic
|
| 62 |
print("Detected update, trigger UI refreshing...")
|
| 63 |
+
# Add UI refresh code here
|
| 64 |
# ...
|
| 65 |
+
# Reset the has_update flag
|
| 66 |
code_interpreter.has_update = False
|
| 67 |
|
| 68 |
+
# Wait until the next check
|
| 69 |
time.sleep(update_interval)
|
| 70 |
|
| 71 |
def process_instruction_callback(inp_api_key, instruction, llm_name):
|
|
|
|
| 194 |
)
|
| 195 |
# Right-5: Dialogue
|
| 196 |
dialogue = gr.Chatbot(
|
|
|
|
| 197 |
height=470
|
| 198 |
)
|
| 199 |
|
| 200 |
|
| 201 |
# print("Type2:",type(model3d))
|
| 202 |
+
# Passing model3d directly in the inputs list causes the callback to receive a string
|
| 203 |
# bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
|
| 204 |
bt.click(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox,llm_name_text], outputs=[model3d,dialogue])
|
| 205 |
user_instruction_textbox.submit(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox, llm_name_text], outputs=[model3d,dialogue])
|
|
|
|
| 207 |
scene_type_dropdown.select(fn=scene_type_dropdown_callback, inputs=scene_type_dropdown, outputs=model3d)
|
| 208 |
llm_dropdown.select(fn=llm_dropdown_callback, inputs=llm_dropdown, outputs=llm_name_text)
|
| 209 |
|
| 210 |
+
# Define a mapping directly with a lambda function
|
| 211 |
# type(user_instruction_textbox.value)
|
| 212 |
# user_instruction_textbox.
|
| 213 |
# user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
| 214 |
# user_instruction_textbox.
|
| 215 |
# bt.click(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
| 216 |
|
|
|
|
|
|
|
| 217 |
|
| 218 |
|
| 219 |
|
code_interpreter.py
CHANGED
|
@@ -12,9 +12,10 @@ class CodeInterpreter(Dialogue):
|
|
| 12 |
super().__init__(**kwargs)
|
| 13 |
|
| 14 |
def call_openai_with_code_interpreter(self, user_prompt,namespace_for_exec={},token_usage_total=0):
|
| 15 |
-
#
|
| 16 |
-
#
|
| 17 |
-
#
|
|
|
|
| 18 |
assistant_response,token_usage = self.call_openai(user_prompt)
|
| 19 |
token_usage_total+=token_usage
|
| 20 |
|
|
@@ -52,15 +53,15 @@ class CodeInterpreter(Dialogue):
|
|
| 52 |
# f.close()
|
| 53 |
# sys.stdout = sys.__stdout__
|
| 54 |
|
| 55 |
-
#############
|
| 56 |
-
# #
|
| 57 |
# with open("code_snippet.py", "w") as file:
|
| 58 |
# file.write(code_snippet)
|
| 59 |
|
| 60 |
-
# #
|
| 61 |
# os.system("python code_snippet.py > output.txt")
|
| 62 |
|
| 63 |
-
# #
|
| 64 |
# with open("output.txt", "r") as file:
|
| 65 |
# code_exe_result = file.read()
|
| 66 |
##################################################
|
|
|
|
| 12 |
super().__init__(**kwargs)
|
| 13 |
|
| 14 |
def call_openai_with_code_interpreter(self, user_prompt,namespace_for_exec={},token_usage_total=0):
|
| 15 |
+
# If the GPT response contains Python code, execute it and send the result
|
| 16 |
+
# back to GPT, then continue waiting for its reply.
|
| 17 |
+
# If the GPT response does not contain Python code, return the full result.
|
| 18 |
+
# Accumulate token usage on each recursive call and return the total at the end.
|
| 19 |
assistant_response,token_usage = self.call_openai(user_prompt)
|
| 20 |
token_usage_total+=token_usage
|
| 21 |
|
|
|
|
| 53 |
# f.close()
|
| 54 |
# sys.stdout = sys.__stdout__
|
| 55 |
|
| 56 |
+
#############Using the file-saving approach####################
|
| 57 |
+
# # Save the code snippet to the code_snippet.py file
|
| 58 |
# with open("code_snippet.py", "w") as file:
|
| 59 |
# file.write(code_snippet)
|
| 60 |
|
| 61 |
+
# # Execute code_snippet.py and redirect the output to a temporary file
|
| 62 |
# os.system("python code_snippet.py > output.txt")
|
| 63 |
|
| 64 |
+
# # Read the result from the temporary file
|
| 65 |
# with open("output.txt", "r") as file:
|
| 66 |
# code_exe_result = file.read()
|
| 67 |
##################################################
|
display_model.py
CHANGED
|
@@ -62,13 +62,13 @@ def add_1box_to_ply(box, ply_file, new_ply_file, line_width=0.05, obj_id=1):
|
|
| 62 |
box_vertices['blue'] = [color[2]] * 16
|
| 63 |
box_vertices['alpha'] = [obj_id] * 16
|
| 64 |
|
| 65 |
-
#
|
| 66 |
updated_vertices = np.concatenate((vertices, box_vertices))
|
| 67 |
|
| 68 |
-
#
|
| 69 |
updated_vertex_element = PlyElement.describe(updated_vertices, 'vertex')
|
| 70 |
|
| 71 |
-
#
|
| 72 |
# ply_data['vertex'] = updated_vertex_element
|
| 73 |
|
| 74 |
# get the number of original vertices:
|
|
@@ -108,18 +108,18 @@ def add_1box_to_ply(box, ply_file, new_ply_file, line_width=0.05, obj_id=1):
|
|
| 108 |
box_faces = np.zeros(len(box_connections), dtype=faces.dtype)
|
| 109 |
box_faces['vertex_indices'] = box_connections
|
| 110 |
|
| 111 |
-
#
|
| 112 |
updated_faces = np.concatenate((faces, box_faces))
|
| 113 |
|
| 114 |
-
#
|
| 115 |
updated_face_element = PlyElement.describe(updated_faces, 'face')
|
| 116 |
|
| 117 |
-
#
|
| 118 |
# ply_data['face'] = updated_face_element
|
| 119 |
|
| 120 |
new_ply_data = PlyData([updated_vertex_element, updated_face_element])
|
| 121 |
|
| 122 |
-
#
|
| 123 |
with open(new_ply_file, 'wb') as f:
|
| 124 |
new_ply_data.write(f)
|
| 125 |
|
|
@@ -127,36 +127,36 @@ def add_1box_to_ply(box, ply_file, new_ply_file, line_width=0.05, obj_id=1):
|
|
| 127 |
|
| 128 |
|
| 129 |
def ply_to_obj(ply_file, obj_file, mtl_file):
|
| 130 |
-
#
|
| 131 |
with open(ply_file, 'rb') as f:
|
| 132 |
plydata = PlyData.read(f)
|
| 133 |
|
| 134 |
-
#
|
| 135 |
vertices = np.vstack([plydata['vertex'][prop] for prop in ['x', 'y', 'z']]).T
|
| 136 |
colors = np.vstack([plydata['vertex'][prop] for prop in ['red', 'green', 'blue', 'alpha']]).T/255.0
|
| 137 |
faces = plydata['face']['vertex_indices']
|
| 138 |
|
| 139 |
-
#
|
| 140 |
with open(obj_file, 'w') as f:
|
| 141 |
-
#
|
| 142 |
f.write("mtllib %s\n"%mtl_file.split('/')[-1])
|
| 143 |
|
| 144 |
-
#
|
| 145 |
for vertex in vertices:
|
| 146 |
f.write(f"v {' '.join(map(str, vertex))}\n")
|
| 147 |
|
| 148 |
-
#
|
| 149 |
for idx in range(len(vertices)):
|
| 150 |
f.write("usemtl mat%d\n"%(idx+1))
|
| 151 |
|
| 152 |
-
#
|
| 153 |
for face in faces:
|
| 154 |
f.write("f")
|
| 155 |
for vertex_index in face:
|
| 156 |
-
f.write(f" {vertex_index + 1}") # OBJ
|
| 157 |
f.write("\n")
|
| 158 |
|
| 159 |
-
#
|
| 160 |
with open(mtl_file, 'w') as f:
|
| 161 |
for idx, color in enumerate(colors):
|
| 162 |
f.write("newmtl mat%d\n" % (idx+1))
|
|
|
|
| 62 |
box_vertices['blue'] = [color[2]] * 16
|
| 63 |
box_vertices['alpha'] = [obj_id] * 16
|
| 64 |
|
| 65 |
+
# Append the new vertex data after the original vertex data
|
| 66 |
updated_vertices = np.concatenate((vertices, box_vertices))
|
| 67 |
|
| 68 |
+
# Create a PlyElement object containing the new vertices
|
| 69 |
updated_vertex_element = PlyElement.describe(updated_vertices, 'vertex')
|
| 70 |
|
| 71 |
+
# Replace the original vertex data with the updated PlyElement object
|
| 72 |
# ply_data['vertex'] = updated_vertex_element
|
| 73 |
|
| 74 |
# get the number of original vertices:
|
|
|
|
| 108 |
box_faces = np.zeros(len(box_connections), dtype=faces.dtype)
|
| 109 |
box_faces['vertex_indices'] = box_connections
|
| 110 |
|
| 111 |
+
# Append the new face data after the original face data
|
| 112 |
updated_faces = np.concatenate((faces, box_faces))
|
| 113 |
|
| 114 |
+
# Create a PlyElement object containing the new faces
|
| 115 |
updated_face_element = PlyElement.describe(updated_faces, 'face')
|
| 116 |
|
| 117 |
+
# Replace the original face data with the updated PlyElement object
|
| 118 |
# ply_data['face'] = updated_face_element
|
| 119 |
|
| 120 |
new_ply_data = PlyData([updated_vertex_element, updated_face_element])
|
| 121 |
|
| 122 |
+
# Write the updated PlyData object back to the PLY file
|
| 123 |
with open(new_ply_file, 'wb') as f:
|
| 124 |
new_ply_data.write(f)
|
| 125 |
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def ply_to_obj(ply_file, obj_file, mtl_file):
|
| 130 |
+
# Read the PLY file
|
| 131 |
with open(ply_file, 'rb') as f:
|
| 132 |
plydata = PlyData.read(f)
|
| 133 |
|
| 134 |
+
# Get the vertex and face data
|
| 135 |
vertices = np.vstack([plydata['vertex'][prop] for prop in ['x', 'y', 'z']]).T
|
| 136 |
colors = np.vstack([plydata['vertex'][prop] for prop in ['red', 'green', 'blue', 'alpha']]).T/255.0
|
| 137 |
faces = plydata['face']['vertex_indices']
|
| 138 |
|
| 139 |
+
# Write the OBJ file
|
| 140 |
with open(obj_file, 'w') as f:
|
| 141 |
+
# Write the referenced MTL file (material colors)
|
| 142 |
f.write("mtllib %s\n"%mtl_file.split('/')[-1])
|
| 143 |
|
| 144 |
+
# Write the vertex data
|
| 145 |
for vertex in vertices:
|
| 146 |
f.write(f"v {' '.join(map(str, vertex))}\n")
|
| 147 |
|
| 148 |
+
# Write the material assignments
|
| 149 |
for idx in range(len(vertices)):
|
| 150 |
f.write("usemtl mat%d\n"%(idx+1))
|
| 151 |
|
| 152 |
+
# Write the face data
|
| 153 |
for face in faces:
|
| 154 |
f.write("f")
|
| 155 |
for vertex_index in face:
|
| 156 |
+
f.write(f" {vertex_index + 1}") # OBJ indices start from 1
|
| 157 |
f.write("\n")
|
| 158 |
|
| 159 |
+
# Write the MTL file
|
| 160 |
with open(mtl_file, 'w') as f:
|
| 161 |
for idx, color in enumerate(colors):
|
| 162 |
f.write("newmtl mat%d\n" % (idx+1))
|
object_filter_gpt4.py
CHANGED
|
@@ -40,7 +40,7 @@ class ObjectFilter(Dialogue):
|
|
| 40 |
super().__init__(**config)
|
| 41 |
|
| 42 |
def extract_all_int_lists_from_text(self,text) ->list:
|
| 43 |
-
#
|
| 44 |
pattern = r'\[([^\[\]]+)\]'
|
| 45 |
matches = re.findall(pattern, text)
|
| 46 |
|
|
@@ -134,7 +134,7 @@ if __name__ == "__main__":
|
|
| 134 |
scanrefer_data=json.load(json_file)
|
| 135 |
|
| 136 |
from datetime import datetime
|
| 137 |
-
#
|
| 138 |
current_time = datetime.now()
|
| 139 |
formatted_time = current_time.strftime("%Y-%m-%d-%H-%M-%S")
|
| 140 |
print("formatted_time:",formatted_time)
|
|
|
|
| 40 |
super().__init__(**config)
|
| 41 |
|
| 42 |
def extract_all_int_lists_from_text(self,text) ->list:
|
| 43 |
+
# Match the content inside square brackets
|
| 44 |
pattern = r'\[([^\[\]]+)\]'
|
| 45 |
matches = re.findall(pattern, text)
|
| 46 |
|
|
|
|
| 134 |
scanrefer_data=json.load(json_file)
|
| 135 |
|
| 136 |
from datetime import datetime
|
| 137 |
+
# Record the current time to use as the folder name
|
| 138 |
current_time = datetime.now()
|
| 139 |
formatted_time = current_time.strftime("%Y-%m-%d-%H-%M-%S")
|
| 140 |
print("formatted_time:",formatted_time)
|
transcrib3d_main.py
CHANGED
|
@@ -98,10 +98,10 @@ def gen_prompt(user_instruction, scan_id):
|
|
| 98 |
objects_related = objects_info
|
| 99 |
|
| 100 |
|
| 101 |
-
#
|
| 102 |
# scene_center=get_scene_center(objects_related)
|
| 103 |
-
scene_center = get_scene_center(objects_info) #
|
| 104 |
-
#
|
| 105 |
prompt = scan_id + ":objects with quantitative description based on right-hand Cartesian coordinate system with x-y-z axes, x-y plane=ground, z-axis=up/down. Coords format [x, y, z].\n\n"
|
| 106 |
# if dataset == 'nr3d':
|
| 107 |
# prompt = prompt + "Scene center:%s. If no direction vector, observer at center for objorientation.\n" % remove_spaces(str(scene_center))
|
|
@@ -114,57 +114,58 @@ def gen_prompt(user_instruction, scan_id):
|
|
| 114 |
prompt = prompt + "Scene center:%s. If no direction vector, observer at center for obj orientation.\n\n" % remove_spaces(str(scene_center))
|
| 115 |
prompt = prompt + "objs list:\n"
|
| 116 |
lines = []
|
| 117 |
-
#
|
| 118 |
for obj in objects_related:
|
| 119 |
-
#
|
| 120 |
center_position = obj['center_position']
|
| 121 |
center_position = round_list(center_position, 2)
|
| 122 |
-
#
|
| 123 |
size = obj['size']
|
| 124 |
size = round_list(size, 2)
|
| 125 |
-
#
|
| 126 |
extension = obj['extension']
|
| 127 |
extension = round_list(extension, 2)
|
| 128 |
-
#
|
|
|
|
| 129 |
if obj['has_front']:
|
| 130 |
front_point = np.array(obj['front_point'])
|
| 131 |
center = np.array(obj['obb'][0:3])
|
| 132 |
direction_vector = front_point - center
|
| 133 |
direction_vector_normalized = direction_vector / np.linalg.norm(direction_vector)
|
| 134 |
-
#
|
| 135 |
front_vector = round_list(direction_vector_normalized, 2)
|
| 136 |
up_vector = np.array([0, 0, 1])
|
| 137 |
left_vector = round_list(np.cross(direction_vector_normalized, up_vector), 2)
|
| 138 |
right_vector = round_list(np.cross(up_vector, direction_vector_normalized), 2)
|
| 139 |
behind_vector = round_list(-np.array(front_vector), 2)
|
| 140 |
-
#
|
| 141 |
direction_info = ";direction vectors:front=%s,left=%s,right=%s,behind=%s\n" %(front_vector, left_vector, right_vector, behind_vector)
|
| 142 |
#
|
| 143 |
else:
|
| 144 |
-
direction_info = "\n" #
|
| 145 |
|
| 146 |
-
# sr3d
|
| 147 |
# if dataset == 'sr3d':
|
| 148 |
if False:
|
| 149 |
line = f'{obj["label"]},id={obj["id"]},ctr={remove_spaces(str(center_position))},size={remove_spaces(str(size))}'
|
| 150 |
-
# nr3d
|
| 151 |
else:
|
| 152 |
rgb = obj['avg_rgba'][0:3]
|
| 153 |
hsl = round_list(rgb_to_hsl(rgb), 2)
|
| 154 |
-
# line="%s,id=%s,ctr=%s,size=%s,RGB=%s" %(obj['label'], obj['id'], self.remove_space(str(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(rgb) ))
|
| 155 |
-
line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(obj['label'], obj['id'], remove_spaces(str(center_position)), remove_spaces(str(size)), remove_spaces(str(hsl)))#
|
| 156 |
-
# line = "%s(relevant to %s),id=%s,ctr=%s,size=%s,HSL=%s" % (obj['label'],id_to_name_in_description[obj['id']], obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl))) #
|
| 157 |
# if id_to_name_in_description[obj['id']]=='room':
|
| 158 |
# name=obj['label']
|
| 159 |
# else:
|
| 160 |
# name=id_to_name_in_description[obj['id']]
|
| 161 |
-
# line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(name, obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl) )) #
|
| 162 |
lines.append(line + direction_info)
|
| 163 |
# if self.obj_info_ablation_type == 4:
|
| 164 |
# random.seed(0)
|
| 165 |
# random.shuffle(lines)
|
| 166 |
prompt += ''.join(lines)
|
| 167 |
-
# prompt
|
| 168 |
line = "\nInstruction:find the one described object in description: \n\"%s\"\n" % user_instruction
|
| 169 |
prompt = prompt + line
|
| 170 |
|
|
@@ -208,21 +209,23 @@ def get_gpt_response(prompt: str, code_interpreter: CodeInterpreter):
|
|
| 208 |
return response
|
| 209 |
|
| 210 |
def extract_answer_id_from_last_line(last_line, random_choice_list=[0,]):
|
| 211 |
-
#
|
|
|
|
| 212 |
wrong_return_format = False
|
| 213 |
last_line_split = last_line.split('--')
|
| 214 |
-
#
|
| 215 |
pattern = r"\{[^\}]*\}"
|
| 216 |
match = re.search(pattern, last_line_split[-1])
|
| 217 |
if match:
|
| 218 |
-
#
|
| 219 |
matched_dict_str = match.group()
|
| 220 |
try:
|
| 221 |
-
#
|
| 222 |
extracted_dict = ast.literal_eval(matched_dict_str)
|
| 223 |
print(extracted_dict)
|
| 224 |
answer_id = extracted_dict['ID']
|
| 225 |
-
#
|
|
|
|
| 226 |
if not isinstance(answer_id, int):
|
| 227 |
if isinstance(answer_id, list) and all([isinstance(e, int) for e in answer_id]):
|
| 228 |
print("Wrong answer format: %s. random choice from this list" % str(answer_id))
|
|
|
|
| 98 |
objects_related = objects_info
|
| 99 |
|
| 100 |
|
| 101 |
+
# Get the center coordinates of the scene
|
| 102 |
# scene_center=get_scene_center(objects_related)
|
| 103 |
+
scene_center = get_scene_center(objects_info) # Note: all object information should be used here, not just the relevant ones
|
| 104 |
+
# Generate the background information section of the prompt
|
| 105 |
prompt = scan_id + ":objects with quantitative description based on right-hand Cartesian coordinate system with x-y-z axes, x-y plane=ground, z-axis=up/down. Coords format [x, y, z].\n\n"
|
| 106 |
# if dataset == 'nr3d':
|
| 107 |
# prompt = prompt + "Scene center:%s. If no direction vector, observer at center for objorientation.\n" % remove_spaces(str(scene_center))
|
|
|
|
| 114 |
prompt = prompt + "Scene center:%s. If no direction vector, observer at center for obj orientation.\n\n" % remove_spaces(str(scene_center))
|
| 115 |
prompt = prompt + "objs list:\n"
|
| 116 |
lines = []
|
| 117 |
+
# Generate the quantitative object descriptions in the prompt (iterate over all relevant objects)
|
| 118 |
for obj in objects_related:
|
| 119 |
+
# Position information, rounded to 2 decimal places
|
| 120 |
center_position = obj['center_position']
|
| 121 |
center_position = round_list(center_position, 2)
|
| 122 |
+
# Size information, rounded to 2 decimal places
|
| 123 |
size = obj['size']
|
| 124 |
size = round_list(size, 2)
|
| 125 |
+
# Extension information, rounded to 2 decimal places
|
| 126 |
extension = obj['extension']
|
| 127 |
extension = round_list(extension, 2)
|
| 128 |
+
# Direction information represented by direction vectors.
|
| 129 |
+
# Note: ScanRefer does not use the original ScanNet object IDs, so direction information cannot be used.
|
| 130 |
if obj['has_front']:
|
| 131 |
front_point = np.array(obj['front_point'])
|
| 132 |
center = np.array(obj['obb'][0:3])
|
| 133 |
direction_vector = front_point - center
|
| 134 |
direction_vector_normalized = direction_vector / np.linalg.norm(direction_vector)
|
| 135 |
+
# Compute the left and right direction vectors as well, all rounded to 2 decimal places
|
| 136 |
front_vector = round_list(direction_vector_normalized, 2)
|
| 137 |
up_vector = np.array([0, 0, 1])
|
| 138 |
left_vector = round_list(np.cross(direction_vector_normalized, up_vector), 2)
|
| 139 |
right_vector = round_list(np.cross(up_vector, direction_vector_normalized), 2)
|
| 140 |
behind_vector = round_list(-np.array(front_vector), 2)
|
| 141 |
+
# Generate the direction information
|
| 142 |
direction_info = ";direction vectors:front=%s,left=%s,right=%s,behind=%s\n" %(front_vector, left_vector, right_vector, behind_vector)
|
| 143 |
#
|
| 144 |
else:
|
| 145 |
+
direction_info = "\n" # If the direction vector is unknown, leave this blank
|
| 146 |
|
| 147 |
+
# For sr3d, provide center and size
|
| 148 |
# if dataset == 'sr3d':
|
| 149 |
if False:
|
| 150 |
line = f'{obj["label"]},id={obj["id"]},ctr={remove_spaces(str(center_position))},size={remove_spaces(str(size))}'
|
| 151 |
+
# For nr3d and ScanRefer, provide center, size, and color
|
| 152 |
else:
|
| 153 |
rgb = obj['avg_rgba'][0:3]
|
| 154 |
hsl = round_list(rgb_to_hsl(rgb), 2)
|
| 155 |
+
# line="%s,id=%s,ctr=%s,size=%s,RGB=%s" %(obj['label'], obj['id'], self.remove_space(str(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(rgb) )) original RGB version
|
| 156 |
+
line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(obj['label'], obj['id'], remove_spaces(str(center_position)), remove_spaces(str(size)), remove_spaces(str(hsl)))#switched from RGB to HSL
|
| 157 |
+
# line = "%s(relevant to %s),id=%s,ctr=%s,size=%s,HSL=%s" % (obj['label'],id_to_name_in_description[obj['id']], obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl))) # Format: name=original name (the name used in the description)
|
| 158 |
# if id_to_name_in_description[obj['id']]=='room':
|
| 159 |
# name=obj['label']
|
| 160 |
# else:
|
| 161 |
# name=id_to_name_in_description[obj['id']]
|
| 162 |
+
# line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(name, obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl) )) # Format: name=the name used in the description
|
| 163 |
lines.append(line + direction_info)
|
| 164 |
# if self.obj_info_ablation_type == 4:
|
| 165 |
# random.seed(0)
|
| 166 |
# random.shuffle(lines)
|
| 167 |
prompt += ''.join(lines)
|
| 168 |
+
# Requirements in the prompt
|
| 169 |
line = "\nInstruction:find the one described object in description: \n\"%s\"\n" % user_instruction
|
| 170 |
prompt = prompt + line
|
| 171 |
|
|
|
|
| 209 |
return response
|
| 210 |
|
| 211 |
def extract_answer_id_from_last_line(last_line, random_choice_list=[0,]):
|
| 212 |
+
# If the reply does not follow the expected format, choose randomly (Sr3d) or default to 0 (Nr3d and ScanRefer);
|
| 213 |
+
# otherwise, extract the answer from the expected format.
|
| 214 |
wrong_return_format = False
|
| 215 |
last_line_split = last_line.split('--')
|
| 216 |
+
# Use a regular expression to extract the dictionary portion from the string
|
| 217 |
pattern = r"\{[^\}]*\}"
|
| 218 |
match = re.search(pattern, last_line_split[-1])
|
| 219 |
if match:
|
| 220 |
+
# Get the matched dictionary string
|
| 221 |
matched_dict_str = match.group()
|
| 222 |
try:
|
| 223 |
+
# Parse the dictionary string into a dictionary object
|
| 224 |
extracted_dict = ast.literal_eval(matched_dict_str)
|
| 225 |
print(extracted_dict)
|
| 226 |
answer_id = extracted_dict['ID']
|
| 227 |
+
# If the response does follow the expected format but xxx is not a number
|
| 228 |
+
# (for example, None), still fall back to a random choice.
|
| 229 |
if not isinstance(answer_id, int):
|
| 230 |
if isinstance(answer_id, list) and all([isinstance(e, int) for e in answer_id]):
|
| 231 |
print("Wrong answer format: %s. random choice from this list" % str(answer_id))
|