File size: 3,935 Bytes
f66b0a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from http.server import HTTPServer, BaseHTTPRequestHandler
import json
import argparse
from tokenizer.tokenizer import get_qwen_tokenizer
class Tokenizer_Http():
    """Small adapter around the tokenizer returned by ``get_qwen_tokenizer``.

    Exposes ``encode``/``decode`` plus fixed end-of-sequence metadata.
    No ``bos_id`` / ``bos_token`` attribute is provided by this class.
    """

    def __init__(self):
        # NOTE(review): the meaning of the second positional argument (True)
        # is defined by get_qwen_tokenizer — confirm against that helper.
        self.tokenizer = get_qwen_tokenizer("CosyVoice-BlankEN", True)

    def encode(self, prompt):
        """Return the token ids for ``prompt``, with all special tokens allowed."""
        return self.tokenizer.encode(prompt, allowed_special="all")

    def decode(self, token_ids):
        """Return the text the wrapped tokenizer produces for ``token_ids``."""
        text = self.tokenizer.decode(token_ids)
        return text

    @property
    def eos_id(self):
        """Hard-coded end-of-sequence token id."""
        # NOTE(review): presumably the id of `eos_token` in this vocabulary —
        # confirm; it is not looked up from the tokenizer.
        return 1773

    @property
    def eos_token(self):
        """Hard-coded end-of-sequence token string."""
        return "<|eot_id|>"
# Module-level tokenizer instance used by the Request handler's
# do_GET/do_POST methods. It is built at import time, before the server
# in the __main__ section is created.
tokenizer = Tokenizer_Http()
# Disabled debug print: bos_id / bos_token are commented out on Tokenizer_Http.
# print(tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token)
# Startup smoke check: print the token ids produced for a sample string.
print(tokenizer.encode("hello world"))
class Request(BaseHTTPRequestHandler):
    """HTTP handler that exposes the module-level ``tokenizer`` over GET/POST.

    Routes:
        GET  /bos_id                       -> {"bos_id": <id>}  (-1 if absent)
        GET  /eos_id                       -> {"eos_id": <id>}  (-1 if absent)
        POST /encode  {"text": ...}        -> {"token_ids": [...]}
        POST /decode  {"token_ids": [...]} -> {"text": "..."}

    Any other path is answered with status 200 and the plain body ``error``.
    A POST to a known route whose body cannot be read or parsed (missing or
    invalid Content-Length, malformed JSON, missing field) is answered with
    status 400 and the body ``error`` instead of failing after a 200 status
    line has already been sent.
    """

    # Socket timeout in seconds, applied by the base stream handler.
    timeout = 5
    # Value advertised in the "Server" response header.
    server_version = 'Apache'

    def _reply(self, kind, msg, status=200):
        """Print ``msg`` and send it as the complete response.

        Args:
            kind: value of the custom ``type`` response header
                ("get" or "post").
            msg: response payload; converted with ``str`` and UTF-8 encoded.
            status: HTTP status code for the response.
        """
        print(msg)
        body = str(msg).encode()
        self.send_response(status)
        self.send_header("type", kind)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Serve the special-token id lookups."""
        print(self.path)
        if self.path == '/bos_id':
            # The tokenizer wrapper may not define bos_id at all; report -1
            # in that case rather than raising AttributeError mid-response.
            bos_id = getattr(tokenizer, 'bos_id', None)
            msg = json.dumps({'bos_id': -1 if bos_id is None else bos_id})
        elif self.path == '/eos_id':
            eos_id = tokenizer.eos_id
            msg = json.dumps({'eos_id': -1 if eos_id is None else eos_id})
        else:
            msg = 'error'
        self._reply("get", msg)

    def do_POST(self):
        """Serve /encode and /decode for a JSON request body."""
        # Validate the declared body length before reading: a missing header
        # gives int(None) -> TypeError, and a negative length would make
        # rfile.read() wait for EOF.
        try:
            length = int(self.headers['content-length'])
        except (TypeError, ValueError):
            length = -1
        if length < 0:
            self._reply("post", 'error', status=400)
            return
        raw = self.rfile.read(length)

        if self.path not in ('/encode', '/decode'):
            self._reply("post", 'error')
            return

        # Parse and extract the argument before any header is sent, so a bad
        # request yields a clean 400 instead of a truncated 200 response.
        field = 'text' if self.path == '/encode' else 'token_ids'
        try:
            arg = json.loads(raw.decode())[field]
        except (ValueError, KeyError, TypeError):
            # ValueError covers invalid UTF-8 and invalid JSON; TypeError
            # covers a JSON document that is not an object.
            self._reply("post", 'error', status=400)
            return

        if self.path == '/encode':
            token_ids = tokenizer.encode(arg)
            msg = json.dumps(
                {'token_ids': -1 if token_ids is None else token_ids})
        else:
            text = tokenizer.decode(arg)
            msg = json.dumps({'text': "" if text is None else text})
        self._reply("post", msg)
if __name__ == "__main__":
    # Command-line options: the address and port the server binds to.
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='localhost')
    parser.add_argument('--port', type=int, default=12345)
    cli = parser.parse_args()

    # (host, port) pair handed to the server; also used to print the URL.
    address = (cli.host, cli.port)
    print('http://%s:%s' % address)

    # Create the server with the Request handler and serve until the
    # process is stopped.
    HTTPServer(address, Request).serve_forever()
|