import torch
from transformers import T5EncoderModel, BitsAndBytesConfig
from diffusers import FluxKontextPipeline

class KontextBackend:
    def __init__(self, model_id, optimized_model_path=None):
        self.model_id = model_id
        self.optimized_model_path = optimized_model_path
        self.pipeline = None

    def load(self):
        print(f"Loading Kontext backend from {self.model_id}...")
        
        if self.optimized_model_path:
            print(f"Loading optimized transformer from {self.optimized_model_path}...")
            # Load the optimized transformer via Nunchaku. nunchaku is an optional
            # dependency, so it is imported only when an optimized path is given.
            try:
                from nunchaku import NunchakuFluxTransformer2dModel
            except ImportError:
                print("nunchaku is not installed; it is required to load the optimized transformer.")
                raise

            transformer = NunchakuFluxTransformer2dModel.from_pretrained(self.optimized_model_path)

            text_quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True
            )
            
            text_encoder_2_4bit = T5EncoderModel.from_pretrained(
                self.model_id,
                subfolder="text_encoder_2",
                quantization_config=text_quant_config,
                torch_dtype=torch.bfloat16,  # same dtype as the pipeline below
            )
            
            # Build the Kontext (image editing) pipeline around the optimized
            # transformer and the 4-bit T5 text encoder.
            pipeline = FluxKontextPipeline.from_pretrained(
                self.model_id,
                text_encoder_2=text_encoder_2_4bit,
                transformer=transformer,
                torch_dtype=torch.bfloat16,
            )
        else:
            print("No optimized model path provided for KontextBackend; falling back to standard loading.")
            # The original script switched to Flux2Pipeline when no optimized model
            # was given. This backend is meant to build a pipeline from the base
            # model with an optional optimized path, so it loads a standard
            # FluxKontextPipeline instead. This assumes the base checkpoint ships
            # every component that FluxKontextPipeline expects.
            
            print(f"Loading standard FluxKontextPipeline from {self.model_id}...")
            # Plain from_pretrained with no bitsandbytes quantization on this path.
            # If memory is tight, a 4-bit text encoder could be passed in the same
            # way as on the optimized path.
            pipeline = FluxKontextPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.bfloat16,
            )
            
        self.pipeline = pipeline
        self.pipeline.to("cuda")
        
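        # Optional: trade speed for memory with CPU offload (the original script
        # enabled this on the optimized path). If enabled, the .to("cuda") call
        # above should likely be dropped, since offload manages device placement.
        # self.pipeline.enable_model_cpu_offload()

        # The same pipeline object is returned in both tuple slots.
        return self.pipeline, self.pipeline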