# node.py — by yoinked (HuggingFace commit 2987d48, verified)
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import os
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class Qwen3Tokenizer512Limit(sd1_clip.SDTokenizer):
    """Qwen2-based tokenizer hard-capped at 512 tokens.

    Loads the tokenizer files from the ``qwen25_tokenizer`` directory next to
    this file and configures it with no start/end tokens, padding to the
    512-token maximum with pad token 151643, and embedding key ``qwen3_4b``
    (embedding size 2560).
    """

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        # Use a None sentinel instead of a mutable {} default, which would be
        # shared across every instance of this class.
        if tokenizer_data is None:
            tokenizer_data = {}
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=True, max_length=512, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class ZImageTokenizer512Limit(sd1_clip.SD1Tokenizer):
    """SD1-style wrapper that tokenizes text through Qwen3Tokenizer512Limit.

    Wraps the prompt in a Qwen chat template
    (``<|im_start|>user ... <|im_end|> <|im_start|>assistant``) before
    tokenizing, and disables per-token weighting.
    """

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        # None sentinel avoids the shared-mutable-default-dict pitfall.
        if tokenizer_data is None:
            tokenizer_data = {}
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer512Limit)
        self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
        """Tokenize *text* after formatting it into the chat template.

        If *llama_template* is given it overrides the default template; either
        way ``{}`` in the template is replaced by *text*. Prompt weighting is
        always disabled for this tokenizer.
        """
        template = self.llama_template if llama_template is None else llama_template
        return super().tokenize_with_weights(template.format(text), return_word_ids=return_word_ids, disable_weights=True, **kwargs)
class CapZImageTokenizer512(io.ComfyNode):
    """Node that swaps a CLIP object's tokenizer for the 512-token-capped one."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: one CLIP input, one CLIP output."""
        clip_input = io.Clip.Input()
        clip_output = io.Clip.Output()
        return io.Schema(
            node_id="CapZImageTokenizer512",
            display_name="Cap ZImage Tokenizer to 512 tokens ",
            category="conditioning",
            inputs=[clip_input],
            outputs=[clip_output],
        )

    @classmethod
    def execute(cls, base) -> io.NodeOutput:
        """Replace the tokenizer on *base* and pass the object through.

        NOTE(review): this mutates the incoming CLIP object in place rather
        than working on a copy, so any other node sharing it sees the new
        tokenizer — confirm that is intended.
        """
        base.tokenizer = ZImageTokenizer512Limit()
        return io.NodeOutput(base)
class CapZImageTokenizer512Ext(ComfyExtension):
    """Extension wrapper that registers this file's nodes with ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return every node class provided by this extension."""
        nodes: list[type[io.ComfyNode]] = [CapZImageTokenizer512]
        return nodes
async def comfy_entrypoint() -> CapZImageTokenizer512Ext:
    """Entry point ComfyUI awaits to load this extension."""
    extension = CapZImageTokenizer512Ext()
    return extension