import torch
from transformers import T5EncoderModel, BitsAndBytesConfig
from diffusers import FluxKontextPipeline


class KontextBackend:
    """Build a ``FluxKontextPipeline`` from a base model and an optional
    optimized (Nunchaku) transformer checkpoint.

    Attributes:
        model_id: Identifier or path handed to ``from_pretrained`` for the
            base pipeline (and for the ``text_encoder_2`` subfolder).
        optimized_model_path: Optional path/identifier of a Nunchaku
            transformer. When falsy, the unmodified pipeline is loaded.
        pipeline: The loaded pipeline, or ``None`` until :meth:`load` runs.
    """

    def __init__(self, model_id, optimized_model_path=None):
        """Store the model locations; nothing is loaded until :meth:`load`."""
        self.model_id = model_id
        self.optimized_model_path = optimized_model_path
        self.pipeline = None

    def load(self):
        """Load the pipeline, move it to CUDA and return it.

        The optimized path is taken when ``optimized_model_path`` is truthy;
        otherwise the pipeline is loaded directly from ``model_id``.

        Returns:
            A 2-tuple holding the *same* pipeline object twice. The two-slot
            shape is preserved for callers that unpack two values.
            NOTE(review): presumably the slots are (generation, editing)
            pipelines shared with other backends -- confirm against callers.

        Raises:
            ImportError: If the optimized path is requested but ``nunchaku``
                is not installed.
        """
        print(f"Loading Kontext backend from {self.model_id}...")

        if self.optimized_model_path:
            pipeline = self._load_optimized()
        else:
            pipeline = self._load_standard()

        # Assign before moving so the attribute is set even if .to() raises,
        # matching the original statement order.
        self.pipeline = pipeline
        self.pipeline.to("cuda")
        return self.pipeline, self.pipeline

    def _load_optimized(self):
        """Load the pipeline with a Nunchaku transformer and a 4-bit T5 encoder."""
        print(f"Loading optimized transformer from {self.optimized_model_path}...")

        # nunchaku is an optional dependency, so it is imported lazily and
        # only when the optimized path is actually requested.
        try:
            from nunchaku import NunchakuFluxTransformer2dModel
        except ImportError:
            print("Oops, nunchaku not found! Please install it for optimized magic.")
            raise

        transformer = NunchakuFluxTransformer2dModel.from_pretrained(
            self.optimized_model_path
        )

        # NF4 4-bit quantization (with double quantization) for the T5 text
        # encoder; compute dtype matches the pipeline's bfloat16.
        text_quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        text_encoder_2_4bit = T5EncoderModel.from_pretrained(
            self.model_id,
            subfolder="text_encoder_2",
            quantization_config=text_quant_config,
            torch_dtype=torch.bfloat16,
        )

        # Swap the quantized encoder and the optimized transformer into the
        # otherwise standard base pipeline.
        return FluxKontextPipeline.from_pretrained(
            self.model_id,
            text_encoder_2=text_encoder_2_4bit,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )

    def _load_standard(self):
        """Load the unmodified ``FluxKontextPipeline`` in bfloat16.

        No component is quantized on this path. The previous implementation
        built a ``BitsAndBytesConfig`` here but never passed it to
        ``from_pretrained``, so it had no effect on the loaded pipeline and
        has been removed.
        """
        print("No optimized model path provided for KontextBackend. Falling back to standard loading if possible, or maybe we should insist on one?")
        print(f"Loading standard FluxKontextPipeline from {self.model_id}...")

        return FluxKontextPipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch.bfloat16,
        )