Spaces:
Sleeping
Sleeping
| # encoding:utf-8 | |
| import ast | |
| import csv | |
| import json | |
| import logging | |
| import os | |
| import random | |
| import re | |
| import time | |
| from copy import deepcopy | |
| from datetime import datetime | |
| import numpy as np | |
| from tenacity import RetryError, before_sleep_log, retry, stop_after_attempt, wait_exponential_jitter # for exponential backoff | |
| from code_interpreter import CodeInterpreter | |
| # from config import confs_nr3d, confs_scanrefer, confs_sr3d | |
| # from gpt_dialogue import Dialogue | |
| # from object_filter_gpt4 import ObjectFilter | |
| from prompt_text import get_principle, get_principle_sr3d, get_system_message | |
# Module-level logging setup: a dedicated logger that only emits ERROR (and
# above) records to a console stream handler, using a timestamped format.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.ERROR)

logger = logging.getLogger(__name__ + 'logger')
logger.addHandler(console_handler)
logger.setLevel(logging.ERROR)
| # def round_list(lst, length): | |
| # # round every element in lst | |
| # for idx, num in enumerate(lst): | |
| # lst[idx] = round(num, length) | |
| # return list(lst) | |
def round_list(lst, length):
    """Round the numbers in ``lst`` to ``length`` decimals.

    Args:
        lst: Any array-like of numbers accepted by NumPy (list, tuple, ndarray).
        length: Number of decimal places to keep.

    Returns:
        The rounded values converted back to plain Python objects via
        ``ndarray.tolist()``.
    """
    rounded = np.around(lst, decimals=length)
    return rounded.tolist()
def remove_spaces(s: str):
    """Return ``s`` with every space character (``' '``) removed.

    Only the plain space is stripped; tabs, newlines and other whitespace
    are kept as they are.
    """
    return ''.join(s.split(' '))
def rgb_to_hsl(rgb):
    """Convert an RGB colour with 0-255 channels to ``[hue, saturation, lightness]``.

    Args:
        rgb: Iterable of exactly three channel values (red, green, blue),
            each on the 0-255 scale.

    Returns:
        A list ``[hue, saturation, lightness]`` where hue is in degrees and
        saturation/lightness are fractions of 1. For achromatic input
        (all channels equal) hue and saturation are the integer ``0``.
    """
    red, green, blue = (channel / 255.0 for channel in rgb)

    strongest = max(red, green, blue)
    weakest = min(red, green, blue)
    spread = strongest - weakest  # chroma
    lightness = (strongest + weakest) / 2

    # Greys have no hue or saturation; returning early also keeps the
    # saturation division below away from lightness values of 0 and 1.
    if spread == 0:
        return [0, 0, lightness]

    # Hue is measured in 60-degree sectors, selected by the dominant channel.
    hue = 0
    if strongest == red:
        hue = ((green - blue) / spread) % 6
    elif strongest == green:
        hue = ((blue - red) / spread) + 2
    elif strongest == blue:
        hue = ((red - green) / spread) + 4
    hue *= 60

    # Saturation normalises chroma by the room available at this lightness.
    denominator = 2 * lightness if lightness <= 0.5 else 2 - 2 * lightness
    return [hue, spread / denominator, lightness]
def get_scene_center(objects):
    """Return the midpoint of the axis-aligned box spanned by the object centres.

    Args:
        objects: Iterable of mappings, each with a ``'center_position'``
            entry that unpacks into three coordinates ``(x, y, z)``.

    Returns:
        ``[x, y, z]`` of the midpoint between the per-axis minimum and
        maximum centre coordinates, rounded to 2 decimals by ``round_list``.
    """
    xs, ys, zs = [], [], []
    for obj in objects:
        x, y, z = obj['center_position']
        xs.append(x)
        ys.append(y)
        zs.append(z)

    # Seeding each search with +/-inf mirrors a running min/max that starts
    # from an "empty" bounding box, so the comparisons are unchanged.
    midpoint = [
        (min([float('inf'), *axis]) + max([float('-inf'), *axis])) / 2
        for axis in (xs, ys, zs)
    ]
    return round_list(midpoint, 2)
def find_relevant_objects(user_instruction, scan_id):
    """Placeholder for filtering a scene's objects by the user instruction.

    Not implemented: both arguments are ignored and ``None`` is returned.
    """
    return None
def _format_direction_info(obj):
    """Return the newline-terminated direction-vector suffix for one object line."""
    if not obj['has_front']:
        return "\n"  # Unknown orientation: nothing to report.

    direction = np.array(obj['front_point']) - np.array(obj['obb'][0:3])
    norm = np.linalg.norm(direction)
    # A front point that coincides with the box centre (or a non-finite norm)
    # carries no direction; treat it as unknown instead of emitting NaNs.
    if not norm > 0:
        return "\n"

    unit = direction / norm
    up_vector = np.array([0, 0, 1])
    front_vector = round_list(unit, 2)
    # NOTE(review): in a right-handed frame with z up, front x up points to
    # the right-hand side of someone looking along `front`; the labels are
    # kept as they were - confirm they match the convention the system
    # prompt expects.
    left_vector = round_list(np.cross(unit, up_vector), 2)
    right_vector = round_list(np.cross(up_vector, unit), 2)
    behind_vector = round_list(-np.array(front_vector), 2)
    return ";direction vectors:front=%s,left=%s,right=%s,behind=%s\n" % (
        front_vector, left_vector, right_vector, behind_vector)


def _format_object_line(obj):
    """Return the ``label,id,ctr,size,HSL`` description of one object (no newline)."""
    center_position = round_list(obj['center_position'], 2)
    size = round_list(obj['size'], 2)
    # Colour is reported as HSL rather than RGB; only the first three
    # channels of 'avg_rgba' are used.
    hsl = round_list(rgb_to_hsl(obj['avg_rgba'][0:3]), 2)
    return "%s,id=%s,ctr=%s,size=%s,HSL=%s" % (
        obj['label'],
        obj['id'],
        remove_spaces(str(center_position)),
        remove_spaces(str(size)),
        remove_spaces(str(hsl)),
    )


def gen_prompt(user_instruction, scan_id):
    """Build the textual grounding prompt for one scene and one instruction.

    The prompt consists of a coordinate-system preamble, the scene centre,
    one line per object (label, id, centre, size, HSL colour and, when the
    object has a known front, its direction vectors), the user instruction,
    and the answer-format requirements.

    Args:
        user_instruction: Natural-language description of the target object.
        scan_id: Scene identifier (a string); the object data is read from
            ``objects_info/objects_info_<scan_id>.npy``.

    Returns:
        The complete prompt as a single string.

    Raises:
        Whatever ``np.load`` raises when the scene file is missing or
        unreadable; ``KeyError`` if an object lacks one of the keys read
        here ('label', 'id', 'center_position', 'size', 'avg_rgba',
        'has_front', and for oriented objects 'front_point' and 'obb').
    """
    npy_path = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
    # SECURITY: allow_pickle=True unpickles the file's contents, which can
    # execute arbitrary code. Only load scene files from a trusted source.
    objects_info = np.load(npy_path, allow_pickle=True)

    # The scene centre must be computed from ALL objects, even if the object
    # list below is ever narrowed to the relevant ones.
    scene_center = get_scene_center(objects_info)

    parts = [
        # Background information.
        scan_id + ":objects with quantitative description based on right-hand Cartesian coordinate system with x-y-z axes, x-y plane=ground, z-axis=up/down. Coords format [x, y, z].\n\n",
        "Scene center:%s. If no direction vector, observer at center for obj orientation.\n\n" % remove_spaces(str(scene_center)),
        "objs list:\n",
    ]

    # Quantitative description of every object, one line each.
    parts.extend(
        _format_object_line(obj) + _format_direction_info(obj)
        for obj in objects_info
    )

    # Task statement and answer-format requirements. The instruction is
    # wrapped in a 1-tuple so a tuple-valued instruction cannot break '%'.
    parts.append("\nInstruction:find the one described object in description: \n\"%s\"\n" % (user_instruction,))
    parts.append("\n\nThere is exactly one answer, so if you receive multiple answers, consider other constraints; if get no answers, loosen constraints.")
    parts.append("\n\nWork this out step by step to ensure right answer.")
    parts.append("\n\nIf the answer is complete, add \"Now the answer is complete -- {'ID':id}\" to the end of your answer(that is, your completion, not your code), where id is the id of the referred obj. Do not add anything after.")
    return "".join(parts)
| # 20s,40s,80s,120s + random.uniform(0,20) | |
def get_gpt_response(prompt: str, code_interpreter: CodeInterpreter):
    """Query the model through ``code_interpreter`` until it declares completion.

    The prompt is sent once. While the reply does not contain
    "Now the answer is complete", an empty user message is sent to let the
    model continue, for at most 10 follow-up calls.

    Args:
        prompt: The full prompt text for the first call.
        code_interpreter: Object exposing ``model`` and
            ``call_openai_with_code_interpreter(message)``, which returns a
            ``(response, token_usage)`` pair where ``response['content']``
            is the reply text.

    Returns:
        The content of the last reply received (which may still lack the
        completion marker if the follow-up limit was reached).
    """
    completion_marker = "Now the answer is complete"
    max_follow_ups = 10

    print("llm_name:", code_interpreter.model)

    started_at = time.time()

    # First call carries the real prompt.
    reply, token_usage_total = code_interpreter.call_openai_with_code_interpreter(prompt)
    response = reply['content']

    # Follow-up calls with an empty message nudge the model to keep going.
    for follow_up in range(1, max_follow_ups + 1):
        if completion_marker in response:
            break
        reply, extra_tokens = code_interpreter.call_openai_with_code_interpreter('')
        response = reply['content']
        token_usage_total += extra_tokens
        print("count_response:", follow_up)
    else:
        # All follow-ups were used; warn if the marker is still missing.
        if completion_marker not in response:
            print("Response does not end with 'Now the answer is complete.' !")

    # Timing and token totals are gathered but not reported at the moment.
    time_consumed = time.time() - started_at
    return response
def extract_answer_id_from_last_line(last_line, random_choice_list=(0,)):
    """Extract the answer id from a reply ending in ``-- {'ID': <id>}``.

    Only the text after the last ``--`` is inspected, and the first
    ``{...}`` found there is parsed as a Python literal.

    Args:
        last_line: The reply text (typically its final line).
        random_choice_list: Non-empty sequence of fallback ids; one is picked
            at random whenever no valid id can be extracted. Defaults to
            ``(0,)``, i.e. always fall back to 0.

    Returns:
        ``(answer_id, wrong_return_format)``:
        * an ``int`` ID is returned as is, with ``False``;
        * a non-empty list of ints yields a random element of that list,
          with ``False``;
        * anything else (no dict, unparsable dict, missing ``'ID'`` key,
          non-integer or empty-list ID) yields a random element of
          ``random_choice_list``, with ``True``.
    """
    candidate = last_line.split('--')[-1]
    match = re.search(r"\{[^\}]*\}", candidate)
    if match is None:
        print("Wrong answer format!! No dict found. Random choice.")
        return random.choice(random_choice_list), True

    try:
        extracted_dict = ast.literal_eval(match.group())
        print(extracted_dict)
        answer_id = extracted_dict['ID']
    except Exception:
        # Best-effort parsing: malformed literals, non-dict literals and a
        # missing 'ID' key all fall back. Exception (not BaseException) so
        # KeyboardInterrupt / SystemExit still propagate.
        print("Wrong answer format!! No dict found. Random choice.")
        return random.choice(random_choice_list), True

    if isinstance(answer_id, int):
        return answer_id, False

    # Several candidate ids: pick one of them. The list must be non-empty,
    # otherwise there is nothing to choose from.
    if (isinstance(answer_id, list) and answer_id
            and all(isinstance(e, int) for e in answer_id)):
        print("Wrong answer format: %s. random choice from this list" % str(answer_id))
        return random.choice(answer_id), False

    print("Wrong answer format: %s. No dict found. Random choice from relevant objects." % str(answer_id))
    return random.choice(random_choice_list), True
def get_openai_config(llm_name='gpt-3.5-turbo-0125'):
    """Return the keyword configuration for the model named ``llm_name``.

    The system message is the text from ``get_system_message()`` followed by
    the text from ``get_principle()``.

    Args:
        llm_name: Model identifier stored under the ``'model'`` key.

    Returns:
        A dict with the keys ``model``, ``temperature``, ``top_p``,
        ``max_tokens``, ``system_message``, ``save_path`` and ``debug``.
    """
    system_message = "" + get_system_message() + get_principle()
    return dict(
        model=llm_name,
        # Near-zero sampling parameters - presumably to make replies as
        # repeatable as possible; confirm before changing.
        temperature=1e-7,
        top_p=1e-7,
        max_tokens=8192,
        system_message=system_message,
        save_path='chats',
        debug=True,
    )
if __name__ == "__main__":
    # Manual smoke test: build the prompt for one sample scene, send it to
    # GPT-4 through the code interpreter, and dump the recorded conversation.
    # The configuration comes from get_openai_config() so the settings are
    # defined in one place only.
    openai_config = get_openai_config(llm_name='gpt-4')
    code_interpreter = CodeInterpreter(**openai_config)

    prompt = gen_prompt("Find the chair next to the table.", "scene0132_00")
    print(prompt)

    response = get_gpt_response(prompt, code_interpreter)
    # print(response)
    print("-------pretext--------")
    print(code_interpreter.pretext)