| |
| |
|
|
| from transformers import AutoTokenizer |
| from http.server import HTTPServer, BaseHTTPRequestHandler |
| import json |
| import argparse |
|
|
|
|
class Tokenizer_Http:
    """Thin wrapper around a Hugging Face tokenizer for the TinySwallow chat model.

    ``encode`` renders the prompt through the tokenizer's chat template
    (a system turn plus a user turn, with the generation prompt requested)
    before converting the rendered text to token ids.
    """

    # Model id / local directory handed to ``AutoTokenizer.from_pretrained`` by default.
    DEFAULT_MODEL_ID = "TinySwallow-1.5B-Instruct-ax630c"
    # Default system-turn content. In English: "You are TinySwallow, developed by
    # Sakana AI Co., Ltd. Though small, you are an honest and capable assistant."
    DEFAULT_SYSTEM_PROMPT = "あなたは、Sakana AI株式会社が開発したTinySwallowです。小型ながら、誠実で優秀なアシスタントです。"

    def __init__(self, model_id=DEFAULT_MODEL_ID, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """Load the tokenizer.

        Args:
            model_id: Hugging Face model id or local path passed to
                ``AutoTokenizer.from_pretrained``.
            system_prompt: content of the system turn that ``encode`` places
                before every user prompt.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.system_prompt = system_prompt

    def encode(self, prompt):
        """Wrap *prompt* in the chat template and return its token ids.

        The rendered template text is printed to stdout for debugging.

        Args:
            prompt: content of the user turn.

        Returns:
            The token ids produced by the underlying tokenizer's ``encode``.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        # tokenize=False asks for the rendered string rather than ids;
        # add_generation_prompt=True requests the assistant-turn prefix so the
        # model continues as the assistant (see HF chat-templating docs).
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        print(text)
        return self.tokenizer.encode(text)

    def decode(self, token_ids):
        """Convert *token_ids* back to text with the underlying tokenizer."""
        return self.tokenizer.decode(token_ids)

    @property
    def bos_id(self):
        """Beginning-of-sequence token id reported by the tokenizer (may be None)."""
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        """End-of-sequence token id reported by the tokenizer (may be None)."""
        return self.tokenizer.eos_token_id

    @property
    def bos_token(self):
        """Beginning-of-sequence token string reported by the tokenizer (may be None)."""
        return self.tokenizer.bos_token

    @property
    def eos_token(self):
        """End-of-sequence token string reported by the tokenizer (may be None)."""
        return self.tokenizer.eos_token
|
|
|
|
# Module-level tokenizer instance; the ``Request`` handler methods look it up
# by this global name. NOTE(review): constructing it at import time means the
# tokenizer is loaded whenever this module is imported.
tokenizer = Tokenizer_Http()

# Startup smoke test: print the special tokens, then a sample encoding.
_special = (tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token)
print(*_special)
print(tokenizer.encode("hello world"))
del _special
|
|
|
|
class Request(BaseHTTPRequestHandler):
    """JSON-over-HTTP front end for the module-level ``tokenizer``.

    Endpoints:
        GET  /bos_id  -> {"bos_id": <int>}  (-1 when the tokenizer has no BOS id)
        GET  /eos_id  -> {"eos_id": <int>}  (-1 when the tokenizer has no EOS id)
        POST /encode  body {"text": <str>}        -> {"token_ids": [<int>, ...]}
                      (-1 instead of a list when encoding yields no ids)
        POST /decode  body {"token_ids": [<int>]} -> {"text": <str>}

    Errors are returned as ``{"error": <message>}`` with a 4xx status code.
    """

    # Socket timeout in seconds applied to each connection (StreamRequestHandler attribute).
    timeout = 5
    # Product token sent in the ``Server`` response header.
    server_version = 'Apache'

    def _set_headers(self, status=200, content_length=None):
        """Send the status line and JSON response headers.

        Args:
            status: HTTP status code (defaults to 200).
            content_length: byte length of the body; ``Content-Length`` is
                only sent when this is provided.
        """
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        if content_length is not None:
            self.send_header("Content-Length", str(content_length))
        self.end_headers()

    def _send_json(self, payload, status=200):
        """Serialize *payload* as UTF-8 JSON, print it, and send it as the whole response."""
        msg = json.dumps(payload, ensure_ascii=False)
        print(msg)
        body = msg.encode('utf-8')
        self._set_headers(status, len(body))
        self.wfile.write(body)

    def do_GET(self):
        """Serve ``/bos_id`` and ``/eos_id``; any other path gets a 404 JSON error."""
        print(self.path)
        if self.path == '/bos_id':
            bos_id = tokenizer.bos_id
            self._send_json({'bos_id': bos_id if bos_id is not None else -1})
        elif self.path == '/eos_id':
            eos_id = tokenizer.eos_id
            self._send_json({'eos_id': eos_id if eos_id is not None else -1})
        else:
            self._send_json({'error': 'Invalid path'}, status=404)

    def do_POST(self):
        """Serve ``/encode`` and ``/decode``.

        The body must be a JSON object. No status line is sent until the
        result (or the error) is known, so malformed input produces a 4xx JSON
        error instead of an exception after ``200 OK`` was already written.
        """
        try:
            content_length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            self._send_json({'error': 'Invalid Content-Length'}, status=400)
            return
        # Missing/zero/negative length means "no body"; rfile.read(-1) would
        # otherwise wait for the client to close the connection.
        raw = self.rfile.read(content_length) if content_length > 0 else b''

        try:
            req = json.loads(raw.decode('utf-8'))
        except ValueError:  # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError
            self._send_json({'error': 'Invalid JSON'}, status=400)
            return
        if not isinstance(req, dict):
            self._send_json({'error': 'JSON body must be an object'}, status=400)
            return

        if self.path not in ('/encode', '/decode'):
            self._send_json({'error': 'Invalid path'}, status=404)
            return

        try:
            if self.path == '/encode':
                token_ids = tokenizer.encode(req.get('text', ''))
                payload = {'token_ids': token_ids if token_ids else -1}
            else:
                text = tokenizer.decode(req.get('token_ids', []))
                payload = {'text': text if text else ""}
        except (TypeError, ValueError) as err:
            # NOTE(review): presumably what the tokenizer raises for malformed
            # fields (e.g. non-integer token ids) — report it instead of crashing.
            self._send_json({'error': str(err)}, status=400)
            return

        self._send_json(payload)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: serve the tokenizer API until interrupted.
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='localhost',
                        help='interface to bind (default: localhost)')
    parser.add_argument('--port', type=int, default=8080,
                        help='TCP port to listen on (default: 8080)')
    args = parser.parse_args()

    host = (args.host, args.port)
    print(f'http://{host[0]}:{host[1]}')
    # HTTPServer supports the context-manager protocol (Python 3.6+): leaving
    # the `with` block calls server_close(), releasing the listening socket.
    with HTTPServer(host, Request) as server:
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop this server; exit without a traceback.
            print('shutting down')
|
|
|
|