yoinked commited on
Commit
2987d48
·
verified ·
1 Parent(s): b6c0219

Create node.py

Browse files
Files changed (1) hide show
  1. not-wan/node.py +60 -0
not-wan/node.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Qwen2Tokenizer
2
+ from comfy import sd1_clip
3
+ import os
4
+ from typing_extensions import override
5
+ from comfy_api.latest import ComfyExtension, io
6
+
7
class Qwen3Tokenizer512Limit(sd1_clip.SDTokenizer):
    """Qwen2-based tokenizer for the ``qwen3_4b`` text encoder, capped at 512 tokens.

    Every prompt is padded to exactly 512 tokens (``pad_to_max_length=True``,
    ``max_length=512``) with the Qwen end-of-text token 151643 as the pad token.
    No start/end tokens are added.
    """

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        # Fix: the original signature used a mutable default (tokenizer_data={}),
        # which is shared across every call/instance. Use a None sentinel instead;
        # callers that passed a dict (or nothing) see identical behavior.
        if tokenizer_data is None:
            tokenizer_data = {}
        # Tokenizer files are expected to ship alongside this module in
        # a "qwen25_tokenizer" directory.
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560,
                         embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer,
                         has_start_token=False, has_end_token=False,
                         pad_to_max_length=True, max_length=512, min_length=1,
                         pad_token=151643, tokenizer_data=tokenizer_data)
13
class ZImageTokenizer512Limit(sd1_clip.SD1Tokenizer):
    """SD1-style tokenizer wrapper that routes text through ``Qwen3Tokenizer512Limit``.

    Before tokenizing, the prompt is inserted into a Qwen chat template
    (user turn + assistant header). Prompt weighting is disabled
    (``disable_weights=True``).
    """

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        # Fix: replace the shared mutable default (tokenizer_data={}) with a
        # None sentinel; behavior for existing callers is unchanged.
        if tokenizer_data is None:
            tokenizer_data = {}
        super().__init__(embedding_directory=embedding_directory,
                         tokenizer_data=tokenizer_data,
                         name="qwen3_4b", tokenizer=Qwen3Tokenizer512Limit)
        # Default chat template used when the caller does not supply one.
        self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
        """Tokenize *text* after wrapping it in the (optionally overridden) chat template.

        ``llama_template``: a format string with one ``{}`` slot; defaults to
        ``self.llama_template``.
        """
        template = self.llama_template if llama_template is None else llama_template
        llama_text = template.format(text)
        # Weighting syntax is intentionally disabled for this encoder.
        return super().tokenize_with_weights(llama_text,
                                             return_word_ids=return_word_ids,
                                             disable_weights=True, **kwargs)
30
class CapZImageTokenizer512(io.ComfyNode):
    """ComfyUI node that swaps a CLIP object's tokenizer for the 512-token-capped variant.

    NOTE(review): ``execute`` mutates the incoming clip object in place and
    returns that same object, so anything downstream of the original clip
    also sees the replaced tokenizer — confirm this sharing is intended.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node's id, display name, category, and CLIP in/out sockets."""
        return io.Schema(
            node_id="CapZImageTokenizer512",
            # Fix: the original display name carried a stray trailing space.
            display_name="Cap ZImage Tokenizer to 512 tokens",
            category="conditioning",
            inputs=[
                io.Clip.Input(),
            ],
            outputs=[
                io.Clip.Output(),
            ],
        )

    @classmethod
    def execute(cls, base) -> io.NodeOutput:
        """Replace ``base``'s tokenizer (in place) and pass the clip through."""
        base.tokenizer = ZImageTokenizer512Limit()
        return io.NodeOutput(base)
51
class CapZImageTokenizer512Ext(ComfyExtension):
    """Extension wrapper that registers the CapZImageTokenizer512 node with ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes this extension provides (just one)."""
        nodes: list[type[io.ComfyNode]] = [CapZImageTokenizer512]
        return nodes
59
async def comfy_entrypoint() -> CapZImageTokenizer512Ext:
    """Async entry point ComfyUI calls to load this extension."""
    extension = CapZImageTokenizer512Ext()
    return extension