ObjectRelator-Original / example_eval_seg_with_text.py
YuqianFu's picture
Upload folder using huggingface_hub
625a17f verified
#!/usr/bin/env python3
"""
示例:如何使用修改后的eval_seg函数同时获取分割结果和文本输出
"""
import torch
from transformers import AutoTokenizer
def example_eval_seg_with_text_output(model, tokenizer, input_data):
"""
示例函数:展示如何使用修改后的eval_seg函数
Args:
model: PSALM模型实例
tokenizer: 对应的tokenizer
input_data: 输入数据字典,包含以下键:
- input_ids: torch.LongTensor
- images: torch.FloatTensor
- seg_info: 分割信息
- 其他必要参数...
Returns:
dict: 包含分割结果和解码后的文本
"""
# 调用修改后的eval_seg函数
result = model.eval_seg(
input_ids=input_data['input_ids'],
images=input_data['images'],
seg_info=input_data['seg_info'],
# 以下是新添加的文本生成参数
generate_text=True, # 是否生成文本
max_new_tokens=512, # 最大生成token数
temperature=0.2, # 温度参数
do_sample=True, # 是否采样
# 其他原有参数...
attention_mask=input_data.get('attention_mask'),
class_name_embedding_indices=input_data.get('class_name_embedding_indices'),
cls_indices=input_data.get('cls_indices'),
token_refer_id=input_data.get('token_refer_id'),
refer_embedding_indices=input_data.get('refer_embedding_indices'),
is_thing_list=input_data.get('is_thing_list')
)
# 提取分割结果
segmentation_results = result['segmentation_results']
# 提取并解码文本输出
output_token_ids = result['output_token_ids']
decoded_text = None
if output_token_ids is not None:
# 解码生成的token ids为文本
decoded_text = tokenizer.decode(
output_token_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)
print(f"生成的文本: {decoded_text}")
# 处理分割结果
for i, seg_result in enumerate(segmentation_results):
print(f"图片 {i} 的分割结果:")
if 'instances' in seg_result:
instances = seg_result['instances']
print(f" - 检测到 {len(instances.pred_masks)} 个实例")
print(f" - pred_masks shape: {instances.pred_masks.shape}")
print(f" - scores: {instances.scores}")
if hasattr(instances, 'pred_boxes'):
print(f" - pred_boxes: {instances.pred_boxes}")
if 'sem_seg' in seg_result:
print(f" - 语义分割结果 shape: {seg_result['sem_seg'].shape}")
if 'panoptic_seg' in seg_result:
print(f" - 全景分割结果")
return {
'segmentation_results': segmentation_results,
'decoded_text': decoded_text,
'raw_token_ids': output_token_ids
}
def example_usage():
"""
完整的使用示例
"""
# 假设您已经加载了模型和tokenizer
# model = ... # 您的PSALM模型
# tokenizer = ... # 对应的tokenizer
# 准备输入数据
input_data = {
'input_ids': torch.tensor([[1, 2, 3, ...]]), # 您的输入token ids
'images': torch.randn(1, 3, 224, 224), # 您的图像数据
'seg_info': [{'instances': ...}], # 您的分割信息
# 其他必要的输入...
}
# 调用函数(需要实际的model和tokenizer)
# result = example_eval_seg_with_text_output(model, tokenizer, input_data)
print("示例代码准备完毕!")
print("使用时请确保:")
print("1. 已正确加载PSALM模型")
print("2. 已正确加载对应的tokenizer")
print("3. 准备了正确格式的input_data")
if __name__ == "__main__":
example_usage()