import json
import os

import torch
import torch.nn as nn
from mmengine.model import BaseModel

from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
|
class MLLM_Annotor(BaseModel):
    """Offline caption annotator built on an lmdeploy multimodal pipeline.

    Generated captions are buffered in ``self.results_list`` and flushed to
    numbered JSON files under ``save_folder`` in chunks of 100.
    """

    def __init__(self,
                 model=None,
                 save_folder='./work_dirs/internvl72b_sam2_obj_cap/'):
        super().__init__()
        num_gpus = torch.cuda.device_count()
        print(f'Found {num_gpus} CUDA device(s); using all of them for tensor parallelism.')
        pipe = pipeline(
            model,
            backend_config=TurbomindEngineConfig(
                session_len=16000,
                tp=num_gpus),
        )
        # Keep the pipeline in a plain list so it stays out of the module's
        # parameter/submodule bookkeeping.
        self.pipe = [pipe]
        # Placeholder layer so the wrapper exposes at least one parameter.
        self._zero = nn.Linear(10, 10)
        self.results_list = []
        self.item_idx = 0

        # Create the save folder (and any missing parents) if needed.
        os.makedirs(save_folder, exist_ok=True)
        self.save_folder = save_folder
|
    def forward(self, **kwargs):
        # No training/validation forward; this wrapper is inference-only.
        return None
|
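    # Rough layout of each entry in ``data_dicts`` consumed by the
    # predict_forward_* methods below (illustrative sketch only; the exact
    # fields vary per task and are produced by the upstream data loader):
    #
    #   {
    #       'images': <image(s) accepted by the lmdeploy pipeline>,
    #       'text_prompt': '<task-specific instruction>',
    #       'video_id': ...,   # video tasks
    #       'obj_id': ...,     # video tasks
    #       'image_id': ...,   # image dense-caption tasks
    #   }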
    def predict_forward_sam2(self, data_dicts):
        prompts = []
        for data_dict in data_dicts:
            images = data_dict['images']
            texts = data_dict['text_prompt']
            prompts.append((texts, images))

        response_list = self.pipe[0](prompts)
        text_list = [item.text for item in response_list]

        results_list = []
        for i, text in enumerate(text_list):
            # Drop degenerate responses: captions that mention '图'
            # ("image" in Chinese) or are shorter than 15 characters.
            if '图' in text or len(text) < 15:
                continue
            results_list.append({
                'video_id': data_dicts[i]['video_id'],
                'obj_id': data_dicts[i]['obj_id'],
                'caption': text,
                'type': data_dicts[i]['type'],
            })

        self.results_list += results_list

        if len(self.results_list) > 100:
            self.save_step()
            print("Saved !!!")

        return {}
|
    def predict_forward_sam2_recap(self, data_dicts):
        prompts = []
        for data_dict in data_dicts:
            images = data_dict['images']
            texts = data_dict['text_prompt']
            prompts.append((texts, images))

        response_list = self.pipe[0](prompts)
        text_list = [item.text for item in response_list]

        results_list = []
        for i, text in enumerate(text_list):
            # Drop degenerate responses: captions that mention '图'
            # ("image" in Chinese) or are shorter than 15 characters.
            if '图' in text or len(text) < 15:
                continue
            results_list.append({
                'video_id': data_dicts[i]['video_id'],
                'obj_id': data_dicts[i]['obj_id'],
                'crop_caption': data_dicts[i]['crop_caption'],
                'crop_category': data_dicts[i]['crop_category'],
                'caption': text,
            })

        self.results_list += results_list

        if len(self.results_list) > 100:
            self.save_step()
            print("Saved !!!")

        return {}
|
    def predict_forward_sam2_video_recap(self, data_dicts):
        prompts = []
        for data_dict in data_dicts:
            images = data_dict['images']
            texts = data_dict['text_prompt']
            prompts.append((texts, images))

        response_list = self.pipe[0](prompts)
        text_list = [item.text for item in response_list]

        results_list = []
        for i, text in enumerate(text_list):
            # Drop degenerate responses: captions that mention '图'
            # ("image" in Chinese) or are shorter than 15 characters.
            if '图' in text or len(text) < 15:
                continue
            results_list.append({
                'video_id': data_dicts[i]['video_id'],
                'obj_id': data_dicts[i]['obj_id'],
                'crop_caption': data_dicts[i]['crop_caption'],
                'crop_category': data_dicts[i]['crop_category'],
                'image_caption': data_dicts[i]['image_caption'],
                'video_caption': text,
            })

        self.results_list += results_list

        if len(self.results_list) > 100:
            self.save_step()
            print("Saved !!!")

        return {}
|
    def predict_forward_image_dense_cap_objcap(self, data_dicts):
        prompts = []
        for data_dict in data_dicts:
            images = data_dict['images']
            texts = data_dict['text_prompt']
            prompts.append((texts, images))

        response_list = self.pipe[0](prompts)
        text_list = [item.text for item in response_list]

        results_list = []
        for i, text in enumerate(text_list):
            results_list.append({
                'image_id': data_dicts[i]['image_id'],
                'caption': text,
                'object_anno': data_dicts[i]['object_anno'],
            })
            # Debug output: echo the prompt and the stored result.
            print('\n\n', prompts[i], '\n', results_list[-1], '\n\n')

        self.results_list += results_list

        if len(self.results_list) > 100:
            self.save_step()
            print("Saved !!!")

        return {}
|
    def predict_forward_image_dense_cap_overallcap(self, data_dicts):
        prompts = []
        for data_dict in data_dicts:
            images = data_dict['images']
            texts = data_dict['text_prompt']
            prompts.append((texts, images))

        response_list = self.pipe[0](prompts)
        text_list = [item.text for item in response_list]

        results_list = []
        for i, text in enumerate(text_list):
            results_list.append({
                'image_id': data_dicts[i]['image_id'],
                'caption': text,
            })
            # Debug output: echo the prompt and the stored result.
            print('\n\n', prompts[i], '\n', results_list[-1], '\n\n')

        self.results_list += results_list

        if len(self.results_list) > 100:
            self.save_step()
            print("Saved !!!")

        return {}
|
    def save_step(self, last=False):
        # Flush buffered results to a numbered JSON shard. By default only
        # the first 100 items are written; the remainder stays buffered.
        if last:
            save_list = self.results_list
            self.results_list = []
        else:
            save_list = self.results_list[:100]
            self.results_list = self.results_list[100:]

        json_path = os.path.join(self.save_folder, f'{self.item_idx}.json')
        self.item_idx += 1
        with open(json_path, 'w') as f:
            json.dump(save_list, fp=f)
        return
|
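    # ``predict_forward`` routes on ``kwargs['type']`` to one of the
    # task-specific paths above; without a ``type`` it falls back to plain
    # single-sentence captioning of ``image_paths``.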
    def predict_forward(self, image_paths, **kwargs):
        task_type = kwargs.get('type')
        if task_type == 'sam2':
            return self.predict_forward_sam2(kwargs['data_dicts'])
        elif task_type == 'sam2_recap':
            return self.predict_forward_sam2_recap(kwargs['data_dicts'])
        elif task_type == 'sam2_video_recap':
            return self.predict_forward_sam2_video_recap(kwargs['data_dicts'])
        elif task_type == 'demo_imgcap':
            return self.predict_forward_image_dense_cap_objcap(kwargs['data_dicts'])
        elif task_type == 'demo_imgcap_overall':
            return self.predict_forward_image_dense_cap_overallcap(kwargs['data_dicts'])

        # Fallback: plain single-sentence captioning of the given images.
        images = [load_image(image_path) for image_path in image_paths]
        prompts = [('Please briefly describe this image in a sentence.', image)
                   for image in images]
        response_list = self.pipe[0](prompts)
        text_list = [item.text for item in response_list]
        print(text_list)
        return {}
|
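# Minimal usage sketch (hypothetical checkpoint id and image files, not part
# of the original pipeline): caption a couple of images via the fallback path.
if __name__ == '__main__':
    annotator = MLLM_Annotor(
        model='OpenGVLab/InternVL2-Llama3-76B',  # assumed model id; replace with your checkpoint
        save_folder='./work_dirs/demo_caps/')
    annotator.predict_forward(image_paths=['demo_0.jpg', 'demo_1.jpg'])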