File size: 3,745 Bytes
9cf08e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | import requests
import json
import os
class Qwen3TokenizerClient:
    """HTTP client for a tokenizer service exposing health/encode/batch_encode/decode endpoints.

    ``encode``, ``batch_encode`` and ``decode`` return the parsed JSON response
    body on HTTP 200. On any failure they print a diagnostic and return ``None``.
    Failures are a non-200 status, a network error or timeout, or a body that
    is not valid JSON. ``check_health`` reports the same failures as ``False``.
    """

    def __init__(self, base_url: "str | None" = None, timeout: float = 10.0):
        """Build the endpoint URLs for the service.

        Args:
            base_url: Root URL of the service. When falsy, the
                ``Tokenizer_API_BASE_URL`` environment variable is used,
                falling back to ``http://localhost:8000/v1``. Trailing slashes
                are stripped so endpoint URLs never contain ``//``.
            timeout: Seconds each HTTP request may take. Without a timeout,
                ``requests`` can block indefinitely on an unresponsive server.
        """
        root = base_url or os.getenv("Tokenizer_API_BASE_URL", "http://localhost:8000/v1")
        self.base_url = root.rstrip("/")
        self.timeout = timeout
        self.encode_url = f"{self.base_url}/encode"
        self.decode_url = f"{self.base_url}/decode"
        self.batch_url = f"{self.base_url}/batch_encode"
        self.health_url = f"{self.base_url}/health"

    def check_health(self):
        """Check whether the service is running.

        Returns:
            True if the health endpoint answers HTTP 200 with a JSON body
            (the body is printed), otherwise False.
        """
        try:
            response = requests.get(self.health_url, timeout=self.timeout)
        except requests.RequestException as e:
            print(f"❌ 无法连接到服务: {e}")
            return False
        if response.status_code != 200:
            print(f"❌ 服务检查失败: {response.status_code}")
            return False
        try:
            body = response.json()
        except ValueError as e:
            # Same message the original broad handler printed for a bad body.
            print(f"❌ 无法连接到服务: {e}")
            return False
        print(f"✅ 服务状态正常: {body}")
        return True

    def _post(self, url, payload, failure_prefix):
        """POST *payload* as JSON to *url*; return the parsed body, or None on failure.

        *failure_prefix* starts every diagnostic line printed on failure.
        """
        try:
            response = requests.post(url, json=payload, timeout=self.timeout)
        except requests.RequestException as e:
            print(f"{failure_prefix}: {e}")
            return None
        if response.status_code != 200:
            print(f"{failure_prefix}: {response.text}")
            return None
        try:
            return response.json()
        except ValueError as e:
            # requests' JSONDecodeError subclasses ValueError.
            print(f"{failure_prefix}: {e}")
            return None

    def encode(self, text, add_special_tokens=True):
        """Convert *text* to token IDs.

        Returns:
            The server's JSON response, or None on failure.
        """
        payload = {
            "text": text,
            "add_special_tokens": add_special_tokens,
        }
        return self._post(self.encode_url, payload, "❌ 编码失败")

    def batch_encode(self, texts, padding=True, max_length=None):
        """Encode several texts in a single request.

        ``max_length=None`` is sent as JSON ``null``.

        Returns:
            The server's JSON response, or None on failure.
        """
        payload = {
            "texts": texts,
            "padding": padding,
            "max_length": max_length,
        }
        return self._post(self.batch_url, payload, "❌ 批量编码失败")

    def decode(self, token_ids, skip_special_tokens=True):
        """Convert *token_ids* back to text.

        Returns:
            The server's JSON response, or None on failure.
        """
        payload = {
            "token_ids": token_ids,
            "skip_special_tokens": skip_special_tokens,
        }
        return self._post(self.decode_url, payload, "❌ 解码失败")
# --- Main entry point ---
if __name__ == "__main__":
    import sys  # sys.exit gives a reliable non-zero status; bare exit() depends on the site module

    # 1. Create the client (change this if your API listens on a different port).
    client = Qwen3TokenizerClient(base_url="http://127.0.0.1:8001")

    # 2. Check service status; abort with a non-zero exit code if it is unavailable.
    if not client.check_health():
        print("请先启动 API 服务 (python tokenizer_api.py)")
        sys.exit(1)
    print("-" * 30)

    # 3. Encode test (text -> IDs).
    input_text = "你好,Qwen3!"
    print(f"📝 原始文本: {input_text}")
    encode_result = client.encode(input_text)
    if not encode_result:
        # encode() has already printed the failure reason.
        sys.exit(1)
    token_ids = encode_result['token_ids']
    count = encode_result['count']
    print(f"🔢 Token IDs: {token_ids}")
    print(f"📏 Token 长度: {count}")
    print("-" * 30)

    # 4. Decode test (IDs -> text): round-trip the IDs produced above.
    decode_result = client.decode(token_ids, skip_special_tokens=False)
    if decode_result:
        restored_text = decode_result['text']
        print(f"🔄 还原结果 (含特殊标记): {restored_text}")

    # Decode again, this time dropping special tokens.
    clean_result = client.decode(token_ids, skip_special_tokens=True)
    if clean_result:
        print(f"✨ 纯净文本: {clean_result['text']}")