catplusplus commited on
Commit
1e103b7
·
verified ·
1 Parent(s): 8c09cde

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
chat_template.jinja ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% macro render_extra_keys(json_dict, handled_keys) %}
2
+ {%- if json_dict is mapping %}
3
+ {%- for json_key in json_dict if json_key not in handled_keys %}
4
+ {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
5
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
6
+ {%- else %}
7
+ {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
8
+ {%- endif %}
9
+ {%- endfor %}
10
+ {%- endif %}
11
+ {% endmacro %}
12
+
13
+ {%- if messages[0]["role"] == "system" %}
14
+ {%- set system_message = messages[0]["content"] %}
15
+ {%- set loop_messages = messages[1:] %}
16
+ {%- else %}
17
+ {%- set loop_messages = messages %}
18
+ {%- endif %}
19
+
20
+ {%- if not tools is defined %}
21
+ {%- set tools = [] %}
22
+ {%- endif %}
23
+
24
+ {%- if system_message is defined %}
25
+ {{- "<|im_start|>system\n" + system_message }}
26
+ {%- else %}
27
+ {%- if tools is iterable and tools | length > 0 %}
28
+ {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }}
29
+ {%- endif %}
30
+ {%- endif %}
31
+ {%- if tools is iterable and tools | length > 0 %}
32
+ {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }}
33
+ {{- "<tools>" }}
34
+ {%- for tool in tools %}
35
+ {%- if tool.function is defined %}
36
+ {%- set tool = tool.function %}
37
+ {%- endif %}
38
+ {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
39
+ {%- if tool.description is defined %}
40
+ {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
41
+ {%- endif %}
42
+ {{- '\n<parameters>' }}
43
+ {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
44
+ {%- for param_name, param_fields in tool.parameters.properties|items %}
45
+ {{- '\n<parameter>' }}
46
+ {{- '\n<name>' ~ param_name ~ '</name>' }}
47
+ {%- if param_fields.type is defined %}
48
+ {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
49
+ {%- endif %}
50
+ {%- if param_fields.description is defined %}
51
+ {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
52
+ {%- endif %}
53
+ {%- set handled_keys = ['name', 'type', 'description'] %}
54
+ {{- render_extra_keys(param_fields, handled_keys) }}
55
+ {{- '\n</parameter>' }}
56
+ {%- endfor %}
57
+ {%- endif %}
58
+ {% set handled_keys = ['type', 'properties'] %}
59
+ {{- render_extra_keys(tool.parameters, handled_keys) }}
60
+ {{- '\n</parameters>' }}
61
+ {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
62
+ {{- render_extra_keys(tool, handled_keys) }}
63
+ {{- '\n</function>' }}
64
+ {%- endfor %}
65
+ {{- "\n</tools>" }}
66
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
67
+ {%- endif %}
68
+ {%- if system_message is defined %}
69
+ {{- '<|im_end|>\n' }}
70
+ {%- else %}
71
+ {%- if tools is iterable and tools | length > 0 %}
72
+ {{- '<|im_end|>\n' }}
73
+ {%- endif %}
74
+ {%- endif %}
75
+ {%- for message in loop_messages %}
76
+ {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
77
+ {{- '<|im_start|>' + message.role }}
78
+ {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
79
+ {{- '\n' + message.content | trim + '\n' }}
80
+ {%- endif %}
81
+ {%- for tool_call in message.tool_calls %}
82
+ {%- if tool_call.function is defined %}
83
+ {%- set tool_call = tool_call.function %}
84
+ {%- endif %}
85
+ {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
86
+ {%- if tool_call.arguments is defined %}
87
+ {%- for args_name, args_value in tool_call.arguments|items %}
88
+ {{- '<parameter=' + args_name + '>\n' }}
89
+ {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
90
+ {{- args_value }}
91
+ {{- '\n</parameter>\n' }}
92
+ {%- endfor %}
93
+ {%- endif %}
94
+ {{- '</function>\n</tool_call>' }}
95
+ {%- endfor %}
96
+ {{- '<|im_end|>\n' }}
97
+ {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %}
98
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
99
+ {%- elif message.role == "tool" %}
100
+ {%- if loop.previtem and loop.previtem.role != "tool" %}
101
+ {{- '<|im_start|>user\n' }}
102
+ {%- endif %}
103
+ {{- '<tool_response>\n' }}
104
+ {{- message.content }}
105
+ {{- '\n</tool_response>\n' }}
106
+ {%- if not loop.last and loop.nextitem.role != "tool" %}
107
+ {{- '<|im_end|>\n' }}
108
+ {%- elif loop.last %}
109
+ {{- '<|im_end|>\n' }}
110
+ {%- endif %}
111
+ {%- else %}
112
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
113
+ {%- endif %}
114
+ {%- endfor %}
115
+ {%- if add_generation_prompt %}
116
+ {{- '<|im_start|>assistant\n' }}
117
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 151645,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 2560,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 9728,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention"
52
+ ],
53
+ "max_position_embeddings": 40960,
54
+ "max_window_layers": 36,
55
+ "model_type": "qwen3",
56
+ "num_attention_heads": 32,
57
+ "num_hidden_layers": 36,
58
+ "num_key_value_heads": 8,
59
+ "pad_token_id": null,
60
+ "quantization_config": {
61
+ "config_groups": {
62
+ "group_0": {
63
+ "format": "nvfp4-pack-quantized",
64
+ "input_activations": {
65
+ "actorder": null,
66
+ "block_structure": null,
67
+ "dynamic": "local",
68
+ "group_size": 16,
69
+ "num_bits": 4,
70
+ "observer": "static_minmax",
71
+ "observer_kwargs": {},
72
+ "scale_dtype": "torch.float8_e4m3fn",
73
+ "strategy": "tensor_group",
74
+ "symmetric": true,
75
+ "type": "float",
76
+ "zp_dtype": null
77
+ },
78
+ "output_activations": null,
79
+ "targets": [
80
+ "Linear"
81
+ ],
82
+ "weights": {
83
+ "actorder": null,
84
+ "block_structure": null,
85
+ "dynamic": false,
86
+ "group_size": 16,
87
+ "num_bits": 4,
88
+ "observer": "memoryless_minmax",
89
+ "observer_kwargs": {},
90
+ "scale_dtype": "torch.float8_e4m3fn",
91
+ "strategy": "tensor_group",
92
+ "symmetric": true,
93
+ "type": "float",
94
+ "zp_dtype": null
95
+ }
96
+ }
97
+ },
98
+ "format": "nvfp4-pack-quantized",
99
+ "global_compression_ratio": null,
100
+ "ignore": [
101
+ "lm_head"
102
+ ],
103
+ "kv_cache_scheme": null,
104
+ "quant_method": "compressed-tensors",
105
+ "quantization_status": "compressed",
106
+ "sparsity_config": {},
107
+ "transform_config": {},
108
+ "version": "0.15.1.dev14+g01a1c9a"
109
+ },
110
+ "rms_norm_eps": 1e-06,
111
+ "rope_parameters": {
112
+ "rope_theta": 1000000,
113
+ "rope_type": "default"
114
+ },
115
+ "sliding_window": null,
116
+ "tie_word_embeddings": true,
117
+ "transformers_version": "5.2.0",
118
+ "use_cache": true,
119
+ "use_sliding_window": false,
120
+ "vocab_size": 151936
121
+ }
extras/Flux2Backend.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Mistral3ForConditionalGeneration, PixtralProcessor, BitsAndBytesConfig
3
+ from diffusers import Flux2Pipeline, AutoencoderKLFlux2, Flux2Transformer2DModel
4
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
5
+
6
+ class Flux2Backend:
7
+ def __init__(self, model_id):
8
+ self.model_id = model_id
9
+ self.pipeline = None
10
+
11
+ def load(self):
12
+ print(f"Loading Flux2 backend from {self.model_id}...")
13
+
14
+ quantization_config = BitsAndBytesConfig(
15
+ load_in_4bit=True,
16
+ bnb_4bit_quant_type="nf4",
17
+ bnb_4bit_compute_dtype=torch.float16,
18
+ bnb_4bit_use_double_quant=True,
19
+ )
20
+
21
+ # Scheduler
22
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
23
+ self.model_id,
24
+ subfolder="scheduler",
25
+ torch_dtype=torch.bfloat16
26
+ )
27
+
28
+ # VAE - loaded manually with full precision
29
+ vae = AutoencoderKLFlux2.from_pretrained(
30
+ self.model_id,
31
+ subfolder="vae",
32
+ torch_dtype=torch.float16
33
+ )
34
+
35
+ tokenizer = PixtralProcessor.from_pretrained(
36
+ self.model_id,
37
+ subfolder="tokenizer",
38
+ torch_dtype=torch.float16
39
+ )
40
+
41
+ text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
42
+ self.model_id,
43
+ subfolder="text_encoder",
44
+ torch_dtype=torch.float16,
45
+ quantization_config=quantization_config
46
+ )
47
+
48
+ dit = Flux2Transformer2DModel.from_pretrained(
49
+ self.model_id,
50
+ subfolder="transformer",
51
+ torch_dtype=torch.float16,
52
+ quantization_config=quantization_config
53
+ )
54
+
55
+
56
+ # Standard loading without Nunchaku optimization
57
+ # Constructing pipeline manually rather than from_pretrained
58
+ pipeline = Flux2Pipeline(
59
+ scheduler=scheduler,
60
+ vae=vae,
61
+ text_encoder=text_encoder,
62
+ tokenizer=tokenizer,
63
+ transformer=dit,
64
+ )
65
+
66
+ self.pipeline = pipeline
67
+ self.pipeline.to("cuda")
68
+ self.pipeline.transformer.set_attention_backend("flash")
69
+
70
+ return self.pipeline, self.pipeline
extras/GlmBackend.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import diffusers
3
+ try:
4
+ from sdnq import SDNQConfig
5
+ from sdnq.common import use_torch_compile as triton_is_available
6
+ from sdnq.loader import apply_sdnq_options_to_model
7
+ SDNQ_AVAILABLE = True
8
+ except ImportError:
9
+ print("SDNQ not found, optimized GLM loading will be skipped.")
10
+ SDNQ_AVAILABLE = False
11
+
12
+ class GlmBackend:
13
+ def __init__(self, model_id="Disty0/GLM-Image-SDNQ-4bit-dynamic"):
14
+ self.model_id = model_id
15
+ self.pipeline = None
16
+
17
+ def load(self):
18
+ print(f"Loading GLM backend from {self.model_id}...")
19
+
20
+ # Load the pipeline
21
+ # Using bfloat16 as per request snippet
22
+ pipeline = diffusers.GlmImagePipeline.from_pretrained(
23
+ self.model_id,
24
+ torch_dtype=torch.bfloat16,
25
+ trust_remote_code=True,
26
+ )
27
+
28
+ if SDNQ_AVAILABLE:
29
+ # Enable INT8 MatMul for GPUs if Triton is available
30
+ if triton_is_available and (torch.cuda.is_available() or torch.xpu.is_available()):
31
+ print("Applying SDNQ optimizations (INT8 MatMul)...")
32
+ pipeline.transformer = apply_sdnq_options_to_model(pipeline.transformer, use_quantized_matmul=True)
33
+ # pipeline.transformer = torch.compile(pipeline.transformer) # Optional, commented out as in snippet
34
+ else:
35
+ print("Triton or CUDA/XPU not available, skipping SDNQ optimization.")
36
+
37
+ print("Enabling CPU offload for GLM pipeline...")
38
+ pipeline.enable_model_cpu_offload()
39
+
40
+
41
+ self.pipeline = pipeline
42
+
43
+ # The user stated: "this one uses same pipe line for image generation and editing"
44
+ # So we return the same pipeline for both.
45
+ return self.pipeline, self.pipeline
extras/ImageEditServer.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import io
4
+ import time
5
+ import torch
6
+ import uvicorn
7
+ import gc
8
+ import asyncio
9
+ import traceback
10
+ from typing import List, Optional, Union
11
+ from contextlib import asynccontextmanager
12
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
13
+ from pydantic import BaseModel
14
+ from PIL import Image, ImageOps
15
+
16
+ # Argument parsing
17
+ parser = argparse.ArgumentParser(description="Flux Image Edit Server with Nunchaku")
18
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
19
+ parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
20
+ parser.add_argument("--model", type=str, default="black-forest-labs/FLUX.1-Kontext-dev", help="Path or Repo ID of the base model")
21
+ parser.add_argument("--optimized-model", type=str, default=None, help="Path to the optimized Nunchaku model safetensors file")
22
+ parser.add_argument("--optimized-edit-model", type=str, default=None, help="Path to the optimized Nunchaku model safetensors file for editing (optional)")
23
+ parser.add_argument("--backend", type=str, default="kontext", choices=["kontext", "flux2", "qwen", "glm", "zimage"], help="Backend to use: 'kontext', 'flux2', 'qwen', 'glm', or 'zimage'")
24
+ parser.add_argument("--steps", type=int, default=28, help="Default number of inference steps")
25
+ parser.add_argument("--guidance-scale", type=float, default=3.5, help="Default guidance scale")
26
+ parser.add_argument("--qwenimage", action="store_true", help="Use QwenImageBackend (T2I only) instead of full Qwen edit backend")
27
+ parser.add_argument("--uma", action="store_true", help="Enable Unified Memory Architecture mode (load all to GPU, disable offload)")
28
+ parser.add_argument(
29
+ "--nvfp4-text-encoder",
30
+ type=str,
31
+ default=None,
32
+ help=(
33
+ "Path to an NVFP4-pack-quantized HuggingFace text encoder "
34
+ "(compressed-tensors format). Currently honoured by the zimage backend; "
35
+ "swaps in vLLM's W4A4 NVFP4 CUTLASS GEMM for ~4x text-encoder VRAM savings."
36
+ ),
37
+ )
38
+ args = parser.parse_args()
39
+
40
+ @asynccontextmanager
41
+ async def lifespan(app: FastAPI):
42
+ # Startup logic
43
+ load_model()
44
+ yield
45
+ # Shutdown logic (if any) could go here
46
+
47
+ app = FastAPI(lifespan=lifespan)
48
+
49
+ # Global components
50
+ IMAGE_DIMENSION_ALIGNMENT = 32
51
+ pipeline = None
52
+ edit_pipeline = None
53
+ request_lock = asyncio.Lock()
54
+ is_sleeping_flag = False
55
+ sleep_requested = False
56
+
57
+ def load_model():
58
+ global pipeline, edit_pipeline
59
+
60
+ try:
61
+ if args.backend == "kontext":
62
+ import KontextBackend
63
+ print(f"Initializing KontextBackend...")
64
+ backend = KontextBackend.KontextBackend(args.model, args.optimized_model)
65
+ pipeline, edit_pipeline = backend.load()
66
+ elif args.backend == "flux2":
67
+ import Flux2Backend
68
+ print(f"Initializing Flux2Backend...")
69
+ backend = Flux2Backend.Flux2Backend(args.model)
70
+ pipeline, edit_pipeline = backend.load()
71
+ elif args.backend == "glm":
72
+ import GlmBackend
73
+ print(f"Initializing GlmBackend...")
74
+ # Use provided model or default to the one in the snippet if args.model is generic
75
+ # The user might pass the specific GLM model via --model, or we default in GlmBackend.
76
+ # Let's pass args.model if it's not the default flux one, otherwise let GlmBackend use its default.
77
+ model_to_use = args.model if args.model != "black-forest-labs/FLUX.1-Kontext-dev" else "Disty0/GLM-Image-SDNQ-4bit-dynamic"
78
+ backend = GlmBackend.GlmBackend(model_to_use)
79
+ pipeline, edit_pipeline = backend.load()
80
+ elif args.backend.startswith("qwen"):
81
+ if args.qwenimage:
82
+ import QwenImageBackend
83
+ print(f"Initializing QwenImageBackend (T2I only)...")
84
+ backend = QwenImageBackend.QwenImageBackend(args.model, args.optimized_model)
85
+ pipeline, edit_pipeline = backend.load()
86
+ else:
87
+ import QwenBackend
88
+ print(f"Initializing QwenBackend...")
89
+ backend = QwenBackend.QwenBackend(args.model, args.optimized_model, optimized_edit_model_path=args.optimized_edit_model, uma=args.uma)
90
+ pipeline, edit_pipeline = backend.load()
91
+ elif args.backend == "zimage":
92
+ import ZImageTurboBackend
93
+ print(f"Initializing ZImageTurboBackend...")
94
+ backend = ZImageTurboBackend.ZImageTurboBackend(
95
+ args.model,
96
+ args.optimized_model,
97
+ uma=args.uma,
98
+ nvfp4_text_encoder_path=args.nvfp4_text_encoder,
99
+ )
100
+ pipeline, edit_pipeline = backend.load()
101
+ else:
102
+ raise ValueError(f"Unknown backend: {args.backend}")
103
+
104
+ except Exception as e:
105
+ print(f"Oh no! The model refused to wake up: {e}")
106
+ raise e
107
+
108
+ # Enable progress bar for diffusers
109
+ import diffusers.utils.logging
110
+ diffusers.utils.logging.enable_progress_bar()
111
+ diffusers.utils.logging.set_verbosity_info()
112
+
113
+ print("Model loaded successfully! Ready for editing quests!")
114
+
115
+ def flush():
116
+ gc.collect()
117
+ torch.cuda.empty_cache()
118
+
119
+
120
+ class ImageGenerationRequest(BaseModel):
121
+ prompt: str
122
+ n: int = 1
123
+ size: str = "1024x1024"
124
+ response_format: str = "b64_json"
125
+ quality: str = "standard"
126
+ style: str = "vivid"
127
+ num_inference_steps: Optional[int] = None
128
+ guidance_scale: Optional[float] = None
129
+ negative_prompt: Optional[str] = None
130
+ seed: Optional[int] = None
131
+
132
+
133
+ @app.post("/v1/sleep")
134
+ async def sleep_endpoint():
135
+ global is_sleeping_flag, sleep_requested
136
+ sleep_requested = True
137
+ try:
138
+ async with request_lock:
139
+ if not is_sleeping_flag and sleep_requested:
140
+ print("Sleep requested, moving models to CPU...")
141
+ for p in [pipeline, edit_pipeline]:
142
+ if not p: continue
143
+ for name, component in p.components.items():
144
+ if isinstance(component, torch.nn.Module):
145
+ # Special handling for Nunchaku which blocks .to() if offload is True
146
+ if hasattr(component, "set_offload") and getattr(component, "offload", False):
147
+ component.set_offload(False)
148
+ component._nunchaku_was_offloaded = True
149
+
150
+ try:
151
+ component.to("cpu")
152
+ except Exception as e:
153
+ pass
154
+ flush()
155
+ is_sleeping_flag = True
156
+ finally:
157
+ sleep_requested = False
158
+ return {"status": "sleep completed", "is_sleeping": is_sleeping_flag}
159
+
160
+ @app.post("/v1/wake_up")
161
+ async def wake_up_endpoint():
162
+ global is_sleeping_flag, sleep_requested
163
+ sleep_requested = False
164
+ async with request_lock:
165
+ if is_sleeping_flag:
166
+ print("Waking up, restoring models to CUDA...")
167
+ for p in [pipeline, edit_pipeline]:
168
+ if not p: continue
169
+ excluded = getattr(p, "_exclude_from_cpu_offload", [])
170
+ for name, component in p.components.items():
171
+ if isinstance(component, torch.nn.Module):
172
+ if getattr(component, "_nunchaku_was_offloaded", False):
173
+ component.set_offload(True, use_pin_memory=True, num_blocks_on_gpu=8)
174
+ for attr in ["img_in", "txt_in", "txt_norm", "time_text_embed", "norm_out", "proj_out"]:
175
+ if hasattr(component, attr):
176
+ try:
177
+ getattr(component, attr).to("cuda")
178
+ except Exception:
179
+ pass
180
+ component._nunchaku_was_offloaded = False
181
+ elif not hasattr(component, "_hf_hook") or name in excluded:
182
+ try:
183
+ component.to("cuda")
184
+ except Exception:
185
+ pass
186
+ is_sleeping_flag = False
187
+ return {"status": "awoken", "is_sleeping": False}
188
+
189
+ @app.get("/v1/is_sleeping")
190
+ async def is_sleeping_endpoint():
191
+ return {"is_sleeping": is_sleeping_flag}
192
+
193
+
194
+ @app.get("/v1/memory_stats")
195
+ async def memory_stats_endpoint():
196
+ """Lightweight introspection endpoint that returns PyTorch's CUDA allocator
197
+ snapshot. Used to diagnose VRAM/UMA bloat without restarting the server."""
198
+ stats = {}
199
+ if torch.cuda.is_available():
200
+ stats["allocated_gb"] = torch.cuda.memory_allocated() / 1e9
201
+ stats["reserved_gb"] = torch.cuda.memory_reserved() / 1e9
202
+ stats["max_allocated_gb"] = torch.cuda.max_memory_allocated() / 1e9
203
+ stats["max_reserved_gb"] = torch.cuda.max_memory_reserved() / 1e9
204
+ # Top allocations by size from the allocator snapshot (>=64 MiB)
205
+ try:
206
+ snap = torch.cuda.memory_snapshot()
207
+ blocks = []
208
+ for seg in snap:
209
+ for b in seg.get("blocks", []):
210
+ if b.get("state") == "active_allocated" and b.get("size", 0) >= 64 * 1024 * 1024:
211
+ blocks.append(b["size"])
212
+ blocks.sort(reverse=True)
213
+ stats["large_active_blocks_gb"] = [round(s / 1e9, 3) for s in blocks[:20]]
214
+ stats["large_active_blocks_total_gb"] = round(sum(blocks) / 1e9, 3)
215
+ stats["large_active_blocks_count"] = len(blocks)
216
+ except Exception as e:
217
+ stats["snapshot_error"] = str(e)
218
+ # Walk Python objects to find big tensors and group them
219
+ try:
220
+ import gc as _gc
221
+ seen = set()
222
+ big = []
223
+ for obj in _gc.get_objects():
224
+ try:
225
+ if isinstance(obj, torch.Tensor) and obj.is_cuda:
226
+ ptr = obj.data_ptr()
227
+ if ptr in seen or ptr == 0:
228
+ continue
229
+ seen.add(ptr)
230
+ sz = obj.element_size() * obj.numel()
231
+ if sz >= 16 * 1024 * 1024:
232
+ big.append((sz, tuple(obj.shape), str(obj.dtype)))
233
+ except Exception:
234
+ continue
235
+ big.sort(reverse=True)
236
+ # Group by (shape, dtype)
237
+ from collections import Counter
238
+ grouped = Counter((shape, dtype) for _, shape, dtype in big)
239
+ stats["big_tensor_groups"] = [
240
+ {"shape": list(shape), "dtype": dtype, "count": cnt,
241
+ "size_gb_each": round(
242
+ (1 if shape == () else (lambda l: __import__('functools').reduce(lambda a, b: a*b, l, 1))(shape)) * (
243
+ 8 if 'int64' in dtype or 'float64' in dtype else
244
+ 4 if 'int32' in dtype or 'float32' in dtype else
245
+ 2 if 'bfloat16' in dtype or 'float16' in dtype else 1
246
+ ) / 1e9, 4)}
247
+ for (shape, dtype), cnt in grouped.most_common(30)
248
+ ]
249
+ stats["big_tensor_count"] = len(big)
250
+ stats["big_tensor_total_gb"] = round(sum(s for s, _, _ in big) / 1e9, 3)
251
+ except Exception as e:
252
+ stats["walk_error"] = str(e)
253
+ return stats
254
+
255
+ @app.post("/v1/images/edits")
256
+ async def edit_image(
257
+ image: Union[List[UploadFile], UploadFile] = File(...),
258
+ prompt: str = Form(...),
259
+ n: int = Form(1),
260
+ size: str = Form("1024x1024"),
261
+ response_format: str = Form("b64_json"), # Default to b64_json
262
+ guidance_scale: Optional[float] = Form(None),
263
+ num_inference_steps: Optional[int] = Form(None),
264
+ negative_prompt: Optional[str] = Form(None),
265
+ seed: Optional[int] = Form(None)
266
+ ):
267
+ # Use CLI defaults if not provided
268
+ steps = num_inference_steps if num_inference_steps is not None else args.steps
269
+ cfg_scale = guidance_scale if guidance_scale is not None else args.guidance_scale
270
+ neg_prompt = negative_prompt if negative_prompt is not None else "" # Default empty for now, or maybe None?
271
+
272
+ generator = None
273
+ import random
274
+ if seed is None:
275
+ seed = random.randint(0, 2**32 - 1)
276
+
277
+ print(f"Using seed: {seed}")
278
+ generator = torch.Generator(device="cuda").manual_seed(seed)
279
+
280
+ if not edit_pipeline:
281
+ raise HTTPException(status_code=500, detail="Model not loaded")
282
+
283
+ if sleep_requested or is_sleeping_flag:
284
+ raise HTTPException(status_code=503, detail="Server is sleeping or trying to sleep.")
285
+
286
+ async with request_lock:
287
+ print(f"Received edit request: {prompt}")
288
+
289
+ # Processing the input image(s)
290
+ input_files = image if isinstance(image, list) else [image]
291
+ init_images = []
292
+
293
+ try:
294
+ for img_file in input_files:
295
+ await img_file.seek(0)
296
+ contents = await img_file.read()
297
+ img = Image.open(io.BytesIO(contents)).convert("RGB")
298
+ init_images.append(img)
299
+ except Exception as e:
300
+ raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")
301
+
302
+ if not init_images:
303
+ raise HTTPException(status_code=400, detail="No images provided")
304
+
305
+ # Parse max target dimensions from requested size
306
+ try:
307
+ target_width, target_height = map(int, size.split("x"))
308
+ except ValueError:
309
+ target_width, target_height = 1024, 1024
310
+
311
+ # Calculate new dimensions preserving aspect ratio based on the first image
312
+ first_image = init_images[0]
313
+ orig_width, orig_height = first_image.size
314
+ scale = min(target_width / orig_width, target_height / orig_height)
315
+ new_width = int(orig_width * scale)
316
+ new_height = int(orig_height * scale)
317
+
318
+ # Ensure dimensions are aligned to 32 for compatibility (e.g. GLM-Image)
319
+ width = (new_width // IMAGE_DIMENSION_ALIGNMENT) * IMAGE_DIMENSION_ALIGNMENT
320
+ height = (new_height // IMAGE_DIMENSION_ALIGNMENT) * IMAGE_DIMENSION_ALIGNMENT
321
+
322
+ # Resize input images to match the calculated target size, padding if necessary
323
+ resized_images = []
324
+ for img in init_images:
325
+ if img.size != (width, height):
326
+ # Use ImageOps.pad to preserve aspect ratio and center in the target size
327
+ # This handles cases where subsequent images might have different ARs
328
+ img = ImageOps.pad(img, (width, height), method=Image.LANCZOS, color=(0, 0, 0))
329
+ resized_images.append(img)
330
+
331
+ # If single image, pass as item, if multiple, pass as list
332
+ # GLM pipeline has a bug where it checks len() on the input, so it must be a list
333
+ if len(resized_images) > 1 or args.backend == "glm":
334
+ image_input = resized_images
335
+ else:
336
+ image_input = resized_images[0]
337
+
338
+ response_images = []
339
+
340
+ try:
341
+ if args.backend.startswith("qwen"):
342
+ # Qwen specific parameters
343
+ # guidance_scale maps to true_cfg_scale
344
+ if args.qwenimage: # QwenImageBackend is T2I only, so it doesn't take an image
345
+ generated_images = edit_pipeline(
346
+ prompt=prompt,
347
+ height=height,
348
+ width=width,
349
+ num_inference_steps=steps,
350
+ true_cfg_scale=cfg_scale,
351
+ num_images_per_prompt=n,
352
+ generator=generator,
353
+ ).images
354
+ else: # Full Qwen edit backend takes an image (or list of images now)
355
+ generated_images = edit_pipeline(
356
+ image=image_input,
357
+ prompt=prompt,
358
+ height=height,
359
+ width=width,
360
+ negative_prompt=neg_prompt,
361
+ num_inference_steps=steps,
362
+ true_cfg_scale=cfg_scale,
363
+ num_images_per_prompt=n,
364
+ generator=generator,
365
+ ).images
366
+ else:
367
+ # Standard Flux/Kontext or GLM
368
+ # GLM I2I Fix: Manually move vision encoder to GPU because get_image_features escapes hooks
369
+ if args.backend == "glm" and hasattr(edit_pipeline, "vision_language_encoder"):
370
+ print("Manually moving GLM Vision Encoder to GPU...")
371
+ edit_pipeline.vision_language_encoder.to("cuda")
372
+
373
+ try:
374
+ generated_images = edit_pipeline(
375
+ image=image_input,
376
+ prompt=prompt,
377
+ height=height,
378
+ width=width,
379
+ num_inference_steps=steps,
380
+ guidance_scale=cfg_scale,
381
+ num_images_per_prompt=n,
382
+ generator=generator,
383
+ ).images
384
+ finally:
385
+ if args.backend == "glm" and hasattr(edit_pipeline, "vision_language_encoder"):
386
+ print("Moving GLM Vision Encoder back to CPU...")
387
+ edit_pipeline.vision_language_encoder.to("cpu")
388
+
389
+ for img in generated_images:
390
+ buffered = io.BytesIO()
391
+ img.save(buffered, format="PNG")
392
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
393
+
394
+ if response_format == "b64_json":
395
+ response_images.append({"b64_json": img_str})
396
+ else:
397
+ # If url is requested we can't really do it without storage, so we fallback or error?
398
+ # For now, let's just assume simple b64_json as per request
399
+ response_images.append({"b64_json": img_str}) # Fallback
400
+
401
+ except Exception as e:
402
+ print(f"Error during editing: {e}")
403
+ print(traceback.format_exc())
404
+ raise HTTPException(status_code=500, detail=str(e))
405
+ finally:
406
+ flush()
407
+
408
+ return {
409
+ "created": int(time.time()),
410
+ "data": response_images
411
+ }
412
+
413
+
414
+
415
+ @app.post("/v1/images/generations")
416
+ async def generate_image(request: ImageGenerationRequest):
417
+ if not pipeline:
418
+ raise HTTPException(status_code=500, detail="Model not loaded")
419
+
420
+ if sleep_requested or is_sleeping_flag:
421
+ raise HTTPException(status_code=503, detail="Server is sleeping or trying to sleep.")
422
+
423
+ async with request_lock:
424
+ #print(f"Received generation request: {request.prompt}")
425
+
426
+ # Parse size
427
+ try:
428
+ width, height = map(int, request.size.split("x"))
429
+ except ValueError:
430
+ width, height = 1024, 1024
431
+
432
+ # Ensure dimensions are aligned to 32
433
+ width = (width // IMAGE_DIMENSION_ALIGNMENT) * IMAGE_DIMENSION_ALIGNMENT
434
+ height = (height // IMAGE_DIMENSION_ALIGNMENT) * IMAGE_DIMENSION_ALIGNMENT
435
+
436
+ response_images = []
437
+
438
+ try:
439
+ # Generate images (no image argument for txt2img!)
440
+ steps = request.num_inference_steps if request.num_inference_steps is not None else args.steps
441
+ cfg_scale = request.guidance_scale if request.guidance_scale is not None else args.guidance_scale
442
+ # negative_prompt not in standard request body in original snippet, but we added it to model
443
+ neg_prompt = request.negative_prompt if request.negative_prompt is not None else ""
444
+
445
+ generator = None
446
+ import random
447
+ seed = request.seed
448
+ if seed is None:
449
+ seed = random.randint(0, 2**32 - 1)
450
+
451
+ print(f"Using seed: {seed}")
452
+ generator = torch.Generator(device="cuda").manual_seed(seed)
453
+
454
+ if args.backend.startswith("qwen"):
455
+ generated_images = pipeline(
456
+ prompt=request.prompt,
457
+ height=height,
458
+ width=width,
459
+ num_inference_steps=steps,
460
+ true_cfg_scale=cfg_scale,
461
+ num_images_per_prompt=request.n,
462
+ negative_prompt=neg_prompt,
463
+ generator=generator,
464
+ ).images
465
+ else:
466
+ generated_images = pipeline(
467
+ prompt=request.prompt,
468
+ height=height,
469
+ width=width,
470
+ num_inference_steps=steps,
471
+ guidance_scale=cfg_scale,
472
+ num_images_per_prompt=request.n,
473
+ generator=generator,
474
+ # Not passing negative_prompt here for generation unless we confirm support in standard Flux pipeline?
475
+ ).images
476
+
477
+ for img in generated_images:
478
+ buffered = io.BytesIO()
479
+ img.save(buffered, format="PNG")
480
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
481
+ response_images.append({"b64_json": img_str})
482
+
483
+ except Exception as e:
484
+ print(f"Error during generation: {e}")
485
+ print(traceback.format_exc())
486
+ raise HTTPException(status_code=500, detail=str(e))
487
+ finally:
488
+ flush()
489
+
490
+ return {
491
+ "created": int(time.time()),
492
+ "data": response_images
493
+ }
494
+
495
+ if __name__ == "__main__":
496
+ uvicorn.run(app, host=args.host, port=args.port)
extras/ImageGenClient.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ import os
4
+ import base64
5
+ import time
6
+ import io
7
+ import requests
8
+ from PIL import Image
9
+
10
+ # Oh, hello there! Nikola here, ready to help this little client talk to the big server!
11
+ # It's like sending a messenger bird from our village to the capital!
12
+
13
+ def main():
14
+ # Peeking at the arguments... gotta make sure we have all our supplies for the journey!
15
+ parser = argparse.ArgumentParser(description="ImageGen Client - A little seeker tool!")
16
+ parser.add_argument("--host", type=str, default="localhost", help="Where the server lives (Host)")
17
+ parser.add_argument("--port", type=int, default=8000, help="The door to knock on (Port)")
18
+ parser.add_argument("--num_images", type=int, default=1, help="How many pictures to paint?")
19
+ parser.add_argument("--image_folder", type=str, default="generated_images", help="Where to keep our treasures")
20
+ # Changing defaults to None so we can use input image size if needed!
21
+ parser.add_argument("--width", type=int, default=None, help="Canvas width (default: 1024 or input image size)")
22
+ parser.add_argument("--height", type=int, default=None, help="Canvas height (default: 1024 or input image size)")
23
+
24
+ # New shiny tools for our quest!
25
+ parser.add_argument("--input", type=str, default=None, help="Path to an input image (for image-to-image magic!)")
26
+ parser.add_argument("--max-size", type=int, default=1024, help="Max size for the input image (we don't want it to get too heavy for the bird!)")
27
+
28
+ args = parser.parse_args()
29
+
30
+ # Reading the prompt from the spirits... I mean, stdin!
31
+ # "What do you desire to see?" *sparkle*
32
+ print("Waiting for a prompt from stdin... (Type something and press Ctrl+D!)")
33
+ try:
34
+ prompt = sys.stdin.read().strip()
35
+ except Exception as e:
36
+ print(f"Oh no! The spirits were silent (stdin error): {e}")
37
+ return
38
+
39
+ if not prompt:
40
+ print("Aww, the prompt was empty! The canvas remains blank.")
41
+ return
42
+
43
+ print(f"Yay! We got a prompt: '{prompt}'")
44
+
45
+ # Restoring the canvas size variables from the journey's start!
46
+ final_width = args.width
47
+ final_height = args.height
48
+
49
+ # Prepare prompt and payload
50
+ url_gen = f"http://{args.host}:{args.port}/v1/images/generations"
51
+ url_edit = f"http://{args.host}:{args.port}/v1/images/edits"
52
+
53
+ try:
54
+ if args.input:
55
+ print(f"Oh! You brought a reference image: {args.input}. Let's go to the Editing Shrine!")
56
+
57
+ # Prepare for multipart upload
58
+ # We need to open the image file effectively
59
+ if not os.path.exists(args.input):
60
+ print(f"Eek! I can't find the image at {args.input}")
61
+ return
62
+
63
+ # Open image to ensure it's valid and memory-friendly resize if needed
64
+ with Image.open(args.input) as img:
65
+ img = img.convert("RGB")
66
+ w, h = img.size
67
+ max_dim = max(w, h)
68
+ if max_dim > args.max_size:
69
+ scale = args.max_size / max_dim
70
+ new_w = int(w * scale)
71
+ new_h = int(h * scale)
72
+ print(f"Resizing big image from {w}x{h} to {new_w}x{new_h}. Compact and cute!")
73
+ img = img.resize((new_w, new_h), Image.LANCZOS)
74
+
75
+ if final_width is None: final_width = img.width
76
+ if final_height is None: final_height = img.height
77
+
78
+ # Save to buffer for upload
79
+ buffered = io.BytesIO()
80
+ img.save(buffered, format="PNG")
81
+ buffered.seek(0)
82
+ image_bytes = buffered.getvalue()
83
+
84
+ # Construct multipart payload
85
+ files = {
86
+ 'image': ('input.png', image_bytes, 'image/png')
87
+ }
88
+ data = {
89
+ 'prompt': prompt,
90
+ 'n': args.num_images,
91
+ 'size': f"{final_width}x{final_height}",
92
+ 'response_format': 'b64_json',
93
+ 'guidance_scale': 2.5 # Default specific to edit/kontext if needed
94
+ }
95
+
96
+ print(f"Sending input image to {url_edit}... *whoosh*")
97
+ response = requests.post(url_edit, files=files, data=data)
98
+
99
+ else:
100
+ # Standard Generation
101
+ print("Just a prompt? Off to the Creation Forge!")
102
+ if final_width is None: final_width = 1024
103
+ if final_height is None: final_height = 1024
104
+
105
+ payload = {
106
+ "prompt": prompt,
107
+ "n": args.num_images,
108
+ "size": f"{final_width}x{final_height}",
109
+ "response_format": "b64_json"
110
+ }
111
+
112
+ print(f"Sending prompt to {url_gen}... *sparkle*")
113
+ response = requests.post(url_gen, json=payload)
114
+
115
+ response.raise_for_status()
116
+
117
+ data = response.json()
118
+
119
+ # Making sure we have a chest for our treasures
120
+ if not os.path.exists(args.image_folder):
121
+ print(f"Creating a new treasure chest at {args.image_folder}...")
122
+ os.makedirs(args.image_folder)
123
+
124
+ # Unpacking the magic
125
+ images = data.get("data", [])
126
+ print(f"Ooh! The server sent back {len(images)} masterpieces!")
127
+
128
+ for i, img_data in enumerate(images):
129
+ # Decoding the spell
130
+ img_bytes = base64.b64decode(img_data["b64_json"])
131
+
132
+ timestamp = int(time.time())
133
+ filename = f"image_{timestamp}_{i}.png"
134
+ filepath = os.path.join(args.image_folder, filename)
135
+
136
+ with open(filepath, "wb") as f:
137
+ f.write(img_bytes)
138
+
139
+ print(f"Saved masterpiece #{i+1} to {filepath}! It sparkles!")
140
+
141
+ except requests.exceptions.ConnectionError:
142
+ print("Oh no! The server didn't answer. Is it sleeping? (Connection Refused)")
143
+ print("Maybe check if the host and port are correct? We tried: " + url)
144
+ except Exception as e:
145
+ print(f"Eek! Something went wrong on the journey: {e}")
146
+ # We'll give it a gentle hug and try to understand...
147
+ print("Don't worry, we can try again later!")
148
+
149
+ if __name__ == "__main__":
150
+ main()
extras/ImageGenServer.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import io
4
+ import time
5
+ import torch
6
+ import uvicorn
7
+ import gc
8
+ import asyncio
9
+ from fastapi import FastAPI, HTTPException
10
+ from pydantic import BaseModel
11
+ from diffusers import FluxPipeline
12
+ from nunchaku import NunchakuFluxTransformer2dModel
13
+
14
+ # Argument parsing
15
+ parser = argparse.ArgumentParser(description="Flux Image Generation Server with Nunchaku")
16
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
17
+ parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
18
+ parser.add_argument("--model", type=str, default="black-forest-labs/FLUX.1-dev", help="Path or Repo ID of the base model")
19
+ parser.add_argument("--optimized-model", type=str, required=True, help="Path to the optimized Nunchaku model safetensors file")
20
+ args = parser.parse_args()
21
+
22
+ app = FastAPI()
23
+
24
+ # Global components
25
+ pipeline = None
26
+ request_lock = asyncio.Lock()
27
+
28
+ def load_model():
29
+ global pipeline
30
+
31
+ print(f"Loading base model from {args.model}...")
32
+ print(f"Loading optimized transformer from {args.optimized_model}...")
33
+
34
+ try:
35
+ # Load the optimized transformer
36
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(args.optimized_model)
37
+
38
+ # Load the pipeline with the optimized transformer
39
+ pipeline = FluxPipeline.from_pretrained(
40
+ args.model,
41
+ transformer=transformer,
42
+ torch_dtype=torch.bfloat16,
43
+ ).to("cuda")
44
+
45
+ pipeline.transformer.set_attention_backend("flash")
46
+ pipeline.enable_model_cpu_offload()
47
+ pipeline.enable_vae_tiling()
48
+ pipeline.enable_vae_slicing()
49
+
50
+ except Exception as e:
51
+ print(f"Error loading model: {e}")
52
+ raise e
53
+
54
+ print("Model loaded successfully!")
55
+
56
+ def flush():
57
+ gc.collect()
58
+ torch.cuda.empty_cache()
59
+
60
+ class ImageGenerationRequest(BaseModel):
61
+ prompt: str
62
+ n: int = 1
63
+ size: str = "1024x1024"
64
+ response_format: str = "b64_json"
65
+ quality: str = "standard"
66
+ style: str = "vivid"
67
+
68
+ @app.on_event("startup")
69
+ async def startup_event():
70
+ load_model()
71
+
72
+ @app.post("/v1/images/generations")
73
+ async def generate_image(request: ImageGenerationRequest):
74
+ if not pipeline:
75
+ raise HTTPException(status_code=500, detail="Model not loaded")
76
+
77
+ async with request_lock:
78
+ print(f"Received request: {request.prompt}")
79
+
80
+ # Parse size
81
+ try:
82
+ width, height = map(int, request.size.split("x"))
83
+ except ValueError:
84
+ width, height = 1024, 1024
85
+
86
+ # Flux requires dimensions to be multiples of 16 (or 8 depending on VAE)
87
+ # Standard Flux dev usually works well with 1024x1024
88
+ # We'll ensure they are divisible by 16 just in case
89
+ width = (width // 16) * 16
90
+ height = (height // 16) * 16
91
+
92
+ images = []
93
+
94
+ try:
95
+ # Generate images
96
+ generated_images = pipeline(
97
+ request.prompt,
98
+ height=height,
99
+ width=width,
100
+ num_inference_steps=4, # Standard for Flux Dev
101
+ guidance_scale=3.5, # Nunchaku example uses 3.5, previous code used 4.0. Let's stick to 3.5 or 4.0. Example says 3.5.
102
+ num_images_per_prompt=request.n
103
+ ).images
104
+
105
+ for image in generated_images:
106
+ buffered = io.BytesIO()
107
+ image.save(buffered, format="PNG")
108
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
109
+ images.append({"b64_json": img_str})
110
+
111
+ except Exception as e:
112
+ print(f"Error during generation: {e}")
113
+ raise HTTPException(status_code=500, detail=str(e))
114
+ finally:
115
+ flush()
116
+
117
+ return {
118
+ "created": int(time.time()),
119
+ "data": images
120
+ }
121
+
122
+ if __name__ == "__main__":
123
+ uvicorn.run(app, host=args.host, port=args.port)
extras/ImageGenServer_cpu.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import io
4
+ import time
5
+ import torch
6
+ import uvicorn
7
+ import numpy as np
8
+ import gc
9
+ import asyncio
10
+ from fastapi import FastAPI, HTTPException, Request
11
+ from accelerate import infer_auto_device_map, dispatch_model
12
+ from pydantic import BaseModel
13
+ from diffusers import (
14
+ Flux2Pipeline,
15
+ Flux2Transformer2DModel,
16
+ AutoencoderKLFlux2,
17
+ FlowMatchEulerDiscreteScheduler
18
+ )
19
+ from diffusers.pipelines.flux2.pipeline_flux2 import compute_empirical_mu, retrieve_timesteps
20
+ from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
21
+ from transformers import Mistral3ForConditionalGeneration, AutoProcessor
22
+
23
+ # Argument parsing
24
+ parser = argparse.ArgumentParser(description="Flux2 Image Generation Server")
25
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
26
+ parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
27
+ parser.add_argument("--model", type=str, default="black-forest-labs/FLUX.1-dev", help="Path or Repo ID of the model")
28
+ args = parser.parse_args()
29
+
30
+ app = FastAPI()
31
+
32
+ # Global components
33
+ text_encoder = None
34
+ tokenizer = None
35
+ transformer = None
36
+ vae = None
37
+ scheduler = None
38
+ image_processor = None
39
+ request_lock = asyncio.Lock()
40
+
41
+ # Device maps
42
+ text_encoder_map = None
43
+ transformer_map = None
44
+ vae_map = None
45
+
46
+ GPU_MEMORY_FRACTION = 0.90
47
+
48
+ def load_model():
49
+ global text_encoder, tokenizer, transformer, vae, scheduler, image_processor
50
+ global text_encoder_map, transformer_map, vae_map
51
+
52
+ print(f"Loading model from {args.model}...")
53
+
54
+ try:
55
+ print("Loading Flux2 components...")
56
+
57
+ # Calculate max memory per GPU
58
+ #max_memory = {}
59
+ #if torch.cuda.is_available():
60
+ # for i in range(torch.cuda.device_count()):
61
+ # total_mem = torch.cuda.get_device_properties(i).total_memory
62
+ # max_memory[i] = int(total_mem * GPU_MEMORY_FRACTION)
63
+
64
+ max_memory = {
65
+ 0: "5GB", # leave a little headroom
66
+ # 1: "10GB",
67
+ "cpu": "120GB" # your 128GB RAM minus OS
68
+ }
69
+
70
+ # Load Text Encoder (Mistral3) on CPU
71
+ print("Loading Text Encoder on CPU...")
72
+ text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
73
+ args.model,
74
+ subfolder="text_encoder",
75
+ torch_dtype=torch.bfloat16,
76
+ device_map="cpu"
77
+ )
78
+ print("Calculating Text Encoder device map...")
79
+ text_encoder_map = infer_auto_device_map(text_encoder, max_memory=max_memory)
80
+
81
+ # Load Tokenizer on CPU
82
+ print("Loading Tokenizer on CPU...")
83
+ tokenizer = AutoProcessor.from_pretrained(
84
+ args.model,
85
+ subfolder="tokenizer",
86
+ device_map="cpu"
87
+ )
88
+
89
+ # Load Transformer on CPU
90
+ print("Loading Transformer on CPU...")
91
+ transformer = Flux2Transformer2DModel.from_pretrained(
92
+ args.model,
93
+ subfolder="transformer",
94
+ torch_dtype=torch.bfloat16,
95
+ device_map="cpu"
96
+ )
97
+ print("Calculating Transformer device map...")
98
+ transformer_map = infer_auto_device_map(transformer, max_memory=max_memory)
99
+
100
+ # Load VAE on CPU
101
+ print("Loading VAE on CPU...")
102
+ vae = AutoencoderKLFlux2.from_pretrained(
103
+ args.model,
104
+ subfolder="vae",
105
+ torch_dtype=torch.bfloat16,
106
+ device_map="cpu"
107
+ )
108
+ print("Calculating VAE device map...")
109
+ vae_map = infer_auto_device_map(vae, max_memory=max_memory)
110
+
111
+ # Initialize Scheduler
112
+ print("Initializing Scheduler...")
113
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
114
+ args.model,
115
+ subfolder="scheduler"
116
+ )
117
+
118
+ # Initialize Image Processor
119
+ print("Initializing Image Processor...")
120
+ # VAE scale factor logic from pipeline
121
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
122
+ image_processor = Flux2ImageProcessor(vae_scale_factor=vae_scale_factor * 2)
123
+
124
+ except Exception as e:
125
+ print(f"Error loading model: {e}")
126
+ raise e
127
+
128
+ print("Model loaded successfully!")
129
+
130
+ def flush():
131
+ gc.collect()
132
+ torch.cuda.empty_cache()
133
+
134
+ class ImageGenerationRequest(BaseModel):
135
+ prompt: str
136
+ n: int = 1
137
+ size: str = "1024x1024"
138
+ response_format: str = "b64_json"
139
+ quality: str = "standard"
140
+ style: str = "vivid"
141
+
142
+ @app.on_event("startup")
143
+ async def startup_event():
144
+ load_model()
145
+
146
+ @app.post("/v1/images/generations")
147
+ async def generate_image(request: ImageGenerationRequest):
148
+ if not transformer:
149
+ raise HTTPException(status_code=500, detail="Model not loaded")
150
+
151
+ async with request_lock:
152
+ print(f"Received request: {request.prompt}")
153
+
154
+ # Parse size
155
+ try:
156
+ width, height = map(int, request.size.split("x"))
157
+ except ValueError:
158
+ width, height = 1024, 1024
159
+
160
+ num_inference_steps = 28
161
+ guidance_scale = 4.0
162
+ max_sequence_length = 512
163
+ device = torch.device("cuda")
164
+ dtype = torch.bfloat16
165
+
166
+ images = []
167
+
168
+ # 1. Generate embeddings on CPU
169
+ print("Generating embeddings...")
170
+ flush()
171
+ prompt_embeds = Flux2Pipeline._get_mistral_3_small_prompt_embeds(
172
+ text_encoder=text_encoder,
173
+ tokenizer=tokenizer,
174
+ prompt=request.prompt,
175
+ # device=torch.device("cpu"),
176
+ max_sequence_length=max_sequence_length
177
+ )
178
+
179
+
180
+ # prompt_embeds = prompt_embeds.to("cuda")
181
+
182
+ # 2. Prepare Latents
183
+ # Flux latents are turned into 2x2 patches and packed.
184
+ # This means the latent width and height has to be divisible by the patch size.
185
+ # So the vae scale factor is multiplied by the patch size to account for this
186
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
187
+
188
+ height = height or 1024
189
+ width = width or 1024
190
+
191
+ # Resize to be divisible by vae_scale_factor * 2
192
+ height = 2 * (int(height) // (vae_scale_factor * 2))
193
+ width = 2 * (int(width) // (vae_scale_factor * 2))
194
+
195
+ num_channels_latents = transformer.config.in_channels // 4
196
+ shape = (1, num_channels_latents * 4, height // 2, width // 2)
197
+
198
+ # 3. Prepare IDs
199
+ # We need to prepare text_ids and latent_ids
200
+ # prompt_embeds shape: (batch_size, seq_len, hidden_dim)
201
+ batch_size, seq_len, _ = prompt_embeds.shape
202
+
203
+ # Repeat for num_images_per_prompt (assuming 1 for now per loop iteration)
204
+ # If request.n > 1, we loop outside or handle batching. Here we loop outside.
205
+
206
+ # Prepare text IDs
207
+ text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(device)
208
+
209
+ for _ in range(request.n):
210
+ # Generate random latents
211
+ latents = torch.randn(shape, device=device, dtype=dtype)
212
+
213
+ # Prepare latent IDs
214
+ latent_ids = Flux2Pipeline._prepare_latent_ids(latents).to(device)
215
+
216
+ # Pack latents
217
+ packed_latents = Flux2Pipeline._pack_latents(latents)
218
+
219
+ # 4. Prepare Timesteps
220
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
221
+ image_seq_len = packed_latents.shape[1]
222
+ mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
223
+ timesteps, num_inference_steps = retrieve_timesteps(
224
+ scheduler,
225
+ num_inference_steps,
226
+ device,
227
+ sigmas=sigmas,
228
+ mu=mu,
229
+ )
230
+
231
+ # --- SWAP TRANSFORMER TO CUDA ---
232
+ print("Moving Transformer to CUDA...")
233
+ flush()
234
+ dispatch_model(transformer, device_map=transformer_map)
235
+
236
+ # 5. Denoising Loop
237
+ print("Starting denoising loop on CUDA...")
238
+ scheduler.set_begin_index(0)
239
+
240
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
241
+ guidance = guidance.expand(packed_latents.shape[0])
242
+
243
+ for i, t in enumerate(timesteps):
244
+ start_time = time.time()
245
+ # broadcast to batch dimension
246
+ timestep = t.expand(packed_latents.shape[0]).to(packed_latents.dtype)
247
+
248
+ noise_pred = transformer(
249
+ hidden_states=packed_latents,
250
+ timestep=timestep / 1000,
251
+ guidance=guidance,
252
+ encoder_hidden_states=prompt_embeds,
253
+ txt_ids=text_ids,
254
+ img_ids=latent_ids,
255
+ return_dict=False,
256
+ )[0]
257
+
258
+ # step
259
+ packed_latents = scheduler.step(noise_pred, t, packed_latents, return_dict=False)[0]
260
+
261
+ step_time = time.time() - start_time
262
+ print(f"Step {i+1}/{num_inference_steps}: {step_time:.2f}s")
263
+
264
+ # --- SWAP TRANSFORMER TO CPU ---
265
+ print("Moving Transformer to CPU...")
266
+ transformer.to("cpu")
267
+ flush()
268
+
269
+ # --- SWAP VAE TO CUDA ---
270
+ print("Moving VAE to CUDA...")
271
+ dispatch_model(vae, device_map=vae_map)
272
+
273
+ # 6. Decode
274
+ print("Decoding on CUDA...")
275
+ # Move packed_latents to CUDA for decoding (already there, but ensuring)
276
+ packed_latents = packed_latents.to(device)
277
+ latent_ids = latent_ids.to(device)
278
+
279
+ latents = Flux2Pipeline._unpack_latents_with_ids(packed_latents, latent_ids)
280
+
281
+ latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
282
+ latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
283
+ latents.device, latents.dtype
284
+ )
285
+ latents = latents * latents_bn_std + latents_bn_mean
286
+ latents = Flux2Pipeline._unpatchify_latents(latents)
287
+
288
+ image = vae.decode(latents, return_dict=False)[0]
289
+ image = image_processor.postprocess(image, output_type="pil")[0]
290
+
291
+ # --- SWAP VAE TO CPU ---
292
+ print("Moving VAE to CPU...")
293
+ vae.to("cpu")
294
+
295
+ # Convert to base64
296
+ buffered = io.BytesIO()
297
+ image.save(buffered, format="PNG")
298
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
299
+ images.append({"b64_json": img_str})
300
+
301
+ return {
302
+ "created": int(time.time()),
303
+ "data": images
304
+ }
305
+
306
+ if __name__ == "__main__":
307
+ uvicorn.run(app, host=args.host, port=args.port)
extras/ImageGenServer_new.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import io
4
+ import time
5
+ import torch
6
+ import uvicorn
7
+ import gc
8
+ import asyncio
9
+ from typing import Optional
10
+ from fastapi import FastAPI, HTTPException
11
+ from pydantic import BaseModel
12
+ from diffusers import FluxPipeline, FluxKontextPipeline
13
+ from nunchaku import NunchakuFluxTransformer2dModel
14
+ from PIL import Image
15
+
16
+ # Argument parsing
17
+ parser = argparse.ArgumentParser(description="Flux Image Generation Server with Nunchaku")
18
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
19
+ parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
20
+ parser.add_argument("--model", type=str, default="black-forest-labs/FLUX.1-dev", help="Path or Repo ID of the base model")
21
+ parser.add_argument("--optimized-model", type=str, required=True, help="Path to the optimized Nunchaku model safetensors file")
22
+ args = parser.parse_args()
23
+
24
+ app = FastAPI()
25
+
26
+ # Global components
27
+ pipeline = None
28
+ img2img_pipeline = None
29
+ request_lock = asyncio.Lock()
30
+
31
+ def load_model():
32
+ global pipeline, img2img_pipeline
33
+
34
+ print(f"Loading base model from {args.model}...")
35
+ print(f"Loading optimized transformer from {args.optimized_model}...")
36
+
37
+ try:
38
+ # Load the optimized transformer
39
+ # Ensuring transformer is in bfloat16 to match the pipeline expectation
40
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(args.optimized_model)
41
+
42
+ # Load the pipeline with the optimized transformer
43
+ pipeline = FluxPipeline.from_pretrained(
44
+ args.model,
45
+ transformer=transformer,
46
+ torch_dtype=torch.bfloat16,
47
+ ).to("cuda")
48
+
49
+ # Load the Img2Img/Context pipeline sharing the same components
50
+ # We use strict component sharing to avoid VRAM duplication
51
+ print("Initializing FluxKontextPipeline for image inputs...")
52
+ # Since FluxKontextPipeline shares architecture with FluxPipeline, we can initialize it with the same components
53
+ img2img_pipeline = FluxKontextPipeline.from_pretrained(
54
+ args.model,
55
+ transformer=pipeline.transformer,
56
+ vae=pipeline.vae,
57
+ text_encoder=pipeline.text_encoder,
58
+ text_encoder_2=pipeline.text_encoder_2,
59
+ tokenizer=pipeline.tokenizer,
60
+ tokenizer_2=pipeline.tokenizer_2,
61
+ scheduler=pipeline.scheduler,
62
+ torch_dtype=torch.bfloat16
63
+ ).to("cuda")
64
+
65
+ # Enable CPU offload for the main pipeline.
66
+ # Since components are shared, this should handle memory management for both.
67
+ pipeline.enable_model_cpu_offload()
68
+ # img2img_pipeline.enable_model_cpu_offload() # Avoid double hook registration
69
+
70
+ except Exception as e:
71
+ print(f"Error loading model: {e}")
72
+ raise e
73
+
74
+ print("Model loaded successfully!")
75
+
76
+ def flush():
77
+ gc.collect()
78
+ torch.cuda.empty_cache()
79
+
80
+ class ImageGenerationRequest(BaseModel):
81
+ prompt: str
82
+ n: int = 1
83
+ size: str = "1024x1024"
84
+ response_format: str = "b64_json"
85
+ quality: str = "standard"
86
+ style: str = "vivid"
87
+ image: Optional[str] = None # Base64 encoded image
88
+
89
+ @app.on_event("startup")
90
+ async def startup_event():
91
+ load_model()
92
+
93
+ @app.post("/v1/images/generations")
94
+ async def generate_image(request: ImageGenerationRequest):
95
+ if not pipeline:
96
+ raise HTTPException(status_code=500, detail="Model not loaded")
97
+
98
+ async with request_lock:
99
+ print(f"Received request: {request.prompt}")
100
+
101
+ # Parse size
102
+ try:
103
+ width, height = map(int, request.size.split("x"))
104
+ except ValueError:
105
+ width, height = 1024, 1024
106
+
107
+ # Flux requires dimensions to be multiples of 16 (or 8 depending on VAE)
108
+ # Standard Flux dev usually works well with 1024x1024
109
+ # We'll ensure they are divisible by 16 just in case
110
+ width = (width // 16) * 16
111
+ height = (height // 16) * 16
112
+
113
+ images = []
114
+
115
+ try:
116
+ input_image = None
117
+ if request.image:
118
+ try:
119
+ # Handle data URI if present
120
+ img_data = request.image
121
+ if "," in img_data:
122
+ img_data = img_data.split(",")[1]
123
+
124
+ input_bytes = base64.b64decode(img_data)
125
+ input_image = Image.open(io.BytesIO(input_bytes)).convert("RGB")
126
+ # Resize input image to match request size
127
+ input_image = input_image.resize((width, height), Image.LANCZOS)
128
+ print(f"Processed input image of size {input_image.size}")
129
+ except Exception as e:
130
+ print(f"Failed to decode input image: {e}")
131
+ raise HTTPException(status_code=400, detail="Invalid image data")
132
+
133
+ # Generate images
134
+ if input_image:
135
+ # Use FluxKontextPipeline
136
+ print("Running FluxKontextPipeline...")
137
+ generated_images = pipeline(
138
+ image=input_image,
139
+ prompt=request.prompt,
140
+ height=height,
141
+ width=width,
142
+ num_inference_steps=28,
143
+ guidance_scale=2.5, # Recommended for Kontext
144
+ num_images_per_prompt=request.n
145
+ ).images
146
+ else:
147
+ # Use standard FluxPipeline
148
+ print("Running FluxPipeline...")
149
+ generated_images = pipeline(
150
+ request.prompt,
151
+ height=height,
152
+ width=width,
153
+ num_inference_steps=28, # Standard for Flux Dev
154
+ guidance_scale=3.5, # Nunchaku example uses 3.5
155
+ num_images_per_prompt=request.n
156
+ ).images
157
+
158
+ for image in generated_images:
159
+ buffered = io.BytesIO()
160
+ image.save(buffered, format="PNG")
161
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
162
+ images.append({"b64_json": img_str})
163
+
164
+ except Exception as e:
165
+ print(f"Error during generation: {e}")
166
+ raise HTTPException(status_code=500, detail=str(e))
167
+ finally:
168
+ flush()
169
+
170
+ return {
171
+ "created": int(time.time()),
172
+ "data": images
173
+ }
174
+
175
+ if __name__ == "__main__":
176
+ uvicorn.run(app, host=args.host, port=args.port)
extras/KontextBackend.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import T5EncoderModel, BitsAndBytesConfig
3
+ from diffusers import FluxKontextPipeline
4
+
5
+ class KontextBackend:
6
+ def __init__(self, model_id, optimized_model_path=None):
7
+ self.model_id = model_id
8
+ self.optimized_model_path = optimized_model_path
9
+ self.pipeline = None
10
+
11
+ def load(self):
12
+ print(f"Loading Kontext backend from {self.model_id}...")
13
+
14
+ if self.optimized_model_path:
15
+ print(f"Loading optimized transformer from {self.optimized_model_path}...")
16
+ # Load the optimized transformer (Nunchaku style! *hyah!*)
17
+ try:
18
+ from nunchaku import NunchakuFluxTransformer2dModel
19
+ except ImportError:
20
+ print("Oops, nunchaku not found! Please install it for optimized magic.")
21
+ raise
22
+
23
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(self.optimized_model_path)
24
+
25
+ text_quant_config = BitsAndBytesConfig(
26
+ load_in_4bit=True,
27
+ bnb_4bit_quant_type="nf4",
28
+ bnb_4bit_compute_dtype=torch.bfloat16,
29
+ bnb_4bit_use_double_quant=True
30
+ )
31
+
32
+ text_encoder_2_4bit = T5EncoderModel.from_pretrained(
33
+ self.model_id,
34
+ subfolder="text_encoder_2",
35
+ quantization_config=text_quant_config,
36
+ torch_dtype=torch.bfloat16 # bfloat16 for your NVIDIA setup—faster magic!
37
+ )
38
+
39
+ # Load the pipeline with the optimized transformer
40
+ # We need FluxKontextPipeline for editing magic!
41
+ pipeline = FluxKontextPipeline.from_pretrained(
42
+ self.model_id,
43
+ text_encoder_2=text_encoder_2_4bit,
44
+ transformer=transformer,
45
+ torch_dtype=torch.bfloat16,
46
+ )
47
+ else:
48
+ print("No optimized model path provided for KontextBackend. Falling back to standard loading if possible, or maybe we should insist on one?")
49
+ # Original code implied usage of optimized model for Kontext was the main path, but let's support standard if needed,
50
+ # or minimally just load standard logic if that was the fallback.
51
+ # Looking at original code: "if args.optimized_model: ... else: ... Flux2Pipeline"
52
+ # Wait, the original code fell back to Flux2Pipeline if no optimized model was present!
53
+ # The user request says: "create KontextBackend.py that creates a pipeline from base and optional optimized paths"
54
+ # So KontextBackend *should* support both optimized and unoptimized? Or was the fallback in original code actually switching to Flux2?
55
+ # Original code:
56
+ # if args.optimized_model:
57
+ # # Load Nunchaku stuff
58
+ # pipeline = FluxKontextPipeline(...)
59
+ # else:
60
+ # # Load standard stuff
61
+ # pipeline = Flux2Pipeline(...)
62
+ #
63
+ # The USER request says: "KontextBackend.py that creates a pipeline from base and optional optimized paths".
64
+ # This implies if I choose "kontext" backend but don't provide optimized path, it should still load a FluxKontextPipeline (presumably unoptimized/standard).
65
+ # However, FluxKontextPipeline might expect specific components.
66
+ # Let's assume standard loading for FluxKontextPipeline if no optimized model is separate.
67
+
68
+ print(f"Loading standard FluxKontextPipeline from {self.model_id}...")
69
+ # Assuming standard 4-bit loading for memory savings similar to before
70
+ quantization_config = BitsAndBytesConfig(
71
+ load_in_4bit=True,
72
+ bnb_4bit_quant_type="nf4",
73
+ bnb_4bit_compute_dtype=torch.float16,
74
+ bnb_4bit_use_double_quant=True,
75
+ )
76
+
77
+ # Use basic from_pretrained
78
+ pipeline = FluxKontextPipeline.from_pretrained(
79
+ self.model_id,
80
+ torch_dtype=torch.bfloat16
81
+ # We might need quantization for components if memory is tight, but from_pretrained handles a lot.
82
+ # Let's keep it simple for now as we don't have the Nunchaku specific loading here.
83
+ )
84
+ # Actually, if we look at how specialized the optimized loading was, standard loading might just be:
85
+ # pipeline = FluxKontextPipeline.from_pretrained(model_id, torch_dtype=...)
86
+
87
+ self.pipeline = pipeline
88
+ self.pipeline.to("cuda")
89
+
90
+ # Additional setup if needed (like offload)
91
+ # self.pipeline.enable_model_cpu_offload() # User code had this for optimized path
92
+
93
+ return self.pipeline, self.pipeline
extras/NVFP4TextEncoder.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NVFP4 text encoder loader for diffusers image pipelines.
3
+
4
+ Loads a compressed-tensors NVFP4-pack-quantized HuggingFace causal LM and wraps
5
+ it so it can be plugged into ``diffusers.ZImagePipeline`` (or any pipeline
6
+ calling ``self.text_encoder(input_ids, attention_mask, output_hidden_states=True)``).
7
+
8
+ Strategy:
9
+ - Instantiate the HF model on the ``meta`` device (no real allocation).
10
+ - Walk every ``torch.nn.Linear`` and swap it for vLLM's ``ReplicatedLinear`` with
11
+ ``CompressedTensorsConfig`` derived from the checkpoint's
12
+ ``quantization_config``. This registers ``weight_packed`` / ``weight_scale`` /
13
+ ``*_global_scale`` parameters in the exact layout vLLM's
14
+ ``CompressedTensorsW4A4Fp4`` scheme expects.
15
+ - Materialise remaining (non-Linear) parameters (embeddings, RMSNorm, k/q norms)
16
+ on the target device & dtype.
17
+ - Stream the safetensors file and dispatch each tensor through the registered
18
+ vLLM ``weight_loader`` (which handles layout swizzling on
19
+ ``process_weights_after_loading``).
20
+ - Tie the LM head to the input embedding when ``config.tie_word_embeddings``.
21
+
22
+ The result is a regular ``nn.Module`` matching the HF model's call signature
23
+ (``forward(input_ids, attention_mask, output_hidden_states)``) -- usable directly
24
+ as ``ZImagePipeline.text_encoder``.
25
+
26
+ vLLM requires a minimal global context (distributed process group + model
27
+ parallel state + active VllmConfig) even at TP=1 because ``ReplicatedLinear``
28
+ queries the TP world size at construction. We bootstrap that lazily once.
29
+
30
+ Forced kernel: we set ``VLLM_NVFP4_GEMM_BACKEND=cutlass`` to skip
31
+ flashinfer-cutlass JIT (which needs the ``ninja`` binary on PATH). The vLLM
32
+ CUTLASS kernel is built into the wheel.
33
+ """
34
+ from __future__ import annotations
35
+
36
+ import json
37
+ import os
38
+ from collections.abc import Iterator
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn as nn
43
+
44
+
45
+ # ----------------------------------------------------------------------------
46
+ # One-time vLLM bootstrap (TP=1, no engine, just enough context for ReplicatedLinear)
47
+ # ----------------------------------------------------------------------------
48
+ _VLLM_BOOTSTRAPPED = False
49
+ _VLLM_CONFIG_CTX = None # holds the entered set_current_vllm_config context manager
50
+
51
+
52
+ def _bootstrap_vllm_once() -> None:
53
+ """Initialise the bits of vLLM that ReplicatedLinear needs at TP=1.
54
+
55
+ Idempotent. Uses ``gloo`` so it works without NCCL/CUDA-aware MPI and even
56
+ when CUDA is busy with the diffusion transformer.
57
+ """
58
+ global _VLLM_BOOTSTRAPPED, _VLLM_CONFIG_CTX
59
+ if _VLLM_BOOTSTRAPPED:
60
+ return
61
+
62
+ # Force CUTLASS to avoid flashinfer-cutlass JIT (requires `ninja` on PATH).
63
+ os.environ.setdefault("VLLM_NVFP4_GEMM_BACKEND", "cutlass")
64
+
65
+ from vllm.config import VllmConfig
66
+ from vllm.config.vllm import set_current_vllm_config
67
+ from vllm.distributed import (
68
+ ensure_model_parallel_initialized,
69
+ init_distributed_environment,
70
+ )
71
+
72
+ # Pick a free port; world_size=1.
73
+ import socket
74
+
75
+ s = socket.socket()
76
+ s.bind(("127.0.0.1", 0))
77
+ port = s.getsockname()[1]
78
+ s.close()
79
+
80
+ if not torch.distributed.is_initialized():
81
+ init_distributed_environment(
82
+ world_size=1,
83
+ rank=0,
84
+ local_rank=0,
85
+ distributed_init_method=f"tcp://127.0.0.1:{port}",
86
+ backend="gloo",
87
+ )
88
+
89
+ # Enter a long-lived VllmConfig context. We never exit it -- the encoder
90
+ # may construct submodules lazily and ReplicatedLinear calls
91
+ # get_current_vllm_config() at init.
92
+ vc = VllmConfig()
93
+ _VLLM_CONFIG_CTX = set_current_vllm_config(vc)
94
+ _VLLM_CONFIG_CTX.__enter__()
95
+
96
+ ensure_model_parallel_initialized(1, 1)
97
+ _VLLM_BOOTSTRAPPED = True
98
+
99
+
100
+ # ----------------------------------------------------------------------------
101
+ # Module: linear replacement
102
+ # ----------------------------------------------------------------------------
103
+ def _replace_linears_with_replicated(
104
+ model: nn.Module, quant_config
105
+ ) -> None:
106
+ """Recursively swap every ``nn.Linear`` for vLLM ``ReplicatedLinear``.
107
+
108
+ Carries the ``prefix`` so quant_config's ``ignore`` patterns (e.g. ``lm_head``)
109
+ are correctly applied.
110
+ """
111
+ from vllm.model_executor.layers.linear import ReplicatedLinear
112
+
113
+ def _walk(parent: nn.Module, prefix: str) -> None:
114
+ for child_name, child in list(parent.named_children()):
115
+ qname = f"{prefix}.{child_name}" if prefix else child_name
116
+ if isinstance(child, nn.Linear):
117
+ new = ReplicatedLinear(
118
+ input_size=child.in_features,
119
+ output_size=child.out_features,
120
+ bias=child.bias is not None,
121
+ quant_config=quant_config,
122
+ prefix=qname,
123
+ return_bias=False,
124
+ params_dtype=torch.bfloat16,
125
+ )
126
+ setattr(parent, child_name, new)
127
+ else:
128
+ _walk(child, qname)
129
+
130
+ _walk(model, prefix="")
131
+
132
+
133
+ def _materialize_remaining_meta_params(
134
+ model: nn.Module, dtype: torch.dtype, device: torch.device
135
+ ) -> None:
136
+ """Replace any ``meta`` parameter with empty real storage.
137
+
138
+ Only touches parameters NOT already created on a real device by the
139
+ ReplicatedLinear swap above (i.e. embeddings, layernorms, biases).
140
+ """
141
+ for name, param in list(model.named_parameters(recurse=True)):
142
+ if param.device.type == "meta":
143
+ real = nn.Parameter(
144
+ torch.empty(param.shape, dtype=dtype, device=device),
145
+ requires_grad=False,
146
+ )
147
+ # Replace in the parent module
148
+ parent = model
149
+ *path, leaf = name.split(".")
150
+ for p in path:
151
+ parent = getattr(parent, p)
152
+ setattr(parent, leaf, real)
153
+ # Same for buffers (e.g. rotary inv_freq if registered as buffer on meta)
154
+ for name, buf in list(model.named_buffers(recurse=True)):
155
+ if buf.device.type == "meta":
156
+ real = torch.empty(buf.shape, dtype=buf.dtype, device=device)
157
+ parent = model
158
+ *path, leaf = name.split(".")
159
+ for p in path:
160
+ parent = getattr(parent, p)
161
+ parent.register_buffer(leaf, real, persistent=False)
162
+
163
+
164
+ # ----------------------------------------------------------------------------
165
+ # Weight loading
166
+ # ----------------------------------------------------------------------------
167
+ def _iter_safetensors(model_dir: str) -> Iterator[tuple[str, torch.Tensor]]:
168
+ """Yield (name, tensor) pairs from all *.safetensors shards in ``model_dir``."""
169
+ from safetensors import safe_open
170
+
171
+ # Single-file checkpoint or sharded? Prefer ``model.safetensors.index.json``.
172
+ index_path = os.path.join(model_dir, "model.safetensors.index.json")
173
+ if os.path.exists(index_path):
174
+ with open(index_path) as f:
175
+ index = json.load(f)
176
+ shards = sorted(set(index["weight_map"].values()))
177
+ else:
178
+ # Find all *.safetensors files in dir
179
+ shards = sorted(
180
+ fn for fn in os.listdir(model_dir) if fn.endswith(".safetensors")
181
+ )
182
+ for shard in shards:
183
+ path = os.path.join(model_dir, shard)
184
+ with safe_open(path, framework="pt") as f:
185
+ for key in f.keys():
186
+ yield key, f.get_tensor(key)
187
+
188
+
189
+ def _load_weights_into_model(model: nn.Module, model_dir: str) -> None:
190
+ """Stream safetensors into the (already-structured) model.
191
+
192
+ Uses each ReplicatedLinear's registered ``weight_loader`` for quantised
193
+ params (which handles tensor-parallel sharding, even though TP=1 here it
194
+ keeps casts consistent). Other params (embeddings, layernorms, biases) are
195
+ copied directly.
196
+ """
197
+ # Strip vllm-omni-style "text_encoder." prefix if present; not applicable
198
+ # here since we load the standalone HF Qwen3 checkpoint where keys start
199
+ # with "model.layers..." / "model.embed_tokens..." / "lm_head...".
200
+ name_to_param: dict[str, nn.Parameter] = dict(model.named_parameters(recurse=True))
201
+ name_to_buffer: dict[str, torch.Tensor] = dict(model.named_buffers(recurse=True))
202
+
203
+ missing = set(name_to_param.keys())
204
+ unexpected = []
205
+
206
+ for key, tensor in _iter_safetensors(model_dir):
207
+ # Skip rotary inv_freq etc that aren't params (rare in modern HF saves)
208
+ if key in name_to_param:
209
+ param = name_to_param[key]
210
+ wl = getattr(param, "weight_loader", None)
211
+ if wl is not None:
212
+ wl(param, tensor.to(param.device))
213
+ else:
214
+ with torch.no_grad():
215
+ param.data.copy_(tensor.to(param.device, dtype=param.dtype))
216
+ missing.discard(key)
217
+ elif key in name_to_buffer:
218
+ with torch.no_grad():
219
+ name_to_buffer[key].copy_(tensor.to(name_to_buffer[key].device))
220
+ else:
221
+ unexpected.append(key)
222
+
223
+ # Tied embeddings (lm_head.weight not in checkpoint when tie_word_embeddings=True)
224
+ cfg = getattr(model, "config", None)
225
+ if cfg is not None and getattr(cfg, "tie_word_embeddings", False):
226
+ try:
227
+ inp_emb = model.get_input_embeddings().weight
228
+ model.lm_head.weight = inp_emb # share storage
229
+ missing.discard("lm_head.weight")
230
+ except Exception:
231
+ pass
232
+
233
+ if missing:
234
+ # It's OK if missing entries are *purely* lm_head.weight when tied; we
235
+ # already handled that above. Anything else is fatal-ish.
236
+ leftover = sorted(missing)
237
+ if leftover:
238
+ print(
239
+ f"[NVFP4TextEncoder] WARN: {len(leftover)} params missing from checkpoint; "
240
+ f"first 5: {leftover[:5]}"
241
+ )
242
+ if unexpected:
243
+ print(
244
+ f"[NVFP4TextEncoder] WARN: {len(unexpected)} keys in checkpoint unused; "
245
+ f"first 5: {unexpected[:5]}"
246
+ )
247
+
248
+
249
+ def _process_weights_after_loading(model: nn.Module) -> None:
250
+ """Invoke vLLM's per-layer ``process_weights_after_loading`` for each
251
+ ReplicatedLinear (renames ``weight_packed`` -> ``weight``, computes ``alpha``,
252
+ swizzles scales for the CUTLASS kernel, etc.)."""
253
+ for module in model.modules():
254
+ qm = getattr(module, "quant_method", None)
255
+ if qm is not None and hasattr(qm, "process_weights_after_loading"):
256
+ qm.process_weights_after_loading(module)
257
+
258
+
259
+ # ----------------------------------------------------------------------------
260
+ # Public API
261
+ # ----------------------------------------------------------------------------
262
+ def load_nvfp4_text_encoder(
263
+ model_dir: str,
264
+ device: str | torch.device = "cuda",
265
+ dtype: torch.dtype = torch.bfloat16,
266
+ ) -> nn.Module:
267
+ """Load an NVFP4-quantised HuggingFace causal LM as a plug-in text encoder.
268
+
269
+ Args:
270
+ model_dir: path to the checkpoint directory containing ``config.json``
271
+ and ``model*.safetensors``. The config must carry a
272
+ ``quantization_config`` block with ``"format": "nvfp4-pack-quantized"``.
273
+ device: target CUDA device (forwards to ``model.to(device)``-equivalent
274
+ during materialisation).
275
+ dtype: activation / non-quantised-param dtype.
276
+
277
+ Returns:
278
+ A ``PreTrainedModel`` whose ``Linear`` layers are NVFP4 inside the vLLM
279
+ CUTLASS kernel. Activations flow as ``dtype``.
280
+ """
281
+ _bootstrap_vllm_once()
282
+
283
+ from transformers import AutoConfig, AutoModelForCausalLM
284
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
285
+ CompressedTensorsConfig,
286
+ )
287
+ from vllm.model_executor.models.transformers.utils import (
288
+ init_on_device_without_buffers,
289
+ )
290
+
291
+ hf_config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
292
+ if not getattr(hf_config, "quantization_config", None):
293
+ raise ValueError(
294
+ f"{model_dir}/config.json has no `quantization_config`; "
295
+ "this loader only handles NVFP4-quantised checkpoints."
296
+ )
297
+ quant_config = CompressedTensorsConfig.from_config(hf_config.quantization_config)
298
+
299
+ # 1) Build the model skeleton on meta (zero allocation).
300
+ with init_on_device_without_buffers("meta"):
301
+ model = AutoModelForCausalLM.from_config(hf_config)
302
+
303
+ # 2) Swap Linear -> ReplicatedLinear(quant_config) (creates real CUDA params
304
+ # of the quantised shapes).
305
+ target_device = torch.device(device)
306
+ _replace_linears_with_replicated(model, quant_config)
307
+
308
+ # 3) Materialise any leftover meta parameters (embeddings, RMSNorms, ...)
309
+ _materialize_remaining_meta_params(model, dtype=dtype, device=target_device)
310
+
311
+ # 4) Move newly-created quantised params to target device (ReplicatedLinear
312
+ # creates them on the current default device which is usually CPU).
313
+ model.to(target_device)
314
+
315
+ # 5) Load weights via per-param weight_loader.
316
+ _load_weights_into_model(model, model_dir)
317
+
318
+ # 6) Let vLLM swizzle scales / rename weight_packed->weight / compute alpha.
319
+ _process_weights_after_loading(model)
320
+
321
+ # 7) Match HF semantics for downstream pipelines.
322
+ model.eval()
323
+ model.config.use_cache = False
324
+ return model
extras/OmniImageEditServer.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import io
4
+ import time
5
+ import torch
6
+ import uvicorn
7
+ import gc
8
+ import asyncio
9
+ import os
10
+ import sys
11
+ import os
12
+ import inspect
13
+
14
+ # Add OmniGen2-DFloat11 to path
15
+ # Script is in imagegen/, so we go up one level and into packages/OmniGen2-DFloat11
16
+ current_dir = os.path.dirname(os.path.abspath(__file__))
17
+ project_root = os.path.dirname(current_dir)
18
+ omnigen_path = os.path.join(project_root, "packages", "OmniGen2")
19
+ sys.path.insert(0, omnigen_path)
20
+
21
+ from typing import List, Optional
22
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
23
+ from pydantic import BaseModel
24
+ from PIL import Image, ImageOps
25
+
26
+ # Import OmniGen2 and DFloat11 components
27
+ from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
28
+ from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
29
+ from transformers import CLIPProcessor, BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
30
+ from transformers.modeling_utils import no_init_weights
31
+
32
+ # Yay! Nikola here, ready to bring the OmniGen2 magic to our village!
33
+ # This server is like a new canvas for our artistic endeavors!
34
+
35
+ # Argument parsing
36
+ parser = argparse.ArgumentParser(description="OmniGen2 Image Edit Server")
37
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
38
+ parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
39
+ # Default paths relative to project root as per plan
40
+ parser.add_argument("--base-model", type=str, default="../models/OmniGen2", help="Path to base OmniGen2 model")
41
+ parser.add_argument("--dtype", type=str, default='bf16', choices=['fp32', 'fp16', 'bf16'], help="Model precision")
42
+
43
+ args = parser.parse_args()
44
+
45
+ app = FastAPI()
46
+
47
+ # Global components
48
+ pipeline = None
49
+ request_lock = asyncio.Lock()
50
+
51
+ def load_model():
52
+ global pipeline
53
+
54
+ print(f"Loading OmniGen2 from {args.base_model}...")
55
+
56
+ # Determine usage dtype
57
+ weight_dtype = torch.float32
58
+ if args.dtype == 'fp16':
59
+ weight_dtype = torch.float16
60
+ elif args.dtype == 'bf16':
61
+ weight_dtype = torch.bfloat16
62
+
63
+ try:
64
+ # Load the base pipeline (tokenizer, scheduler, etc.)
65
+ # processor needs to be loaded separately sometimes depending on library version,
66
+ # but following inference.py pattern:
67
+
68
+ # Manually load MLLM in 4-bit to save VRAM, yay!
69
+ print("Loading MLLM in 4-bit mode for extra village efficiency!")
70
+ quantization_config = BitsAndBytesConfig(
71
+ load_in_4bit=True,
72
+ bnb_4bit_quant_type="nf4",
73
+ bnb_4bit_compute_dtype=weight_dtype,
74
+ )
75
+ mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
76
+ args.base_model,
77
+ subfolder="mllm",
78
+ quantization_config=quantization_config,
79
+ torch_dtype=weight_dtype,
80
+ )
81
+
82
+ pipeline = OmniGen2Pipeline.from_pretrained(
83
+ args.base_model,
84
+ mllm=mllm,
85
+ processor=CLIPProcessor.from_pretrained(
86
+ args.base_model,
87
+ subfolder="processor",
88
+ use_fast=True
89
+ ),
90
+ torch_dtype=weight_dtype,
91
+ trust_remote_code=True,
92
+ ).to("cuda")
93
+
94
+ pipeline.enable_taylorseer = True
95
+ pipeline.transformer.set_attention_backend("flash")
96
+
97
+
98
+ print("Enabling CPU offload...")
99
+ #pipeline.enable_model_cpu_offload()
100
+ #pipeline.enable_sequential_cpu_offload()
101
+ except Exception as e:
102
+ print(f"Oh no! The OmniGen2 spirit refused to manifest: {e}")
103
+ raise e
104
+
105
+ print("OmniGen2 loaded successfully! Let's paint the village!")
106
+
107
+ def flush():
108
+ gc.collect()
109
+ torch.cuda.empty_cache()
110
+
111
+ class ImageGenerationRequest(BaseModel):
112
+ prompt: str
113
+ n: int = 1
114
+ size: str = "1024x1024"
115
+ response_format: str = "b64_json"
116
+ quality: str = "standard"
117
+ style: str = "vivid"
118
+
119
+ @app.on_event("startup")
120
+ async def startup_event():
121
+ load_model()
122
+
123
+ @app.post("/v1/images/edits")
124
+ async def edit_image(
125
+ image: UploadFile = File(...),
126
+ prompt: str = Form(...),
127
+ n: int = Form(1),
128
+ size: str = Form("1024x1024"),
129
+ response_format: str = Form("b64_json"),
130
+ guidance_scale: float = Form(2.5), # Image guidance scale
131
+ strength: float = Form(1.0) # Using strength to map to something or just ignored?
132
+ # OmniGen uses image_guidance_scale.
133
+ # We can map strength to text_guidance_scale maybe?
134
+ # Let's keep defaults for now from inference.py
135
+ ):
136
+ if not pipeline:
137
+ raise HTTPException(status_code=500, detail="Model not loaded")
138
+
139
+ async with request_lock:
140
+ print(f"Received edit request: {prompt}")
141
+
142
+ # Processing the input image
143
+ try:
144
+ contents = await image.read()
145
+ init_image = Image.open(io.BytesIO(contents)).convert("RGB")
146
+ init_image = ImageOps.exif_transpose(init_image)
147
+ except Exception as e:
148
+ raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")
149
+
150
+ # Parse max target dimensions from requested size
151
+ try:
152
+ target_width, target_height = map(int, size.split("x"))
153
+ except ValueError:
154
+ target_width, target_height = 1024, 1024
155
+
156
+ # Calculate new dimensions preserving aspect ratio
157
+ orig_width, orig_height = init_image.size
158
+ scale = min(target_width / orig_width, target_height / orig_height)
159
+ new_width = int(orig_width * scale)
160
+ new_height = int(orig_height * scale)
161
+
162
+ # Enforce multiples of 16 for compatibility
163
+ width = (new_width // 16) * 16
164
+ height = (new_height // 16) * 16
165
+
166
+ response_images = []
167
+
168
+ try:
169
+ # Generate edits
170
+ # OmniGen2Pipeline signature from inference.py:
171
+ # prompt, input_images, width, height, num_inference_steps, ...
172
+
173
+ # Using defaults from inference.py for now
174
+ results = pipeline(
175
+ prompt=prompt,
176
+ input_images=[init_image],
177
+ width=width,
178
+ height=height,
179
+ num_inference_steps=26, # Standard for OmniGen2
180
+ max_sequence_length=1024,
181
+ text_guidance_scale=5.0, # Default per inference.py
182
+ image_guidance_scale=guidance_scale, # Map guidance_scale from request here
183
+ cfg_range=(0.0, 1.0),
184
+ negative_prompt="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
185
+ num_images_per_prompt=n,
186
+ output_type="pil",
187
+ )
188
+
189
+ for img in results.images:
190
+ buffered = io.BytesIO()
191
+ img.save(buffered, format="PNG")
192
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
193
+ response_images.append({"b64_json": img_str})
194
+
195
+ except Exception as e:
196
+ print(f"Error during editing: {e}")
197
+ raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
198
+ finally:
199
+ flush()
200
+
201
+ return {
202
+ "created": int(time.time()),
203
+ "data": response_images
204
+ }
205
+
206
+ @app.post("/v1/images/generations")
207
+ async def generate_image(request: ImageGenerationRequest):
208
+ if not pipeline:
209
+ raise HTTPException(status_code=500, detail="Model not loaded")
210
+
211
+ async with request_lock:
212
+ print(f"Received generation request: {request.prompt}")
213
+
214
+ # Parse size
215
+ try:
216
+ width, height = map(int, request.size.split("x"))
217
+ except ValueError:
218
+ width, height = 1024, 1024
219
+
220
+ # Enforce multiples of 16 for compatibility
221
+ width = (width // 16) * 16
222
+ height = (height // 16) * 16
223
+
224
+ response_images = []
225
+
226
+ try:
227
+ # Generate images (input_images=None for txt2img)
228
+ results = pipeline(
229
+ prompt=request.prompt,
230
+ input_images=None,
231
+ width=width,
232
+ height=height,
233
+ num_inference_steps=26,
234
+ max_sequence_length=1024,
235
+ text_guidance_scale=5.0,
236
+ image_guidance_scale=2.0, # Default
237
+ cfg_range=(0.0, 1.0),
238
+ negative_prompt="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
239
+ num_images_per_prompt=request.n,
240
+ output_type="pil",
241
+ )
242
+
243
+ for img in results.images:
244
+ buffered = io.BytesIO()
245
+ img.save(buffered, format="PNG")
246
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
247
+ response_images.append({"b64_json": img_str})
248
+
249
+ except Exception as e:
250
+ print(f"Error during generation: {e}")
251
+ raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
252
+ finally:
253
+ flush()
254
+
255
+ return {
256
+ "created": int(time.time()),
257
+ "data": response_images
258
+ }
259
+
260
+ if __name__ == "__main__":
261
+ uvicorn.run(app, host=args.host, port=args.port)
extras/QwenBackend.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from nunchaku.utils import get_gpu_memory, get_precision
3
+ from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
4
+
5
+ class QwenBackend:
6
+ def __init__(self, model_id, optimized_model_path=None, optimized_edit_model_path=None, uma=False):
7
+ self.model_id = model_id
8
+ self.optimized_model_path = optimized_model_path
9
+ self.optimized_edit_model_path = optimized_edit_model_path
10
+ self.uma = uma
11
+ self.pipeline = None
12
+ self.rank = 32 # Default from example (was 128 in snippet, user example has 32)
13
+ # Check snippet: rank = 32 in the example content I read.
14
+
15
+ def load(self):
16
+ print(f"Loading Qwen backend from {self.model_id}...")
17
+
18
+ if not self.optimized_model_path:
19
+ print("Warning: No optimized model path provided for QwenBackend. This requires the Nunchaku optimized model.")
20
+
21
+ # Scheduler config from example
22
+ import math
23
+ from diffusers import FlowMatchEulerDiscreteScheduler
24
+
25
+ scheduler_config = {
26
+ "base_image_seq_len": 256,
27
+ "base_shift": math.log(3),
28
+ "invert_sigmas": False,
29
+ "max_image_seq_len": 8192,
30
+ "max_shift": math.log(3),
31
+ "num_train_timesteps": 1000,
32
+ "shift": 1.0,
33
+ "shift_terminal": None,
34
+ "stochastic_sampling": False,
35
+ "time_shift_type": "exponential",
36
+ "use_beta_sigmas": False,
37
+ "use_dynamic_shifting": True,
38
+ "use_exponential_sigmas": False,
39
+ "use_karras_sigmas": False,
40
+ }
41
+ scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
42
+
43
+ # Load the base transformer (T2I)
44
+ print(f"Loading T2I NunchakuQwenImageTransformer2DModel from {self.optimized_model_path} with FA2...")
45
+ transformer_t2i = NunchakuQwenImageTransformer2DModel.from_pretrained(
46
+ self.optimized_model_path,
47
+ attn_implementation="flash_attention_2"
48
+ )
49
+
50
+ # Load the edit transformer
51
+ if self.optimized_edit_model_path:
52
+ print(f"Loading Edit NunchakuQwenImageTransformer2DModel from {self.optimized_edit_model_path} with FA2...")
53
+ transformer_edit = NunchakuQwenImageTransformer2DModel.from_pretrained(
54
+ self.optimized_edit_model_path,
55
+ attn_implementation="flash_attention_2"
56
+ )
57
+ else:
58
+ print(f"Using shared transformer for Edit pipeline...")
59
+ transformer_edit = transformer_t2i
60
+
61
+ print(f"Loading QwenImagePipeline from {self.model_id}...")
62
+ # Use QwenImagePipeline (T2I)
63
+ from diffusers import QwenImagePipeline, QwenImageEditPlusPipeline
64
+
65
+ text_encoder = None
66
+ if self.uma:
67
+ print("UMA mode: Loading text_encoder in 8-bit using BitsAndBytes...")
68
+ from transformers import BitsAndBytesConfig, AutoModel
69
+ bnb_config = BitsAndBytesConfig(load_in_8bit=True)
70
+ text_encoder = AutoModel.from_pretrained(
71
+ self.model_id,
72
+ subfolder="text_encoder",
73
+ quantization_config=bnb_config,
74
+ torch_dtype=torch.bfloat16,
75
+ trust_remote_code=True
76
+ )
77
+
78
+ # 1. Load Edit Pipeline (To handle processor correctly)
79
+ print(f"Loading QwenImageEditPlusPipeline from {self.model_id}...")
80
+
81
+ pipeline_kwargs = {
82
+ "transformer": transformer_edit,
83
+ "scheduler": scheduler,
84
+ "torch_dtype": torch.bfloat16
85
+ }
86
+ if text_encoder is not None:
87
+ pipeline_kwargs["text_encoder"] = text_encoder
88
+
89
+ edit_pipeline = QwenImageEditPlusPipeline.from_pretrained(
90
+ self.model_id,
91
+ **pipeline_kwargs
92
+ )
93
+
94
+ # 2. Create T2I Pipeline sharing components (except transformer if separate)
95
+ print("Creating QwenImagePipeline (T2I) with shared components...")
96
+
97
+ # Ensure we have a text_encoder and tokenizer
98
+ if edit_pipeline.text_encoder is None:
99
+ print("Text encoder not found in edit_pipeline, loading manually...")
100
+ # Load from model_id or subfolder
101
+ if text_encoder is None:
102
+ from transformers import AutoModel
103
+ text_encoder = AutoModel.from_pretrained(self.model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, trust_remote_code=True)
104
+
105
+ # CRITICAL FIX: Assign it back to the pipeline!
106
+ edit_pipeline.register_modules(text_encoder=text_encoder)
107
+ else:
108
+ text_encoder = edit_pipeline.text_encoder
109
+
110
+ tokenizer = edit_pipeline.tokenizer
111
+
112
+ if tokenizer is None:
113
+ print("Tokenizer not found in edit_pipeline, loading manually...")
114
+ from transformers import AutoTokenizer
115
+ tokenizer = AutoTokenizer.from_pretrained(self.model_id, subfolder="tokenizer", trust_remote_code=True)
116
+ edit_pipeline.register_modules(tokenizer=tokenizer)
117
+
118
+ pipeline = QwenImagePipeline(
119
+ transformer=transformer_t2i,
120
+ scheduler=edit_pipeline.scheduler,
121
+ vae=edit_pipeline.vae,
122
+ text_encoder=text_encoder,
123
+ tokenizer=tokenizer,
124
+ )
125
+
126
+ # Manually assign processors if needed (though QwenImagePipeline creates its own image_processor)
127
+ # pipeline.feature_extractor = edit_pipeline.image_processor
128
+
129
+ # Logic for offloading / UMA
130
+ if self.uma:
131
+ print("UMA mode enabled: Text encoder loaded in 8-bit. Moving other components to GPU.")
132
+ # Note: 8-bit text encoder is already handled by bitsandbytes (on GPU or offloaded as needed, typically GPU).
133
+
134
+ # Explicitly move transformers to CUDA
135
+ print("Moving T2I Transformer to CUDA...")
136
+ transformer_t2i.to("cuda")
137
+
138
+ if transformer_edit != transformer_t2i:
139
+ print("Moving Edit Transformer to CUDA...")
140
+ transformer_edit.to("cuda")
141
+
142
+ # We need to ensure other components (VAE) are on CUDA.
143
+ if hasattr(edit_pipeline, "vae") and edit_pipeline.vae:
144
+ print("Moving VAE to CUDA...")
145
+ edit_pipeline.vae.to("cuda")
146
+
147
+ # Since we can't call pipeline.to("cuda") generally if 8-bit modules are present (sometimes safe, sometimes not),
148
+ # we manually handle it or trust loaded components.
149
+ pass
150
+ # Note: pipeline (T2I) shares components, so it should be on cuda too.
151
+ else:
152
+ print("Non-UMA mode: Using aggressive per-layer offloading.")
153
+ transformer_t2i.set_offload(
154
+ True, use_pin_memory=True, num_blocks_on_gpu=8
155
+ )
156
+ if self.optimized_edit_model_path:
157
+ transformer_edit.set_offload(
158
+ True, use_pin_memory=True, num_blocks_on_gpu=8
159
+ )
160
+
161
+ edit_pipeline._exclude_from_cpu_offload.append("transformer")
162
+ edit_pipeline.enable_sequential_cpu_offload()
163
+
164
+ # The T2I pipeline (pipeline) also needs to handle offloading.
165
+ # If we manually loaded text_encoder, it might not be attached to edit_pipeline's offload hooks.
166
+ # We should enable sequential CPU offload for the T2I pipeline too.
167
+ pipeline.enable_sequential_cpu_offload()
168
+
169
+ if self.optimized_edit_model_path:
170
+ pass
171
+
172
+ self.pipeline = pipeline
173
+ self.edit_pipeline = edit_pipeline
174
+ return self.pipeline, self.edit_pipeline
extras/QwenImageBackend.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from nunchaku.utils import get_gpu_memory, get_precision
3
+ from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
4
+
5
+ class QwenImageBackend:
6
+ def __init__(self, model_id, optimized_model_path=None):
7
+ self.model_id = model_id
8
+ self.optimized_model_path = optimized_model_path
9
+ self.pipeline = None
10
+ self.rank = 32 # default rank as per example
11
+
12
+ def load(self):
13
+ print(f"Loading QwenImageBackend from {self.model_id}...")
14
+ # Scheduler config (same as QwenBackend)
15
+ import math
16
+ from diffusers import FlowMatchEulerDiscreteScheduler
17
+ scheduler_config = {
18
+ "base_image_seq_len": 256,
19
+ "base_shift": math.log(3),
20
+ "invert_sigmas": False,
21
+ "max_image_seq_len": 8192,
22
+ "max_shift": math.log(3),
23
+ "num_train_timesteps": 1000,
24
+ "shift": 1.0,
25
+ "shift_terminal": None,
26
+ "stochastic_sampling": False,
27
+ "time_shift_type": "exponential",
28
+ "use_beta_sigmas": False,
29
+ "use_dynamic_shifting": True,
30
+ "use_exponential_sigmas": False,
31
+ "use_karras_sigmas": False,
32
+ }
33
+ scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
34
+
35
+ # Load transformer (optimized model)
36
+ print(f"Loading NunchakuQwenImageTransformer2DModel from {self.optimized_model_path}...")
37
+ transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(self.optimized_model_path)
38
+
39
+ # Load T2I pipeline
40
+ from diffusers import QwenImagePipeline
41
+ pipeline = QwenImagePipeline.from_pretrained(
42
+ self.model_id,
43
+ transformer=transformer,
44
+ scheduler=scheduler,
45
+ torch_dtype=torch.bfloat16,
46
+ )
47
+
48
+ # Offloading logic (same as QwenBackend)
49
+ if get_gpu_memory() > 18:
50
+ print("GPU memory > 18GB, using cpu offload")
51
+ pipeline.enable_model_cpu_offload()
52
+ else:
53
+ print("GPU memory <= 18GB, using per-layer offloading for low VRAM")
54
+ transformer.set_offload(True, use_pin_memory=False, num_blocks_on_gpu=1)
55
+ pipeline._exclude_from_cpu_offload.append("transformer")
56
+ pipeline.enable_sequential_cpu_offload()
57
+
58
+ self.pipeline = pipeline
59
+ # For edit endpoint we reuse the same pipeline (ignores image)
60
+ return self.pipeline, self.pipeline
extras/ZImageTurboBackend.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from diffusers import ZImagePipeline
4
+ from nunchaku.models.transformers.transformer_zimage import NunchakuZImageTransformer2DModel
5
+ from nunchaku.utils import get_gpu_memory
6
+
7
+
8
+ class ZImageTurboBackend:
9
+ def __init__(
10
+ self,
11
+ model_id,
12
+ optimized_model_path=None,
13
+ optimized_edit_model_path=None,
14
+ uma=False,
15
+ nvfp4_text_encoder_path: str | None = None,
16
+ ):
17
+ self.model_id = model_id
18
+ self.optimized_model_path = optimized_model_path
19
+ self.pipeline = None
20
+ self.uma = uma
21
+ # Optional path to an NVFP4-pack-quantized Qwen3 text encoder. When set,
22
+ # we load the encoder via vLLM's CompressedTensorsW4A4Fp4 (CUTLASS NVFP4
23
+ # GEMM) instead of the bf16 text_encoder shipped inside the Z-Image
24
+ # base repo. Cuts encoder VRAM ~4x with negligible quality loss
25
+ # (cosine >0.999 vs the bf16 reference on Thor).
26
+ self.nvfp4_text_encoder_path = nvfp4_text_encoder_path
27
+
28
+ def _build_nvfp4_text_encoder(self):
29
+ """Load the NVFP4 text encoder if requested, returns (encoder, tokenizer) or (None, None)."""
30
+ if not self.nvfp4_text_encoder_path:
31
+ return None, None
32
+ print(
33
+ f"[ZImageTurboBackend] Loading NVFP4 text encoder from {self.nvfp4_text_encoder_path} "
34
+ "(vLLM CompressedTensorsW4A4Fp4 + CUTLASS NVFP4 GEMM)"
35
+ )
36
+ from NVFP4TextEncoder import load_nvfp4_text_encoder
37
+ from transformers import AutoTokenizer
38
+
39
+ encoder = load_nvfp4_text_encoder(
40
+ self.nvfp4_text_encoder_path,
41
+ device="cuda",
42
+ dtype=torch.bfloat16,
43
+ )
44
+ tokenizer = AutoTokenizer.from_pretrained(self.nvfp4_text_encoder_path)
45
+ return encoder, tokenizer
46
+
47
+ def load(self):
48
+ print(f"Loading ZImageTurboBackend from {self.model_id}...")
49
+ print(f"Loading NunchakuZImageTransformer2DModel from {self.optimized_model_path}...")
50
+
51
+ # Load transformer (optimized model)
52
+ transformer = NunchakuZImageTransformer2DModel.from_pretrained(self.optimized_model_path)
53
+
54
+ # If requested, build the NVFP4 text encoder before constructing the pipeline so
55
+ # diffusers does not also load the bf16 text_encoder from disk (it would double VRAM).
56
+ nvfp4_encoder, nvfp4_tokenizer = self._build_nvfp4_text_encoder()
57
+
58
+ # Load pipeline
59
+ print("Initializing ZImagePipeline...")
60
+ pipeline_kwargs = dict(
61
+ transformer=transformer,
62
+ torch_dtype=torch.bfloat16,
63
+ low_cpu_mem_usage=False, # standard for HF example
64
+ )
65
+ if nvfp4_encoder is not None:
66
+ # Pass our pre-built encoder so diffusers skips loading the bf16 subfolder.
67
+ pipeline_kwargs["text_encoder"] = nvfp4_encoder
68
+ if nvfp4_tokenizer is not None:
69
+ pipeline_kwargs["tokenizer"] = nvfp4_tokenizer
70
+
71
+ pipeline = ZImagePipeline.from_pretrained(self.model_id, **pipeline_kwargs)
72
+
73
+ gpu_mem = get_gpu_memory()
74
+ print(f"GPU memory available: {gpu_mem} GB")
75
+
76
+ # Enable Flash Attention 2
77
+ try:
78
+ if hasattr(pipeline.transformer, "set_attention_backend"):
79
+ pipeline.transformer.set_attention_backend("native")
80
+ print("Enabled Native SDPA for Z-Image transformer")
81
+ if hasattr(pipeline.vae, "set_attention_backend"):
82
+ pipeline.vae.set_attention_backend("native")
83
+ print("Enabled Native SDPA for Z-Image VAE")
84
+ except Exception as e:
85
+ print(f"Could not enable Flash Attention 2: {e}")
86
+
87
+ if self.uma:
88
+ print("UMA mode enabled: Loading all components to GPU and disabling offloads")
89
+ # When using the NVFP4 encoder, it is already on CUDA and its quantised parameters
90
+ # are not compatible with diffusers' generic .to() pathway (e.g. uint8 weight_packed).
91
+ # We move only the diffusers-managed components (vae, transformer if not nunchaku, ...).
92
+ if nvfp4_encoder is not None:
93
+ # Exclude text_encoder from blanket .to('cuda'); it is already on cuda.
94
+ excl = getattr(pipeline, "_exclude_from_cpu_offload", [])
95
+ if "text_encoder" not in excl:
96
+ excl.append("text_encoder")
97
+ pipeline._exclude_from_cpu_offload = excl
98
+ for name, comp in pipeline.components.items():
99
+ if name == "text_encoder":
100
+ continue
101
+ if isinstance(comp, torch.nn.Module):
102
+ try:
103
+ comp.to("cuda")
104
+ except Exception:
105
+ pass
106
+ else:
107
+ pipeline.to("cuda")
108
+ elif gpu_mem <= 18:
109
+ print("GPU memory <= 18GB, using sequential cpu offload for low VRAM")
110
+ # The prompt requested sequential offloading without splitting layers for Nunchaku
111
+ pipeline._exclude_from_cpu_offload.append("transformer")
112
+ if nvfp4_encoder is not None:
113
+ # NVFP4 weights live entirely on CUDA; do not let accelerate move them.
114
+ pipeline._exclude_from_cpu_offload.append("text_encoder")
115
+ pipeline.enable_sequential_cpu_offload()
116
+ transformer.to("cuda")
117
+ if nvfp4_encoder is not None:
118
+ nvfp4_encoder.to("cuda")
119
+ else:
120
+ print("GPU memory > 18GB, using cpu offload")
121
+ if nvfp4_encoder is not None:
122
+ if not hasattr(pipeline, "_exclude_from_cpu_offload"):
123
+ pipeline._exclude_from_cpu_offload = []
124
+ pipeline._exclude_from_cpu_offload.append("text_encoder")
125
+ pipeline.enable_model_cpu_offload()
126
+ if nvfp4_encoder is not None:
127
+ nvfp4_encoder.to("cuda")
128
+
129
+ self.pipeline = pipeline
130
+ # Return twice for pipeline and edit_pipeline (though Z-Image-Turbo is T2I only)
131
+ return self.pipeline, self.pipeline
extras/compress_mllm.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ from transformers import Qwen2_5_VLForConditionalGeneration
4
+ from dfloat11 import compress_model
5
+
6
+ def main():
7
+ parser = argparse.ArgumentParser("Compress OmniGen2 MLLM (Qwen2.5-VL) using DFloat11")
8
+ parser.add_argument(
9
+ '--model_path',
10
+ type=str,
11
+ required=True,
12
+ help='The path to the OmniGen2 model (containing "mllm" folder) or direct path to MLLM checkpoint'
13
+ )
14
+ parser.add_argument(
15
+ '--save_path',
16
+ type=str,
17
+ default='./OmniGen2-mllm-DF11',
18
+ help='The path to save the compressed model'
19
+ )
20
+ parser.add_argument(
21
+ '--save_single_file',
22
+ action='store_true',
23
+ help='Save the compressed model as a single .safetensors file'
24
+ )
25
+ parser.add_argument(
26
+ '--check_correctness',
27
+ action='store_true',
28
+ help='Check the correctness of the compressed weights during compression'
29
+ )
30
+ parser.add_argument(
31
+ '--block_range',
32
+ type=int,
33
+ nargs=2,
34
+ default=(0, 100),
35
+ help='The range of transformer blocks to compress (for parallel compression over multiple CPU cores)'
36
+ )
37
+ args = parser.parse_args()
38
+
39
+ # Determine MLLM path
40
+ import os
41
+ mllm_path = args.model_path
42
+ if os.path.isdir(os.path.join(args.model_path, "mllm")):
43
+ mllm_path = os.path.join(args.model_path, "mllm")
44
+
45
+ print(f"Loading MLLM from: {mllm_path}")
46
+
47
+ # Load the Qwen2.5-VL model in bfloat16 precision
48
+ # Use trust_remote_code=True same as in inference.py
49
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
50
+ mllm_path,
51
+ torch_dtype=torch.bfloat16,
52
+ trust_remote_code=True
53
+ )
54
+
55
+ # Untie weights to avoid safetensors error about shared memory
56
+ # safetensors.torch.save_file dies if tensors share memory.
57
+ if hasattr(model, 'lm_head') and hasattr(model.lm_head, 'weight'):
58
+ print("Untying lm_head weights to avoid safetensors shared memory error...")
59
+ model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())
60
+
61
+ # Compress the model using DFloat11 compression
62
+ # Pattern updated to match Qwen2.5-VL internal structure (model.language_model.layers...)
63
+ compress_model(
64
+ model=model,
65
+ pattern_dict={
66
+ r"model\.language_model\.layers\.\d+": (
67
+ "self_attn.q_proj",
68
+ "self_attn.k_proj",
69
+ "self_attn.v_proj",
70
+ "self_attn.o_proj",
71
+ "mlp.gate_proj",
72
+ "mlp.up_proj",
73
+ "mlp.down_proj",
74
+ ),
75
+ },
76
+ save_path=args.save_path,
77
+ save_single_file=args.save_single_file, # Force single file to use state_dict keys (model.language_model...)
78
+ check_correctness=args.check_correctness,
79
+ block_range=args.block_range,
80
+ )
81
+
82
+ if __name__ == "__main__":
83
+ main()
extras/imagegen_zimage_turbo.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #source /home/olegk/venv/vllm/bin/activate
3
+ cd /home/olegk/Nikola/src/imagegen
4
+
5
+ # Force vLLM's built-in CUTLASS NVFP4 kernel (skips flashinfer-cutlass JIT which
6
+ # needs the `ninja` binary on PATH). The kernel still uses the CUTLASS FP4 GEMM
7
+ # path on Thor (sm_110).
8
+ export VLLM_NVFP4_GEMM_BACKEND=cutlass
9
+
10
+ python ImageEditServer.py \
11
+ --port 4500 \
12
+ --model /home/olegk/Nikola/models/Z-Image-Turbo \
13
+ --optimized-model /home/olegk/Nikola/models/nunchaku-z-image-turbo/svdq-fp4_r32-z-image-turbo.safetensors \
14
+ --backend zimage \
15
+ --steps 8 \
16
+ --guidance-scale 0.0 \
17
+ --uma \
18
+ --nvfp4-text-encoder /home/olegk/Nikola/models/Z-Image-Turbo-Text-Encoder-NVFP4
extras/imagegen_zimage_turbo_int4.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ source /home/olegk/venv/vllm/bin/activate
3
+ cd /home/olegk/Nikola/src/imagegen
4
+ python ImageEditServer.py \
5
+ --port 4500 \
6
+ --model /home/olegk/Nikola/models/Z-Image-Turbo \
7
+ --optimized-model /home/olegk/Nikola/models/nunchaku-z-image-turbo/svdq-int4_r32-z-image-turbo.safetensors \
8
+ --backend zimage \
9
+ --steps 8 \
10
+ --guidance-scale 0.0 \
11
+ --uma
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.51.0"
13
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eaad037756f1afd7cb847ff4b7c23db02ec56936bb30806903fb57d2b0b1588d
3
+ size 2822178072
recipe.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ default_stage:
2
+ default_modifiers:
3
+ QuantizationModifier:
4
+ targets: [Linear]
5
+ ignore: [lm_head, 're:.*mlp.gate$', 're:.*mlp.shared_expert_gate$', 're:.*linear_attn.*',
6
+ 're:model\.visual\..*', 're:model\.image_encoder\..*']
7
+ scheme: NVFP4
8
+ bypass_divisibility_checks: false
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
3
+ size 11422650
tokenizer_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": null,
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<|im_start|>",
10
+ "<|im_end|>",
11
+ "<|object_ref_start|>",
12
+ "<|object_ref_end|>",
13
+ "<|box_start|>",
14
+ "<|box_end|>",
15
+ "<|quad_start|>",
16
+ "<|quad_end|>",
17
+ "<|vision_start|>",
18
+ "<|vision_end|>",
19
+ "<|vision_pad|>",
20
+ "<|image_pad|>",
21
+ "<|video_pad|>"
22
+ ],
23
+ "is_local": true,
24
+ "model_max_length": 131072,
25
+ "pad_token": "<|endoftext|>",
26
+ "split_special_tokens": false,
27
+ "tokenizer_class": "Qwen2Tokenizer",
28
+ "unk_token": null
29
+ }