import base64
import difflib
import json
import os

import diffusers
import numpy as np
import requests
import torch
import torch.nn.functional as F
import transformers
from diffusers import (AutoencoderKL, DiffusionPipeline,
                       FlowMatchEulerDiscreteScheduler, FluxPipeline,
                       FluxTransformer2DModel, SD3Transformer2DModel,
                       StableDiffusion3Pipeline)
from diffusers.callbacks import PipelineCallback
from torchao.quantization import int8_weight_only, quantize_
from torchvision import transforms
from transformers import (AutoModelForCausalLM, AutoProcessor, CLIPTextModel,
                          CLIPTextModelWithProjection, T5EncoderModel)


def get_flux_pipeline(
    model_id="black-forest-labs/FLUX.1-dev",
    pipeline_class=FluxPipeline,
    torch_dtype=torch.bfloat16,
    quantize=False,
):
    """Load a FLUX pipeline component by component, optionally quantizing
    each module to int8 weight-only (torchao) before assembling the pipe."""
    ############ Diffusion Transformer ############
    transformer = FluxTransformer2DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch_dtype
    )
    ############ Text Encoder ############
    text_encoder = CLIPTextModel.from_pretrained(
        model_id, subfolder="text_encoder", torch_dtype=torch_dtype
    )
    ############ Text Encoder 2 ############
    text_encoder_2 = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder_2", torch_dtype=torch_dtype
    )
    ############ VAE ############
    vae = AutoencoderKL.from_pretrained(
        model_id, subfolder="vae", torch_dtype=torch_dtype
    )

    if quantize:
        quantize_(transformer, int8_weight_only())
        quantize_(text_encoder, int8_weight_only())
        quantize_(text_encoder_2, int8_weight_only())
        quantize_(vae, int8_weight_only())

    # Initialize the pipeline with the (possibly quantized) components.
    pipe = pipeline_class.from_pretrained(
        model_id,
        transformer=transformer,
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch_dtype,
    )
    return pipe
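
# Example (a sketch, assuming a CUDA device and access to the gated
# black-forest-labs/FLUX.1-dev checkpoint):
#
#   pipe = get_flux_pipeline(quantize=True).to("cuda")
#   image = pipe("a photo of a corgi", num_inference_steps=28).images[0]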


def mask_decode(encoded_mask, image_shape=(512, 512)):
    """Decode a run-length-encoded mask given as flat (start, run_length)
    pairs over the flattened image; runs mark foreground pixels."""
    length = image_shape[0] * image_shape[1]
    mask_array = np.zeros((length,))
    for i in range(0, len(encoded_mask), 2):
        # Clamp each run so it never writes past the end of the image.
        splice_len = min(encoded_mask[i + 1], length - encoded_mask[i])
        mask_array[encoded_mask[i] : encoded_mask[i] + splice_len] = 1
    mask_array = mask_array.reshape(image_shape[0], image_shape[1])
    # Force a one-pixel border to avoid annotation errors at the boundary.
    mask_array[0, :] = 1
    mask_array[-1, :] = 1
    mask_array[:, 0] = 1
    mask_array[:, -1] = 1
    return mask_array
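
# Example (sketch): a single run of 1000 foreground pixels starting at
# pixel 0 of the flattened 512x512 image:
#
#   mask = mask_decode([0, 1000])   # (512, 512) array of 0s and 1s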


def mask_interpolate(mask, size=128):
    """Bicubically resize a 2-D mask to `size` x `size`, e.g. to match the
    latent resolution of the diffusion model."""
    mask = torch.tensor(mask)
    # F.interpolate expects (N, C, H, W); add batch/channel dims, then strip them.
    mask = F.interpolate(mask[None, None, ...], size, mode='bicubic')
    mask = mask.squeeze()
    return mask
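
# Example (sketch): downsample a decoded pixel mask to a 128x128 latent grid.
# Bicubic interpolation can overshoot slightly outside [0, 1], so callers may
# want to clamp or threshold the result:
#
#   latent_mask = mask_interpolate(mask_decode([0, 1000]), size=128).clamp(0, 1)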


def get_blend_word_index(prompt, word, tokenizer):
    """Return positions in `prompt`'s token sequence whose ids also occur in
    the tokenization of `word`."""
    input_ids = tokenizer(prompt).input_ids
    blend_ids = tokenizer(word, add_special_tokens=False).input_ids
    index = []
    for i, token_id in enumerate(input_ids):
        # Skip low-id tokens (specials and very common pieces) so they
        # never count as a match.
        if token_id < 100:
            continue
        if token_id in blend_ids:
            index.append(i)
    return index
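
# Example (sketch), assuming the T5 tokenizer FLUX ships as `tokenizer_2`:
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2"
#   )
#   get_blend_word_index("a photo of a corgi at the beach", "corgi", tok)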


def find_token_id_differences(prompt1, prompt2, tokenizer):
    """Diff two prompts at the token level, returning the indices and ids of
    the tokens unique to each prompt."""
    # Tokenize inputs and get input IDs.
    tokens1 = tokenizer.encode(prompt1, add_special_tokens=False)
    tokens2 = tokenizer.encode(prompt2, add_special_tokens=False)

    # Align the two token sequences and walk the edit opcodes.
    seq_matcher = difflib.SequenceMatcher(None, tokens1, tokens2)

    diff1_indices, diff1_ids = [], []
    diff2_indices, diff2_ids = [], []
    for opcode, a_start, a_end, b_start, b_end in seq_matcher.get_opcodes():
        if opcode in ('replace', 'delete'):
            diff1_indices.extend(range(a_start, a_end))
            diff1_ids.extend(tokens1[a_start:a_end])
        if opcode in ('replace', 'insert'):
            diff2_indices.extend(range(b_start, b_end))
            diff2_ids.extend(tokens2[b_start:b_end])

    return {
        'prompt_1': {'index': diff1_indices, 'id': diff1_ids},
        'prompt_2': {'index': diff2_indices, 'id': diff2_ids},
    }
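
# Example (sketch, with `tok` from the example above): locate which token
# positions changed between a source and an edited prompt:
#
#   diff = find_token_id_differences("a photo of a cat", "a photo of a dog", tok)
#   diff['prompt_1']['index']   # positions of the "cat" tokens in prompt 1
#   diff['prompt_2']['index']   # positions of the "dog" tokens in prompt 2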


def find_word_token_indices(prompt, word, tokenizer):
    """Return the token positions in `prompt` covered by `word`, matched as a
    contiguous token subsequence (requires a fast tokenizer for `.tokens()`)."""
    encoding = tokenizer(prompt, add_special_tokens=False)
    tokens = encoding.tokens()

    # Tokenize the target word the same way for sequence matching.
    word_tokens = tokenizer(word, add_special_tokens=False).tokens()

    # Slide a window over the prompt tokens looking for the word's tokens.
    word_indices = []
    for i in range(len(tokens) - len(word_tokens) + 1):
        if tokens[i : i + len(word_tokens)] == word_tokens:
            word_indices.extend(range(i, i + len(word_tokens)))
    return word_indices
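
# Example (sketch, with `tok` from the example above):
#
#   find_word_token_indices("a photo of a corgi", "corgi", tok)
#   # -> indices of the contiguous "corgi" tokens, or [] if the word
#   #    tokenizes differently in context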