File size: 3,926 Bytes
625a17f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
"""
示例:如何使用修改后的eval_seg函数同时获取分割结果和文本输出
"""

import torch
from transformers import AutoTokenizer

def example_eval_seg_with_text_output(model, tokenizer, input_data):
    """
    示例函数:展示如何使用修改后的eval_seg函数
    
    Args:
        model: PSALM模型实例
        tokenizer: 对应的tokenizer
        input_data: 输入数据字典,包含以下键:
            - input_ids: torch.LongTensor
            - images: torch.FloatTensor
            - seg_info: 分割信息
            - 其他必要参数...
    
    Returns:
        dict: 包含分割结果和解码后的文本
    """
    
    # 调用修改后的eval_seg函数
    result = model.eval_seg(
        input_ids=input_data['input_ids'],
        images=input_data['images'], 
        seg_info=input_data['seg_info'],
        # 以下是新添加的文本生成参数
        generate_text=True,           # 是否生成文本
        max_new_tokens=512,          # 最大生成token数
        temperature=0.2,             # 温度参数
        do_sample=True,              # 是否采样
        # 其他原有参数...
        attention_mask=input_data.get('attention_mask'),
        class_name_embedding_indices=input_data.get('class_name_embedding_indices'),
        cls_indices=input_data.get('cls_indices'),
        token_refer_id=input_data.get('token_refer_id'),
        refer_embedding_indices=input_data.get('refer_embedding_indices'),
        is_thing_list=input_data.get('is_thing_list')
    )
    
    # 提取分割结果
    segmentation_results = result['segmentation_results']
    
    # 提取并解码文本输出
    output_token_ids = result['output_token_ids']
    decoded_text = None
    
    if output_token_ids is not None:
        # 解码生成的token ids为文本
        decoded_text = tokenizer.decode(
            output_token_ids[0], 
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        print(f"生成的文本: {decoded_text}")
    
    # 处理分割结果
    for i, seg_result in enumerate(segmentation_results):
        print(f"图片 {i} 的分割结果:")
        if 'instances' in seg_result:
            instances = seg_result['instances']
            print(f"  - 检测到 {len(instances.pred_masks)} 个实例")
            print(f"  - pred_masks shape: {instances.pred_masks.shape}")
            print(f"  - scores: {instances.scores}")
            if hasattr(instances, 'pred_boxes'):
                print(f"  - pred_boxes: {instances.pred_boxes}")
        
        if 'sem_seg' in seg_result:
            print(f"  - 语义分割结果 shape: {seg_result['sem_seg'].shape}")
        
        if 'panoptic_seg' in seg_result:
            print(f"  - 全景分割结果")
    
    return {
        'segmentation_results': segmentation_results,
        'decoded_text': decoded_text,
        'raw_token_ids': output_token_ids
    }

def example_usage():
    """
    完整的使用示例
    """
    # 假设您已经加载了模型和tokenizer
    # model = ...  # 您的PSALM模型
    # tokenizer = ...  # 对应的tokenizer
    
    # 准备输入数据
    input_data = {
        'input_ids': torch.tensor([[1, 2, 3, ...]]),  # 您的输入token ids
        'images': torch.randn(1, 3, 224, 224),        # 您的图像数据
        'seg_info': [{'instances': ...}],             # 您的分割信息
        # 其他必要的输入...
    }
    
    # 调用函数(需要实际的model和tokenizer)
    # result = example_eval_seg_with_text_output(model, tokenizer, input_data)
    
    print("示例代码准备完毕!")
    print("使用时请确保:")
    print("1. 已正确加载PSALM模型")
    print("2. 已正确加载对应的tokenizer") 
    print("3. 准备了正确格式的input_data")

if __name__ == "__main__":
    example_usage()