import torch import torch.nn as nn import os import folder_paths from comfy.utils import load_torch_file, ProgressBar from tqdm import tqdm from comfy import model_management as mm script_directory = os.path.dirname(os.path.abspath(__file__)) device = mm.get_torch_device() offload_device = mm.unet_offload_device() from transformers import AutoTokenizer from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM from accelerate import init_empty_weights from ..utils import set_module_tensor_to_device, log from .system_prompt import SYSTEM_PROMPT_MAP SYSTEM_PROMPT_KEYS = [item["label"] for item in SYSTEM_PROMPT_MAP] config_3b ={ "architectures": [ "Qwen2ForCausalLM" ], "attention_dropout": 0.0, "bos_token_id": 151643, "eos_token_id": 151645, "hidden_act": "silu", "hidden_size": 2048, "initializer_range": 0.02, "intermediate_size": 11008, "max_position_embeddings": 32768, "max_window_layers": 70, "model_type": "qwen2", "num_attention_heads": 16, "num_hidden_layers": 36, "num_key_value_heads": 2, "rms_norm_eps": 1e-06, "rope_theta": 1000000.0, "sliding_window": 32768, "tie_word_embeddings": True, "torch_dtype": "bfloat16", "transformers_version": "4.43.1", "use_cache": True, "use_sliding_window": False, "vocab_size": 151936 } config_7b ={ "architectures": [ "Qwen2ForCausalLM" ], "attention_dropout": 0.0, "bos_token_id": 151643, "eos_token_id": 151645, "hidden_act": "silu", "hidden_size": 3584, "initializer_range": 0.02, "intermediate_size": 18944, "max_position_embeddings": 32768, "max_window_layers": 28, "model_type": "qwen2", "num_attention_heads": 28, "num_hidden_layers": 28, "num_key_value_heads": 4, "rms_norm_eps": 1e-06, "rope_theta": 1000000.0, "sliding_window": 131072, "tie_word_embeddings": False, "torch_dtype": "bfloat16", "transformers_version": "4.43.1", "use_cache": True, "use_sliding_window": False, "vocab_size": 152064 } class QwenLoader: @classmethod def INPUT_TYPES(s): return {"required": { "model": (folder_paths.get_filename_list("text_encoders"), ), "load_device": (["main_device", "offload_device"], {"advanced": True}), "precision": (["fp16", "bf16", "fp32"], {"default": "bf16"}), }, } RETURN_TYPES = ("QWENMODEL",) FUNCTION = "load" CATEGORY = "WanVideoWrapper" def load(self, model, load_device, precision): transformer_load_device = device if load_device == "main_device" else offload_device base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[precision] sd = load_torch_file(folder_paths.get_full_path("text_encoders", model)) tokenizer_path = os.path.join(script_directory, "tokenizer") tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) hf_config = Qwen2Config(**config_3b if "3b" in model.lower() else config_7b) # Fix vocab size to match actual tokenizer actual_vocab_size = len(tokenizer) if hf_config.vocab_size != actual_vocab_size: log.warning(f"Adjusting vocab_size from {hf_config.vocab_size} to {actual_vocab_size} to match tokenizer") hf_config.vocab_size = actual_vocab_size with init_empty_weights(): hf_model = Qwen2ForCausalLM(hf_config) log.info("Using accelerate to load and assign model weights to device...") param_count = sum(1 for _ in hf_model.named_parameters()) pbar = ProgressBar(param_count) for name, param in tqdm(hf_model.named_parameters(), desc=f"Loading transformer parameters to {transformer_load_device}", total=param_count, leave=True): if name not in sd: log.warning(f"Parameter {name} not found in state dict, skipping.") continue set_module_tensor_to_device(hf_model, name, device=transformer_load_device, dtype=base_dtype, value=sd[name]) pbar.update(1) hf_model.lm_head = nn.Linear(hf_model.config.hidden_size, hf_model.config.vocab_size, bias=False) if hf_config.tie_word_embeddings: hf_model.lm_head.weight = hf_model.get_input_embeddings().weight else: if "lm_head.weight" in sd: set_module_tensor_to_device(hf_model, "lm_head.weight", device=transformer_load_device, dtype=base_dtype, value=sd["lm_head.weight"]) else: hf_model.lm_head.weight = hf_model.get_input_embeddings().weight hf_model.lm_head.to(hf_model.device, dtype=base_dtype) class EmptyObj: pass qwen = EmptyObj() qwen.model = hf_model qwen.tokenizer = tokenizer return (qwen,) class WanVideoPromptExtender: @classmethod def INPUT_TYPES(s): return {"required": { "qwen": ("QWENMODEL", ), "prompt": ("STRING", {"multiline": True}), "max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 2048, "step": 1, "tooltip": "Maximum number of new tokens to generate."}), "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the model on. Default uses the main device."}), "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Force offload the model to the offload device after generation. Useful for large models."}) }, "optional": { "system_prompt": (SYSTEM_PROMPT_KEYS, {"tooltip": "System prompt to use for the model."}), "custom_system_prompt": ("STRING", {"default": "", "forceInput": True, "tooltip": "Custom system prompt to use instead of the predefined ones."}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), } } RETURN_TYPES = ("STRING",) FUNCTION = "generate" CATEGORY = "WanVideoWrapper" def generate(self, qwen, prompt, device, force_offload, max_new_tokens, system_prompt=None, custom_system_prompt=None, seed=0): if device == "gpu": device = mm.get_torch_device() elif device == "cpu": device = torch.device("cpu") if custom_system_prompt is None: sys_prompt = next((item["prompt"] for item in SYSTEM_PROMPT_MAP if item["label"] == system_prompt), "") else: sys_prompt = custom_system_prompt messages = [ {"role": "system", "content": sys_prompt}, {"role": "user", "content": prompt} ] text = qwen.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) model_inputs = qwen.tokenizer([text], return_tensors="pt").to(device) torch.manual_seed(seed) qwen.model.to(device) generated_ids = qwen.model.generate( **model_inputs, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7, top_p=0.8, top_k=20, repetition_penalty=1.05, ) if force_offload: qwen.model.to(offload_device) mm.soft_empty_cache() generated_ids = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] response = qwen.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] return (response,) class WanVideoPromptExtenderSelect: @classmethod def INPUT_TYPES(s): return {"required": { "model": (folder_paths.get_filename_list("text_encoders"), ), "max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 2048, "step": 1, "tooltip": "Maximum number of new tokens to generate."}), "system_prompt": (SYSTEM_PROMPT_KEYS, {"tooltip": "System prompt to use for the model."}), }, "optional": { "custom_system_prompt": ("STRING", {"default": "", "forceInput": True, "tooltip": "Custom system prompt to use instead of the predefined ones."}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), } } RETURN_TYPES = ("WANVIDEOPROMPTEXTENDER_ARGS",) RETURN_NAMES = ("extender_args",) FUNCTION = "set" CATEGORY = "WanVideoWrapper" def set(self, model, system_prompt, max_new_tokens, custom_system_prompt=None, seed=0): if custom_system_prompt is None: sys_prompt = next((item["prompt"] for item in SYSTEM_PROMPT_MAP if item["label"] == system_prompt), "") else: sys_prompt = custom_system_prompt extender_settings = { "model": model, "system_prompt": sys_prompt, "max_new_tokens": max_new_tokens, "device": "gpu", "force_offload": True, "seed": seed } return (extender_settings,) NODE_CLASS_MAPPINGS = { "QwenLoader": QwenLoader, "WanVideoPromptExtender": WanVideoPromptExtender, "WanVideoPromptExtenderSelect": WanVideoPromptExtenderSelect } NODE_DISPLAY_NAME_MAPPINGS = { "QwenLoader": "Qwen Loader", "WanVideoPromptExtender": "Wan Video Prompt Extender", "WanVideoPromptExtenderSelect": "Wan Video Prompt Extender Select" }