# -- scraped file-viewer metadata (not Python code; commented out so the file parses) --
# File size: 1,673 Bytes
# bf19a0a a988268 bf19a0a a988268 bf19a0a a988268 bf19a0a 7ec131e |
# gutter: 1..48
from transformers import Pipeline, PreTrainedTokenizer, AutoTokenizer
from typing import Dict, Union, List
import torch
class TokenizerPipeline(Pipeline):
    """A transformers Pipeline that tokenizes text without running a model.

    ``preprocess`` encodes the input with ``self.tokenizer``, ``_forward``
    is a pure passthrough (no inference), and ``postprocess`` converts the
    encoding back into human-readable tokens or a plain list of ids.
    """
    # NOTE(review): no __init__ override needed — the original one only
    # forwarded **kwargs to super().__init__, which this inherits anyway.
    # The inherited Pipeline.__init__ typically also expects a ``model``
    # argument; confirm construction against the pinned transformers version.

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into (preprocess, forward, postprocess) dicts.

        Recognized keys: ``padding`` / ``truncation`` (forwarded to the
        tokenizer in ``preprocess``) and ``return_tokens`` (controls the
        ``postprocess`` output format).
        """
        preprocess_kwargs = {
            key: kwargs[key]
            for key in ("padding", "truncation")
            if key in kwargs
        }
        postprocess_kwargs = {}
        if "return_tokens" in kwargs:
            postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]
        # Middle dict is the (unused) forward kwargs — no model step here.
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, **kwargs) -> Dict:
        """Tokenize ``inputs`` and return the encoding as PyTorch tensors."""
        # setdefault (instead of a hard-coded keyword) avoids a duplicate
        # keyword-argument TypeError when a caller passes return_tensors.
        kwargs.setdefault("return_tensors", "pt")
        return self.tokenizer(inputs, **kwargs)

    def _forward(self, inputs) -> Dict:
        """Identity step: return the preprocessed encoding unchanged (no model)."""
        return inputs

    def postprocess(self, model_outputs, **kwargs) -> Dict:
        """Convert the first sequence's ids to tokens (default) or raw id list.

        Returns ``{"tokens": [...]}`` when ``return_tokens`` is truthy
        (the default), otherwise ``{"input_ids": [...]}``.
        """
        input_ids = model_outputs["input_ids"][0]
        if kwargs.get("return_tokens", True):
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
            return {"tokens": tokens}
        return {"input_ids": input_ids.tolist()}
# Key step: create and export the pipeline instance at import time.
# NOTE(review): from_pretrained(".") loads tokenizer files from the current
# working directory — confirm the expected files (tokenizer.json / vocab
# files) ship alongside this module.
tokenizer = AutoTokenizer.from_pretrained(".")
# NOTE(review): transformers Pipeline subclasses usually also require a
# ``model`` argument; confirm this constructor call succeeds with the
# installed transformers version.
pipeline = TokenizerPipeline(tokenizer=tokenizer)
# Optional: typed accessor so HF tooling can resolve the pipeline instance.
def get_pipeline() -> Pipeline:
    """Return the module-level ``pipeline`` instance."""
    return pipeline
# | (trailing file-viewer gutter residue, commented out so the file parses)