|
|
|
|
|
""" |
|
|
示例:如何使用修改后的eval_seg函数同时获取分割结果和文本输出 |
|
|
""" |
|
|
|
|
|
import torch

from transformers import AutoTokenizer
|
|
|
|
|
def example_eval_seg_with_text_output(model, tokenizer, input_data):
    """Run ``model.eval_seg`` and return both segmentation results and decoded text.

    Example helper showing how to use the modified ``eval_seg`` that can also
    generate text alongside the segmentation outputs.

    Args:
        model: PSALM model instance exposing an ``eval_seg`` method.
        tokenizer: tokenizer matching the model; its ``decode`` is used on the
            generated token ids.
        input_data: dict with at least the keys ``input_ids``
            (torch.LongTensor), ``images`` (torch.FloatTensor) and
            ``seg_info``.  Optional keys (``attention_mask``,
            ``class_name_embedding_indices``, ``cls_indices``,
            ``token_refer_id``, ``refer_embedding_indices``,
            ``is_thing_list``) are forwarded when present; ``dict.get``
            yields ``None`` for any that are absent.

    Returns:
        dict: ``{'segmentation_results': ..., 'decoded_text': str | None,
        'raw_token_ids': ...}`` where ``decoded_text`` is ``None`` when the
        model produced no output token ids.
    """
    result = model.eval_seg(
        input_ids=input_data['input_ids'],
        images=input_data['images'],
        seg_info=input_data['seg_info'],
        # Text-generation options of the modified eval_seg.
        generate_text=True,
        max_new_tokens=512,
        temperature=0.2,
        do_sample=True,
        # Optional inputs; missing keys are passed through as None.
        attention_mask=input_data.get('attention_mask'),
        class_name_embedding_indices=input_data.get('class_name_embedding_indices'),
        cls_indices=input_data.get('cls_indices'),
        token_refer_id=input_data.get('token_refer_id'),
        refer_embedding_indices=input_data.get('refer_embedding_indices'),
        is_thing_list=input_data.get('is_thing_list'),
    )

    segmentation_results = result['segmentation_results']
    output_token_ids = result['output_token_ids']

    # Decode generated tokens only when the model actually produced some.
    decoded_text = None
    if output_token_ids is not None:
        decoded_text = tokenizer.decode(
            output_token_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        print(f"生成的文本: {decoded_text}")

    # Summarize the per-image segmentation outputs for quick inspection.
    for i, seg_result in enumerate(segmentation_results):
        print(f"图片 {i} 的分割结果:")
        if 'instances' in seg_result:
            instances = seg_result['instances']
            print(f" - 检测到 {len(instances.pred_masks)} 个实例")
            print(f" - pred_masks shape: {instances.pred_masks.shape}")
            print(f" - scores: {instances.scores}")
            # pred_boxes is optional on the instances object.
            if hasattr(instances, 'pred_boxes'):
                print(f" - pred_boxes: {instances.pred_boxes}")
        if 'sem_seg' in seg_result:
            print(f" - 语义分割结果 shape: {seg_result['sem_seg'].shape}")
        if 'panoptic_seg' in seg_result:
            print(f" - 全景分割结果")

    return {
        'segmentation_results': segmentation_results,
        'decoded_text': decoded_text,
        'raw_token_ids': output_token_ids,
    }
|
|
|
|
|
def example_usage():
    """Complete usage example (完整的使用示例).

    Builds a dummy ``input_data`` dict that illustrates the expected input
    format, then prints the preconditions for running the real pipeline.
    No model is loaded here; the dict is illustrative only.

    Returns:
        None
    """
    # Placeholder inputs — real input_ids/seg_info come from the data pipeline.
    # NOTE: the original example used torch.tensor([[1, 2, 3, ...]]), which
    # raises TypeError because Ellipsis is not a number; use a valid tensor.
    input_data = {
        'input_ids': torch.tensor([[1, 2, 3]]),
        'images': torch.randn(1, 3, 224, 224),
        'seg_info': [{'instances': ...}],  # Ellipsis placeholder for real instances
    }

    print("示例代码准备完毕!")
    print("使用时请确保:")
    print("1. 已正确加载PSALM模型")
    print("2. 已正确加载对应的tokenizer")
    print("3. 准备了正确格式的input_data")
|
|
|
|
|
if __name__ == "__main__": |
|
|
example_usage() |
|
|
|