File size: 1,673 Bytes
bf19a0a
 
 
a988268
bf19a0a
 
 
a988268
bf19a0a
 
 
 
 
 
 
 
 
 
 
 
 
a988268
bf19a0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ec131e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from transformers import Pipeline, PreTrainedTokenizer, AutoTokenizer
from typing import Dict, Union, List
import torch

class TokenizerPipeline(Pipeline):
    """A model-free pipeline that only runs the wrapped tokenizer on text."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _sanitize_parameters(self, **kwargs):
        # Route caller kwargs to the right stage: padding/truncation belong to
        # preprocessing, return_tokens to postprocessing; _forward takes none.
        preprocess_kwargs = {
            key: kwargs[key]
            for key in ("padding", "truncation")
            if key in kwargs
        }
        postprocess_kwargs = {
            key: kwargs[key]
            for key in ("return_tokens",)
            if key in kwargs
        }
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, **kwargs) -> Dict:
        # Tokenize the input text into PyTorch tensors.
        return self.tokenizer(inputs, return_tensors="pt", **kwargs)

    def _forward(self, inputs) -> Dict:
        # No model inference: pass the preprocessed encoding straight through.
        return inputs

    def postprocess(self, model_outputs, **kwargs) -> Dict:
        # Convert the first (single-input) sequence into a readable form:
        # token strings by default, raw ids when return_tokens is False.
        ids = model_outputs["input_ids"][0]
        if kwargs.get("return_tokens", True):
            return {"tokens": self.tokenizer.convert_ids_to_tokens(ids)}
        return {"input_ids": ids.tolist()}

# Key step: create and export a pipeline instance for consumers of this module.
# NOTE(review): transformers' Pipeline.__init__ normally requires a `model`
# argument; constructing with only a tokenizer may raise at import time —
# confirm against the installed transformers version.
tokenizer = AutoTokenizer.from_pretrained(".")  # load tokenizer files from the current directory
pipeline = TokenizerPipeline(tokenizer=tokenizer)

# Optional: typed accessor so Hugging Face tooling can resolve the pipeline.
def get_pipeline() -> Pipeline:
    """Return the module-level pipeline instance."""
    return pipeline