# Transcrib3D-Demo / transcrib3d_main.py
# Author: Vincent-Tann — commit b8de9e2 ("fix Gradio Chatbot bug; clear comments.")
# encoding:utf-8
import ast
import csv
import json
import logging
import os
import random
import re
import time
from copy import deepcopy
from datetime import datetime
import numpy as np
from tenacity import RetryError, before_sleep_log, retry, stop_after_attempt, wait_exponential_jitter # for exponential backoff
from code_interpreter import CodeInterpreter
# from config import confs_nr3d, confs_scanrefer, confs_sr3d
# from gpt_dialogue import Dialogue
# from object_filter_gpt4 import ObjectFilter
from prompt_text import get_principle, get_principle_sr3d, get_system_message
# Module-level logger: emits only ERROR-and-above to the console.
# Passed to tenacity's before_sleep_log in get_gpt_response so retry
# back-off events are reported between attempts.
logger = logging.getLogger(__name__ + 'logger')
logger.setLevel(logging.ERROR)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.ERROR)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# def round_list(lst, length):
# # round every element in lst
# for idx, num in enumerate(lst):
# lst[idx] = round(num, length)
# return list(lst)
def round_list(lst, length):
    """Round every element of *lst* to *length* decimal places.

    Returns a plain Python list (numpy scalars converted to native floats).
    """
    rounded = np.round(lst, decimals=length)
    return rounded.tolist()
def remove_spaces(s: str):
    """Return *s* with every ASCII space character removed."""
    return ''.join(s.split(' '))
def rgb_to_hsl(rgb):
    """Convert an RGB triple (0-255 per channel) to HSL.

    Returns [hue, saturation, lightness] with hue in degrees [0, 360)
    and saturation/lightness in [0, 1].
    """
    # Normalize each channel to [0, 1].
    red, green, blue = (channel / 255.0 for channel in rgb)
    c_max = max(red, green, blue)
    c_min = min(red, green, blue)
    chroma = c_max - c_min
    lightness = (c_max + c_min) / 2
    # Achromatic (grey) pixels carry no hue or saturation.
    if chroma == 0:
        return [0, 0, lightness]
    # Hue depends on which channel dominates; scale sector value to degrees.
    if c_max == red:
        hue = (((green - blue) / chroma) % 6) * 60
    elif c_max == green:
        hue = (((blue - red) / chroma) + 2) * 60
    else:
        hue = (((red - green) / chroma) + 4) * 60
    # Saturation formula differs for dark vs. light colors.
    if lightness <= 0.5:
        saturation = chroma / (2 * lightness)
    else:
        saturation = chroma / (2 - 2 * lightness)
    return [hue, saturation, lightness]
def get_scene_center(objects):
    """Return the center of the axis-aligned bounding box of all object
    centers, each coordinate rounded to 2 decimal places.
    """
    # Track per-axis extremes (x, y, z) in one pass over the objects.
    mins = [float('inf')] * 3
    maxs = [float('-inf')] * 3
    for obj in objects:
        for axis, coord in enumerate(obj['center_position']):
            if coord < mins[axis]:
                mins[axis] = coord
            if coord > maxs[axis]:
                maxs[axis] = coord
    midpoints = [(lo + hi) / 2 for lo, hi in zip(mins, maxs)]
    return np.round(midpoints, 2).tolist()
def find_relevant_objects(user_instruction, scan_id):
    """Placeholder for pre-filtering scene objects relevant to *user_instruction*.

    Not implemented: gen_prompt currently includes all objects instead (its
    call to this function is commented out). Returns None.
    """
    pass
def gen_prompt(user_instruction: str, scan_id: str) -> str:
    """Build the 3D referring-reasoning prompt for one scene.

    Loads per-object info from objects_info/objects_info_<scan_id>.npy and
    renders each object as a compact quantitative line (label, id, center,
    size, HSL color, optional direction vectors), then appends the user
    instruction and the required answer format
    ("Now the answer is complete -- {'ID':id}").
    """
    npy_path = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
    objects_info = np.load(npy_path, allow_pickle=True)
    # objects_related = find_relevant_objects(user_instruction, scan_id)
    objects_related = objects_info  # relevance pre-filtering not implemented: use all objects
    # Get the center coordinates of the scene
    # scene_center=get_scene_center(objects_related)
    scene_center = get_scene_center(objects_info)  # Note: all object information should be used here, not just the relevant ones
    # Generate the background information section of the prompt
    prompt = scan_id + ":objects with quantitative description based on right-hand Cartesian coordinate system with x-y-z axes, x-y plane=ground, z-axis=up/down. Coords format [x, y, z].\n\n"
    # Dataset-specific variants kept for reference:
    # if dataset == 'nr3d':
    #     prompt = prompt + "Scene center:%s. If no direction vector, observer at center for objorientation.\n" % remove_spaces(str(scene_center))
    # elif dataset == 'scanrefer':
    #     if use_camera_position:
    #         prompt = prompt + "Scene center:%s.\n" % remove_spaces(str(scene_center))
    #         prompt = prompt + "Observer position:%s.\n" % remove_spaces(str(round_list(camera_info_aligned['position'], 2)))
    #     else:
    #         prompt = prompt + "Scene center:%s. If no direction vector, observer at center for objorientation.\n" % remove_spaces(str(scene_center))
    prompt = prompt + "Scene center:%s. If no direction vector, observer at center for obj orientation.\n\n" % remove_spaces(str(scene_center))
    prompt = prompt + "objs list:\n"
    lines = []
    # Generate the quantitative object descriptions in the prompt (iterate over all relevant objects)
    for obj in objects_related:
        # Position information, rounded to 2 decimal places
        center_position = obj['center_position']
        center_position = round_list(center_position, 2)
        # Size information, rounded to 2 decimal places
        size = obj['size']
        size = round_list(size, 2)
        # Extension information, rounded to 2 decimal places
        extension = obj['extension']
        extension = round_list(extension, 2)
        # Direction information represented by direction vectors.
        # Note: ScanRefer does not use the original ScanNet object IDs, so direction information cannot be used.
        if obj['has_front']:
            front_point = np.array(obj['front_point'])
            center = np.array(obj['obb'][0:3])
            direction_vector = front_point - center
            direction_vector_normalized = direction_vector / np.linalg.norm(direction_vector)
            # Compute the left and right direction vectors as well, all rounded to 2 decimal places
            front_vector = round_list(direction_vector_normalized, 2)
            up_vector = np.array([0, 0, 1])
            left_vector = round_list(np.cross(direction_vector_normalized, up_vector), 2)
            right_vector = round_list(np.cross(up_vector, direction_vector_normalized), 2)
            behind_vector = round_list(-np.array(front_vector), 2)
            # Generate the direction information
            direction_info = ";direction vectors:front=%s,left=%s,right=%s,behind=%s\n" %(front_vector, left_vector, right_vector, behind_vector)
        else:
            direction_info = "\n"  # If the direction vector is unknown, leave this blank
        # For sr3d, provide center and size
        # if dataset == 'sr3d':
        if False:  # disabled sr3d branch (center and size only, no color)
            line = f'{obj["label"]},id={obj["id"]},ctr={remove_spaces(str(center_position))},size={remove_spaces(str(size))}'
        # For nr3d and ScanRefer, provide center, size, and color
        else:
            rgb = obj['avg_rgba'][0:3]
            hsl = round_list(rgb_to_hsl(rgb), 2)
            # line="%s,id=%s,ctr=%s,size=%s,RGB=%s" %(obj['label'], obj['id'], self.remove_space(str(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(rgb) )) original RGB version
            line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(obj['label'], obj['id'], remove_spaces(str(center_position)), remove_spaces(str(size)), remove_spaces(str(hsl)))#switched from RGB to HSL
            # line = "%s(relevant to %s),id=%s,ctr=%s,size=%s,HSL=%s" % (obj['label'],id_to_name_in_description[obj['id']], obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl))) # Format: name=original name (the name used in the description)
            # if id_to_name_in_description[obj['id']]=='room':
            #     name=obj['label']
            # else:
            #     name=id_to_name_in_description[obj['id']]
            # line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(name, obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl) )) # Format: name=the name used in the description
        lines.append(line + direction_info)
    # if self.obj_info_ablation_type == 4:
    #     random.seed(0)
    #     random.shuffle(lines)
    prompt += ''.join(lines)
    # Requirements in the prompt
    line = "\nInstruction:find the one described object in description: \n\"%s\"\n" % user_instruction
    prompt = prompt + line
    prompt = prompt + "\n\nThere is exactly one answer, so if you receive multiple answers, considerother constraints; if get no answers, loosen constraints."
    prompt = prompt + "\n\nWork this out step by step to ensure right answer."
    prompt = prompt + "\n\nIf the answer is complete, add \"Now the answer is complete -- {'ID':id}\" to the end of your answer(that is, your completion, not your code), where id is the id of the referred obj. Do not add anything after."
    return prompt
@retry(wait=wait_exponential_jitter(initial=20, max=120, jitter=20), stop=stop_after_attempt(5), before_sleep=before_sleep_log(logger, logging.ERROR))  # back-off: 20s,40s,80s,120s + random.uniform(0,20)
def get_gpt_response(prompt: str, code_interpreter: CodeInterpreter) -> str:
    """Query the LLM (with code-interpreter tool use) until the answer is complete.

    Sends *prompt*, then keeps sending empty user messages (at most 10 extra
    rounds) until the response contains the completion marker
    "Now the answer is complete". The whole exchange is retried with
    exponential back-off via tenacity on failure.

    Returns the final response text (which may lack the marker if the
    10-round limit was hit).
    """
    print("llm_name:", code_interpreter.model)
    # Time the full exchange (kept for future token/time accounting).
    call_start_time = time.time()
    # First call with the original prompt.
    response, token_usage_total = code_interpreter.call_openai_with_code_interpreter(prompt)
    response = response['content']
    # Continue the dialogue until the completion marker appears, or give up
    # after 10 continuation rounds.
    count_response = 0
    # Fix: use the single `not in` operator instead of `not ... in ...`.
    while "Now the answer is complete" not in response:
        if count_response >= 10:
            print("Response does not end with 'Now the answer is complete.' !")
            break
        # An empty user message prompts the model to keep going.
        response, token_usage_add = code_interpreter.call_openai_with_code_interpreter('')
        response = response['content']
        token_usage_total += token_usage_add
        count_response += 1
        print("count_response:", count_response)
    call_end_time = time.time()
    time_consumed = call_end_time - call_start_time  # currently unused; retained for accounting
    return response
def extract_answer_id_from_last_line(last_line, random_choice_list=(0,)):
    """Parse the referred object ID from the model's final answer line.

    Expects *last_line* to end with "... -- {'ID': <int>}". When the format is
    wrong, falls back to a random element of *random_choice_list* (random
    choice for Sr3d, default 0 for Nr3d and ScanRefer). If 'ID' maps to a list
    of ints, one of them is chosen at random (not counted as a format error).

    Returns a tuple (answer_id, wrong_return_format).

    Fixes vs. original: the default argument is an immutable tuple instead of
    a mutable list, and the bare `except BaseException` (which swallowed
    KeyboardInterrupt/SystemExit) is narrowed to the exceptions that
    ast.literal_eval and the 'ID' lookup can actually raise.
    """
    wrong_return_format = False
    last_line_split = last_line.split('--')
    # Extract the dictionary-looking portion from the tail of the line.
    pattern = r"\{[^\}]*\}"
    match = re.search(pattern, last_line_split[-1])
    if match:
        matched_dict_str = match.group()
        try:
            # Safely parse the dictionary string into a Python object.
            extracted_dict = ast.literal_eval(matched_dict_str)
            print(extracted_dict)
            answer_id = extracted_dict['ID']
            # The format was followed but the value is not an int
            # (e.g. None or a list of candidates).
            if not isinstance(answer_id, int):
                if isinstance(answer_id, list) and all(isinstance(e, int) for e in answer_id):
                    print("Wrong answer format: %s. random choice from this list" % str(answer_id))
                    answer_id = random.choice(answer_id)
                else:
                    print("Wrong answer format: %s. No dict found. Random choice from relevant objects." % str(answer_id))
                    answer_id = random.choice(random_choice_list)
                    wrong_return_format = True
        except (ValueError, SyntaxError, KeyError, TypeError):
            # literal_eval failed, or the parsed value had no 'ID' key /
            # was not subscriptable.
            print("Wrong answer format!! No dict found. Random choice.")
            answer_id = random.choice(random_choice_list)
            wrong_return_format = True
    else:
        print("Wrong answer format!! No dict found. Random choice.")
        answer_id = random.choice(random_choice_list)
        wrong_return_format = True
    return answer_id, wrong_return_format
def get_openai_config(llm_name='gpt-3.5-turbo-0125'):
    """Build the OpenAI chat configuration dict for the given model name.

    The system message is assembled from the project's prompt templates;
    temperature/top_p are near zero for (almost) deterministic completions.
    """
    system_message = "" + get_system_message() + get_principle()
    return {
        'model': llm_name,
        'temperature': 1e-7,
        'top_p': 1e-7,
        'max_tokens': 8192,
        'system_message': system_message,
        'save_path': 'chats',
        'debug': True,
    }
if __name__ == "__main__":
    # Demo entry point: assemble the system message, open a GPT-4
    # code-interpreter session, and run one referring query on scene0132_00.
    # system_message = 'Imagine you are an artificial intelligence assistant. You job is to do 3D referring reasoning, namely to find the object for a given utterance from a 3d scene presented as object-centric semantic information.\n'
    system_message = ""
    system_message += get_system_message()
    system_message += get_principle()
    openai_config = {
        'model': 'gpt-4',
        'temperature': 1e-7,  # near-zero for (almost) deterministic output
        'top_p': 1e-7,
        # 'max_tokens': 4096,
        'max_tokens': 8192,
        'system_message': system_message,
        # 'load_path': '',
        'save_path': 'chats',
        'debug': True
    }
    code_interpreter = CodeInterpreter(**openai_config)
    prompt = gen_prompt("Find the chair next to the table.", "scene0132_00")
    print(prompt)
    response = get_gpt_response(prompt, code_interpreter)
    # print(response)
    # Dump the full accumulated dialogue context for inspection.
    print("-------pretext--------")
    print(code_interpreter.pretext)