Spaces:
Configuration error
Configuration error
| import re | |
| import random | |
| import gc | |
| import comfy.model_management as mm | |
| from nodes import ConditioningConcat, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine | |
def chatglm3_text_encode(chatglm3_model, prompt, clean_gpu=False):
    """Encode a prompt with the ChatGLM3 text encoder into a conditioning list.

    The prompt is pre-processed in two steps before tokenizing:

    1. Each ``{a|b|c}`` group that contains no nested braces is replaced by
       one of its options, picked with ``random.choice`` (global ``random``
       state, single pass).
    2. If the resulting string still contains ``|``, it is split on ``|`` and
       the pieces are tokenized together as one batch.

    Args:
        chatglm3_model: Mapping providing a ``'tokenizer'`` and a
            ``'text_encoder'`` entry.
        prompt: Prompt string, optionally using the ``{a|b}`` / ``|`` syntax
            described above.
        clean_gpu: When true, unload all models and empty the cache before
            encoding, then empty the cache and run ``gc.collect()`` afterwards.

    Returns:
        ``[[prompt_embeds, {"pooled_output": text_proj}]]`` where
        ``prompt_embeds`` is the second-to-last hidden state with its first
        two axes swapped and ``text_proj`` is index ``-1`` along the first
        axis of the last hidden state.

    Raises:
        Whatever the tokenizer or text encoder raises. The text encoder is
        moved back to the offload device before the error propagates.
    """
    device = mm.get_torch_device()
    offload_device = mm.unet_offload_device()
    if clean_gpu:
        mm.unload_all_models()
        mm.soft_empty_cache()

    def choose_random_option(match):
        # Pick one alternative out of a "{a|b|c}" group.
        return random.choice(match.group(1).split('|'))

    prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)

    # A remaining top-level "|" turns the prompt into a batch of prompts.
    if "|" in prompt:
        prompt = prompt.split("|")

    tokenizer = chatglm3_model['tokenizer']
    text_encoder = chatglm3_model['text_encoder']

    try:
        text_encoder.to(device)
        # Every prompt is padded / truncated to exactly 256 tokens.
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=256,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        output = text_encoder(
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs['attention_mask'],
            position_ids=text_inputs['position_ids'],
            output_hidden_states=True,
        )

        # Presumably hidden states are (seq, batch, hidden), so the permute
        # yields (batch, seq, hidden) -- TODO confirm against the encoder.
        # permute() only changes strides; .contiguous() materializes the
        # permuted layout so downstream .view() calls keep working.
        prompt_embeds = (
            output.hidden_states[-2].permute(1, 0, 2).clone().contiguous()
        )
        # Index -1 along the first axis of the final hidden state
        # (presumably the last token position, giving (batch, hidden)).
        text_proj = output.hidden_states[-1][-1, :, :].clone()
    finally:
        # Always release the compute device, even when encoding fails.
        text_encoder.to(offload_device)

    if clean_gpu:
        mm.soft_empty_cache()
        gc.collect()

    return [[prompt_embeds, {"pooled_output": text_proj}]]
def chatglm3_adv_text_encode(chatglm3_model, text, clean_gpu=False):
    """Encode a prompt that may use ``BREAK`` and ``TIMESTEP`` directives.

    ``TIMESTEP`` handling:
        The regex ``TIMESTEP.*$`` is searched without flags, so the directive
        has to sit on the last line of ``text``. The directive token is the
        match up to its first space; every occurrence of that token is removed
        from ``text``. The token is split on ``:``:

        * ``TIMESTEP``            -> start ``0.1``, end ``1``
        * ``TIMESTEP:<a>``        -> start ``float(a)``, end ``1``
        * ``TIMESTEP:<a>:<b>...`` -> start ``float(a)``, end ``float(b)``

        Without a directive the range is start ``0``, end ``1``.

    ``BREAK`` handling:
        The remaining text is split on ``BREAK``; pieces are stripped and
        empty ones dropped (a single empty prompt is used if nothing is
        left). Each piece is encoded with ``chatglm3_text_encode`` and the
        results are joined left to right with ``ConditioningConcat``.

    When ``start > 0`` or ``end < 1``, the result is the combination
    (``ConditioningCombine``) of a zeroed-out copy restricted to
    ``start..end`` and the real conditioning restricted to ``0..start``.

    Args:
        chatglm3_model: Passed through to ``chatglm3_text_encode``.
        text: Prompt string with optional directives.
        clean_gpu: Passed through to ``chatglm3_text_encode``.

    Returns:
        The resulting conditioning.

    Raises:
        ValueError: If a ``TIMESTEP`` field cannot be parsed by ``float``.
    """
    range_start, range_end = 0, 1

    directive = re.search(r'TIMESTEP.*$', text)
    if directive:
        token = directive.group().split(' ')[0]
        text = text.replace(token, '')
        fields = token.split(':')
        if len(fields) == 1:
            # Bare "TIMESTEP" uses the built-in default start.
            range_start = 0.1
        else:
            range_start = float(fields[1])
            if len(fields) >= 3:
                range_end = float(fields[2])

    stripped = (piece.strip() for piece in text.split("BREAK"))
    segments = [piece for piece in stripped if piece] or ['']

    # Encode the first segment, then append each further one in order.
    first_segment, *remaining = segments
    conditioning = chatglm3_text_encode(chatglm3_model, first_segment, clean_gpu)
    for segment in remaining:
        encoded = chatglm3_text_encode(chatglm3_model, segment, clean_gpu)
        conditioning = ConditioningConcat().concat(conditioning, encoded)[0]

    # Full default range: nothing to restrict.
    if not (range_start > 0 or range_end < 1):
        return conditioning

    (active,) = ConditioningSetTimestepRange().set_range(
        conditioning, 0, range_start)
    (zeroed,) = ConditioningZeroOut().zero_out(conditioning)
    (zeroed,) = ConditioningSetTimestepRange().set_range(
        zeroed, range_start, range_end)
    (combined,) = ConditioningCombine().combine(zeroed, active)
    return combined