custom_toolbox / apps /text_tools.py
MashiroLn's picture
Upload folder using huggingface_hub
4f4e23f verified
raw
history blame
7.28 kB
import gradio as gr
import math
import json
import os
from transformers import AutoTokenizer
# --- Tokenizer loading ---
# Loaded tokenizers are cached per model name so a request does not have to
# reload them. Original author's notes: memory on HF Spaces is limited and the
# models themselves may be large; for Qwen2.5-VL the tokenizer of
# Qwen/Qwen2.5-VL-7B-Instruct can be used, and for Llava a Llama-2 / Vicuna
# tokenizer is typically used.
TOKENIZERS = {}

# UI model name -> Hugging Face repo whose tokenizer is loaded for that model.
# Adding support for another model only requires a new entry here.
_TOKENIZER_REPOS = {
    # Author's note: Qwen2-VL uses the Qwen2 tokenizer.
    "Qwen2.5-VL / Qwen2-VL": "Qwen/Qwen2.5-VL-7B-Instruct",
    # Author's note: Llava-1.6 is based on Vicuna/Llama-2; the llava-hf repo is
    # used for generality.
    "Llava-1.6 (Next)": "llava-hf/llava-v1.6-vicuna-7b-hf",
}


def get_tokenizer(model_name):
    """Return the tokenizer for ``model_name``, loading and caching it on demand.

    Args:
        model_name: One of the keys of ``_TOKENIZER_REPOS`` (the names shown
            in the UI dropdown).

    Returns:
        The cached or freshly loaded tokenizer, or ``None`` when the model
        name is not supported or loading fails. Failures are printed and are
        not cached, so a later call tries to load again.
    """
    if model_name in TOKENIZERS:
        return TOKENIZERS[model_name]

    repo_id = _TOKENIZER_REPOS.get(model_name)
    if repo_id is None:
        return None

    try:
        # Author's note: this needs network access to download the tokenizer
        # files, which HF Spaces normally allows.
        # NOTE(review): trust_remote_code=True is kept from the original; it
        # may not be needed for these two repos -- confirm before removing.
        tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    except Exception as e:
        # Best-effort: callers fall back to an estimate when no tokenizer is
        # available, so report the problem instead of propagating it.
        print(f"Error loading tokenizer for {model_name}: {e}")
        return None

    TOKENIZERS[model_name] = tokenizer
    return tokenizer
# --- Token 计算逻辑 ---
def _qwen2_vl_grid_tokens(width, height):
    """Return the 14px patch count of one frame after snapping each side to a multiple of 28."""
    snapped_w = int(round(width / 28.0) * 28)
    snapped_h = int(round(height / 28.0) * 28)
    return (snapped_h // 14) * (snapped_w // 14)


def calculate_qwen2_vl_tokens(text, images, videos, tokenizer):
    """Estimate the token count of a Qwen2-VL / Qwen2.5-VL request.

    Args:
        text: Prompt text; ``None`` is treated as an empty string.
        images: Iterable of dicts with ``'width'`` and ``'height'`` keys;
            ``None`` is treated as empty.
        videos: Iterable of dicts with ``'width'``, ``'height'`` and
            ``'frames'`` keys; ``None`` is treated as empty.
        tokenizer: Object exposing ``encode(text)``, or ``None`` to fall back
            to the rough estimate ``len(text) // 2``.

    Returns:
        A ``(total_tokens, text_tokens)`` tuple. ``text_tokens`` is whatever
        ``tokenizer.encode`` returned, or ``[]`` when no tokenizer was given.

    NOTE(review): each image/frame is counted as the raw patch grid
    (H/14 * W/14 after snapping to 28). The upstream Qwen2-VL processor
    appears to additionally merge 2x2 patches and pair video frames, which
    would make real counts smaller -- confirm against the upstream processor.
    """
    text = text or ""
    total_tokens = 0

    # 1. Text tokens: exact when a tokenizer is available, estimated otherwise.
    text_tokens = []
    if tokenizer is not None:
        text_tokens = tokenizer.encode(text)
        total_tokens += len(text_tokens)
    else:
        total_tokens += len(text) // 2

    # 2. Image tokens: one grid per image.
    for img in images or ():
        total_tokens += _qwen2_vl_grid_tokens(img['width'], img['height'])

    # 3. Video tokens: the per-frame grid multiplied by the frame count.
    for vid in videos or ():
        total_tokens += vid['frames'] * _qwen2_vl_grid_tokens(vid['width'], vid['height'])

    return total_tokens, text_tokens
def calculate_llava_next_tokens(text, images, tokenizer):
    """Estimate the token count of a Llava-1.6 (Next) request.

    Each image is covered by ``ceil(width / 336) * ceil(height / 336)`` tiles
    plus one extra tile, and every tile is counted as 576 tokens.

    Args:
        text: Prompt text; ``None`` is treated as an empty string.
        images: Iterable of dicts with ``'width'`` and ``'height'`` keys;
            ``None`` is treated as empty.
        tokenizer: Object exposing ``encode(text)``, or ``None`` to fall back
            to the rough estimate ``len(text) // 2``.

    Returns:
        A ``(total_tokens, text_tokens)`` tuple. ``text_tokens`` is whatever
        ``tokenizer.encode`` returned, or ``[]`` when no tokenizer was given.
    """
    tile_size = 336          # side length of one tile, in pixels
    tokens_per_tile = 576    # presumably (336 / 14) ** 2 -- TODO confirm

    text = text or ""
    total_tokens = 0

    # 1. Text tokens: exact when a tokenizer is available, estimated otherwise.
    text_tokens = []
    if tokenizer is not None:
        text_tokens = tokenizer.encode(text)
        total_tokens += len(text_tokens)
    else:
        total_tokens += len(text) // 2

    # 2. Image tokens: grid tiles plus one extra tile per image.
    for img in images or ():
        tiles_x = math.ceil(img['width'] / tile_size)
        tiles_y = math.ceil(img['height'] / tile_size)
        total_tokens += (tiles_x * tiles_y + 1) * tokens_per_tile

    return total_tokens, text_tokens
# --- 实际 UI 逻辑 ---
def run_calculation(text, model, img_count, img_w, img_h, vid_count, vid_frames, vid_w, vid_h):
    """Compute the token estimate for the UI inputs and dump the text tokens to a file.

    Args:
        text: Prompt text; ``None`` is treated as an empty string.
        model: Model name as shown in the dropdown.
        img_count, img_w, img_h: Number of images and their size in pixels.
        vid_count, vid_frames, vid_w, vid_h: Number of videos, frames per
            video and frame size in pixels.
        Numeric inputs that arrive as ``None`` (a cleared number field) are
        treated as 0.

    Returns:
        A ``(summary, token_file_path)`` tuple. ``summary`` is a dict for the
        JSON output; ``token_file_path`` is the path of a JSON file listing
        every text token id with its decoded string, or ``None`` when no
        tokenizer was available or the text produced no tokens.
    """
    # Local import so this function brings its own dependency into scope.
    import tempfile

    text = text or ""

    # Build placeholder media descriptors from the numeric inputs.
    image_total = int(img_count or 0)
    video_total = int(vid_count or 0)
    frames_per_video = int(vid_frames or 0)
    images = [{'width': img_w or 0, 'height': img_h or 0} for _ in range(image_total)]
    videos = [
        {'width': vid_w or 0, 'height': vid_h or 0, 'frames': frames_per_video}
        for _ in range(video_total)
    ]

    # Fetch the tokenizer; None means the calculators fall back to an estimate.
    tokenizer = get_tokenizer(model)
    if tokenizer is not None:
        tokenizer_status = "✅ 已加载真实 Tokenizer"
    else:
        tokenizer_status = "⚠️ Tokenizer 加载失败,使用估算值"

    text_tokens_ids = []
    if model == "Qwen2.5-VL / Qwen2-VL":
        tokens, text_tokens_ids = calculate_qwen2_vl_tokens(text, images, videos, tokenizer)
        info = "Qwen2-VL 使用 Naive Dynamic Resolution (patch 14x14)。\n图片会被 resize 为 28 的倍数。"
    elif model == "Llava-1.6 (Next)":
        tokens, text_tokens_ids = calculate_llava_next_tokens(text, images, tokenizer)
        info = "Llava-1.6 使用 AnyRes 技术 (base 336x336)。\n包含 Base Image + Grid Patches。"
    else:
        tokens = 0
        info = "未知模型"

    # Write the per-token breakdown to a downloadable file.
    token_file_path = None
    if tokenizer is not None and text_tokens_ids:
        # Decode every id on its own to get the string for that single token.
        token_data = [
            {"id": tid, "token": tokenizer.decode([tid])}
            for tid in text_tokens_ids
        ]
        # A unique temp file per call: a fixed file name in the working
        # directory would be overwritten by concurrent requests. delete=False
        # because the returned path is read after this function returns.
        with tempfile.NamedTemporaryFile(
            "w",
            encoding="utf-8",
            prefix="token_analysis_",
            suffix=".json",
            delete=False,
        ) as f:
            json.dump({"text": text, "tokens": token_data}, f, ensure_ascii=False, indent=2)
        token_file_path = f.name

    return {
        "总 Token 数": tokens,
        "自然语言字符数": len(text),
        "Tokenizer 状态": tokenizer_status,
        "模型": model,
        "说明": info
    }, token_file_path
def create_ui():
    """Build the token-calculator widgets and connect the button to run_calculation.

    Layout: a left column with the model dropdown, prompt box, image and
    video settings and the calculate button; a right column with the JSON
    result, the downloadable token file and a short help text.

    NOTE(review): presumably called inside an active ``gr.Blocks`` context --
    confirm with the caller.
    """
    model_choices = ["Qwen2.5-VL / Qwen2-VL", "Llava-1.6 (Next)"]

    with gr.Row():
        # Left column: all inputs.
        with gr.Column(scale=1):
            model_select = gr.Dropdown(
                choices=model_choices,
                value="Qwen2.5-VL / Qwen2-VL",
                label="选择模型",
            )
            text_input = gr.Textbox(
                lines=5,
                label="输入文本 (Text)",
                placeholder="输入 Prompt...",
            )

            with gr.Accordion("🖼️ 图片设置 (Images)", open=True):
                with gr.Row():
                    img_count = gr.Number(value=1, label="图片数量", precision=0)
                    img_w = gr.Number(value=1024, label="宽 (px)")
                    img_h = gr.Number(value=1024, label="高 (px)")

            with gr.Accordion("🎥 视频设置 (Videos)", open=False):
                with gr.Row():
                    vid_count = gr.Number(value=0, label="视频数量", precision=0)
                    vid_frames = gr.Number(value=16, label="总帧数/视频", precision=0)
                    vid_w = gr.Number(value=512, label="宽 (px)")
                    vid_h = gr.Number(value=512, label="高 (px)")

            btn = gr.Button("🚀 计算 Token", variant="primary")

        # Right column: results and help text.
        with gr.Column(scale=1):
            out_json = gr.JSON(label="计算结果")
            out_file = gr.File(label="下载 Token 分析 (JSON)")
            gr.Markdown("""
### 说明
* **真实 Tokenizer**: 首次运行时会自动下载 `transformers` 模型配置,可能需要几秒钟。
* **Qwen2-VL**: 基于 `H/14 * W/14` 计算,自动对齐到 28px 网格。
* **Llava-1.6**: 基于 `(Patches + 1) * 576` 计算,Patch 大小为 336px。
""")

    # The input order must match the parameter order of run_calculation.
    calc_inputs = [
        text_input,
        model_select,
        img_count,
        img_w,
        img_h,
        vid_count,
        vid_frames,
        vid_w,
        vid_h,
    ]
    calc_outputs = [out_json, out_file]
    btn.click(fn=run_calculation, inputs=calc_inputs, outputs=calc_outputs)