import itertools

import numpy as np
from PIL import Image
from PIL import ImageSequence
from nltk import pos_tag, word_tokenize

from LLaMA2_Accessory.SPHINX import SPHINXModel
from gpt_combinator import caption_summary


class CaptionRefiner():
    """Refine a coarse video caption with scene, object-attribute and spatial-relation cues."""

    def __init__(self, sample_num, add_detect=True, add_pos=True, add_attr=True,
                 openai_api_key=None, openai_api_base=None,
                 ):
        self.sample_num = sample_num
        self.ADD_DETECTION_OBJ = add_detect
        self.ADD_POS = add_pos
        self.ADD_ATTR = add_attr
        self.openai_api_key = openai_api_key
        self.openai_api_base = openai_api_base

    def video_load_split(self, video_path=None):
        frame_img_list, sampled_img_list = [], []
        if ".gif" in video_path:
            img = Image.open(video_path)
            # convert every GIF frame from <PIL.GifImagePlugin.GifImageFile> to a standalone RGB PIL image
            for frame in ImageSequence.Iterator(img):
                frame_np = np.array(frame.copy().convert('RGB').getdata(), dtype=np.uint8).reshape(frame.size[1], frame.size[0], 3)
                frame_img_list.append(Image.fromarray(frame_np))
        elif ".mp4" in video_path:
            pass
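            # Illustrative mp4-decoding sketch, not part of the original stub; it assumes
            # opencv-python (cv2) is installed and appends every decoded frame as RGB PIL images.
            import cv2
            cap = cv2.VideoCapture(video_path)
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV returns BGR ndarrays; convert to RGB before wrapping in PIL
                frame_img_list.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            cap.release()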
        # sample frames uniformly from the mp4/gif (step clamped to at least 1 to avoid a zero step)
        step = max(1, len(frame_img_list) // self.sample_num)
        for i in range(0, len(frame_img_list), step):
            sampled_img_list.append(frame_img_list[i])
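        # e.g. sample_num=8 on a 40-frame clip keeps every 5th frame (8 frames in total)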
        return sampled_img_list  # [<PIL.Image.Image>, ...]

    def caption_refine(self, video_path, org_caption, model_path):
        sampled_imgs = self.video_load_split(video_path)
        model = SPHINXModel.from_pretrained(
            pretrained_path=model_path,
            with_visual=True
        )

        scene_description = []
        # extract candidate object nouns (singular, plural and proper nouns) from the original caption
        text = word_tokenize(org_caption)
        existing_objects = [word for word, tag in pos_tag(text) if tag in ["NN", "NNS", "NNP"]]
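        # e.g. "a dog runs across the yard" typically yields ["dog", "yard"]
        # (illustrative; the exact result depends on the installed NLTK tagger)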

        if self.ADD_DETECTION_OBJ:
            # describe the overall scene using the first sampled frame;
            # each qas entry is a [question, answer] turn, with None marking the answer to be generated
            qas = [["Where is this scene in the picture most likely to take place?", None]]
            sc_response = model.generate_response(qas, sampled_imgs[0], max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
            scene_description.append(sc_response)
        # # Lacking accuracy
        # for img in sampled_imgs:
        #     qas = [["Please detect the objects in the image.", None]]
        #     response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
        #     print(response)

        object_attrs = []
        if self.ADD_ATTR:
            # ask for a detailed description of every extracted object in every sampled frame
            for obj in existing_objects:
                obj_attr = []
                for img in sampled_imgs:
                    qas = [["Please describe the attributes of the {}, including color, position, etc.".format(obj), None]]
                    response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
                    obj_attr.append(response)
                object_attrs.append({obj: obj_attr})

        space_relations = []
        if self.ADD_POS:
            obj_pairs = list(itertools.combinations(existing_objects, 2))
            # describe the spatial relationship between each object pair, queried on the first sampled frame
            for obj_pair in obj_pairs:
                qas = [["What is the spatial relationship between the {} and the {}? Please describe in less than twenty words.".format(obj_pair[0], obj_pair[1]), None]]
                response = model.generate_response(qas, sampled_imgs[0], max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
                space_relations.append(response)

        return dict(
            org_caption=org_caption,
            scene_description=scene_description,
            existing_objects=existing_objects,
            object_attrs=object_attrs,
            space_relations=space_relations,
        )

    def gpt_summary(self, total_captions):
        # combine all partial captions into one detailed long caption
        detailed_caption = ""
        if "org_caption" in total_captions.keys():
            detailed_caption += "In summary, " + total_captions['org_caption'] + " "
        if "scene_description" in total_captions.keys() and total_captions['scene_description']:
            detailed_caption += "We first describe the whole scene. " + total_captions['scene_description'][-1] + " "
        if "existing_objects" in total_captions.keys():
            tmp_sentence = "There are multiple objects in the video, including "
            tmp_sentence += ", ".join(total_captions['existing_objects']) + ". "
            detailed_caption += tmp_sentence
| # if "object_attrs" in total_captions.keys(): | |
| # caption_summary( | |
| # caption_list="", | |
| # api_key=self.openai_api_key, | |
| # api_base=self.openai_api_base, | |
| # ) | |
| if "space_relations" in total_captions.keys(): | |
| tmp_sentence = "As for the spatial relationship. " | |
| for sentence in total_captions['space_relations']: tmp_sentence += sentence | |
| detailed_caption += tmp_sentence | |
| detailed_caption = caption_summary(detailed_caption, self.open_api_key, self.open_api_base) | |
| return detailed_caption |
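

# Usage sketch (illustrative only): the checkpoint path, video path, caption and API key
# below are placeholders, not values taken from the original project.
if __name__ == "__main__":
    refiner = CaptionRefiner(
        sample_num=8,
        add_detect=True, add_pos=True, add_attr=True,
        openai_api_key="YOUR_OPENAI_KEY",        # placeholder
        openai_api_base=None,                    # placeholder
    )
    captions = refiner.caption_refine(
        video_path="example.gif",                # placeholder clip
        org_caption="A dog runs across the yard.",
        model_path="path/to/SPHINX/checkpoint",  # placeholder checkpoint directory
    )
    print(refiner.gpt_summary(captions))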