# Hugging Face Hub upload-page residue (kept for provenance, commented out so
# the module parses as Python):
#   zhouyik's picture
#   Upload folder using huggingface_hub
#   032e687 verified
import json
import os.path
from mmengine.model import BaseModel
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
import torch
import torch.nn as nn
class MLLM_Annotor(BaseModel):
    """Batch caption annotator built on an lmdeploy multimodal pipeline.

    Each ``predict_forward_*`` entry point takes a list of data dicts (each
    carrying a ``'text_prompt'`` and ``'images'`` plus task-specific metadata),
    runs the vision-language pipeline over them in one batch, buffers the
    resulting records, and periodically flushes them to numbered JSON shards
    under ``save_folder``.
    """

    # Flush buffered results to disk once more than this many are queued.
    FLUSH_SIZE = 100

    def __init__(self,
                 model=None,
                 save_folder='./work_dirs/internvl72b_sam2_obj_cap/'
                 ):
        """Build the lmdeploy pipeline and prepare the output folder.

        Args:
            model: model path/name forwarded verbatim to ``lmdeploy.pipeline``.
            save_folder: directory that receives the ``<idx>.json`` shards.
        """
        super().__init__()
        print(torch.cuda.device_count())
        pipe = pipeline(model,
                        backend_config=TurbomindEngineConfig(
                            session_len=16000,
                            # shard the engine across every visible GPU
                            tp=torch.cuda.device_count()),
                        )
        # Wrapped in a plain list so BaseModel/nn.Module attribute machinery
        # does not try to register the pipeline as a submodule.
        self.pipe = [pipe]
        # Dummy layer so this wrapper owns at least one parameter — presumably
        # required by the surrounding mmengine runner; TODO confirm.
        self._zero = nn.Linear(10, 10)
        self.results_list = []  # buffered records awaiting a flush
        self.item_idx = 0       # index of the next JSON shard to write
        # makedirs(+exist_ok) handles nested paths and concurrent creation,
        # unlike the original os.mkdir.
        os.makedirs(save_folder, exist_ok=True)
        self.save_folder = save_folder

    def forward(self, **kwargs):
        """Unused nn-style forward; annotation goes through predict_forward."""
        return None

    # ------------------------------------------------------------------ #
    # shared helpers (factored out of five near-identical methods)        #
    # ------------------------------------------------------------------ #

    def _generate(self, data_dicts):
        """Run the pipeline over ``data_dicts``; return one text per item."""
        prompts = [(d['text_prompt'], d['images']) for d in data_dicts]
        responses = self.pipe[0](prompts)
        return [item.text for item in responses]

    @staticmethod
    def _is_bad_caption(text):
        """Reject heuristic: the model echoed '图' or answered too briefly."""
        return '图' in text or len(text) < 15

    def _accumulate(self, results_list):
        """Buffer new records and flush a shard when the buffer is full."""
        self.results_list += results_list
        if len(self.results_list) > self.FLUSH_SIZE:
            self.save_step()
            print("Saved !!!")

    # ------------------------------------------------------------------ #
    # task-specific entry points                                          #
    # ------------------------------------------------------------------ #

    def predict_forward_sam2(self, data_dicts):
        """Caption SAM2 video objects, dropping malformed responses."""
        texts = self._generate(data_dicts)
        results_list = [
            {
                'video_id': data_dicts[i]['video_id'],
                'obj_id': data_dicts[i]['obj_id'],
                'caption': text,
                'type': data_dicts[i]['type'],
            }
            for i, text in enumerate(texts)
            if not self._is_bad_caption(text)
        ]
        self._accumulate(results_list)
        return {}

    def predict_forward_sam2_recap(self, data_dicts):
        """Re-caption SAM2 objects, carrying over the crop-level metadata."""
        texts = self._generate(data_dicts)
        results_list = [
            {
                'video_id': data_dicts[i]['video_id'],
                'obj_id': data_dicts[i]['obj_id'],
                'crop_caption': data_dicts[i]['crop_caption'],
                'crop_category': data_dicts[i]['crop_category'],
                'caption': text,
            }
            for i, text in enumerate(texts)
            if not self._is_bad_caption(text)
        ]
        self._accumulate(results_list)
        return {}

    def predict_forward_sam2_video_recap(self, data_dicts):
        """Video-level re-caption of SAM2 objects, keeping prior captions."""
        texts = self._generate(data_dicts)
        results_list = [
            {
                'video_id': data_dicts[i]['video_id'],
                'obj_id': data_dicts[i]['obj_id'],
                'crop_caption': data_dicts[i]['crop_caption'],
                'crop_category': data_dicts[i]['crop_category'],
                'image_caption': data_dicts[i]['image_caption'],
                'video_caption': text,
            }
            for i, text in enumerate(texts)
            if not self._is_bad_caption(text)
        ]
        self._accumulate(results_list)
        return {}

    def predict_forward_image_dense_cap_objcap(self, data_dicts):
        """Caption annotated objects in images (demo path, no filtering)."""
        texts = self._generate(data_dicts)
        results_list = []
        for i, text in enumerate(texts):
            results_list.append({
                'image_id': data_dicts[i]['image_id'],
                'caption': text,
                'object_anno': data_dicts[i]['object_anno']
            })
            # Debug echo of the exact (prompt, images) pair and its record.
            print('\n\n',
                  (data_dicts[i]['text_prompt'], data_dicts[i]['images']),
                  '\n', results_list[-1], '\n\n')
        self._accumulate(results_list)
        return {}

    def predict_forward_image_dense_cap_overallcap(self, data_dicts):
        """Whole-image caption (demo path, no filtering)."""
        texts = self._generate(data_dicts)
        results_list = []
        for i, text in enumerate(texts):
            results_list.append({
                'image_id': data_dicts[i]['image_id'],
                'caption': text,
            })
            # Debug echo of the exact (prompt, images) pair and its record.
            print('\n\n',
                  (data_dicts[i]['text_prompt'], data_dicts[i]['images']),
                  '\n', results_list[-1], '\n\n')
        self._accumulate(results_list)
        return {}

    # ------------------------------------------------------------------ #
    # persistence & dispatch                                              #
    # ------------------------------------------------------------------ #

    def save_step(self, last=False):
        """Write one JSON shard of buffered results to ``save_folder``.

        Args:
            last: if True, write everything still buffered and clear the
                buffer (bug fix: the original kept the records, so a second
                final save would duplicate them); otherwise write only the
                first FLUSH_SIZE records and keep the remainder buffered.
        """
        if last:
            save_list = self.results_list
            self.results_list = []
        else:
            save_list = self.results_list[:self.FLUSH_SIZE]
            self.results_list = self.results_list[self.FLUSH_SIZE:]
        json_path = os.path.join(self.save_folder, f'{self.item_idx}.json')
        self.item_idx += 1
        # ensure_ascii=False keeps the Chinese captions human-readable;
        # explicit utf-8 makes the file encoding platform-independent.
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(save_list, fp=f, ensure_ascii=False)
        return

    def predict_forward(self, image_paths, **kwargs):
        """Dispatch on ``kwargs['type']``; default to one-line image captions.

        Args:
            image_paths: image files captioned by the default path (only used
                when no known ``type`` is supplied).
            **kwargs: ``type`` selects a task-specific handler and
                ``data_dicts`` carries its inputs.
        """
        dispatch = {
            'sam2': self.predict_forward_sam2,
            'sam2_recap': self.predict_forward_sam2_recap,
            'sam2_video_recap': self.predict_forward_sam2_video_recap,
            'demo_imgcap': self.predict_forward_image_dense_cap_objcap,
            'demo_imgcap_overall':
                self.predict_forward_image_dense_cap_overallcap,
        }
        handler = dispatch.get(kwargs.get('type'))
        if handler is not None:
            return handler(kwargs['data_dicts'])
        # Default: briefly caption each image and print the responses.
        images = [load_image(image_path) for image_path in image_paths]
        prompts = [('Please briefly describe this image in a sentence.', image)
                   for image in images]
        response_list = self.pipe[0](prompts)
        text_list = [item.text for item in response_list]
        print(text_list)
        return {}