Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- chat_template.jinja +117 -0
- config.json +121 -0
- extras/Flux2Backend.py +70 -0
- extras/GlmBackend.py +45 -0
- extras/ImageEditServer.py +496 -0
- extras/ImageGenClient.py +150 -0
- extras/ImageGenServer.py +123 -0
- extras/ImageGenServer_cpu.py +307 -0
- extras/ImageGenServer_new.py +176 -0
- extras/KontextBackend.py +93 -0
- extras/NVFP4TextEncoder.py +324 -0
- extras/OmniImageEditServer.py +261 -0
- extras/QwenBackend.py +174 -0
- extras/QwenImageBackend.py +60 -0
- extras/ZImageTurboBackend.py +131 -0
- extras/compress_mllm.py +83 -0
- extras/imagegen_zimage_turbo.sh +18 -0
- extras/imagegen_zimage_turbo_int4.sh +11 -0
- generation_config.json +13 -0
- model.safetensors +3 -0
- recipe.yaml +8 -0
- tokenizer.json +3 -0
- tokenizer_config.json +29 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% macro render_extra_keys(json_dict, handled_keys) %}
|
| 2 |
+
{%- if json_dict is mapping %}
|
| 3 |
+
{%- for json_key in json_dict if json_key not in handled_keys %}
|
| 4 |
+
{%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
|
| 5 |
+
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
|
| 6 |
+
{%- else %}
|
| 7 |
+
{{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
|
| 8 |
+
{%- endif %}
|
| 9 |
+
{%- endfor %}
|
| 10 |
+
{%- endif %}
|
| 11 |
+
{% endmacro %}
|
| 12 |
+
|
| 13 |
+
{%- if messages[0]["role"] == "system" %}
|
| 14 |
+
{%- set system_message = messages[0]["content"] %}
|
| 15 |
+
{%- set loop_messages = messages[1:] %}
|
| 16 |
+
{%- else %}
|
| 17 |
+
{%- set loop_messages = messages %}
|
| 18 |
+
{%- endif %}
|
| 19 |
+
|
| 20 |
+
{%- if not tools is defined %}
|
| 21 |
+
{%- set tools = [] %}
|
| 22 |
+
{%- endif %}
|
| 23 |
+
|
| 24 |
+
{%- if system_message is defined %}
|
| 25 |
+
{{- "<|im_start|>system\n" + system_message }}
|
| 26 |
+
{%- else %}
|
| 27 |
+
{%- if tools is iterable and tools | length > 0 %}
|
| 28 |
+
{{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }}
|
| 29 |
+
{%- endif %}
|
| 30 |
+
{%- endif %}
|
| 31 |
+
{%- if tools is iterable and tools | length > 0 %}
|
| 32 |
+
{{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }}
|
| 33 |
+
{{- "<tools>" }}
|
| 34 |
+
{%- for tool in tools %}
|
| 35 |
+
{%- if tool.function is defined %}
|
| 36 |
+
{%- set tool = tool.function %}
|
| 37 |
+
{%- endif %}
|
| 38 |
+
{{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
|
| 39 |
+
{%- if tool.description is defined %}
|
| 40 |
+
{{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
|
| 41 |
+
{%- endif %}
|
| 42 |
+
{{- '\n<parameters>' }}
|
| 43 |
+
{%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
|
| 44 |
+
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
| 45 |
+
{{- '\n<parameter>' }}
|
| 46 |
+
{{- '\n<name>' ~ param_name ~ '</name>' }}
|
| 47 |
+
{%- if param_fields.type is defined %}
|
| 48 |
+
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
|
| 49 |
+
{%- endif %}
|
| 50 |
+
{%- if param_fields.description is defined %}
|
| 51 |
+
{{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
|
| 52 |
+
{%- endif %}
|
| 53 |
+
{%- set handled_keys = ['name', 'type', 'description'] %}
|
| 54 |
+
{{- render_extra_keys(param_fields, handled_keys) }}
|
| 55 |
+
{{- '\n</parameter>' }}
|
| 56 |
+
{%- endfor %}
|
| 57 |
+
{%- endif %}
|
| 58 |
+
{% set handled_keys = ['type', 'properties'] %}
|
| 59 |
+
{{- render_extra_keys(tool.parameters, handled_keys) }}
|
| 60 |
+
{{- '\n</parameters>' }}
|
| 61 |
+
{%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
|
| 62 |
+
{{- render_extra_keys(tool, handled_keys) }}
|
| 63 |
+
{{- '\n</function>' }}
|
| 64 |
+
{%- endfor %}
|
| 65 |
+
{{- "\n</tools>" }}
|
| 66 |
+
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
| 67 |
+
{%- endif %}
|
| 68 |
+
{%- if system_message is defined %}
|
| 69 |
+
{{- '<|im_end|>\n' }}
|
| 70 |
+
{%- else %}
|
| 71 |
+
{%- if tools is iterable and tools | length > 0 %}
|
| 72 |
+
{{- '<|im_end|>\n' }}
|
| 73 |
+
{%- endif %}
|
| 74 |
+
{%- endif %}
|
| 75 |
+
{%- for message in loop_messages %}
|
| 76 |
+
{%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
|
| 77 |
+
{{- '<|im_start|>' + message.role }}
|
| 78 |
+
{%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
|
| 79 |
+
{{- '\n' + message.content | trim + '\n' }}
|
| 80 |
+
{%- endif %}
|
| 81 |
+
{%- for tool_call in message.tool_calls %}
|
| 82 |
+
{%- if tool_call.function is defined %}
|
| 83 |
+
{%- set tool_call = tool_call.function %}
|
| 84 |
+
{%- endif %}
|
| 85 |
+
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
| 86 |
+
{%- if tool_call.arguments is defined %}
|
| 87 |
+
{%- for args_name, args_value in tool_call.arguments|items %}
|
| 88 |
+
{{- '<parameter=' + args_name + '>\n' }}
|
| 89 |
+
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
| 90 |
+
{{- args_value }}
|
| 91 |
+
{{- '\n</parameter>\n' }}
|
| 92 |
+
{%- endfor %}
|
| 93 |
+
{%- endif %}
|
| 94 |
+
{{- '</function>\n</tool_call>' }}
|
| 95 |
+
{%- endfor %}
|
| 96 |
+
{{- '<|im_end|>\n' }}
|
| 97 |
+
{%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %}
|
| 98 |
+
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
| 99 |
+
{%- elif message.role == "tool" %}
|
| 100 |
+
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
| 101 |
+
{{- '<|im_start|>user\n' }}
|
| 102 |
+
{%- endif %}
|
| 103 |
+
{{- '<tool_response>\n' }}
|
| 104 |
+
{{- message.content }}
|
| 105 |
+
{{- '\n</tool_response>\n' }}
|
| 106 |
+
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
| 107 |
+
{{- '<|im_end|>\n' }}
|
| 108 |
+
{%- elif loop.last %}
|
| 109 |
+
{{- '<|im_end|>\n' }}
|
| 110 |
+
{%- endif %}
|
| 111 |
+
{%- else %}
|
| 112 |
+
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
|
| 113 |
+
{%- endif %}
|
| 114 |
+
{%- endfor %}
|
| 115 |
+
{%- if add_generation_prompt %}
|
| 116 |
+
{{- '<|im_start|>assistant\n' }}
|
| 117 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen3ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"dtype": "bfloat16",
|
| 9 |
+
"eos_token_id": 151645,
|
| 10 |
+
"head_dim": 128,
|
| 11 |
+
"hidden_act": "silu",
|
| 12 |
+
"hidden_size": 2560,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 9728,
|
| 15 |
+
"layer_types": [
|
| 16 |
+
"full_attention",
|
| 17 |
+
"full_attention",
|
| 18 |
+
"full_attention",
|
| 19 |
+
"full_attention",
|
| 20 |
+
"full_attention",
|
| 21 |
+
"full_attention",
|
| 22 |
+
"full_attention",
|
| 23 |
+
"full_attention",
|
| 24 |
+
"full_attention",
|
| 25 |
+
"full_attention",
|
| 26 |
+
"full_attention",
|
| 27 |
+
"full_attention",
|
| 28 |
+
"full_attention",
|
| 29 |
+
"full_attention",
|
| 30 |
+
"full_attention",
|
| 31 |
+
"full_attention",
|
| 32 |
+
"full_attention",
|
| 33 |
+
"full_attention",
|
| 34 |
+
"full_attention",
|
| 35 |
+
"full_attention",
|
| 36 |
+
"full_attention",
|
| 37 |
+
"full_attention",
|
| 38 |
+
"full_attention",
|
| 39 |
+
"full_attention",
|
| 40 |
+
"full_attention",
|
| 41 |
+
"full_attention",
|
| 42 |
+
"full_attention",
|
| 43 |
+
"full_attention",
|
| 44 |
+
"full_attention",
|
| 45 |
+
"full_attention",
|
| 46 |
+
"full_attention",
|
| 47 |
+
"full_attention",
|
| 48 |
+
"full_attention",
|
| 49 |
+
"full_attention",
|
| 50 |
+
"full_attention",
|
| 51 |
+
"full_attention"
|
| 52 |
+
],
|
| 53 |
+
"max_position_embeddings": 40960,
|
| 54 |
+
"max_window_layers": 36,
|
| 55 |
+
"model_type": "qwen3",
|
| 56 |
+
"num_attention_heads": 32,
|
| 57 |
+
"num_hidden_layers": 36,
|
| 58 |
+
"num_key_value_heads": 8,
|
| 59 |
+
"pad_token_id": null,
|
| 60 |
+
"quantization_config": {
|
| 61 |
+
"config_groups": {
|
| 62 |
+
"group_0": {
|
| 63 |
+
"format": "nvfp4-pack-quantized",
|
| 64 |
+
"input_activations": {
|
| 65 |
+
"actorder": null,
|
| 66 |
+
"block_structure": null,
|
| 67 |
+
"dynamic": "local",
|
| 68 |
+
"group_size": 16,
|
| 69 |
+
"num_bits": 4,
|
| 70 |
+
"observer": "static_minmax",
|
| 71 |
+
"observer_kwargs": {},
|
| 72 |
+
"scale_dtype": "torch.float8_e4m3fn",
|
| 73 |
+
"strategy": "tensor_group",
|
| 74 |
+
"symmetric": true,
|
| 75 |
+
"type": "float",
|
| 76 |
+
"zp_dtype": null
|
| 77 |
+
},
|
| 78 |
+
"output_activations": null,
|
| 79 |
+
"targets": [
|
| 80 |
+
"Linear"
|
| 81 |
+
],
|
| 82 |
+
"weights": {
|
| 83 |
+
"actorder": null,
|
| 84 |
+
"block_structure": null,
|
| 85 |
+
"dynamic": false,
|
| 86 |
+
"group_size": 16,
|
| 87 |
+
"num_bits": 4,
|
| 88 |
+
"observer": "memoryless_minmax",
|
| 89 |
+
"observer_kwargs": {},
|
| 90 |
+
"scale_dtype": "torch.float8_e4m3fn",
|
| 91 |
+
"strategy": "tensor_group",
|
| 92 |
+
"symmetric": true,
|
| 93 |
+
"type": "float",
|
| 94 |
+
"zp_dtype": null
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
},
|
| 98 |
+
"format": "nvfp4-pack-quantized",
|
| 99 |
+
"global_compression_ratio": null,
|
| 100 |
+
"ignore": [
|
| 101 |
+
"lm_head"
|
| 102 |
+
],
|
| 103 |
+
"kv_cache_scheme": null,
|
| 104 |
+
"quant_method": "compressed-tensors",
|
| 105 |
+
"quantization_status": "compressed",
|
| 106 |
+
"sparsity_config": {},
|
| 107 |
+
"transform_config": {},
|
| 108 |
+
"version": "0.15.1.dev14+g01a1c9a"
|
| 109 |
+
},
|
| 110 |
+
"rms_norm_eps": 1e-06,
|
| 111 |
+
"rope_parameters": {
|
| 112 |
+
"rope_theta": 1000000,
|
| 113 |
+
"rope_type": "default"
|
| 114 |
+
},
|
| 115 |
+
"sliding_window": null,
|
| 116 |
+
"tie_word_embeddings": true,
|
| 117 |
+
"transformers_version": "5.2.0",
|
| 118 |
+
"use_cache": true,
|
| 119 |
+
"use_sliding_window": false,
|
| 120 |
+
"vocab_size": 151936
|
| 121 |
+
}
|
extras/Flux2Backend.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import Mistral3ForConditionalGeneration, PixtralProcessor, BitsAndBytesConfig
|
| 3 |
+
from diffusers import Flux2Pipeline, AutoencoderKLFlux2, Flux2Transformer2DModel
|
| 4 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 5 |
+
|
| 6 |
+
class Flux2Backend:
|
| 7 |
+
def __init__(self, model_id):
|
| 8 |
+
self.model_id = model_id
|
| 9 |
+
self.pipeline = None
|
| 10 |
+
|
| 11 |
+
def load(self):
|
| 12 |
+
print(f"Loading Flux2 backend from {self.model_id}...")
|
| 13 |
+
|
| 14 |
+
quantization_config = BitsAndBytesConfig(
|
| 15 |
+
load_in_4bit=True,
|
| 16 |
+
bnb_4bit_quant_type="nf4",
|
| 17 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 18 |
+
bnb_4bit_use_double_quant=True,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# Scheduler
|
| 22 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 23 |
+
self.model_id,
|
| 24 |
+
subfolder="scheduler",
|
| 25 |
+
torch_dtype=torch.bfloat16
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# VAE - loaded manually with full precision
|
| 29 |
+
vae = AutoencoderKLFlux2.from_pretrained(
|
| 30 |
+
self.model_id,
|
| 31 |
+
subfolder="vae",
|
| 32 |
+
torch_dtype=torch.float16
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
tokenizer = PixtralProcessor.from_pretrained(
|
| 36 |
+
self.model_id,
|
| 37 |
+
subfolder="tokenizer",
|
| 38 |
+
torch_dtype=torch.float16
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
|
| 42 |
+
self.model_id,
|
| 43 |
+
subfolder="text_encoder",
|
| 44 |
+
torch_dtype=torch.float16,
|
| 45 |
+
quantization_config=quantization_config
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
dit = Flux2Transformer2DModel.from_pretrained(
|
| 49 |
+
self.model_id,
|
| 50 |
+
subfolder="transformer",
|
| 51 |
+
torch_dtype=torch.float16,
|
| 52 |
+
quantization_config=quantization_config
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Standard loading without Nunchaku optimization
|
| 57 |
+
# Constructing pipeline manually rather than from_pretrained
|
| 58 |
+
pipeline = Flux2Pipeline(
|
| 59 |
+
scheduler=scheduler,
|
| 60 |
+
vae=vae,
|
| 61 |
+
text_encoder=text_encoder,
|
| 62 |
+
tokenizer=tokenizer,
|
| 63 |
+
transformer=dit,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
self.pipeline = pipeline
|
| 67 |
+
self.pipeline.to("cuda")
|
| 68 |
+
self.pipeline.transformer.set_attention_backend("flash")
|
| 69 |
+
|
| 70 |
+
return self.pipeline, self.pipeline
|
extras/GlmBackend.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import diffusers
|
| 3 |
+
try:
|
| 4 |
+
from sdnq import SDNQConfig
|
| 5 |
+
from sdnq.common import use_torch_compile as triton_is_available
|
| 6 |
+
from sdnq.loader import apply_sdnq_options_to_model
|
| 7 |
+
SDNQ_AVAILABLE = True
|
| 8 |
+
except ImportError:
|
| 9 |
+
print("SDNQ not found, optimized GLM loading will be skipped.")
|
| 10 |
+
SDNQ_AVAILABLE = False
|
| 11 |
+
|
| 12 |
+
class GlmBackend:
|
| 13 |
+
def __init__(self, model_id="Disty0/GLM-Image-SDNQ-4bit-dynamic"):
|
| 14 |
+
self.model_id = model_id
|
| 15 |
+
self.pipeline = None
|
| 16 |
+
|
| 17 |
+
def load(self):
|
| 18 |
+
print(f"Loading GLM backend from {self.model_id}...")
|
| 19 |
+
|
| 20 |
+
# Load the pipeline
|
| 21 |
+
# Using bfloat16 as per request snippet
|
| 22 |
+
pipeline = diffusers.GlmImagePipeline.from_pretrained(
|
| 23 |
+
self.model_id,
|
| 24 |
+
torch_dtype=torch.bfloat16,
|
| 25 |
+
trust_remote_code=True,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
if SDNQ_AVAILABLE:
|
| 29 |
+
# Enable INT8 MatMul for GPUs if Triton is available
|
| 30 |
+
if triton_is_available and (torch.cuda.is_available() or torch.xpu.is_available()):
|
| 31 |
+
print("Applying SDNQ optimizations (INT8 MatMul)...")
|
| 32 |
+
pipeline.transformer = apply_sdnq_options_to_model(pipeline.transformer, use_quantized_matmul=True)
|
| 33 |
+
# pipeline.transformer = torch.compile(pipeline.transformer) # Optional, commented out as in snippet
|
| 34 |
+
else:
|
| 35 |
+
print("Triton or CUDA/XPU not available, skipping SDNQ optimization.")
|
| 36 |
+
|
| 37 |
+
print("Enabling CPU offload for GLM pipeline...")
|
| 38 |
+
pipeline.enable_model_cpu_offload()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
self.pipeline = pipeline
|
| 42 |
+
|
| 43 |
+
# The user stated: "this one uses same pipe line for image generation and editing"
|
| 44 |
+
# So we return the same pipeline for both.
|
| 45 |
+
return self.pipeline, self.pipeline
|
extras/ImageEditServer.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import base64
|
| 3 |
+
import io
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import uvicorn
|
| 7 |
+
import gc
|
| 8 |
+
import asyncio
|
| 9 |
+
import traceback
|
| 10 |
+
from typing import List, Optional, Union
|
| 11 |
+
from contextlib import asynccontextmanager
|
| 12 |
+
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
| 13 |
+
from pydantic import BaseModel
|
| 14 |
+
from PIL import Image, ImageOps
|
| 15 |
+
|
| 16 |
+
# Argument parsing
|
| 17 |
+
parser = argparse.ArgumentParser(description="Flux Image Edit Server with Nunchaku")
|
| 18 |
+
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
| 19 |
+
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
|
| 20 |
+
parser.add_argument("--model", type=str, default="black-forest-labs/FLUX.1-Kontext-dev", help="Path or Repo ID of the base model")
|
| 21 |
+
parser.add_argument("--optimized-model", type=str, default=None, help="Path to the optimized Nunchaku model safetensors file")
|
| 22 |
+
parser.add_argument("--optimized-edit-model", type=str, default=None, help="Path to the optimized Nunchaku model safetensors file for editing (optional)")
|
| 23 |
+
parser.add_argument("--backend", type=str, default="kontext", choices=["kontext", "flux2", "qwen", "glm", "zimage"], help="Backend to use: 'kontext', 'flux2', 'qwen', 'glm', or 'zimage'")
|
| 24 |
+
parser.add_argument("--steps", type=int, default=28, help="Default number of inference steps")
|
| 25 |
+
parser.add_argument("--guidance-scale", type=float, default=3.5, help="Default guidance scale")
|
| 26 |
+
parser.add_argument("--qwenimage", action="store_true", help="Use QwenImageBackend (T2I only) instead of full Qwen edit backend")
|
| 27 |
+
parser.add_argument("--uma", action="store_true", help="Enable Unified Memory Architecture mode (load all to GPU, disable offload)")
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--nvfp4-text-encoder",
|
| 30 |
+
type=str,
|
| 31 |
+
default=None,
|
| 32 |
+
help=(
|
| 33 |
+
"Path to an NVFP4-pack-quantized HuggingFace text encoder "
|
| 34 |
+
"(compressed-tensors format). Currently honoured by the zimage backend; "
|
| 35 |
+
"swaps in vLLM's W4A4 NVFP4 CUTLASS GEMM for ~4x text-encoder VRAM savings."
|
| 36 |
+
),
|
| 37 |
+
)
|
| 38 |
+
args = parser.parse_args()
|
| 39 |
+
|
| 40 |
+
@asynccontextmanager
|
| 41 |
+
async def lifespan(app: FastAPI):
|
| 42 |
+
# Startup logic
|
| 43 |
+
load_model()
|
| 44 |
+
yield
|
| 45 |
+
# Shutdown logic (if any) could go here
|
| 46 |
+
|
| 47 |
+
app = FastAPI(lifespan=lifespan)
|
| 48 |
+
|
| 49 |
+
# Global components
|
| 50 |
+
IMAGE_DIMENSION_ALIGNMENT = 32
|
| 51 |
+
pipeline = None
|
| 52 |
+
edit_pipeline = None
|
| 53 |
+
request_lock = asyncio.Lock()
|
| 54 |
+
is_sleeping_flag = False
|
| 55 |
+
sleep_requested = False
|
| 56 |
+
|
| 57 |
+
def load_model():
|
| 58 |
+
global pipeline, edit_pipeline
|
| 59 |
+
|
| 60 |
+
try:
|
| 61 |
+
if args.backend == "kontext":
|
| 62 |
+
import KontextBackend
|
| 63 |
+
print(f"Initializing KontextBackend...")
|
| 64 |
+
backend = KontextBackend.KontextBackend(args.model, args.optimized_model)
|
| 65 |
+
pipeline, edit_pipeline = backend.load()
|
| 66 |
+
elif args.backend == "flux2":
|
| 67 |
+
import Flux2Backend
|
| 68 |
+
print(f"Initializing Flux2Backend...")
|
| 69 |
+
backend = Flux2Backend.Flux2Backend(args.model)
|
| 70 |
+
pipeline, edit_pipeline = backend.load()
|
| 71 |
+
elif args.backend == "glm":
|
| 72 |
+
import GlmBackend
|
| 73 |
+
print(f"Initializing GlmBackend...")
|
| 74 |
+
# Use provided model or default to the one in the snippet if args.model is generic
|
| 75 |
+
# The user might pass the specific GLM model via --model, or we default in GlmBackend.
|
| 76 |
+
# Let's pass args.model if it's not the default flux one, otherwise let GlmBackend use its default.
|
| 77 |
+
model_to_use = args.model if args.model != "black-forest-labs/FLUX.1-Kontext-dev" else "Disty0/GLM-Image-SDNQ-4bit-dynamic"
|
| 78 |
+
backend = GlmBackend.GlmBackend(model_to_use)
|
| 79 |
+
pipeline, edit_pipeline = backend.load()
|
| 80 |
+
elif args.backend.startswith("qwen"):
|
| 81 |
+
if args.qwenimage:
|
| 82 |
+
import QwenImageBackend
|
| 83 |
+
print(f"Initializing QwenImageBackend (T2I only)...")
|
| 84 |
+
backend = QwenImageBackend.QwenImageBackend(args.model, args.optimized_model)
|
| 85 |
+
pipeline, edit_pipeline = backend.load()
|
| 86 |
+
else:
|
| 87 |
+
import QwenBackend
|
| 88 |
+
print(f"Initializing QwenBackend...")
|
| 89 |
+
backend = QwenBackend.QwenBackend(args.model, args.optimized_model, optimized_edit_model_path=args.optimized_edit_model, uma=args.uma)
|
| 90 |
+
pipeline, edit_pipeline = backend.load()
|
| 91 |
+
elif args.backend == "zimage":
|
| 92 |
+
import ZImageTurboBackend
|
| 93 |
+
print(f"Initializing ZImageTurboBackend...")
|
| 94 |
+
backend = ZImageTurboBackend.ZImageTurboBackend(
|
| 95 |
+
args.model,
|
| 96 |
+
args.optimized_model,
|
| 97 |
+
uma=args.uma,
|
| 98 |
+
nvfp4_text_encoder_path=args.nvfp4_text_encoder,
|
| 99 |
+
)
|
| 100 |
+
pipeline, edit_pipeline = backend.load()
|
| 101 |
+
else:
|
| 102 |
+
raise ValueError(f"Unknown backend: {args.backend}")
|
| 103 |
+
|
| 104 |
+
except Exception as e:
|
| 105 |
+
print(f"Oh no! The model refused to wake up: {e}")
|
| 106 |
+
raise e
|
| 107 |
+
|
| 108 |
+
# Enable progress bar for diffusers
|
| 109 |
+
import diffusers.utils.logging
|
| 110 |
+
diffusers.utils.logging.enable_progress_bar()
|
| 111 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 112 |
+
|
| 113 |
+
print("Model loaded successfully! Ready for editing quests!")
|
| 114 |
+
|
| 115 |
+
def flush():
|
| 116 |
+
gc.collect()
|
| 117 |
+
torch.cuda.empty_cache()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class ImageGenerationRequest(BaseModel):
|
| 121 |
+
prompt: str
|
| 122 |
+
n: int = 1
|
| 123 |
+
size: str = "1024x1024"
|
| 124 |
+
response_format: str = "b64_json"
|
| 125 |
+
quality: str = "standard"
|
| 126 |
+
style: str = "vivid"
|
| 127 |
+
num_inference_steps: Optional[int] = None
|
| 128 |
+
guidance_scale: Optional[float] = None
|
| 129 |
+
negative_prompt: Optional[str] = None
|
| 130 |
+
seed: Optional[int] = None
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@app.post("/v1/sleep")
|
| 134 |
+
async def sleep_endpoint():
|
| 135 |
+
global is_sleeping_flag, sleep_requested
|
| 136 |
+
sleep_requested = True
|
| 137 |
+
try:
|
| 138 |
+
async with request_lock:
|
| 139 |
+
if not is_sleeping_flag and sleep_requested:
|
| 140 |
+
print("Sleep requested, moving models to CPU...")
|
| 141 |
+
for p in [pipeline, edit_pipeline]:
|
| 142 |
+
if not p: continue
|
| 143 |
+
for name, component in p.components.items():
|
| 144 |
+
if isinstance(component, torch.nn.Module):
|
| 145 |
+
# Special handling for Nunchaku which blocks .to() if offload is True
|
| 146 |
+
if hasattr(component, "set_offload") and getattr(component, "offload", False):
|
| 147 |
+
component.set_offload(False)
|
| 148 |
+
component._nunchaku_was_offloaded = True
|
| 149 |
+
|
| 150 |
+
try:
|
| 151 |
+
component.to("cpu")
|
| 152 |
+
except Exception as e:
|
| 153 |
+
pass
|
| 154 |
+
flush()
|
| 155 |
+
is_sleeping_flag = True
|
| 156 |
+
finally:
|
| 157 |
+
sleep_requested = False
|
| 158 |
+
return {"status": "sleep completed", "is_sleeping": is_sleeping_flag}
|
| 159 |
+
|
| 160 |
+
@app.post("/v1/wake_up")
|
| 161 |
+
async def wake_up_endpoint():
|
| 162 |
+
global is_sleeping_flag, sleep_requested
|
| 163 |
+
sleep_requested = False
|
| 164 |
+
async with request_lock:
|
| 165 |
+
if is_sleeping_flag:
|
| 166 |
+
print("Waking up, restoring models to CUDA...")
|
| 167 |
+
for p in [pipeline, edit_pipeline]:
|
| 168 |
+
if not p: continue
|
| 169 |
+
excluded = getattr(p, "_exclude_from_cpu_offload", [])
|
| 170 |
+
for name, component in p.components.items():
|
| 171 |
+
if isinstance(component, torch.nn.Module):
|
| 172 |
+
if getattr(component, "_nunchaku_was_offloaded", False):
|
| 173 |
+
component.set_offload(True, use_pin_memory=True, num_blocks_on_gpu=8)
|
| 174 |
+
for attr in ["img_in", "txt_in", "txt_norm", "time_text_embed", "norm_out", "proj_out"]:
|
| 175 |
+
if hasattr(component, attr):
|
| 176 |
+
try:
|
| 177 |
+
getattr(component, attr).to("cuda")
|
| 178 |
+
except Exception:
|
| 179 |
+
pass
|
| 180 |
+
component._nunchaku_was_offloaded = False
|
| 181 |
+
elif not hasattr(component, "_hf_hook") or name in excluded:
|
| 182 |
+
try:
|
| 183 |
+
component.to("cuda")
|
| 184 |
+
except Exception:
|
| 185 |
+
pass
|
| 186 |
+
is_sleeping_flag = False
|
| 187 |
+
return {"status": "awoken", "is_sleeping": False}
|
| 188 |
+
|
| 189 |
+
@app.get("/v1/is_sleeping")
|
| 190 |
+
async def is_sleeping_endpoint():
|
| 191 |
+
return {"is_sleeping": is_sleeping_flag}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@app.get("/v1/memory_stats")
|
| 195 |
+
async def memory_stats_endpoint():
|
| 196 |
+
"""Lightweight introspection endpoint that returns PyTorch's CUDA allocator
|
| 197 |
+
snapshot. Used to diagnose VRAM/UMA bloat without restarting the server."""
|
| 198 |
+
stats = {}
|
| 199 |
+
if torch.cuda.is_available():
|
| 200 |
+
stats["allocated_gb"] = torch.cuda.memory_allocated() / 1e9
|
| 201 |
+
stats["reserved_gb"] = torch.cuda.memory_reserved() / 1e9
|
| 202 |
+
stats["max_allocated_gb"] = torch.cuda.max_memory_allocated() / 1e9
|
| 203 |
+
stats["max_reserved_gb"] = torch.cuda.max_memory_reserved() / 1e9
|
| 204 |
+
# Top allocations by size from the allocator snapshot (>=64 MiB)
|
| 205 |
+
try:
|
| 206 |
+
snap = torch.cuda.memory_snapshot()
|
| 207 |
+
blocks = []
|
| 208 |
+
for seg in snap:
|
| 209 |
+
for b in seg.get("blocks", []):
|
| 210 |
+
if b.get("state") == "active_allocated" and b.get("size", 0) >= 64 * 1024 * 1024:
|
| 211 |
+
blocks.append(b["size"])
|
| 212 |
+
blocks.sort(reverse=True)
|
| 213 |
+
stats["large_active_blocks_gb"] = [round(s / 1e9, 3) for s in blocks[:20]]
|
| 214 |
+
stats["large_active_blocks_total_gb"] = round(sum(blocks) / 1e9, 3)
|
| 215 |
+
stats["large_active_blocks_count"] = len(blocks)
|
| 216 |
+
except Exception as e:
|
| 217 |
+
stats["snapshot_error"] = str(e)
|
| 218 |
+
# Walk Python objects to find big tensors and group them
|
| 219 |
+
try:
|
| 220 |
+
import gc as _gc
|
| 221 |
+
seen = set()
|
| 222 |
+
big = []
|
| 223 |
+
for obj in _gc.get_objects():
|
| 224 |
+
try:
|
| 225 |
+
if isinstance(obj, torch.Tensor) and obj.is_cuda:
|
| 226 |
+
ptr = obj.data_ptr()
|
| 227 |
+
if ptr in seen or ptr == 0:
|
| 228 |
+
continue
|
| 229 |
+
seen.add(ptr)
|
| 230 |
+
sz = obj.element_size() * obj.numel()
|
| 231 |
+
if sz >= 16 * 1024 * 1024:
|
| 232 |
+
big.append((sz, tuple(obj.shape), str(obj.dtype)))
|
| 233 |
+
except Exception:
|
| 234 |
+
continue
|
| 235 |
+
big.sort(reverse=True)
|
| 236 |
+
# Group by (shape, dtype)
|
| 237 |
+
from collections import Counter
|
| 238 |
+
grouped = Counter((shape, dtype) for _, shape, dtype in big)
|
| 239 |
+
stats["big_tensor_groups"] = [
|
| 240 |
+
{"shape": list(shape), "dtype": dtype, "count": cnt,
|
| 241 |
+
"size_gb_each": round(
|
| 242 |
+
(1 if shape == () else (lambda l: __import__('functools').reduce(lambda a, b: a*b, l, 1))(shape)) * (
|
| 243 |
+
8 if 'int64' in dtype or 'float64' in dtype else
|
| 244 |
+
4 if 'int32' in dtype or 'float32' in dtype else
|
| 245 |
+
2 if 'bfloat16' in dtype or 'float16' in dtype else 1
|
| 246 |
+
) / 1e9, 4)}
|
| 247 |
+
for (shape, dtype), cnt in grouped.most_common(30)
|
| 248 |
+
]
|
| 249 |
+
stats["big_tensor_count"] = len(big)
|
| 250 |
+
stats["big_tensor_total_gb"] = round(sum(s for s, _, _ in big) / 1e9, 3)
|
| 251 |
+
except Exception as e:
|
| 252 |
+
stats["walk_error"] = str(e)
|
| 253 |
+
return stats
|
| 254 |
+
|
| 255 |
+
@app.post("/v1/images/edits")
|
| 256 |
+
async def edit_image(
|
| 257 |
+
image: Union[List[UploadFile], UploadFile] = File(...),
|
| 258 |
+
prompt: str = Form(...),
|
| 259 |
+
n: int = Form(1),
|
| 260 |
+
size: str = Form("1024x1024"),
|
| 261 |
+
response_format: str = Form("b64_json"), # Default to b64_json
|
| 262 |
+
guidance_scale: Optional[float] = Form(None),
|
| 263 |
+
num_inference_steps: Optional[int] = Form(None),
|
| 264 |
+
negative_prompt: Optional[str] = Form(None),
|
| 265 |
+
seed: Optional[int] = Form(None)
|
| 266 |
+
):
|
| 267 |
+
# Use CLI defaults if not provided
|
| 268 |
+
steps = num_inference_steps if num_inference_steps is not None else args.steps
|
| 269 |
+
cfg_scale = guidance_scale if guidance_scale is not None else args.guidance_scale
|
| 270 |
+
neg_prompt = negative_prompt if negative_prompt is not None else "" # Default empty for now, or maybe None?
|
| 271 |
+
|
| 272 |
+
generator = None
|
| 273 |
+
import random
|
| 274 |
+
if seed is None:
|
| 275 |
+
seed = random.randint(0, 2**32 - 1)
|
| 276 |
+
|
| 277 |
+
print(f"Using seed: {seed}")
|
| 278 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
| 279 |
+
|
| 280 |
+
if not edit_pipeline:
|
| 281 |
+
raise HTTPException(status_code=500, detail="Model not loaded")
|
| 282 |
+
|
| 283 |
+
if sleep_requested or is_sleeping_flag:
|
| 284 |
+
raise HTTPException(status_code=503, detail="Server is sleeping or trying to sleep.")
|
| 285 |
+
|
| 286 |
+
async with request_lock:
|
| 287 |
+
print(f"Received edit request: {prompt}")
|
| 288 |
+
|
| 289 |
+
# Processing the input image(s)
|
| 290 |
+
input_files = image if isinstance(image, list) else [image]
|
| 291 |
+
init_images = []
|
| 292 |
+
|
| 293 |
+
try:
|
| 294 |
+
for img_file in input_files:
|
| 295 |
+
await img_file.seek(0)
|
| 296 |
+
contents = await img_file.read()
|
| 297 |
+
img = Image.open(io.BytesIO(contents)).convert("RGB")
|
| 298 |
+
init_images.append(img)
|
| 299 |
+
except Exception as e:
|
| 300 |
+
raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")
|
| 301 |
+
|
| 302 |
+
if not init_images:
|
| 303 |
+
raise HTTPException(status_code=400, detail="No images provided")
|
| 304 |
+
|
| 305 |
+
# Parse max target dimensions from requested size
|
| 306 |
+
try:
|
| 307 |
+
target_width, target_height = map(int, size.split("x"))
|
| 308 |
+
except ValueError:
|
| 309 |
+
target_width, target_height = 1024, 1024
|
| 310 |
+
|
| 311 |
+
# Calculate new dimensions preserving aspect ratio based on the first image
|
| 312 |
+
first_image = init_images[0]
|
| 313 |
+
orig_width, orig_height = first_image.size
|
| 314 |
+
scale = min(target_width / orig_width, target_height / orig_height)
|
| 315 |
+
new_width = int(orig_width * scale)
|
| 316 |
+
new_height = int(orig_height * scale)
|
| 317 |
+
|
| 318 |
+
# Ensure dimensions are aligned to 32 for compatibility (e.g. GLM-Image)
|
| 319 |
+
width = (new_width // IMAGE_DIMENSION_ALIGNMENT) * IMAGE_DIMENSION_ALIGNMENT
|
| 320 |
+
height = (new_height // IMAGE_DIMENSION_ALIGNMENT) * IMAGE_DIMENSION_ALIGNMENT
|
| 321 |
+
|
| 322 |
+
# Resize input images to match the calculated target size, padding if necessary
|
| 323 |
+
resized_images = []
|
| 324 |
+
for img in init_images:
|
| 325 |
+
if img.size != (width, height):
|
| 326 |
+
# Use ImageOps.pad to preserve aspect ratio and center in the target size
|
| 327 |
+
# This handles cases where subsequent images might have different ARs
|
| 328 |
+
img = ImageOps.pad(img, (width, height), method=Image.LANCZOS, color=(0, 0, 0))
|
| 329 |
+
resized_images.append(img)
|
| 330 |
+
|
| 331 |
+
# If single image, pass as item, if multiple, pass as list
|
| 332 |
+
# GLM pipeline has a bug where it checks len() on the input, so it must be a list
|
| 333 |
+
if len(resized_images) > 1 or args.backend == "glm":
|
| 334 |
+
image_input = resized_images
|
| 335 |
+
else:
|
| 336 |
+
image_input = resized_images[0]
|
| 337 |
+
|
| 338 |
+
response_images = []
|
| 339 |
+
|
| 340 |
+
try:
|
| 341 |
+
if args.backend.startswith("qwen"):
|
| 342 |
+
# Qwen specific parameters
|
| 343 |
+
# guidance_scale maps to true_cfg_scale
|
| 344 |
+
if args.qwenimage: # QwenImageBackend is T2I only, so it doesn't take an image
|
| 345 |
+
generated_images = edit_pipeline(
|
| 346 |
+
prompt=prompt,
|
| 347 |
+
height=height,
|
| 348 |
+
width=width,
|
| 349 |
+
num_inference_steps=steps,
|
| 350 |
+
true_cfg_scale=cfg_scale,
|
| 351 |
+
num_images_per_prompt=n,
|
| 352 |
+
generator=generator,
|
| 353 |
+
).images
|
| 354 |
+
else: # Full Qwen edit backend takes an image (or list of images now)
|
| 355 |
+
generated_images = edit_pipeline(
|
| 356 |
+
image=image_input,
|
| 357 |
+
prompt=prompt,
|
| 358 |
+
height=height,
|
| 359 |
+
width=width,
|
| 360 |
+
negative_prompt=neg_prompt,
|
| 361 |
+
num_inference_steps=steps,
|
| 362 |
+
true_cfg_scale=cfg_scale,
|
| 363 |
+
num_images_per_prompt=n,
|
| 364 |
+
generator=generator,
|
| 365 |
+
).images
|
| 366 |
+
else:
|
| 367 |
+
# Standard Flux/Kontext or GLM
|
| 368 |
+
# GLM I2I Fix: Manually move vision encoder to GPU because get_image_features escapes hooks
|
| 369 |
+
if args.backend == "glm" and hasattr(edit_pipeline, "vision_language_encoder"):
|
| 370 |
+
print("Manually moving GLM Vision Encoder to GPU...")
|
| 371 |
+
edit_pipeline.vision_language_encoder.to("cuda")
|
| 372 |
+
|
| 373 |
+
try:
|
| 374 |
+
generated_images = edit_pipeline(
|
| 375 |
+
image=image_input,
|
| 376 |
+
prompt=prompt,
|
| 377 |
+
height=height,
|
| 378 |
+
width=width,
|
| 379 |
+
num_inference_steps=steps,
|
| 380 |
+
guidance_scale=cfg_scale,
|
| 381 |
+
num_images_per_prompt=n,
|
| 382 |
+
generator=generator,
|
| 383 |
+
).images
|
| 384 |
+
finally:
|
| 385 |
+
if args.backend == "glm" and hasattr(edit_pipeline, "vision_language_encoder"):
|
| 386 |
+
print("Moving GLM Vision Encoder back to CPU...")
|
| 387 |
+
edit_pipeline.vision_language_encoder.to("cpu")
|
| 388 |
+
|
| 389 |
+
for img in generated_images:
|
| 390 |
+
buffered = io.BytesIO()
|
| 391 |
+
img.save(buffered, format="PNG")
|
| 392 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 393 |
+
|
| 394 |
+
if response_format == "b64_json":
|
| 395 |
+
response_images.append({"b64_json": img_str})
|
| 396 |
+
else:
|
| 397 |
+
# If url is requested we can't really do it without storage, so we fallback or error?
|
| 398 |
+
# For now, let's just assume simple b64_json as per request
|
| 399 |
+
response_images.append({"b64_json": img_str}) # Fallback
|
| 400 |
+
|
| 401 |
+
except Exception as e:
|
| 402 |
+
print(f"Error during editing: {e}")
|
| 403 |
+
print(traceback.format_exc())
|
| 404 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 405 |
+
finally:
|
| 406 |
+
flush()
|
| 407 |
+
|
| 408 |
+
return {
|
| 409 |
+
"created": int(time.time()),
|
| 410 |
+
"data": response_images
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
@app.post("/v1/images/generations")
|
| 416 |
+
async def generate_image(request: ImageGenerationRequest):
|
| 417 |
+
if not pipeline:
|
| 418 |
+
raise HTTPException(status_code=500, detail="Model not loaded")
|
| 419 |
+
|
| 420 |
+
if sleep_requested or is_sleeping_flag:
|
| 421 |
+
raise HTTPException(status_code=503, detail="Server is sleeping or trying to sleep.")
|
| 422 |
+
|
| 423 |
+
async with request_lock:
|
| 424 |
+
#print(f"Received generation request: {request.prompt}")
|
| 425 |
+
|
| 426 |
+
# Parse size
|
| 427 |
+
try:
|
| 428 |
+
width, height = map(int, request.size.split("x"))
|
| 429 |
+
except ValueError:
|
| 430 |
+
width, height = 1024, 1024
|
| 431 |
+
|
| 432 |
+
# Ensure dimensions are aligned to 32
|
| 433 |
+
width = (width // IMAGE_DIMENSION_ALIGNMENT) * IMAGE_DIMENSION_ALIGNMENT
|
| 434 |
+
height = (height // IMAGE_DIMENSION_ALIGNMENT) * IMAGE_DIMENSION_ALIGNMENT
|
| 435 |
+
|
| 436 |
+
response_images = []
|
| 437 |
+
|
| 438 |
+
try:
|
| 439 |
+
# Generate images (no image argument for txt2img!)
|
| 440 |
+
steps = request.num_inference_steps if request.num_inference_steps is not None else args.steps
|
| 441 |
+
cfg_scale = request.guidance_scale if request.guidance_scale is not None else args.guidance_scale
|
| 442 |
+
# negative_prompt not in standard request body in original snippet, but we added it to model
|
| 443 |
+
neg_prompt = request.negative_prompt if request.negative_prompt is not None else ""
|
| 444 |
+
|
| 445 |
+
generator = None
|
| 446 |
+
import random
|
| 447 |
+
seed = request.seed
|
| 448 |
+
if seed is None:
|
| 449 |
+
seed = random.randint(0, 2**32 - 1)
|
| 450 |
+
|
| 451 |
+
print(f"Using seed: {seed}")
|
| 452 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
| 453 |
+
|
| 454 |
+
if args.backend.startswith("qwen"):
|
| 455 |
+
generated_images = pipeline(
|
| 456 |
+
prompt=request.prompt,
|
| 457 |
+
height=height,
|
| 458 |
+
width=width,
|
| 459 |
+
num_inference_steps=steps,
|
| 460 |
+
true_cfg_scale=cfg_scale,
|
| 461 |
+
num_images_per_prompt=request.n,
|
| 462 |
+
negative_prompt=neg_prompt,
|
| 463 |
+
generator=generator,
|
| 464 |
+
).images
|
| 465 |
+
else:
|
| 466 |
+
generated_images = pipeline(
|
| 467 |
+
prompt=request.prompt,
|
| 468 |
+
height=height,
|
| 469 |
+
width=width,
|
| 470 |
+
num_inference_steps=steps,
|
| 471 |
+
guidance_scale=cfg_scale,
|
| 472 |
+
num_images_per_prompt=request.n,
|
| 473 |
+
generator=generator,
|
| 474 |
+
# Not passing negative_prompt here for generation unless we confirm support in standard Flux pipeline?
|
| 475 |
+
).images
|
| 476 |
+
|
| 477 |
+
for img in generated_images:
|
| 478 |
+
buffered = io.BytesIO()
|
| 479 |
+
img.save(buffered, format="PNG")
|
| 480 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 481 |
+
response_images.append({"b64_json": img_str})
|
| 482 |
+
|
| 483 |
+
except Exception as e:
|
| 484 |
+
print(f"Error during generation: {e}")
|
| 485 |
+
print(traceback.format_exc())
|
| 486 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 487 |
+
finally:
|
| 488 |
+
flush()
|
| 489 |
+
|
| 490 |
+
return {
|
| 491 |
+
"created": int(time.time()),
|
| 492 |
+
"data": response_images
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
if __name__ == "__main__":
|
| 496 |
+
uvicorn.run(app, host=args.host, port=args.port)
|
extras/ImageGenClient.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import base64
|
| 5 |
+
import time
|
| 6 |
+
import io
|
| 7 |
+
import requests
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
# Oh, hello there! Nikola here, ready to help this little client talk to the big server!
|
| 11 |
+
# It's like sending a messenger bird from our village to the capital!
|
| 12 |
+
|
| 13 |
+
def main():
|
| 14 |
+
# Peeking at the arguments... gotta make sure we have all our supplies for the journey!
|
| 15 |
+
parser = argparse.ArgumentParser(description="ImageGen Client - A little seeker tool!")
|
| 16 |
+
parser.add_argument("--host", type=str, default="localhost", help="Where the server lives (Host)")
|
| 17 |
+
parser.add_argument("--port", type=int, default=8000, help="The door to knock on (Port)")
|
| 18 |
+
parser.add_argument("--num_images", type=int, default=1, help="How many pictures to paint?")
|
| 19 |
+
parser.add_argument("--image_folder", type=str, default="generated_images", help="Where to keep our treasures")
|
| 20 |
+
# Changing defaults to None so we can use input image size if needed!
|
| 21 |
+
parser.add_argument("--width", type=int, default=None, help="Canvas width (default: 1024 or input image size)")
|
| 22 |
+
parser.add_argument("--height", type=int, default=None, help="Canvas height (default: 1024 or input image size)")
|
| 23 |
+
|
| 24 |
+
# New shiny tools for our quest!
|
| 25 |
+
parser.add_argument("--input", type=str, default=None, help="Path to an input image (for image-to-image magic!)")
|
| 26 |
+
parser.add_argument("--max-size", type=int, default=1024, help="Max size for the input image (we don't want it to get too heavy for the bird!)")
|
| 27 |
+
|
| 28 |
+
args = parser.parse_args()
|
| 29 |
+
|
| 30 |
+
# Reading the prompt from the spirits... I mean, stdin!
|
| 31 |
+
# "What do you desire to see?" *sparkle*
|
| 32 |
+
print("Waiting for a prompt from stdin... (Type something and press Ctrl+D!)")
|
| 33 |
+
try:
|
| 34 |
+
prompt = sys.stdin.read().strip()
|
| 35 |
+
except Exception as e:
|
| 36 |
+
print(f"Oh no! The spirits were silent (stdin error): {e}")
|
| 37 |
+
return
|
| 38 |
+
|
| 39 |
+
if not prompt:
|
| 40 |
+
print("Aww, the prompt was empty! The canvas remains blank.")
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
+
print(f"Yay! We got a prompt: '{prompt}'")
|
| 44 |
+
|
| 45 |
+
# Restoring the canvas size variables from the journey's start!
|
| 46 |
+
final_width = args.width
|
| 47 |
+
final_height = args.height
|
| 48 |
+
|
| 49 |
+
# Prepare prompt and payload
|
| 50 |
+
url_gen = f"http://{args.host}:{args.port}/v1/images/generations"
|
| 51 |
+
url_edit = f"http://{args.host}:{args.port}/v1/images/edits"
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
if args.input:
|
| 55 |
+
print(f"Oh! You brought a reference image: {args.input}. Let's go to the Editing Shrine!")
|
| 56 |
+
|
| 57 |
+
# Prepare for multipart upload
|
| 58 |
+
# We need to open the image file effectively
|
| 59 |
+
if not os.path.exists(args.input):
|
| 60 |
+
print(f"Eek! I can't find the image at {args.input}")
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
# Open image to ensure it's valid and memory-friendly resize if needed
|
| 64 |
+
with Image.open(args.input) as img:
|
| 65 |
+
img = img.convert("RGB")
|
| 66 |
+
w, h = img.size
|
| 67 |
+
max_dim = max(w, h)
|
| 68 |
+
if max_dim > args.max_size:
|
| 69 |
+
scale = args.max_size / max_dim
|
| 70 |
+
new_w = int(w * scale)
|
| 71 |
+
new_h = int(h * scale)
|
| 72 |
+
print(f"Resizing big image from {w}x{h} to {new_w}x{new_h}. Compact and cute!")
|
| 73 |
+
img = img.resize((new_w, new_h), Image.LANCZOS)
|
| 74 |
+
|
| 75 |
+
if final_width is None: final_width = img.width
|
| 76 |
+
if final_height is None: final_height = img.height
|
| 77 |
+
|
| 78 |
+
# Save to buffer for upload
|
| 79 |
+
buffered = io.BytesIO()
|
| 80 |
+
img.save(buffered, format="PNG")
|
| 81 |
+
buffered.seek(0)
|
| 82 |
+
image_bytes = buffered.getvalue()
|
| 83 |
+
|
| 84 |
+
# Construct multipart payload
|
| 85 |
+
files = {
|
| 86 |
+
'image': ('input.png', image_bytes, 'image/png')
|
| 87 |
+
}
|
| 88 |
+
data = {
|
| 89 |
+
'prompt': prompt,
|
| 90 |
+
'n': args.num_images,
|
| 91 |
+
'size': f"{final_width}x{final_height}",
|
| 92 |
+
'response_format': 'b64_json',
|
| 93 |
+
'guidance_scale': 2.5 # Default specific to edit/kontext if needed
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
print(f"Sending input image to {url_edit}... *whoosh*")
|
| 97 |
+
response = requests.post(url_edit, files=files, data=data)
|
| 98 |
+
|
| 99 |
+
else:
|
| 100 |
+
# Standard Generation
|
| 101 |
+
print("Just a prompt? Off to the Creation Forge!")
|
| 102 |
+
if final_width is None: final_width = 1024
|
| 103 |
+
if final_height is None: final_height = 1024
|
| 104 |
+
|
| 105 |
+
payload = {
|
| 106 |
+
"prompt": prompt,
|
| 107 |
+
"n": args.num_images,
|
| 108 |
+
"size": f"{final_width}x{final_height}",
|
| 109 |
+
"response_format": "b64_json"
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
print(f"Sending prompt to {url_gen}... *sparkle*")
|
| 113 |
+
response = requests.post(url_gen, json=payload)
|
| 114 |
+
|
| 115 |
+
response.raise_for_status()
|
| 116 |
+
|
| 117 |
+
data = response.json()
|
| 118 |
+
|
| 119 |
+
# Making sure we have a chest for our treasures
|
| 120 |
+
if not os.path.exists(args.image_folder):
|
| 121 |
+
print(f"Creating a new treasure chest at {args.image_folder}...")
|
| 122 |
+
os.makedirs(args.image_folder)
|
| 123 |
+
|
| 124 |
+
# Unpacking the magic
|
| 125 |
+
images = data.get("data", [])
|
| 126 |
+
print(f"Ooh! The server sent back {len(images)} masterpieces!")
|
| 127 |
+
|
| 128 |
+
for i, img_data in enumerate(images):
|
| 129 |
+
# Decoding the spell
|
| 130 |
+
img_bytes = base64.b64decode(img_data["b64_json"])
|
| 131 |
+
|
| 132 |
+
timestamp = int(time.time())
|
| 133 |
+
filename = f"image_{timestamp}_{i}.png"
|
| 134 |
+
filepath = os.path.join(args.image_folder, filename)
|
| 135 |
+
|
| 136 |
+
with open(filepath, "wb") as f:
|
| 137 |
+
f.write(img_bytes)
|
| 138 |
+
|
| 139 |
+
print(f"Saved masterpiece #{i+1} to {filepath}! It sparkles!")
|
| 140 |
+
|
| 141 |
+
except requests.exceptions.ConnectionError:
|
| 142 |
+
print("Oh no! The server didn't answer. Is it sleeping? (Connection Refused)")
|
| 143 |
+
print("Maybe check if the host and port are correct? We tried: " + url)
|
| 144 |
+
except Exception as e:
|
| 145 |
+
print(f"Eek! Something went wrong on the journey: {e}")
|
| 146 |
+
# We'll give it a gentle hug and try to understand...
|
| 147 |
+
print("Don't worry, we can try again later!")
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
main()
|
extras/ImageGenServer.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import base64
|
| 3 |
+
import io
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import uvicorn
|
| 7 |
+
import gc
|
| 8 |
+
import asyncio
|
| 9 |
+
from fastapi import FastAPI, HTTPException
|
| 10 |
+
from pydantic import BaseModel
|
| 11 |
+
from diffusers import FluxPipeline
|
| 12 |
+
from nunchaku import NunchakuFluxTransformer2dModel
|
| 13 |
+
|
| 14 |
+
# Argument parsing
|
| 15 |
+
parser = argparse.ArgumentParser(description="Flux Image Generation Server with Nunchaku")
|
| 16 |
+
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
| 17 |
+
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
|
| 18 |
+
parser.add_argument("--model", type=str, default="black-forest-labs/FLUX.1-dev", help="Path or Repo ID of the base model")
|
| 19 |
+
parser.add_argument("--optimized-model", type=str, required=True, help="Path to the optimized Nunchaku model safetensors file")
|
| 20 |
+
args = parser.parse_args()
|
| 21 |
+
|
| 22 |
+
app = FastAPI()
|
| 23 |
+
|
| 24 |
+
# Global components
|
| 25 |
+
pipeline = None
|
| 26 |
+
request_lock = asyncio.Lock()
|
| 27 |
+
|
| 28 |
+
def load_model():
|
| 29 |
+
global pipeline
|
| 30 |
+
|
| 31 |
+
print(f"Loading base model from {args.model}...")
|
| 32 |
+
print(f"Loading optimized transformer from {args.optimized_model}...")
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
# Load the optimized transformer
|
| 36 |
+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(args.optimized_model)
|
| 37 |
+
|
| 38 |
+
# Load the pipeline with the optimized transformer
|
| 39 |
+
pipeline = FluxPipeline.from_pretrained(
|
| 40 |
+
args.model,
|
| 41 |
+
transformer=transformer,
|
| 42 |
+
torch_dtype=torch.bfloat16,
|
| 43 |
+
).to("cuda")
|
| 44 |
+
|
| 45 |
+
pipeline.transformer.set_attention_backend("flash")
|
| 46 |
+
pipeline.enable_model_cpu_offload()
|
| 47 |
+
pipeline.enable_vae_tiling()
|
| 48 |
+
pipeline.enable_vae_slicing()
|
| 49 |
+
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(f"Error loading model: {e}")
|
| 52 |
+
raise e
|
| 53 |
+
|
| 54 |
+
print("Model loaded successfully!")
|
| 55 |
+
|
| 56 |
+
def flush():
|
| 57 |
+
gc.collect()
|
| 58 |
+
torch.cuda.empty_cache()
|
| 59 |
+
|
| 60 |
+
class ImageGenerationRequest(BaseModel):
|
| 61 |
+
prompt: str
|
| 62 |
+
n: int = 1
|
| 63 |
+
size: str = "1024x1024"
|
| 64 |
+
response_format: str = "b64_json"
|
| 65 |
+
quality: str = "standard"
|
| 66 |
+
style: str = "vivid"
|
| 67 |
+
|
| 68 |
+
@app.on_event("startup")
|
| 69 |
+
async def startup_event():
|
| 70 |
+
load_model()
|
| 71 |
+
|
| 72 |
+
@app.post("/v1/images/generations")
|
| 73 |
+
async def generate_image(request: ImageGenerationRequest):
|
| 74 |
+
if not pipeline:
|
| 75 |
+
raise HTTPException(status_code=500, detail="Model not loaded")
|
| 76 |
+
|
| 77 |
+
async with request_lock:
|
| 78 |
+
print(f"Received request: {request.prompt}")
|
| 79 |
+
|
| 80 |
+
# Parse size
|
| 81 |
+
try:
|
| 82 |
+
width, height = map(int, request.size.split("x"))
|
| 83 |
+
except ValueError:
|
| 84 |
+
width, height = 1024, 1024
|
| 85 |
+
|
| 86 |
+
# Flux requires dimensions to be multiples of 16 (or 8 depending on VAE)
|
| 87 |
+
# Standard Flux dev usually works well with 1024x1024
|
| 88 |
+
# We'll ensure they are divisible by 16 just in case
|
| 89 |
+
width = (width // 16) * 16
|
| 90 |
+
height = (height // 16) * 16
|
| 91 |
+
|
| 92 |
+
images = []
|
| 93 |
+
|
| 94 |
+
try:
|
| 95 |
+
# Generate images
|
| 96 |
+
generated_images = pipeline(
|
| 97 |
+
request.prompt,
|
| 98 |
+
height=height,
|
| 99 |
+
width=width,
|
| 100 |
+
num_inference_steps=4, # Standard for Flux Dev
|
| 101 |
+
guidance_scale=3.5, # Nunchaku example uses 3.5, previous code used 4.0. Let's stick to 3.5 or 4.0. Example says 3.5.
|
| 102 |
+
num_images_per_prompt=request.n
|
| 103 |
+
).images
|
| 104 |
+
|
| 105 |
+
for image in generated_images:
|
| 106 |
+
buffered = io.BytesIO()
|
| 107 |
+
image.save(buffered, format="PNG")
|
| 108 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 109 |
+
images.append({"b64_json": img_str})
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
print(f"Error during generation: {e}")
|
| 113 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 114 |
+
finally:
|
| 115 |
+
flush()
|
| 116 |
+
|
| 117 |
+
return {
|
| 118 |
+
"created": int(time.time()),
|
| 119 |
+
"data": images
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
|
| 123 |
+
uvicorn.run(app, host=args.host, port=args.port)
|
extras/ImageGenServer_cpu.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import base64
|
| 3 |
+
import io
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import uvicorn
|
| 7 |
+
import numpy as np
|
| 8 |
+
import gc
|
| 9 |
+
import asyncio
|
| 10 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 11 |
+
from accelerate import infer_auto_device_map, dispatch_model
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
from diffusers import (
|
| 14 |
+
Flux2Pipeline,
|
| 15 |
+
Flux2Transformer2DModel,
|
| 16 |
+
AutoencoderKLFlux2,
|
| 17 |
+
FlowMatchEulerDiscreteScheduler
|
| 18 |
+
)
|
| 19 |
+
from diffusers.pipelines.flux2.pipeline_flux2 import compute_empirical_mu, retrieve_timesteps
|
| 20 |
+
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
|
| 21 |
+
from transformers import Mistral3ForConditionalGeneration, AutoProcessor
|
| 22 |
+
|
| 23 |
+
# Argument parsing
|
| 24 |
+
parser = argparse.ArgumentParser(description="Flux2 Image Generation Server")
|
| 25 |
+
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
| 26 |
+
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
|
| 27 |
+
parser.add_argument("--model", type=str, default="black-forest-labs/FLUX.1-dev", help="Path or Repo ID of the model")
|
| 28 |
+
args = parser.parse_args()
|
| 29 |
+
|
| 30 |
+
app = FastAPI()
|
| 31 |
+
|
| 32 |
+
# Global components
|
| 33 |
+
text_encoder = None
|
| 34 |
+
tokenizer = None
|
| 35 |
+
transformer = None
|
| 36 |
+
vae = None
|
| 37 |
+
scheduler = None
|
| 38 |
+
image_processor = None
|
| 39 |
+
request_lock = asyncio.Lock()
|
| 40 |
+
|
| 41 |
+
# Device maps
|
| 42 |
+
text_encoder_map = None
|
| 43 |
+
transformer_map = None
|
| 44 |
+
vae_map = None
|
| 45 |
+
|
| 46 |
+
GPU_MEMORY_FRACTION = 0.90
|
| 47 |
+
|
| 48 |
+
def load_model():
|
| 49 |
+
global text_encoder, tokenizer, transformer, vae, scheduler, image_processor
|
| 50 |
+
global text_encoder_map, transformer_map, vae_map
|
| 51 |
+
|
| 52 |
+
print(f"Loading model from {args.model}...")
|
| 53 |
+
|
| 54 |
+
try:
|
| 55 |
+
print("Loading Flux2 components...")
|
| 56 |
+
|
| 57 |
+
# Calculate max memory per GPU
|
| 58 |
+
#max_memory = {}
|
| 59 |
+
#if torch.cuda.is_available():
|
| 60 |
+
# for i in range(torch.cuda.device_count()):
|
| 61 |
+
# total_mem = torch.cuda.get_device_properties(i).total_memory
|
| 62 |
+
# max_memory[i] = int(total_mem * GPU_MEMORY_FRACTION)
|
| 63 |
+
|
| 64 |
+
max_memory = {
|
| 65 |
+
0: "5GB", # leave a little headroom
|
| 66 |
+
# 1: "10GB",
|
| 67 |
+
"cpu": "120GB" # your 128GB RAM minus OS
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
# Load Text Encoder (Mistral3) on CPU
|
| 71 |
+
print("Loading Text Encoder on CPU...")
|
| 72 |
+
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
|
| 73 |
+
args.model,
|
| 74 |
+
subfolder="text_encoder",
|
| 75 |
+
torch_dtype=torch.bfloat16,
|
| 76 |
+
device_map="cpu"
|
| 77 |
+
)
|
| 78 |
+
print("Calculating Text Encoder device map...")
|
| 79 |
+
text_encoder_map = infer_auto_device_map(text_encoder, max_memory=max_memory)
|
| 80 |
+
|
| 81 |
+
# Load Tokenizer on CPU
|
| 82 |
+
print("Loading Tokenizer on CPU...")
|
| 83 |
+
tokenizer = AutoProcessor.from_pretrained(
|
| 84 |
+
args.model,
|
| 85 |
+
subfolder="tokenizer",
|
| 86 |
+
device_map="cpu"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Load Transformer on CPU
|
| 90 |
+
print("Loading Transformer on CPU...")
|
| 91 |
+
transformer = Flux2Transformer2DModel.from_pretrained(
|
| 92 |
+
args.model,
|
| 93 |
+
subfolder="transformer",
|
| 94 |
+
torch_dtype=torch.bfloat16,
|
| 95 |
+
device_map="cpu"
|
| 96 |
+
)
|
| 97 |
+
print("Calculating Transformer device map...")
|
| 98 |
+
transformer_map = infer_auto_device_map(transformer, max_memory=max_memory)
|
| 99 |
+
|
| 100 |
+
# Load VAE on CPU
|
| 101 |
+
print("Loading VAE on CPU...")
|
| 102 |
+
vae = AutoencoderKLFlux2.from_pretrained(
|
| 103 |
+
args.model,
|
| 104 |
+
subfolder="vae",
|
| 105 |
+
torch_dtype=torch.bfloat16,
|
| 106 |
+
device_map="cpu"
|
| 107 |
+
)
|
| 108 |
+
print("Calculating VAE device map...")
|
| 109 |
+
vae_map = infer_auto_device_map(vae, max_memory=max_memory)
|
| 110 |
+
|
| 111 |
+
# Initialize Scheduler
|
| 112 |
+
print("Initializing Scheduler...")
|
| 113 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 114 |
+
args.model,
|
| 115 |
+
subfolder="scheduler"
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# Initialize Image Processor
|
| 119 |
+
print("Initializing Image Processor...")
|
| 120 |
+
# VAE scale factor logic from pipeline
|
| 121 |
+
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
| 122 |
+
image_processor = Flux2ImageProcessor(vae_scale_factor=vae_scale_factor * 2)
|
| 123 |
+
|
| 124 |
+
except Exception as e:
|
| 125 |
+
print(f"Error loading model: {e}")
|
| 126 |
+
raise e
|
| 127 |
+
|
| 128 |
+
print("Model loaded successfully!")
|
| 129 |
+
|
| 130 |
+
def flush():
|
| 131 |
+
gc.collect()
|
| 132 |
+
torch.cuda.empty_cache()
|
| 133 |
+
|
| 134 |
+
class ImageGenerationRequest(BaseModel):
|
| 135 |
+
prompt: str
|
| 136 |
+
n: int = 1
|
| 137 |
+
size: str = "1024x1024"
|
| 138 |
+
response_format: str = "b64_json"
|
| 139 |
+
quality: str = "standard"
|
| 140 |
+
style: str = "vivid"
|
| 141 |
+
|
| 142 |
+
@app.on_event("startup")
|
| 143 |
+
async def startup_event():
|
| 144 |
+
load_model()
|
| 145 |
+
|
| 146 |
+
@app.post("/v1/images/generations")
|
| 147 |
+
async def generate_image(request: ImageGenerationRequest):
|
| 148 |
+
if not transformer:
|
| 149 |
+
raise HTTPException(status_code=500, detail="Model not loaded")
|
| 150 |
+
|
| 151 |
+
async with request_lock:
|
| 152 |
+
print(f"Received request: {request.prompt}")
|
| 153 |
+
|
| 154 |
+
# Parse size
|
| 155 |
+
try:
|
| 156 |
+
width, height = map(int, request.size.split("x"))
|
| 157 |
+
except ValueError:
|
| 158 |
+
width, height = 1024, 1024
|
| 159 |
+
|
| 160 |
+
num_inference_steps = 28
|
| 161 |
+
guidance_scale = 4.0
|
| 162 |
+
max_sequence_length = 512
|
| 163 |
+
device = torch.device("cuda")
|
| 164 |
+
dtype = torch.bfloat16
|
| 165 |
+
|
| 166 |
+
images = []
|
| 167 |
+
|
| 168 |
+
# 1. Generate embeddings on CPU
|
| 169 |
+
print("Generating embeddings...")
|
| 170 |
+
flush()
|
| 171 |
+
prompt_embeds = Flux2Pipeline._get_mistral_3_small_prompt_embeds(
|
| 172 |
+
text_encoder=text_encoder,
|
| 173 |
+
tokenizer=tokenizer,
|
| 174 |
+
prompt=request.prompt,
|
| 175 |
+
# device=torch.device("cpu"),
|
| 176 |
+
max_sequence_length=max_sequence_length
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# prompt_embeds = prompt_embeds.to("cuda")
|
| 181 |
+
|
| 182 |
+
# 2. Prepare Latents
|
| 183 |
+
# Flux latents are turned into 2x2 patches and packed.
|
| 184 |
+
# This means the latent width and height has to be divisible by the patch size.
|
| 185 |
+
# So the vae scale factor is multiplied by the patch size to account for this
|
| 186 |
+
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
| 187 |
+
|
| 188 |
+
height = height or 1024
|
| 189 |
+
width = width or 1024
|
| 190 |
+
|
| 191 |
+
# Resize to be divisible by vae_scale_factor * 2
|
| 192 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 193 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 194 |
+
|
| 195 |
+
num_channels_latents = transformer.config.in_channels // 4
|
| 196 |
+
shape = (1, num_channels_latents * 4, height // 2, width // 2)
|
| 197 |
+
|
| 198 |
+
# 3. Prepare IDs
|
| 199 |
+
# We need to prepare text_ids and latent_ids
|
| 200 |
+
# prompt_embeds shape: (batch_size, seq_len, hidden_dim)
|
| 201 |
+
batch_size, seq_len, _ = prompt_embeds.shape
|
| 202 |
+
|
| 203 |
+
# Repeat for num_images_per_prompt (assuming 1 for now per loop iteration)
|
| 204 |
+
# If request.n > 1, we loop outside or handle batching. Here we loop outside.
|
| 205 |
+
|
| 206 |
+
# Prepare text IDs
|
| 207 |
+
text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(device)
|
| 208 |
+
|
| 209 |
+
for _ in range(request.n):
|
| 210 |
+
# Generate random latents
|
| 211 |
+
latents = torch.randn(shape, device=device, dtype=dtype)
|
| 212 |
+
|
| 213 |
+
# Prepare latent IDs
|
| 214 |
+
latent_ids = Flux2Pipeline._prepare_latent_ids(latents).to(device)
|
| 215 |
+
|
| 216 |
+
# Pack latents
|
| 217 |
+
packed_latents = Flux2Pipeline._pack_latents(latents)
|
| 218 |
+
|
| 219 |
+
# 4. Prepare Timesteps
|
| 220 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 221 |
+
image_seq_len = packed_latents.shape[1]
|
| 222 |
+
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
| 223 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 224 |
+
scheduler,
|
| 225 |
+
num_inference_steps,
|
| 226 |
+
device,
|
| 227 |
+
sigmas=sigmas,
|
| 228 |
+
mu=mu,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# --- SWAP TRANSFORMER TO CUDA ---
|
| 232 |
+
print("Moving Transformer to CUDA...")
|
| 233 |
+
flush()
|
| 234 |
+
dispatch_model(transformer, device_map=transformer_map)
|
| 235 |
+
|
| 236 |
+
# 5. Denoising Loop
|
| 237 |
+
print("Starting denoising loop on CUDA...")
|
| 238 |
+
scheduler.set_begin_index(0)
|
| 239 |
+
|
| 240 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 241 |
+
guidance = guidance.expand(packed_latents.shape[0])
|
| 242 |
+
|
| 243 |
+
for i, t in enumerate(timesteps):
|
| 244 |
+
start_time = time.time()
|
| 245 |
+
# broadcast to batch dimension
|
| 246 |
+
timestep = t.expand(packed_latents.shape[0]).to(packed_latents.dtype)
|
| 247 |
+
|
| 248 |
+
noise_pred = transformer(
|
| 249 |
+
hidden_states=packed_latents,
|
| 250 |
+
timestep=timestep / 1000,
|
| 251 |
+
guidance=guidance,
|
| 252 |
+
encoder_hidden_states=prompt_embeds,
|
| 253 |
+
txt_ids=text_ids,
|
| 254 |
+
img_ids=latent_ids,
|
| 255 |
+
return_dict=False,
|
| 256 |
+
)[0]
|
| 257 |
+
|
| 258 |
+
# step
|
| 259 |
+
packed_latents = scheduler.step(noise_pred, t, packed_latents, return_dict=False)[0]
|
| 260 |
+
|
| 261 |
+
step_time = time.time() - start_time
|
| 262 |
+
print(f"Step {i+1}/{num_inference_steps}: {step_time:.2f}s")
|
| 263 |
+
|
| 264 |
+
# --- SWAP TRANSFORMER TO CPU ---
|
| 265 |
+
print("Moving Transformer to CPU...")
|
| 266 |
+
transformer.to("cpu")
|
| 267 |
+
flush()
|
| 268 |
+
|
| 269 |
+
# --- SWAP VAE TO CUDA ---
|
| 270 |
+
print("Moving VAE to CUDA...")
|
| 271 |
+
dispatch_model(vae, device_map=vae_map)
|
| 272 |
+
|
| 273 |
+
# 6. Decode
|
| 274 |
+
print("Decoding on CUDA...")
|
| 275 |
+
# Move packed_latents to CUDA for decoding (already there, but ensuring)
|
| 276 |
+
packed_latents = packed_latents.to(device)
|
| 277 |
+
latent_ids = latent_ids.to(device)
|
| 278 |
+
|
| 279 |
+
latents = Flux2Pipeline._unpack_latents_with_ids(packed_latents, latent_ids)
|
| 280 |
+
|
| 281 |
+
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
| 282 |
+
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
|
| 283 |
+
latents.device, latents.dtype
|
| 284 |
+
)
|
| 285 |
+
latents = latents * latents_bn_std + latents_bn_mean
|
| 286 |
+
latents = Flux2Pipeline._unpatchify_latents(latents)
|
| 287 |
+
|
| 288 |
+
image = vae.decode(latents, return_dict=False)[0]
|
| 289 |
+
image = image_processor.postprocess(image, output_type="pil")[0]
|
| 290 |
+
|
| 291 |
+
# --- SWAP VAE TO CPU ---
|
| 292 |
+
print("Moving VAE to CPU...")
|
| 293 |
+
vae.to("cpu")
|
| 294 |
+
|
| 295 |
+
# Convert to base64
|
| 296 |
+
buffered = io.BytesIO()
|
| 297 |
+
image.save(buffered, format="PNG")
|
| 298 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 299 |
+
images.append({"b64_json": img_str})
|
| 300 |
+
|
| 301 |
+
return {
|
| 302 |
+
"created": int(time.time()),
|
| 303 |
+
"data": images
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
if __name__ == "__main__":
|
| 307 |
+
uvicorn.run(app, host=args.host, port=args.port)
|
extras/ImageGenServer_new.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import base64
|
| 3 |
+
import io
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import uvicorn
|
| 7 |
+
import gc
|
| 8 |
+
import asyncio
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from fastapi import FastAPI, HTTPException
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
from diffusers import FluxPipeline, FluxKontextPipeline
|
| 13 |
+
from nunchaku import NunchakuFluxTransformer2dModel
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
# Argument parsing
|
| 17 |
+
parser = argparse.ArgumentParser(description="Flux Image Generation Server with Nunchaku")
|
| 18 |
+
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
| 19 |
+
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
|
| 20 |
+
parser.add_argument("--model", type=str, default="black-forest-labs/FLUX.1-dev", help="Path or Repo ID of the base model")
|
| 21 |
+
parser.add_argument("--optimized-model", type=str, required=True, help="Path to the optimized Nunchaku model safetensors file")
|
| 22 |
+
args = parser.parse_args()
|
| 23 |
+
|
| 24 |
+
app = FastAPI()
|
| 25 |
+
|
| 26 |
+
# Global components
|
| 27 |
+
pipeline = None
|
| 28 |
+
img2img_pipeline = None
|
| 29 |
+
request_lock = asyncio.Lock()
|
| 30 |
+
|
| 31 |
+
def load_model():
|
| 32 |
+
global pipeline, img2img_pipeline
|
| 33 |
+
|
| 34 |
+
print(f"Loading base model from {args.model}...")
|
| 35 |
+
print(f"Loading optimized transformer from {args.optimized_model}...")
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
# Load the optimized transformer
|
| 39 |
+
# Ensuring transformer is in bfloat16 to match the pipeline expectation
|
| 40 |
+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(args.optimized_model)
|
| 41 |
+
|
| 42 |
+
# Load the pipeline with the optimized transformer
|
| 43 |
+
pipeline = FluxPipeline.from_pretrained(
|
| 44 |
+
args.model,
|
| 45 |
+
transformer=transformer,
|
| 46 |
+
torch_dtype=torch.bfloat16,
|
| 47 |
+
).to("cuda")
|
| 48 |
+
|
| 49 |
+
# Load the Img2Img/Context pipeline sharing the same components
|
| 50 |
+
# We use strict component sharing to avoid VRAM duplication
|
| 51 |
+
print("Initializing FluxKontextPipeline for image inputs...")
|
| 52 |
+
# Since FluxKontextPipeline shares architecture with FluxPipeline, we can initialize it with the same components
|
| 53 |
+
img2img_pipeline = FluxKontextPipeline.from_pretrained(
|
| 54 |
+
args.model,
|
| 55 |
+
transformer=pipeline.transformer,
|
| 56 |
+
vae=pipeline.vae,
|
| 57 |
+
text_encoder=pipeline.text_encoder,
|
| 58 |
+
text_encoder_2=pipeline.text_encoder_2,
|
| 59 |
+
tokenizer=pipeline.tokenizer,
|
| 60 |
+
tokenizer_2=pipeline.tokenizer_2,
|
| 61 |
+
scheduler=pipeline.scheduler,
|
| 62 |
+
torch_dtype=torch.bfloat16
|
| 63 |
+
).to("cuda")
|
| 64 |
+
|
| 65 |
+
# Enable CPU offload for the main pipeline.
|
| 66 |
+
# Since components are shared, this should handle memory management for both.
|
| 67 |
+
pipeline.enable_model_cpu_offload()
|
| 68 |
+
# img2img_pipeline.enable_model_cpu_offload() # Avoid double hook registration
|
| 69 |
+
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(f"Error loading model: {e}")
|
| 72 |
+
raise e
|
| 73 |
+
|
| 74 |
+
print("Model loaded successfully!")
|
| 75 |
+
|
| 76 |
+
def flush():
|
| 77 |
+
gc.collect()
|
| 78 |
+
torch.cuda.empty_cache()
|
| 79 |
+
|
| 80 |
+
class ImageGenerationRequest(BaseModel):
|
| 81 |
+
prompt: str
|
| 82 |
+
n: int = 1
|
| 83 |
+
size: str = "1024x1024"
|
| 84 |
+
response_format: str = "b64_json"
|
| 85 |
+
quality: str = "standard"
|
| 86 |
+
style: str = "vivid"
|
| 87 |
+
image: Optional[str] = None # Base64 encoded image
|
| 88 |
+
|
| 89 |
+
@app.on_event("startup")
|
| 90 |
+
async def startup_event():
|
| 91 |
+
load_model()
|
| 92 |
+
|
| 93 |
+
@app.post("/v1/images/generations")
|
| 94 |
+
async def generate_image(request: ImageGenerationRequest):
|
| 95 |
+
if not pipeline:
|
| 96 |
+
raise HTTPException(status_code=500, detail="Model not loaded")
|
| 97 |
+
|
| 98 |
+
async with request_lock:
|
| 99 |
+
print(f"Received request: {request.prompt}")
|
| 100 |
+
|
| 101 |
+
# Parse size
|
| 102 |
+
try:
|
| 103 |
+
width, height = map(int, request.size.split("x"))
|
| 104 |
+
except ValueError:
|
| 105 |
+
width, height = 1024, 1024
|
| 106 |
+
|
| 107 |
+
# Flux requires dimensions to be multiples of 16 (or 8 depending on VAE)
|
| 108 |
+
# Standard Flux dev usually works well with 1024x1024
|
| 109 |
+
# We'll ensure they are divisible by 16 just in case
|
| 110 |
+
width = (width // 16) * 16
|
| 111 |
+
height = (height // 16) * 16
|
| 112 |
+
|
| 113 |
+
images = []
|
| 114 |
+
|
| 115 |
+
try:
|
| 116 |
+
input_image = None
|
| 117 |
+
if request.image:
|
| 118 |
+
try:
|
| 119 |
+
# Handle data URI if present
|
| 120 |
+
img_data = request.image
|
| 121 |
+
if "," in img_data:
|
| 122 |
+
img_data = img_data.split(",")[1]
|
| 123 |
+
|
| 124 |
+
input_bytes = base64.b64decode(img_data)
|
| 125 |
+
input_image = Image.open(io.BytesIO(input_bytes)).convert("RGB")
|
| 126 |
+
# Resize input image to match request size
|
| 127 |
+
input_image = input_image.resize((width, height), Image.LANCZOS)
|
| 128 |
+
print(f"Processed input image of size {input_image.size}")
|
| 129 |
+
except Exception as e:
|
| 130 |
+
print(f"Failed to decode input image: {e}")
|
| 131 |
+
raise HTTPException(status_code=400, detail="Invalid image data")
|
| 132 |
+
|
| 133 |
+
# Generate images
|
| 134 |
+
if input_image:
|
| 135 |
+
# Use FluxKontextPipeline
|
| 136 |
+
print("Running FluxKontextPipeline...")
|
| 137 |
+
generated_images = pipeline(
|
| 138 |
+
image=input_image,
|
| 139 |
+
prompt=request.prompt,
|
| 140 |
+
height=height,
|
| 141 |
+
width=width,
|
| 142 |
+
num_inference_steps=28,
|
| 143 |
+
guidance_scale=2.5, # Recommended for Kontext
|
| 144 |
+
num_images_per_prompt=request.n
|
| 145 |
+
).images
|
| 146 |
+
else:
|
| 147 |
+
# Use standard FluxPipeline
|
| 148 |
+
print("Running FluxPipeline...")
|
| 149 |
+
generated_images = pipeline(
|
| 150 |
+
request.prompt,
|
| 151 |
+
height=height,
|
| 152 |
+
width=width,
|
| 153 |
+
num_inference_steps=28, # Standard for Flux Dev
|
| 154 |
+
guidance_scale=3.5, # Nunchaku example uses 3.5
|
| 155 |
+
num_images_per_prompt=request.n
|
| 156 |
+
).images
|
| 157 |
+
|
| 158 |
+
for image in generated_images:
|
| 159 |
+
buffered = io.BytesIO()
|
| 160 |
+
image.save(buffered, format="PNG")
|
| 161 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 162 |
+
images.append({"b64_json": img_str})
|
| 163 |
+
|
| 164 |
+
except Exception as e:
|
| 165 |
+
print(f"Error during generation: {e}")
|
| 166 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 167 |
+
finally:
|
| 168 |
+
flush()
|
| 169 |
+
|
| 170 |
+
return {
|
| 171 |
+
"created": int(time.time()),
|
| 172 |
+
"data": images
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
if __name__ == "__main__":
|
| 176 |
+
uvicorn.run(app, host=args.host, port=args.port)
|
extras/KontextBackend.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import T5EncoderModel, BitsAndBytesConfig
|
| 3 |
+
from diffusers import FluxKontextPipeline
|
| 4 |
+
|
| 5 |
+
class KontextBackend:
|
| 6 |
+
def __init__(self, model_id, optimized_model_path=None):
|
| 7 |
+
self.model_id = model_id
|
| 8 |
+
self.optimized_model_path = optimized_model_path
|
| 9 |
+
self.pipeline = None
|
| 10 |
+
|
| 11 |
+
def load(self):
|
| 12 |
+
print(f"Loading Kontext backend from {self.model_id}...")
|
| 13 |
+
|
| 14 |
+
if self.optimized_model_path:
|
| 15 |
+
print(f"Loading optimized transformer from {self.optimized_model_path}...")
|
| 16 |
+
# Load the optimized transformer (Nunchaku style! *hyah!*)
|
| 17 |
+
try:
|
| 18 |
+
from nunchaku import NunchakuFluxTransformer2dModel
|
| 19 |
+
except ImportError:
|
| 20 |
+
print("Oops, nunchaku not found! Please install it for optimized magic.")
|
| 21 |
+
raise
|
| 22 |
+
|
| 23 |
+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(self.optimized_model_path)
|
| 24 |
+
|
| 25 |
+
text_quant_config = BitsAndBytesConfig(
|
| 26 |
+
load_in_4bit=True,
|
| 27 |
+
bnb_4bit_quant_type="nf4",
|
| 28 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 29 |
+
bnb_4bit_use_double_quant=True
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
|
| 33 |
+
self.model_id,
|
| 34 |
+
subfolder="text_encoder_2",
|
| 35 |
+
quantization_config=text_quant_config,
|
| 36 |
+
torch_dtype=torch.bfloat16 # bfloat16 for your NVIDIA setup—faster magic!
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Load the pipeline with the optimized transformer
|
| 40 |
+
# We need FluxKontextPipeline for editing magic!
|
| 41 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
| 42 |
+
self.model_id,
|
| 43 |
+
text_encoder_2=text_encoder_2_4bit,
|
| 44 |
+
transformer=transformer,
|
| 45 |
+
torch_dtype=torch.bfloat16,
|
| 46 |
+
)
|
| 47 |
+
else:
|
| 48 |
+
print("No optimized model path provided for KontextBackend. Falling back to standard loading if possible, or maybe we should insist on one?")
|
| 49 |
+
# Original code implied usage of optimized model for Kontext was the main path, but let's support standard if needed,
|
| 50 |
+
# or minimally just load standard logic if that was the fallback.
|
| 51 |
+
# Looking at original code: "if args.optimized_model: ... else: ... Flux2Pipeline"
|
| 52 |
+
# Wait, the original code fell back to Flux2Pipeline if no optimized model was present!
|
| 53 |
+
# The user request says: "create KontextBackend.py that creates a pipeline from base and optional optimized paths"
|
| 54 |
+
# So KontextBackend *should* support both optimized and unoptimized? Or was the fallback in original code actually switching to Flux2?
|
| 55 |
+
# Original code:
|
| 56 |
+
# if args.optimized_model:
|
| 57 |
+
# # Load Nunchaku stuff
|
| 58 |
+
# pipeline = FluxKontextPipeline(...)
|
| 59 |
+
# else:
|
| 60 |
+
# # Load standard stuff
|
| 61 |
+
# pipeline = Flux2Pipeline(...)
|
| 62 |
+
#
|
| 63 |
+
# The USER request says: "KontextBackend.py that creates a pipeline from base and optional optimized paths".
|
| 64 |
+
# This implies if I choose "kontext" backend but don't provide optimized path, it should still load a FluxKontextPipeline (presumably unoptimized/standard).
|
| 65 |
+
# However, FluxKontextPipeline might expect specific components.
|
| 66 |
+
# Let's assume standard loading for FluxKontextPipeline if no optimized model is separate.
|
| 67 |
+
|
| 68 |
+
print(f"Loading standard FluxKontextPipeline from {self.model_id}...")
|
| 69 |
+
# Assuming standard 4-bit loading for memory savings similar to before
|
| 70 |
+
quantization_config = BitsAndBytesConfig(
|
| 71 |
+
load_in_4bit=True,
|
| 72 |
+
bnb_4bit_quant_type="nf4",
|
| 73 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 74 |
+
bnb_4bit_use_double_quant=True,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Use basic from_pretrained
|
| 78 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
| 79 |
+
self.model_id,
|
| 80 |
+
torch_dtype=torch.bfloat16
|
| 81 |
+
# We might need quantization for components if memory is tight, but from_pretrained handles a lot.
|
| 82 |
+
# Let's keep it simple for now as we don't have the Nunchaku specific loading here.
|
| 83 |
+
)
|
| 84 |
+
# Actually, if we look at how specialized the optimized loading was, standard loading might just be:
|
| 85 |
+
# pipeline = FluxKontextPipeline.from_pretrained(model_id, torch_dtype=...)
|
| 86 |
+
|
| 87 |
+
self.pipeline = pipeline
|
| 88 |
+
self.pipeline.to("cuda")
|
| 89 |
+
|
| 90 |
+
# Additional setup if needed (like offload)
|
| 91 |
+
# self.pipeline.enable_model_cpu_offload() # User code had this for optimized path
|
| 92 |
+
|
| 93 |
+
return self.pipeline, self.pipeline
|
extras/NVFP4TextEncoder.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NVFP4 text encoder loader for diffusers image pipelines.
|
| 3 |
+
|
| 4 |
+
Loads a compressed-tensors NVFP4-pack-quantized HuggingFace causal LM and wraps
|
| 5 |
+
it so it can be plugged into ``diffusers.ZImagePipeline`` (or any pipeline
|
| 6 |
+
calling ``self.text_encoder(input_ids, attention_mask, output_hidden_states=True)``).
|
| 7 |
+
|
| 8 |
+
Strategy:
|
| 9 |
+
- Instantiate the HF model on the ``meta`` device (no real allocation).
|
| 10 |
+
- Walk every ``torch.nn.Linear`` and swap it for vLLM's ``ReplicatedLinear`` with
|
| 11 |
+
``CompressedTensorsConfig`` derived from the checkpoint's
|
| 12 |
+
``quantization_config``. This registers ``weight_packed`` / ``weight_scale`` /
|
| 13 |
+
``*_global_scale`` parameters in the exact layout vLLM's
|
| 14 |
+
``CompressedTensorsW4A4Fp4`` scheme expects.
|
| 15 |
+
- Materialise remaining (non-Linear) parameters (embeddings, RMSNorm, k/q norms)
|
| 16 |
+
on the target device & dtype.
|
| 17 |
+
- Stream the safetensors file and dispatch each tensor through the registered
|
| 18 |
+
vLLM ``weight_loader`` (which handles layout swizzling on
|
| 19 |
+
``process_weights_after_loading``).
|
| 20 |
+
- Tie the LM head to the input embedding when ``config.tie_word_embeddings``.
|
| 21 |
+
|
| 22 |
+
The result is a regular ``nn.Module`` matching the HF model's call signature
|
| 23 |
+
(``forward(input_ids, attention_mask, output_hidden_states)``) -- usable directly
|
| 24 |
+
as ``ZImagePipeline.text_encoder``.
|
| 25 |
+
|
| 26 |
+
vLLM requires a minimal global context (distributed process group + model
|
| 27 |
+
parallel state + active VllmConfig) even at TP=1 because ``ReplicatedLinear``
|
| 28 |
+
queries the TP world size at construction. We bootstrap that lazily once.
|
| 29 |
+
|
| 30 |
+
Forced kernel: we set ``VLLM_NVFP4_GEMM_BACKEND=cutlass`` to skip
|
| 31 |
+
flashinfer-cutlass JIT (which needs the ``ninja`` binary on PATH). The vLLM
|
| 32 |
+
CUTLASS kernel is built into the wheel.
|
| 33 |
+
"""
|
| 34 |
+
from __future__ import annotations
|
| 35 |
+
|
| 36 |
+
import json
|
| 37 |
+
import os
|
| 38 |
+
from collections.abc import Iterator
|
| 39 |
+
from typing import Optional
|
| 40 |
+
|
| 41 |
+
import torch
|
| 42 |
+
import torch.nn as nn
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ----------------------------------------------------------------------------
|
| 46 |
+
# One-time vLLM bootstrap (TP=1, no engine, just enough context for ReplicatedLinear)
|
| 47 |
+
# ----------------------------------------------------------------------------
|
| 48 |
+
_VLLM_BOOTSTRAPPED = False
|
| 49 |
+
_VLLM_CONFIG_CTX = None # holds the entered set_current_vllm_config context manager
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _bootstrap_vllm_once() -> None:
|
| 53 |
+
"""Initialise the bits of vLLM that ReplicatedLinear needs at TP=1.
|
| 54 |
+
|
| 55 |
+
Idempotent. Uses ``gloo`` so it works without NCCL/CUDA-aware MPI and even
|
| 56 |
+
when CUDA is busy with the diffusion transformer.
|
| 57 |
+
"""
|
| 58 |
+
global _VLLM_BOOTSTRAPPED, _VLLM_CONFIG_CTX
|
| 59 |
+
if _VLLM_BOOTSTRAPPED:
|
| 60 |
+
return
|
| 61 |
+
|
| 62 |
+
# Force CUTLASS to avoid flashinfer-cutlass JIT (requires `ninja` on PATH).
|
| 63 |
+
os.environ.setdefault("VLLM_NVFP4_GEMM_BACKEND", "cutlass")
|
| 64 |
+
|
| 65 |
+
from vllm.config import VllmConfig
|
| 66 |
+
from vllm.config.vllm import set_current_vllm_config
|
| 67 |
+
from vllm.distributed import (
|
| 68 |
+
ensure_model_parallel_initialized,
|
| 69 |
+
init_distributed_environment,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# Pick a free port; world_size=1.
|
| 73 |
+
import socket
|
| 74 |
+
|
| 75 |
+
s = socket.socket()
|
| 76 |
+
s.bind(("127.0.0.1", 0))
|
| 77 |
+
port = s.getsockname()[1]
|
| 78 |
+
s.close()
|
| 79 |
+
|
| 80 |
+
if not torch.distributed.is_initialized():
|
| 81 |
+
init_distributed_environment(
|
| 82 |
+
world_size=1,
|
| 83 |
+
rank=0,
|
| 84 |
+
local_rank=0,
|
| 85 |
+
distributed_init_method=f"tcp://127.0.0.1:{port}",
|
| 86 |
+
backend="gloo",
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Enter a long-lived VllmConfig context. We never exit it -- the encoder
|
| 90 |
+
# may construct submodules lazily and ReplicatedLinear calls
|
| 91 |
+
# get_current_vllm_config() at init.
|
| 92 |
+
vc = VllmConfig()
|
| 93 |
+
_VLLM_CONFIG_CTX = set_current_vllm_config(vc)
|
| 94 |
+
_VLLM_CONFIG_CTX.__enter__()
|
| 95 |
+
|
| 96 |
+
ensure_model_parallel_initialized(1, 1)
|
| 97 |
+
_VLLM_BOOTSTRAPPED = True
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ----------------------------------------------------------------------------
|
| 101 |
+
# Module: linear replacement
|
| 102 |
+
# ----------------------------------------------------------------------------
|
| 103 |
+
def _replace_linears_with_replicated(
|
| 104 |
+
model: nn.Module, quant_config
|
| 105 |
+
) -> None:
|
| 106 |
+
"""Recursively swap every ``nn.Linear`` for vLLM ``ReplicatedLinear``.
|
| 107 |
+
|
| 108 |
+
Carries the ``prefix`` so quant_config's ``ignore`` patterns (e.g. ``lm_head``)
|
| 109 |
+
are correctly applied.
|
| 110 |
+
"""
|
| 111 |
+
from vllm.model_executor.layers.linear import ReplicatedLinear
|
| 112 |
+
|
| 113 |
+
def _walk(parent: nn.Module, prefix: str) -> None:
|
| 114 |
+
for child_name, child in list(parent.named_children()):
|
| 115 |
+
qname = f"{prefix}.{child_name}" if prefix else child_name
|
| 116 |
+
if isinstance(child, nn.Linear):
|
| 117 |
+
new = ReplicatedLinear(
|
| 118 |
+
input_size=child.in_features,
|
| 119 |
+
output_size=child.out_features,
|
| 120 |
+
bias=child.bias is not None,
|
| 121 |
+
quant_config=quant_config,
|
| 122 |
+
prefix=qname,
|
| 123 |
+
return_bias=False,
|
| 124 |
+
params_dtype=torch.bfloat16,
|
| 125 |
+
)
|
| 126 |
+
setattr(parent, child_name, new)
|
| 127 |
+
else:
|
| 128 |
+
_walk(child, qname)
|
| 129 |
+
|
| 130 |
+
_walk(model, prefix="")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _materialize_remaining_meta_params(
|
| 134 |
+
model: nn.Module, dtype: torch.dtype, device: torch.device
|
| 135 |
+
) -> None:
|
| 136 |
+
"""Replace any ``meta`` parameter with empty real storage.
|
| 137 |
+
|
| 138 |
+
Only touches parameters NOT already created on a real device by the
|
| 139 |
+
ReplicatedLinear swap above (i.e. embeddings, layernorms, biases).
|
| 140 |
+
"""
|
| 141 |
+
for name, param in list(model.named_parameters(recurse=True)):
|
| 142 |
+
if param.device.type == "meta":
|
| 143 |
+
real = nn.Parameter(
|
| 144 |
+
torch.empty(param.shape, dtype=dtype, device=device),
|
| 145 |
+
requires_grad=False,
|
| 146 |
+
)
|
| 147 |
+
# Replace in the parent module
|
| 148 |
+
parent = model
|
| 149 |
+
*path, leaf = name.split(".")
|
| 150 |
+
for p in path:
|
| 151 |
+
parent = getattr(parent, p)
|
| 152 |
+
setattr(parent, leaf, real)
|
| 153 |
+
# Same for buffers (e.g. rotary inv_freq if registered as buffer on meta)
|
| 154 |
+
for name, buf in list(model.named_buffers(recurse=True)):
|
| 155 |
+
if buf.device.type == "meta":
|
| 156 |
+
real = torch.empty(buf.shape, dtype=buf.dtype, device=device)
|
| 157 |
+
parent = model
|
| 158 |
+
*path, leaf = name.split(".")
|
| 159 |
+
for p in path:
|
| 160 |
+
parent = getattr(parent, p)
|
| 161 |
+
parent.register_buffer(leaf, real, persistent=False)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ----------------------------------------------------------------------------
|
| 165 |
+
# Weight loading
|
| 166 |
+
# ----------------------------------------------------------------------------
|
| 167 |
+
def _iter_safetensors(model_dir: str) -> Iterator[tuple[str, torch.Tensor]]:
|
| 168 |
+
"""Yield (name, tensor) pairs from all *.safetensors shards in ``model_dir``."""
|
| 169 |
+
from safetensors import safe_open
|
| 170 |
+
|
| 171 |
+
# Single-file checkpoint or sharded? Prefer ``model.safetensors.index.json``.
|
| 172 |
+
index_path = os.path.join(model_dir, "model.safetensors.index.json")
|
| 173 |
+
if os.path.exists(index_path):
|
| 174 |
+
with open(index_path) as f:
|
| 175 |
+
index = json.load(f)
|
| 176 |
+
shards = sorted(set(index["weight_map"].values()))
|
| 177 |
+
else:
|
| 178 |
+
# Find all *.safetensors files in dir
|
| 179 |
+
shards = sorted(
|
| 180 |
+
fn for fn in os.listdir(model_dir) if fn.endswith(".safetensors")
|
| 181 |
+
)
|
| 182 |
+
for shard in shards:
|
| 183 |
+
path = os.path.join(model_dir, shard)
|
| 184 |
+
with safe_open(path, framework="pt") as f:
|
| 185 |
+
for key in f.keys():
|
| 186 |
+
yield key, f.get_tensor(key)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _load_weights_into_model(model: nn.Module, model_dir: str) -> None:
|
| 190 |
+
"""Stream safetensors into the (already-structured) model.
|
| 191 |
+
|
| 192 |
+
Uses each ReplicatedLinear's registered ``weight_loader`` for quantised
|
| 193 |
+
params (which handles tensor-parallel sharding, even though TP=1 here it
|
| 194 |
+
keeps casts consistent). Other params (embeddings, layernorms, biases) are
|
| 195 |
+
copied directly.
|
| 196 |
+
"""
|
| 197 |
+
# Strip vllm-omni-style "text_encoder." prefix if present; not applicable
|
| 198 |
+
# here since we load the standalone HF Qwen3 checkpoint where keys start
|
| 199 |
+
# with "model.layers..." / "model.embed_tokens..." / "lm_head...".
|
| 200 |
+
name_to_param: dict[str, nn.Parameter] = dict(model.named_parameters(recurse=True))
|
| 201 |
+
name_to_buffer: dict[str, torch.Tensor] = dict(model.named_buffers(recurse=True))
|
| 202 |
+
|
| 203 |
+
missing = set(name_to_param.keys())
|
| 204 |
+
unexpected = []
|
| 205 |
+
|
| 206 |
+
for key, tensor in _iter_safetensors(model_dir):
|
| 207 |
+
# Skip rotary inv_freq etc that aren't params (rare in modern HF saves)
|
| 208 |
+
if key in name_to_param:
|
| 209 |
+
param = name_to_param[key]
|
| 210 |
+
wl = getattr(param, "weight_loader", None)
|
| 211 |
+
if wl is not None:
|
| 212 |
+
wl(param, tensor.to(param.device))
|
| 213 |
+
else:
|
| 214 |
+
with torch.no_grad():
|
| 215 |
+
param.data.copy_(tensor.to(param.device, dtype=param.dtype))
|
| 216 |
+
missing.discard(key)
|
| 217 |
+
elif key in name_to_buffer:
|
| 218 |
+
with torch.no_grad():
|
| 219 |
+
name_to_buffer[key].copy_(tensor.to(name_to_buffer[key].device))
|
| 220 |
+
else:
|
| 221 |
+
unexpected.append(key)
|
| 222 |
+
|
| 223 |
+
# Tied embeddings (lm_head.weight not in checkpoint when tie_word_embeddings=True)
|
| 224 |
+
cfg = getattr(model, "config", None)
|
| 225 |
+
if cfg is not None and getattr(cfg, "tie_word_embeddings", False):
|
| 226 |
+
try:
|
| 227 |
+
inp_emb = model.get_input_embeddings().weight
|
| 228 |
+
model.lm_head.weight = inp_emb # share storage
|
| 229 |
+
missing.discard("lm_head.weight")
|
| 230 |
+
except Exception:
|
| 231 |
+
pass
|
| 232 |
+
|
| 233 |
+
if missing:
|
| 234 |
+
# It's OK if missing entries are *purely* lm_head.weight when tied; we
|
| 235 |
+
# already handled that above. Anything else is fatal-ish.
|
| 236 |
+
leftover = sorted(missing)
|
| 237 |
+
if leftover:
|
| 238 |
+
print(
|
| 239 |
+
f"[NVFP4TextEncoder] WARN: {len(leftover)} params missing from checkpoint; "
|
| 240 |
+
f"first 5: {leftover[:5]}"
|
| 241 |
+
)
|
| 242 |
+
if unexpected:
|
| 243 |
+
print(
|
| 244 |
+
f"[NVFP4TextEncoder] WARN: {len(unexpected)} keys in checkpoint unused; "
|
| 245 |
+
f"first 5: {unexpected[:5]}"
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _process_weights_after_loading(model: nn.Module) -> None:
|
| 250 |
+
"""Invoke vLLM's per-layer ``process_weights_after_loading`` for each
|
| 251 |
+
ReplicatedLinear (renames ``weight_packed`` -> ``weight``, computes ``alpha``,
|
| 252 |
+
swizzles scales for the CUTLASS kernel, etc.)."""
|
| 253 |
+
for module in model.modules():
|
| 254 |
+
qm = getattr(module, "quant_method", None)
|
| 255 |
+
if qm is not None and hasattr(qm, "process_weights_after_loading"):
|
| 256 |
+
qm.process_weights_after_loading(module)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ----------------------------------------------------------------------------
|
| 260 |
+
# Public API
|
| 261 |
+
# ----------------------------------------------------------------------------
|
| 262 |
+
def load_nvfp4_text_encoder(
|
| 263 |
+
model_dir: str,
|
| 264 |
+
device: str | torch.device = "cuda",
|
| 265 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 266 |
+
) -> nn.Module:
|
| 267 |
+
"""Load an NVFP4-quantised HuggingFace causal LM as a plug-in text encoder.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
model_dir: path to the checkpoint directory containing ``config.json``
|
| 271 |
+
and ``model*.safetensors``. The config must carry a
|
| 272 |
+
``quantization_config`` block with ``"format": "nvfp4-pack-quantized"``.
|
| 273 |
+
device: target CUDA device (forwards to ``model.to(device)``-equivalent
|
| 274 |
+
during materialisation).
|
| 275 |
+
dtype: activation / non-quantised-param dtype.
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
A ``PreTrainedModel`` whose ``Linear`` layers are NVFP4 inside the vLLM
|
| 279 |
+
CUTLASS kernel. Activations flow as ``dtype``.
|
| 280 |
+
"""
|
| 281 |
+
_bootstrap_vllm_once()
|
| 282 |
+
|
| 283 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 284 |
+
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
| 285 |
+
CompressedTensorsConfig,
|
| 286 |
+
)
|
| 287 |
+
from vllm.model_executor.models.transformers.utils import (
|
| 288 |
+
init_on_device_without_buffers,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
hf_config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
|
| 292 |
+
if not getattr(hf_config, "quantization_config", None):
|
| 293 |
+
raise ValueError(
|
| 294 |
+
f"{model_dir}/config.json has no `quantization_config`; "
|
| 295 |
+
"this loader only handles NVFP4-quantised checkpoints."
|
| 296 |
+
)
|
| 297 |
+
quant_config = CompressedTensorsConfig.from_config(hf_config.quantization_config)
|
| 298 |
+
|
| 299 |
+
# 1) Build the model skeleton on meta (zero allocation).
|
| 300 |
+
with init_on_device_without_buffers("meta"):
|
| 301 |
+
model = AutoModelForCausalLM.from_config(hf_config)
|
| 302 |
+
|
| 303 |
+
# 2) Swap Linear -> ReplicatedLinear(quant_config) (creates real CUDA params
|
| 304 |
+
# of the quantised shapes).
|
| 305 |
+
target_device = torch.device(device)
|
| 306 |
+
_replace_linears_with_replicated(model, quant_config)
|
| 307 |
+
|
| 308 |
+
# 3) Materialise any leftover meta parameters (embeddings, RMSNorms, ...)
|
| 309 |
+
_materialize_remaining_meta_params(model, dtype=dtype, device=target_device)
|
| 310 |
+
|
| 311 |
+
# 4) Move newly-created quantised params to target device (ReplicatedLinear
|
| 312 |
+
# creates them on the current default device which is usually CPU).
|
| 313 |
+
model.to(target_device)
|
| 314 |
+
|
| 315 |
+
# 5) Load weights via per-param weight_loader.
|
| 316 |
+
_load_weights_into_model(model, model_dir)
|
| 317 |
+
|
| 318 |
+
# 6) Let vLLM swizzle scales / rename weight_packed->weight / compute alpha.
|
| 319 |
+
_process_weights_after_loading(model)
|
| 320 |
+
|
| 321 |
+
# 7) Match HF semantics for downstream pipelines.
|
| 322 |
+
model.eval()
|
| 323 |
+
model.config.use_cache = False
|
| 324 |
+
return model
|
extras/OmniImageEditServer.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import base64
|
| 3 |
+
import io
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import uvicorn
|
| 7 |
+
import gc
|
| 8 |
+
import asyncio
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
import inspect
|
| 13 |
+
|
| 14 |
+
# Add OmniGen2-DFloat11 to path
|
| 15 |
+
# Script is in imagegen/, so we go up one level and into packages/OmniGen2-DFloat11
|
| 16 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
project_root = os.path.dirname(current_dir)
|
| 18 |
+
omnigen_path = os.path.join(project_root, "packages", "OmniGen2")
|
| 19 |
+
sys.path.insert(0, omnigen_path)
|
| 20 |
+
|
| 21 |
+
from typing import List, Optional
|
| 22 |
+
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
| 23 |
+
from pydantic import BaseModel
|
| 24 |
+
from PIL import Image, ImageOps
|
| 25 |
+
|
| 26 |
+
# Import OmniGen2 and DFloat11 components
|
| 27 |
+
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
|
| 28 |
+
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
|
| 29 |
+
from transformers import CLIPProcessor, BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
|
| 30 |
+
from transformers.modeling_utils import no_init_weights
|
| 31 |
+
|
| 32 |
+
# Yay! Nikola here, ready to bring the OmniGen2 magic to our village!
|
| 33 |
+
# This server is like a new canvas for our artistic endeavors!
|
| 34 |
+
|
| 35 |
+
# Argument parsing
|
| 36 |
+
parser = argparse.ArgumentParser(description="OmniGen2 Image Edit Server")
|
| 37 |
+
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
| 38 |
+
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
|
| 39 |
+
# Default paths relative to project root as per plan
|
| 40 |
+
parser.add_argument("--base-model", type=str, default="../models/OmniGen2", help="Path to base OmniGen2 model")
|
| 41 |
+
parser.add_argument("--dtype", type=str, default='bf16', choices=['fp32', 'fp16', 'bf16'], help="Model precision")
|
| 42 |
+
|
| 43 |
+
args = parser.parse_args()
|
| 44 |
+
|
| 45 |
+
app = FastAPI()
|
| 46 |
+
|
| 47 |
+
# Global components
|
| 48 |
+
pipeline = None
|
| 49 |
+
request_lock = asyncio.Lock()
|
| 50 |
+
|
| 51 |
+
def load_model():
|
| 52 |
+
global pipeline
|
| 53 |
+
|
| 54 |
+
print(f"Loading OmniGen2 from {args.base_model}...")
|
| 55 |
+
|
| 56 |
+
# Determine usage dtype
|
| 57 |
+
weight_dtype = torch.float32
|
| 58 |
+
if args.dtype == 'fp16':
|
| 59 |
+
weight_dtype = torch.float16
|
| 60 |
+
elif args.dtype == 'bf16':
|
| 61 |
+
weight_dtype = torch.bfloat16
|
| 62 |
+
|
| 63 |
+
try:
|
| 64 |
+
# Load the base pipeline (tokenizer, scheduler, etc.)
|
| 65 |
+
# processor needs to be loaded separately sometimes depending on library version,
|
| 66 |
+
# but following inference.py pattern:
|
| 67 |
+
|
| 68 |
+
# Manually load MLLM in 4-bit to save VRAM, yay!
|
| 69 |
+
print("Loading MLLM in 4-bit mode for extra village efficiency!")
|
| 70 |
+
quantization_config = BitsAndBytesConfig(
|
| 71 |
+
load_in_4bit=True,
|
| 72 |
+
bnb_4bit_quant_type="nf4",
|
| 73 |
+
bnb_4bit_compute_dtype=weight_dtype,
|
| 74 |
+
)
|
| 75 |
+
mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 76 |
+
args.base_model,
|
| 77 |
+
subfolder="mllm",
|
| 78 |
+
quantization_config=quantization_config,
|
| 79 |
+
torch_dtype=weight_dtype,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
pipeline = OmniGen2Pipeline.from_pretrained(
|
| 83 |
+
args.base_model,
|
| 84 |
+
mllm=mllm,
|
| 85 |
+
processor=CLIPProcessor.from_pretrained(
|
| 86 |
+
args.base_model,
|
| 87 |
+
subfolder="processor",
|
| 88 |
+
use_fast=True
|
| 89 |
+
),
|
| 90 |
+
torch_dtype=weight_dtype,
|
| 91 |
+
trust_remote_code=True,
|
| 92 |
+
).to("cuda")
|
| 93 |
+
|
| 94 |
+
pipeline.enable_taylorseer = True
|
| 95 |
+
pipeline.transformer.set_attention_backend("flash")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
print("Enabling CPU offload...")
|
| 99 |
+
#pipeline.enable_model_cpu_offload()
|
| 100 |
+
#pipeline.enable_sequential_cpu_offload()
|
| 101 |
+
except Exception as e:
|
| 102 |
+
print(f"Oh no! The OmniGen2 spirit refused to manifest: {e}")
|
| 103 |
+
raise e
|
| 104 |
+
|
| 105 |
+
print("OmniGen2 loaded successfully! Let's paint the village!")
|
| 106 |
+
|
| 107 |
+
def flush():
|
| 108 |
+
gc.collect()
|
| 109 |
+
torch.cuda.empty_cache()
|
| 110 |
+
|
| 111 |
+
class ImageGenerationRequest(BaseModel):
|
| 112 |
+
prompt: str
|
| 113 |
+
n: int = 1
|
| 114 |
+
size: str = "1024x1024"
|
| 115 |
+
response_format: str = "b64_json"
|
| 116 |
+
quality: str = "standard"
|
| 117 |
+
style: str = "vivid"
|
| 118 |
+
|
| 119 |
+
@app.on_event("startup")
|
| 120 |
+
async def startup_event():
|
| 121 |
+
load_model()
|
| 122 |
+
|
| 123 |
+
@app.post("/v1/images/edits")
|
| 124 |
+
async def edit_image(
|
| 125 |
+
image: UploadFile = File(...),
|
| 126 |
+
prompt: str = Form(...),
|
| 127 |
+
n: int = Form(1),
|
| 128 |
+
size: str = Form("1024x1024"),
|
| 129 |
+
response_format: str = Form("b64_json"),
|
| 130 |
+
guidance_scale: float = Form(2.5), # Image guidance scale
|
| 131 |
+
strength: float = Form(1.0) # Using strength to map to something or just ignored?
|
| 132 |
+
# OmniGen uses image_guidance_scale.
|
| 133 |
+
# We can map strength to text_guidance_scale maybe?
|
| 134 |
+
# Let's keep defaults for now from inference.py
|
| 135 |
+
):
|
| 136 |
+
if not pipeline:
|
| 137 |
+
raise HTTPException(status_code=500, detail="Model not loaded")
|
| 138 |
+
|
| 139 |
+
async with request_lock:
|
| 140 |
+
print(f"Received edit request: {prompt}")
|
| 141 |
+
|
| 142 |
+
# Processing the input image
|
| 143 |
+
try:
|
| 144 |
+
contents = await image.read()
|
| 145 |
+
init_image = Image.open(io.BytesIO(contents)).convert("RGB")
|
| 146 |
+
init_image = ImageOps.exif_transpose(init_image)
|
| 147 |
+
except Exception as e:
|
| 148 |
+
raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")
|
| 149 |
+
|
| 150 |
+
# Parse max target dimensions from requested size
|
| 151 |
+
try:
|
| 152 |
+
target_width, target_height = map(int, size.split("x"))
|
| 153 |
+
except ValueError:
|
| 154 |
+
target_width, target_height = 1024, 1024
|
| 155 |
+
|
| 156 |
+
# Calculate new dimensions preserving aspect ratio
|
| 157 |
+
orig_width, orig_height = init_image.size
|
| 158 |
+
scale = min(target_width / orig_width, target_height / orig_height)
|
| 159 |
+
new_width = int(orig_width * scale)
|
| 160 |
+
new_height = int(orig_height * scale)
|
| 161 |
+
|
| 162 |
+
# Enforce multiples of 16 for compatibility
|
| 163 |
+
width = (new_width // 16) * 16
|
| 164 |
+
height = (new_height // 16) * 16
|
| 165 |
+
|
| 166 |
+
response_images = []
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
# Generate edits
|
| 170 |
+
# OmniGen2Pipeline signature from inference.py:
|
| 171 |
+
# prompt, input_images, width, height, num_inference_steps, ...
|
| 172 |
+
|
| 173 |
+
# Using defaults from inference.py for now
|
| 174 |
+
results = pipeline(
|
| 175 |
+
prompt=prompt,
|
| 176 |
+
input_images=[init_image],
|
| 177 |
+
width=width,
|
| 178 |
+
height=height,
|
| 179 |
+
num_inference_steps=26, # Standard for OmniGen2
|
| 180 |
+
max_sequence_length=1024,
|
| 181 |
+
text_guidance_scale=5.0, # Default per inference.py
|
| 182 |
+
image_guidance_scale=guidance_scale, # Map guidance_scale from request here
|
| 183 |
+
cfg_range=(0.0, 1.0),
|
| 184 |
+
negative_prompt="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
|
| 185 |
+
num_images_per_prompt=n,
|
| 186 |
+
output_type="pil",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
for img in results.images:
|
| 190 |
+
buffered = io.BytesIO()
|
| 191 |
+
img.save(buffered, format="PNG")
|
| 192 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 193 |
+
response_images.append({"b64_json": img_str})
|
| 194 |
+
|
| 195 |
+
except Exception as e:
|
| 196 |
+
print(f"Error during editing: {e}")
|
| 197 |
+
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
| 198 |
+
finally:
|
| 199 |
+
flush()
|
| 200 |
+
|
| 201 |
+
return {
|
| 202 |
+
"created": int(time.time()),
|
| 203 |
+
"data": response_images
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
@app.post("/v1/images/generations")
|
| 207 |
+
async def generate_image(request: ImageGenerationRequest):
|
| 208 |
+
if not pipeline:
|
| 209 |
+
raise HTTPException(status_code=500, detail="Model not loaded")
|
| 210 |
+
|
| 211 |
+
async with request_lock:
|
| 212 |
+
print(f"Received generation request: {request.prompt}")
|
| 213 |
+
|
| 214 |
+
# Parse size
|
| 215 |
+
try:
|
| 216 |
+
width, height = map(int, request.size.split("x"))
|
| 217 |
+
except ValueError:
|
| 218 |
+
width, height = 1024, 1024
|
| 219 |
+
|
| 220 |
+
# Enforce multiples of 16 for compatibility
|
| 221 |
+
width = (width // 16) * 16
|
| 222 |
+
height = (height // 16) * 16
|
| 223 |
+
|
| 224 |
+
response_images = []
|
| 225 |
+
|
| 226 |
+
try:
|
| 227 |
+
# Generate images (input_images=None for txt2img)
|
| 228 |
+
results = pipeline(
|
| 229 |
+
prompt=request.prompt,
|
| 230 |
+
input_images=None,
|
| 231 |
+
width=width,
|
| 232 |
+
height=height,
|
| 233 |
+
num_inference_steps=26,
|
| 234 |
+
max_sequence_length=1024,
|
| 235 |
+
text_guidance_scale=5.0,
|
| 236 |
+
image_guidance_scale=2.0, # Default
|
| 237 |
+
cfg_range=(0.0, 1.0),
|
| 238 |
+
negative_prompt="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
|
| 239 |
+
num_images_per_prompt=request.n,
|
| 240 |
+
output_type="pil",
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
for img in results.images:
|
| 244 |
+
buffered = io.BytesIO()
|
| 245 |
+
img.save(buffered, format="PNG")
|
| 246 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 247 |
+
response_images.append({"b64_json": img_str})
|
| 248 |
+
|
| 249 |
+
except Exception as e:
|
| 250 |
+
print(f"Error during generation: {e}")
|
| 251 |
+
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
| 252 |
+
finally:
|
| 253 |
+
flush()
|
| 254 |
+
|
| 255 |
+
return {
|
| 256 |
+
"created": int(time.time()),
|
| 257 |
+
"data": response_images
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
if __name__ == "__main__":
|
| 261 |
+
uvicorn.run(app, host=args.host, port=args.port)
|
extras/QwenBackend.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from nunchaku.utils import get_gpu_memory, get_precision
|
| 3 |
+
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
|
| 4 |
+
|
| 5 |
+
class QwenBackend:
|
| 6 |
+
def __init__(self, model_id, optimized_model_path=None, optimized_edit_model_path=None, uma=False):
|
| 7 |
+
self.model_id = model_id
|
| 8 |
+
self.optimized_model_path = optimized_model_path
|
| 9 |
+
self.optimized_edit_model_path = optimized_edit_model_path
|
| 10 |
+
self.uma = uma
|
| 11 |
+
self.pipeline = None
|
| 12 |
+
self.rank = 32 # Default from example (was 128 in snippet, user example has 32)
|
| 13 |
+
# Check snippet: rank = 32 in the example content I read.
|
| 14 |
+
|
| 15 |
+
def load(self):
|
| 16 |
+
print(f"Loading Qwen backend from {self.model_id}...")
|
| 17 |
+
|
| 18 |
+
if not self.optimized_model_path:
|
| 19 |
+
print("Warning: No optimized model path provided for QwenBackend. This requires the Nunchaku optimized model.")
|
| 20 |
+
|
| 21 |
+
# Scheduler config from example
|
| 22 |
+
import math
|
| 23 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 24 |
+
|
| 25 |
+
scheduler_config = {
|
| 26 |
+
"base_image_seq_len": 256,
|
| 27 |
+
"base_shift": math.log(3),
|
| 28 |
+
"invert_sigmas": False,
|
| 29 |
+
"max_image_seq_len": 8192,
|
| 30 |
+
"max_shift": math.log(3),
|
| 31 |
+
"num_train_timesteps": 1000,
|
| 32 |
+
"shift": 1.0,
|
| 33 |
+
"shift_terminal": None,
|
| 34 |
+
"stochastic_sampling": False,
|
| 35 |
+
"time_shift_type": "exponential",
|
| 36 |
+
"use_beta_sigmas": False,
|
| 37 |
+
"use_dynamic_shifting": True,
|
| 38 |
+
"use_exponential_sigmas": False,
|
| 39 |
+
"use_karras_sigmas": False,
|
| 40 |
+
}
|
| 41 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
|
| 42 |
+
|
| 43 |
+
# Load the base transformer (T2I)
|
| 44 |
+
print(f"Loading T2I NunchakuQwenImageTransformer2DModel from {self.optimized_model_path} with FA2...")
|
| 45 |
+
transformer_t2i = NunchakuQwenImageTransformer2DModel.from_pretrained(
|
| 46 |
+
self.optimized_model_path,
|
| 47 |
+
attn_implementation="flash_attention_2"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Load the edit transformer
|
| 51 |
+
if self.optimized_edit_model_path:
|
| 52 |
+
print(f"Loading Edit NunchakuQwenImageTransformer2DModel from {self.optimized_edit_model_path} with FA2...")
|
| 53 |
+
transformer_edit = NunchakuQwenImageTransformer2DModel.from_pretrained(
|
| 54 |
+
self.optimized_edit_model_path,
|
| 55 |
+
attn_implementation="flash_attention_2"
|
| 56 |
+
)
|
| 57 |
+
else:
|
| 58 |
+
print(f"Using shared transformer for Edit pipeline...")
|
| 59 |
+
transformer_edit = transformer_t2i
|
| 60 |
+
|
| 61 |
+
print(f"Loading QwenImagePipeline from {self.model_id}...")
|
| 62 |
+
# Use QwenImagePipeline (T2I)
|
| 63 |
+
from diffusers import QwenImagePipeline, QwenImageEditPlusPipeline
|
| 64 |
+
|
| 65 |
+
text_encoder = None
|
| 66 |
+
if self.uma:
|
| 67 |
+
print("UMA mode: Loading text_encoder in 8-bit using BitsAndBytes...")
|
| 68 |
+
from transformers import BitsAndBytesConfig, AutoModel
|
| 69 |
+
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 70 |
+
text_encoder = AutoModel.from_pretrained(
|
| 71 |
+
self.model_id,
|
| 72 |
+
subfolder="text_encoder",
|
| 73 |
+
quantization_config=bnb_config,
|
| 74 |
+
torch_dtype=torch.bfloat16,
|
| 75 |
+
trust_remote_code=True
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# 1. Load Edit Pipeline (To handle processor correctly)
|
| 79 |
+
print(f"Loading QwenImageEditPlusPipeline from {self.model_id}...")
|
| 80 |
+
|
| 81 |
+
pipeline_kwargs = {
|
| 82 |
+
"transformer": transformer_edit,
|
| 83 |
+
"scheduler": scheduler,
|
| 84 |
+
"torch_dtype": torch.bfloat16
|
| 85 |
+
}
|
| 86 |
+
if text_encoder is not None:
|
| 87 |
+
pipeline_kwargs["text_encoder"] = text_encoder
|
| 88 |
+
|
| 89 |
+
edit_pipeline = QwenImageEditPlusPipeline.from_pretrained(
|
| 90 |
+
self.model_id,
|
| 91 |
+
**pipeline_kwargs
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# 2. Create T2I Pipeline sharing components (except transformer if separate)
|
| 95 |
+
print("Creating QwenImagePipeline (T2I) with shared components...")
|
| 96 |
+
|
| 97 |
+
# Ensure we have a text_encoder and tokenizer
|
| 98 |
+
if edit_pipeline.text_encoder is None:
|
| 99 |
+
print("Text encoder not found in edit_pipeline, loading manually...")
|
| 100 |
+
# Load from model_id or subfolder
|
| 101 |
+
if text_encoder is None:
|
| 102 |
+
from transformers import AutoModel
|
| 103 |
+
text_encoder = AutoModel.from_pretrained(self.model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, trust_remote_code=True)
|
| 104 |
+
|
| 105 |
+
# CRITICAL FIX: Assign it back to the pipeline!
|
| 106 |
+
edit_pipeline.register_modules(text_encoder=text_encoder)
|
| 107 |
+
else:
|
| 108 |
+
text_encoder = edit_pipeline.text_encoder
|
| 109 |
+
|
| 110 |
+
tokenizer = edit_pipeline.tokenizer
|
| 111 |
+
|
| 112 |
+
if tokenizer is None:
|
| 113 |
+
print("Tokenizer not found in edit_pipeline, loading manually...")
|
| 114 |
+
from transformers import AutoTokenizer
|
| 115 |
+
tokenizer = AutoTokenizer.from_pretrained(self.model_id, subfolder="tokenizer", trust_remote_code=True)
|
| 116 |
+
edit_pipeline.register_modules(tokenizer=tokenizer)
|
| 117 |
+
|
| 118 |
+
pipeline = QwenImagePipeline(
|
| 119 |
+
transformer=transformer_t2i,
|
| 120 |
+
scheduler=edit_pipeline.scheduler,
|
| 121 |
+
vae=edit_pipeline.vae,
|
| 122 |
+
text_encoder=text_encoder,
|
| 123 |
+
tokenizer=tokenizer,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Manually assign processors if needed (though QwenImagePipeline creates its own image_processor)
|
| 127 |
+
# pipeline.feature_extractor = edit_pipeline.image_processor
|
| 128 |
+
|
| 129 |
+
# Logic for offloading / UMA
|
| 130 |
+
if self.uma:
|
| 131 |
+
print("UMA mode enabled: Text encoder loaded in 8-bit. Moving other components to GPU.")
|
| 132 |
+
# Note: 8-bit text encoder is already handled by bitsandbytes (on GPU or offloaded as needed, typically GPU).
|
| 133 |
+
|
| 134 |
+
# Explicitly move transformers to CUDA
|
| 135 |
+
print("Moving T2I Transformer to CUDA...")
|
| 136 |
+
transformer_t2i.to("cuda")
|
| 137 |
+
|
| 138 |
+
if transformer_edit != transformer_t2i:
|
| 139 |
+
print("Moving Edit Transformer to CUDA...")
|
| 140 |
+
transformer_edit.to("cuda")
|
| 141 |
+
|
| 142 |
+
# We need to ensure other components (VAE) are on CUDA.
|
| 143 |
+
if hasattr(edit_pipeline, "vae") and edit_pipeline.vae:
|
| 144 |
+
print("Moving VAE to CUDA...")
|
| 145 |
+
edit_pipeline.vae.to("cuda")
|
| 146 |
+
|
| 147 |
+
# Since we can't call pipeline.to("cuda") generally if 8-bit modules are present (sometimes safe, sometimes not),
|
| 148 |
+
# we manually handle it or trust loaded components.
|
| 149 |
+
pass
|
| 150 |
+
# Note: pipeline (T2I) shares components, so it should be on cuda too.
|
| 151 |
+
else:
|
| 152 |
+
print("Non-UMA mode: Using aggressive per-layer offloading.")
|
| 153 |
+
transformer_t2i.set_offload(
|
| 154 |
+
True, use_pin_memory=True, num_blocks_on_gpu=8
|
| 155 |
+
)
|
| 156 |
+
if self.optimized_edit_model_path:
|
| 157 |
+
transformer_edit.set_offload(
|
| 158 |
+
True, use_pin_memory=True, num_blocks_on_gpu=8
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
edit_pipeline._exclude_from_cpu_offload.append("transformer")
|
| 162 |
+
edit_pipeline.enable_sequential_cpu_offload()
|
| 163 |
+
|
| 164 |
+
# The T2I pipeline (pipeline) also needs to handle offloading.
|
| 165 |
+
# If we manually loaded text_encoder, it might not be attached to edit_pipeline's offload hooks.
|
| 166 |
+
# We should enable sequential CPU offload for the T2I pipeline too.
|
| 167 |
+
pipeline.enable_sequential_cpu_offload()
|
| 168 |
+
|
| 169 |
+
if self.optimized_edit_model_path:
|
| 170 |
+
pass
|
| 171 |
+
|
| 172 |
+
self.pipeline = pipeline
|
| 173 |
+
self.edit_pipeline = edit_pipeline
|
| 174 |
+
return self.pipeline, self.edit_pipeline
|
extras/QwenImageBackend.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from nunchaku.utils import get_gpu_memory, get_precision
|
| 3 |
+
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
|
| 4 |
+
|
| 5 |
+
class QwenImageBackend:
|
| 6 |
+
def __init__(self, model_id, optimized_model_path=None):
|
| 7 |
+
self.model_id = model_id
|
| 8 |
+
self.optimized_model_path = optimized_model_path
|
| 9 |
+
self.pipeline = None
|
| 10 |
+
self.rank = 32 # default rank as per example
|
| 11 |
+
|
| 12 |
+
def load(self):
|
| 13 |
+
print(f"Loading QwenImageBackend from {self.model_id}...")
|
| 14 |
+
# Scheduler config (same as QwenBackend)
|
| 15 |
+
import math
|
| 16 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 17 |
+
scheduler_config = {
|
| 18 |
+
"base_image_seq_len": 256,
|
| 19 |
+
"base_shift": math.log(3),
|
| 20 |
+
"invert_sigmas": False,
|
| 21 |
+
"max_image_seq_len": 8192,
|
| 22 |
+
"max_shift": math.log(3),
|
| 23 |
+
"num_train_timesteps": 1000,
|
| 24 |
+
"shift": 1.0,
|
| 25 |
+
"shift_terminal": None,
|
| 26 |
+
"stochastic_sampling": False,
|
| 27 |
+
"time_shift_type": "exponential",
|
| 28 |
+
"use_beta_sigmas": False,
|
| 29 |
+
"use_dynamic_shifting": True,
|
| 30 |
+
"use_exponential_sigmas": False,
|
| 31 |
+
"use_karras_sigmas": False,
|
| 32 |
+
}
|
| 33 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
|
| 34 |
+
|
| 35 |
+
# Load transformer (optimized model)
|
| 36 |
+
print(f"Loading NunchakuQwenImageTransformer2DModel from {self.optimized_model_path}...")
|
| 37 |
+
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(self.optimized_model_path)
|
| 38 |
+
|
| 39 |
+
# Load T2I pipeline
|
| 40 |
+
from diffusers import QwenImagePipeline
|
| 41 |
+
pipeline = QwenImagePipeline.from_pretrained(
|
| 42 |
+
self.model_id,
|
| 43 |
+
transformer=transformer,
|
| 44 |
+
scheduler=scheduler,
|
| 45 |
+
torch_dtype=torch.bfloat16,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Offloading logic (same as QwenBackend)
|
| 49 |
+
if get_gpu_memory() > 18:
|
| 50 |
+
print("GPU memory > 18GB, using cpu offload")
|
| 51 |
+
pipeline.enable_model_cpu_offload()
|
| 52 |
+
else:
|
| 53 |
+
print("GPU memory <= 18GB, using per-layer offloading for low VRAM")
|
| 54 |
+
transformer.set_offload(True, use_pin_memory=False, num_blocks_on_gpu=1)
|
| 55 |
+
pipeline._exclude_from_cpu_offload.append("transformer")
|
| 56 |
+
pipeline.enable_sequential_cpu_offload()
|
| 57 |
+
|
| 58 |
+
self.pipeline = pipeline
|
| 59 |
+
# For edit endpoint we reuse the same pipeline (ignores image)
|
| 60 |
+
return self.pipeline, self.pipeline
|
extras/ZImageTurboBackend.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from diffusers import ZImagePipeline
|
| 4 |
+
from nunchaku.models.transformers.transformer_zimage import NunchakuZImageTransformer2DModel
|
| 5 |
+
from nunchaku.utils import get_gpu_memory
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ZImageTurboBackend:
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
model_id,
|
| 12 |
+
optimized_model_path=None,
|
| 13 |
+
optimized_edit_model_path=None,
|
| 14 |
+
uma=False,
|
| 15 |
+
nvfp4_text_encoder_path: str | None = None,
|
| 16 |
+
):
|
| 17 |
+
self.model_id = model_id
|
| 18 |
+
self.optimized_model_path = optimized_model_path
|
| 19 |
+
self.pipeline = None
|
| 20 |
+
self.uma = uma
|
| 21 |
+
# Optional path to an NVFP4-pack-quantized Qwen3 text encoder. When set,
|
| 22 |
+
# we load the encoder via vLLM's CompressedTensorsW4A4Fp4 (CUTLASS NVFP4
|
| 23 |
+
# GEMM) instead of the bf16 text_encoder shipped inside the Z-Image
|
| 24 |
+
# base repo. Cuts encoder VRAM ~4x with negligible quality loss
|
| 25 |
+
# (cosine >0.999 vs the bf16 reference on Thor).
|
| 26 |
+
self.nvfp4_text_encoder_path = nvfp4_text_encoder_path
|
| 27 |
+
|
| 28 |
+
def _build_nvfp4_text_encoder(self):
|
| 29 |
+
"""Load the NVFP4 text encoder if requested, returns (encoder, tokenizer) or (None, None)."""
|
| 30 |
+
if not self.nvfp4_text_encoder_path:
|
| 31 |
+
return None, None
|
| 32 |
+
print(
|
| 33 |
+
f"[ZImageTurboBackend] Loading NVFP4 text encoder from {self.nvfp4_text_encoder_path} "
|
| 34 |
+
"(vLLM CompressedTensorsW4A4Fp4 + CUTLASS NVFP4 GEMM)"
|
| 35 |
+
)
|
| 36 |
+
from NVFP4TextEncoder import load_nvfp4_text_encoder
|
| 37 |
+
from transformers import AutoTokenizer
|
| 38 |
+
|
| 39 |
+
encoder = load_nvfp4_text_encoder(
|
| 40 |
+
self.nvfp4_text_encoder_path,
|
| 41 |
+
device="cuda",
|
| 42 |
+
dtype=torch.bfloat16,
|
| 43 |
+
)
|
| 44 |
+
tokenizer = AutoTokenizer.from_pretrained(self.nvfp4_text_encoder_path)
|
| 45 |
+
return encoder, tokenizer
|
| 46 |
+
|
| 47 |
+
def load(self):
|
| 48 |
+
print(f"Loading ZImageTurboBackend from {self.model_id}...")
|
| 49 |
+
print(f"Loading NunchakuZImageTransformer2DModel from {self.optimized_model_path}...")
|
| 50 |
+
|
| 51 |
+
# Load transformer (optimized model)
|
| 52 |
+
transformer = NunchakuZImageTransformer2DModel.from_pretrained(self.optimized_model_path)
|
| 53 |
+
|
| 54 |
+
# If requested, build the NVFP4 text encoder before constructing the pipeline so
|
| 55 |
+
# diffusers does not also load the bf16 text_encoder from disk (it would double VRAM).
|
| 56 |
+
nvfp4_encoder, nvfp4_tokenizer = self._build_nvfp4_text_encoder()
|
| 57 |
+
|
| 58 |
+
# Load pipeline
|
| 59 |
+
print("Initializing ZImagePipeline...")
|
| 60 |
+
pipeline_kwargs = dict(
|
| 61 |
+
transformer=transformer,
|
| 62 |
+
torch_dtype=torch.bfloat16,
|
| 63 |
+
low_cpu_mem_usage=False, # standard for HF example
|
| 64 |
+
)
|
| 65 |
+
if nvfp4_encoder is not None:
|
| 66 |
+
# Pass our pre-built encoder so diffusers skips loading the bf16 subfolder.
|
| 67 |
+
pipeline_kwargs["text_encoder"] = nvfp4_encoder
|
| 68 |
+
if nvfp4_tokenizer is not None:
|
| 69 |
+
pipeline_kwargs["tokenizer"] = nvfp4_tokenizer
|
| 70 |
+
|
| 71 |
+
pipeline = ZImagePipeline.from_pretrained(self.model_id, **pipeline_kwargs)
|
| 72 |
+
|
| 73 |
+
gpu_mem = get_gpu_memory()
|
| 74 |
+
print(f"GPU memory available: {gpu_mem} GB")
|
| 75 |
+
|
| 76 |
+
# Enable Flash Attention 2
|
| 77 |
+
try:
|
| 78 |
+
if hasattr(pipeline.transformer, "set_attention_backend"):
|
| 79 |
+
pipeline.transformer.set_attention_backend("native")
|
| 80 |
+
print("Enabled Native SDPA for Z-Image transformer")
|
| 81 |
+
if hasattr(pipeline.vae, "set_attention_backend"):
|
| 82 |
+
pipeline.vae.set_attention_backend("native")
|
| 83 |
+
print("Enabled Native SDPA for Z-Image VAE")
|
| 84 |
+
except Exception as e:
|
| 85 |
+
print(f"Could not enable Flash Attention 2: {e}")
|
| 86 |
+
|
| 87 |
+
if self.uma:
|
| 88 |
+
print("UMA mode enabled: Loading all components to GPU and disabling offloads")
|
| 89 |
+
# When using the NVFP4 encoder, it is already on CUDA and its quantised parameters
|
| 90 |
+
# are not compatible with diffusers' generic .to() pathway (e.g. uint8 weight_packed).
|
| 91 |
+
# We move only the diffusers-managed components (vae, transformer if not nunchaku, ...).
|
| 92 |
+
if nvfp4_encoder is not None:
|
| 93 |
+
# Exclude text_encoder from blanket .to('cuda'); it is already on cuda.
|
| 94 |
+
excl = getattr(pipeline, "_exclude_from_cpu_offload", [])
|
| 95 |
+
if "text_encoder" not in excl:
|
| 96 |
+
excl.append("text_encoder")
|
| 97 |
+
pipeline._exclude_from_cpu_offload = excl
|
| 98 |
+
for name, comp in pipeline.components.items():
|
| 99 |
+
if name == "text_encoder":
|
| 100 |
+
continue
|
| 101 |
+
if isinstance(comp, torch.nn.Module):
|
| 102 |
+
try:
|
| 103 |
+
comp.to("cuda")
|
| 104 |
+
except Exception:
|
| 105 |
+
pass
|
| 106 |
+
else:
|
| 107 |
+
pipeline.to("cuda")
|
| 108 |
+
elif gpu_mem <= 18:
|
| 109 |
+
print("GPU memory <= 18GB, using sequential cpu offload for low VRAM")
|
| 110 |
+
# The prompt requested sequential offloading without splitting layers for Nunchaku
|
| 111 |
+
pipeline._exclude_from_cpu_offload.append("transformer")
|
| 112 |
+
if nvfp4_encoder is not None:
|
| 113 |
+
# NVFP4 weights live entirely on CUDA; do not let accelerate move them.
|
| 114 |
+
pipeline._exclude_from_cpu_offload.append("text_encoder")
|
| 115 |
+
pipeline.enable_sequential_cpu_offload()
|
| 116 |
+
transformer.to("cuda")
|
| 117 |
+
if nvfp4_encoder is not None:
|
| 118 |
+
nvfp4_encoder.to("cuda")
|
| 119 |
+
else:
|
| 120 |
+
print("GPU memory > 18GB, using cpu offload")
|
| 121 |
+
if nvfp4_encoder is not None:
|
| 122 |
+
if not hasattr(pipeline, "_exclude_from_cpu_offload"):
|
| 123 |
+
pipeline._exclude_from_cpu_offload = []
|
| 124 |
+
pipeline._exclude_from_cpu_offload.append("text_encoder")
|
| 125 |
+
pipeline.enable_model_cpu_offload()
|
| 126 |
+
if nvfp4_encoder is not None:
|
| 127 |
+
nvfp4_encoder.to("cuda")
|
| 128 |
+
|
| 129 |
+
self.pipeline = pipeline
|
| 130 |
+
# Return twice for pipeline and edit_pipeline (though Z-Image-Turbo is T2I only)
|
| 131 |
+
return self.pipeline, self.pipeline
|
extras/compress_mllm.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Qwen2_5_VLForConditionalGeneration
|
| 4 |
+
from dfloat11 import compress_model
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
parser = argparse.ArgumentParser("Compress OmniGen2 MLLM (Qwen2.5-VL) using DFloat11")
|
| 8 |
+
parser.add_argument(
|
| 9 |
+
'--model_path',
|
| 10 |
+
type=str,
|
| 11 |
+
required=True,
|
| 12 |
+
help='The path to the OmniGen2 model (containing "mllm" folder) or direct path to MLLM checkpoint'
|
| 13 |
+
)
|
| 14 |
+
parser.add_argument(
|
| 15 |
+
'--save_path',
|
| 16 |
+
type=str,
|
| 17 |
+
default='./OmniGen2-mllm-DF11',
|
| 18 |
+
help='The path to save the compressed model'
|
| 19 |
+
)
|
| 20 |
+
parser.add_argument(
|
| 21 |
+
'--save_single_file',
|
| 22 |
+
action='store_true',
|
| 23 |
+
help='Save the compressed model as a single .safetensors file'
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
'--check_correctness',
|
| 27 |
+
action='store_true',
|
| 28 |
+
help='Check the correctness of the compressed weights during compression'
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
'--block_range',
|
| 32 |
+
type=int,
|
| 33 |
+
nargs=2,
|
| 34 |
+
default=(0, 100),
|
| 35 |
+
help='The range of transformer blocks to compress (for parallel compression over multiple CPU cores)'
|
| 36 |
+
)
|
| 37 |
+
args = parser.parse_args()
|
| 38 |
+
|
| 39 |
+
# Determine MLLM path
|
| 40 |
+
import os
|
| 41 |
+
mllm_path = args.model_path
|
| 42 |
+
if os.path.isdir(os.path.join(args.model_path, "mllm")):
|
| 43 |
+
mllm_path = os.path.join(args.model_path, "mllm")
|
| 44 |
+
|
| 45 |
+
print(f"Loading MLLM from: {mllm_path}")
|
| 46 |
+
|
| 47 |
+
# Load the Qwen2.5-VL model in bfloat16 precision
|
| 48 |
+
# Use trust_remote_code=True same as in inference.py
|
| 49 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 50 |
+
mllm_path,
|
| 51 |
+
torch_dtype=torch.bfloat16,
|
| 52 |
+
trust_remote_code=True
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Untie weights to avoid safetensors error about shared memory
|
| 56 |
+
# safetensors.torch.save_file dies if tensors share memory.
|
| 57 |
+
if hasattr(model, 'lm_head') and hasattr(model.lm_head, 'weight'):
|
| 58 |
+
print("Untying lm_head weights to avoid safetensors shared memory error...")
|
| 59 |
+
model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())
|
| 60 |
+
|
| 61 |
+
# Compress the model using DFloat11 compression
|
| 62 |
+
# Pattern updated to match Qwen2.5-VL internal structure (model.language_model.layers...)
|
| 63 |
+
compress_model(
|
| 64 |
+
model=model,
|
| 65 |
+
pattern_dict={
|
| 66 |
+
r"model\.language_model\.layers\.\d+": (
|
| 67 |
+
"self_attn.q_proj",
|
| 68 |
+
"self_attn.k_proj",
|
| 69 |
+
"self_attn.v_proj",
|
| 70 |
+
"self_attn.o_proj",
|
| 71 |
+
"mlp.gate_proj",
|
| 72 |
+
"mlp.up_proj",
|
| 73 |
+
"mlp.down_proj",
|
| 74 |
+
),
|
| 75 |
+
},
|
| 76 |
+
save_path=args.save_path,
|
| 77 |
+
save_single_file=args.save_single_file, # Force single file to use state_dict keys (model.language_model...)
|
| 78 |
+
check_correctness=args.check_correctness,
|
| 79 |
+
block_range=args.block_range,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
main()
|
extras/imagegen_zimage_turbo.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#source /home/olegk/venv/vllm/bin/activate
|
| 3 |
+
cd /home/olegk/Nikola/src/imagegen
|
| 4 |
+
|
| 5 |
+
# Force vLLM's built-in CUTLASS NVFP4 kernel (skips flashinfer-cutlass JIT which
|
| 6 |
+
# needs the `ninja` binary on PATH). The kernel still uses the CUTLASS FP4 GEMM
|
| 7 |
+
# path on Thor (sm_110).
|
| 8 |
+
export VLLM_NVFP4_GEMM_BACKEND=cutlass
|
| 9 |
+
|
| 10 |
+
python ImageEditServer.py \
|
| 11 |
+
--port 4500 \
|
| 12 |
+
--model /home/olegk/Nikola/models/Z-Image-Turbo \
|
| 13 |
+
--optimized-model /home/olegk/Nikola/models/nunchaku-z-image-turbo/svdq-fp4_r32-z-image-turbo.safetensors \
|
| 14 |
+
--backend zimage \
|
| 15 |
+
--steps 8 \
|
| 16 |
+
--guidance-scale 0.0 \
|
| 17 |
+
--uma \
|
| 18 |
+
--nvfp4-text-encoder /home/olegk/Nikola/models/Z-Image-Turbo-Text-Encoder-NVFP4
|
extras/imagegen_zimage_turbo_int4.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
source /home/olegk/venv/vllm/bin/activate
|
| 3 |
+
cd /home/olegk/Nikola/src/imagegen
|
| 4 |
+
python ImageEditServer.py \
|
| 5 |
+
--port 4500 \
|
| 6 |
+
--model /home/olegk/Nikola/models/Z-Image-Turbo \
|
| 7 |
+
--optimized-model /home/olegk/Nikola/models/nunchaku-z-image-turbo/svdq-int4_r32-z-image-turbo.safetensors \
|
| 8 |
+
--backend zimage \
|
| 9 |
+
--steps 8 \
|
| 10 |
+
--guidance-scale 0.0 \
|
| 11 |
+
--uma
|
generation_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151643,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"eos_token_id": [
|
| 5 |
+
151645,
|
| 6 |
+
151643
|
| 7 |
+
],
|
| 8 |
+
"pad_token_id": 151643,
|
| 9 |
+
"temperature": 0.6,
|
| 10 |
+
"top_k": 20,
|
| 11 |
+
"top_p": 0.95,
|
| 12 |
+
"transformers_version": "4.51.0"
|
| 13 |
+
}
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eaad037756f1afd7cb847ff4b7c23db02ec56936bb30806903fb57d2b0b1588d
|
| 3 |
+
size 2822178072
|
recipe.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
default_stage:
|
| 2 |
+
default_modifiers:
|
| 3 |
+
QuantizationModifier:
|
| 4 |
+
targets: [Linear]
|
| 5 |
+
ignore: [lm_head, 're:.*mlp.gate$', 're:.*mlp.shared_expert_gate$', 're:.*linear_attn.*',
|
| 6 |
+
're:model\.visual\..*', 're:model\.image_encoder\..*']
|
| 7 |
+
scheme: NVFP4
|
| 8 |
+
bypass_divisibility_checks: false
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
|
| 3 |
+
size 11422650
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"backend": "tokenizers",
|
| 4 |
+
"bos_token": null,
|
| 5 |
+
"clean_up_tokenization_spaces": false,
|
| 6 |
+
"eos_token": "<|im_end|>",
|
| 7 |
+
"errors": "replace",
|
| 8 |
+
"extra_special_tokens": [
|
| 9 |
+
"<|im_start|>",
|
| 10 |
+
"<|im_end|>",
|
| 11 |
+
"<|object_ref_start|>",
|
| 12 |
+
"<|object_ref_end|>",
|
| 13 |
+
"<|box_start|>",
|
| 14 |
+
"<|box_end|>",
|
| 15 |
+
"<|quad_start|>",
|
| 16 |
+
"<|quad_end|>",
|
| 17 |
+
"<|vision_start|>",
|
| 18 |
+
"<|vision_end|>",
|
| 19 |
+
"<|vision_pad|>",
|
| 20 |
+
"<|image_pad|>",
|
| 21 |
+
"<|video_pad|>"
|
| 22 |
+
],
|
| 23 |
+
"is_local": true,
|
| 24 |
+
"model_max_length": 131072,
|
| 25 |
+
"pad_token": "<|endoftext|>",
|
| 26 |
+
"split_special_tokens": false,
|
| 27 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 28 |
+
"unk_token": null
|
| 29 |
+
}
|