File size: 12,238 Bytes
7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Tuple

import torch
from mmgp import offload
from safetensors import safe_open
from shared.utils import files_locator as fl

from .florence2 import Florence2Config, Florence2ForConditionalGeneration, Florence2Processor
from .florence2.image_processing_florence2 import Florence2ImageProcessorLite

from transformers import AutoTokenizer, BartTokenizer, BartTokenizerFast

from .assets import (
    FLORENCE2_FILES,
    FLORENCE2_FOLDER,
    LLAMA32_FILES,
    LLAMA32_FOLDER,
    LLAMAJOY_FILES,
    LLAMAJOY_FOLDER,
    PROMPT_ENHANCER_REPO,
)


@dataclass(slots=True)
class PromptEnhancerRuntime:
    """Holder for everything ``load_prompt_enhancer_runtime`` loads.

    A freshly constructed instance (all ``None`` / empty dicts) is what is
    returned when the prompt enhancer is disabled (``enhancer_enabled <= 0``).
    """

    # Image captioning model (Florence2, or the Qwen3.5 VL wrapper for modes 3/4)
    # and its matching processor.
    image_caption_model: Any = None
    image_caption_processor: Any = None
    # Text LLM used to rewrite prompts and its tokenizer (may stay None for the
    # Qwen3.5 path if the model exposes no ``_prompt_enhancer_tokenizer``).
    llm_model: Any = None
    llm_tokenizer: Any = None
    # Pipe name (e.g. "prompt_enhancer_llm_model") -> model object; presumably
    # registered with the mmgp offloader by the caller — confirm at call site.
    pipe_models: dict[str, Any] = field(default_factory=dict)
    # Pipe name -> integer budget (units not visible here; presumably MB of
    # VRAM for the offloader — TODO confirm).
    budgets: dict[str, int] = field(default_factory=dict)
    # Pipe name -> list of other pipe names declared as its co-tenants
    # (only populated on the Qwen3.5 path).
    co_tenants: dict[str, list[str]] = field(default_factory=dict)


def ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled: int, qwen_backend: str = "quanto_int8"):
    """Ask ``process_files_def`` to fetch the files needed by an enhancer mode.

    Args:
        process_files_def: callback accepting ``repoId``, ``sourceFolderList``
            and ``fileList`` keyword arguments (for modes 1/2), or whatever the
            Qwen3.5 helper passes it (modes 3/4).
        enhancer_enabled: enhancer mode, coerced with ``int``.
            1 -> Florence2 + Llama 3.2, 2 -> Florence2 + JoyCaption,
            3/4 -> Qwen3.5 variants; any other value requests nothing.
        qwen_backend: quantization backend forwarded to the Qwen3.5 helper.
    """
    mode = int(enhancer_enabled)

    if mode in (3, 4):
        # Qwen3.5 assets are described by their own module; import lazily so
        # the other modes never pay for it.
        from .qwen35_vl import ensure_qwen35_prompt_enhancer_assets, get_qwen35_prompt_enhancer_variant

        variant = get_qwen35_prompt_enhancer_variant(mode)
        ensure_qwen35_prompt_enhancer_assets(process_files_def, backend=qwen_backend, variant=variant)
        return

    if mode not in (1, 2):
        return

    # Both Florence2-based modes share the captioner; only the LLM differs.
    if mode == 1:
        llm_folder, llm_files = LLAMA32_FOLDER, LLAMA32_FILES
    else:
        llm_folder, llm_files = LLAMAJOY_FOLDER, LLAMAJOY_FILES

    process_files_def(
        repoId=PROMPT_ENHANCER_REPO,
        sourceFolderList=[FLORENCE2_FOLDER, llm_folder],
        fileList=[FLORENCE2_FILES, llm_files],
    )


def download_prompt_enhancer_assets(enhancer_enabled: int, qwen_backend: str = "quanto_int8", send_cmd=None, progress=None, status_text="Downloading Prompt Enhancer model files..."):
    """Download any missing prompt-enhancer files for the selected mode.

    Args:
        enhancer_enabled: enhancer mode (coerced with ``int``); ``<= 0`` means
            disabled and nothing is done.
        qwen_backend: backend forwarded to ``ensure_prompt_enhancer_assets``.
        send_cmd: forwarded unchanged to ``process_files_def_if_needed``.
        progress: optional callable; invoked once as ``progress(0, status_text)``
            the first time a download definition reports missing files.
        status_text: message shown via ``progress`` and passed as ``status_text``
            to ``process_files_def_if_needed`` for that first definition only
            (later definitions receive ``None``).

    Returns:
        ``False`` when disabled; otherwise the ``or``-accumulation of every
        ``process_files_def_if_needed`` result (presumably truthy when files
        were actually fetched — confirm in ``shared.utils.download``).
    """
    mode = int(enhancer_enabled)
    if mode <= 0:
        return False

    from shared.utils.download import download_def_missing_files, process_files_def_if_needed

    tracker = {"downloaded": False, "status_shown": False}

    def _download_one(**download_def):
        # Probe every definition, even after the status was already shown.
        missing_count = len(download_def_missing_files(download_def))
        text_for_this_def = None
        if missing_count > 0 and not tracker["status_shown"]:
            if progress is not None:
                progress(0, status_text)
            text_for_this_def = status_text
            tracker["status_shown"] = True
        result = process_files_def_if_needed(download_def, send_cmd=send_cmd, status_text=text_for_this_def)
        tracker["downloaded"] = result or tracker["downloaded"]

    ensure_prompt_enhancer_assets(_download_one, enhancer_enabled=mode, qwen_backend=qwen_backend)
    return tracker["downloaded"]


def unload_prompt_enhancer_models(*models):
    """Call ``unload()`` once on each distinct model passed in.

    ``None`` entries are ignored, duplicates (same object identity) are only
    unloaded once, and objects without a callable ``unload`` attribute are
    skipped silently.
    """
    already_unloaded = set()
    for candidate in models:
        if candidate is None:
            continue
        identity = id(candidate)
        if identity not in already_unloaded:
            already_unloaded.add(identity)
            unloader = getattr(candidate, "unload", None)
            if callable(unloader):
                unloader()


def _set_pad_token_from_tokenizer(model, tokenizer):
    """Use the tokenizer's EOS token as the padding token for generation.

    Always sets ``model.generation_config.pad_token`` to ``tokenizer.eos_token``.
    If ``generation_config.pad_token_id`` is already set it is left alone;
    otherwise it is filled with the generation config's EOS id. When the config
    lists several EOS ids (list or tuple), the first is used. When the config has
    no usable EOS id (``None`` or an empty sequence), the tokenizer's
    ``eos_token_id`` is used, which matches the ``pad_token`` just assigned.
    """
    generation_config = model.generation_config
    generation_config.pad_token = tokenizer.eos_token
    if generation_config.pad_token_id is not None:
        return

    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, (list, tuple)):
        # Several stop ids are allowed; pad with the first. Guard the empty
        # case, which previously raised IndexError.
        eos_token_id = eos_token_id[0] if eos_token_id else None
    if eos_token_id is None:
        eos_token_id = getattr(tokenizer, "eos_token_id", None)
    generation_config.pad_token_id = eos_token_id


def _load_llama32_prompt_enhancer():
    """Load the int8-quantized Llama 3.2 prompt-rewriting LLM and its tokenizer.

    Returns:
        ``(model, tokenizer, 5000)``. The caller stores the last value as this
        model's entry in ``PromptEnhancerRuntime.budgets`` (units not visible
        here; presumably MB of VRAM — TODO confirm).
    """
    weights_file = fl.locate_file(f"{LLAMA32_FOLDER}/Llama3_2_quanto_bf16_int8.safetensors")
    config_file = fl.locate_file(f"{LLAMA32_FOLDER}/config.json", error_if_none=False)
    model = offload.fast_load_transformers_model(
        weights_file,
        defaultConfigPath=config_file,
        configKwargs=dict(attn_implementation="sdpa", hidden_act="silu"),
        writable_tensors=False,
    )
    # NOTE(review): presumably disables transformers' generate() kwarg
    # validation so extra kwargs are accepted — confirm.
    model._validate_model_kwargs = lambda *_args, **_kwargs: None
    # NOTE(review): presumably tells mmgp which method(s) to hook for offloading.
    model._offload_hooks = ["generate"]

    tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMA32_FOLDER))
    _set_pad_token_from_tokenizer(model, tokenizer)
    model.eval()
    budget = 5000
    return model, tokenizer, budget


def _load_joycaption_prompt_enhancer():
    """Load the text decoder of the int8-quantized JoyCaption Llama and its tokenizer.

    Returns:
        ``(model, tokenizer, 10000)``. The caller stores the last value as this
        model's entry in ``PromptEnhancerRuntime.budgets`` (units not visible
        here; presumably MB of VRAM — TODO confirm).
    """

    def preprocess_sd(sd, quant_map=None, tied_map=None):
        # Re-root the language-model weights under "model"; the vision tower and
        # projector prefixes map to None (presumably dropped by
        # offload.map_state_dict — confirm in mmgp).
        remap_rules = {
            "model.language_model": "model",
            "model.vision_tower": None,
            "model.multi_modal_projector": None,
        }
        remapped = offload.map_state_dict([sd, quant_map, tied_map], remap_rules)
        return tuple(remapped)

    weights_file = fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_joycaption_quanto_bf16_int8.safetensors")
    config_file = fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_config.json", error_if_none=False)
    model = offload.fast_load_transformers_model(
        weights_file,
        forcedConfigPath=config_file,
        configKwargs=dict(attn_implementation="sdpa", hidden_act="silu"),
        preprocess_sd=preprocess_sd,
        writable_tensors=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMAJOY_FOLDER))
    _set_pad_token_from_tokenizer(model, tokenizer)
    model.eval()
    budget = 10000
    return model, tokenizer, budget


def load_prompt_enhancer_runtime(process_files_def, enhancer_enabled: int, lm_decoder_engine: str = "", qwen_backend: str = "quanto_int8") -> PromptEnhancerRuntime:
    """Ensure assets exist, then load every model for the selected prompt enhancer.

    Args:
        process_files_def: download callback forwarded to
            ``ensure_prompt_enhancer_assets``.
        enhancer_enabled: enhancer mode, coerced with ``int``. ``<= 0`` disables
            the enhancer; 3/4 load the Qwen3.5 text + VL models; 1 loads
            Florence2 + Llama 3.2; any other positive value loads Florence2 +
            JoyCaption.
        lm_decoder_engine: forwarded as ``requested_lm_engine`` to the Qwen3.5
            text loader (ignored by the other modes).
        qwen_backend: Qwen3.5 quantization backend; a falsy value falls back to
            ``enhancer_quantization_QUANTO_INT8``.

    Returns:
        A populated ``PromptEnhancerRuntime``, or an empty one when disabled.
    """
    mode = int(enhancer_enabled)
    runtime = PromptEnhancerRuntime()
    if mode <= 0:
        return runtime

    ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled=mode, qwen_backend=qwen_backend)

    if mode in (3, 4):
        _fill_qwen35_runtime(runtime, mode, lm_decoder_engine, qwen_backend)
    else:
        # NOTE(review): modes >= 5 also land here (JoyCaption) although
        # ensure_prompt_enhancer_assets fetches nothing for them — confirm intent.
        _fill_florence2_runtime(runtime, mode)
    return runtime


def _fill_qwen35_runtime(runtime, mode, lm_decoder_engine, qwen_backend):
    """Load the Qwen3.5 text decoder and its vision captioner into ``runtime``."""
    from .qwen35_text import load_qwen35_text_prompt_enhancer
    from .qwen35_vl import (
        enhancer_quantization_QUANTO_INT8,
        alias_qwen35_text_embedding_for_mmgp,
        get_qwen35_assets_dir_name,
        get_qwen35_prompt_enhancer_variant,
        load_qwen35_vl_prompt_enhancer,
    )

    backend = qwen_backend or enhancer_quantization_QUANTO_INT8
    variant = get_qwen35_prompt_enhancer_variant(mode)
    folder_name = get_qwen35_assets_dir_name(variant)
    # Prefer an already-located local copy, else the default download location.
    assets_dir = fl.locate_folder(folder_name, error_if_none=False) or fl.get_download_location(folder_name)

    runtime.llm_model = load_qwen35_text_prompt_enhancer(
        assets_dir=assets_dir,
        backend=backend,
        attn_implementation="sdpa",
        requested_lm_engine=lm_decoder_engine,
        variant=variant,
    )
    text_model = runtime.llm_model
    runtime.llm_tokenizer = getattr(text_model, "_prompt_enhancer_tokenizer", None)
    text_model.eval()

    # The VL captioner reuses the text model and its (aliased) embedding module.
    embedding_model = alias_qwen35_text_embedding_for_mmgp(text_model)
    caption_model, vision_tower = load_qwen35_vl_prompt_enhancer(
        assets_dir=assets_dir,
        attn_implementation="sdpa",
        text_model=text_model,
        input_embedding_model=embedding_model,
        backend=backend,
        variant=variant,
    )
    runtime.image_caption_model = caption_model
    runtime.image_caption_processor = getattr(caption_model, "_prompt_enhancer_processor", None)
    caption_model.eval()

    vision_key = "prompt_enhancer_image_caption_vision_tower_model"
    embedding_key = "prompt_enhancer_image_caption_embedding_model"
    llm_key = "prompt_enhancer_llm_model"
    runtime.pipe_models.update({vision_key: vision_tower, embedding_key: embedding_model, llm_key: text_model})
    runtime.budgets.update({vision_key: 3000, embedding_key: 2000, llm_key: 10000})
    # Vision tower and embedding model are declared co-tenants of each other
    # (presumably kept resident together by the offloader — confirm in mmgp).
    runtime.co_tenants.update({vision_key: [embedding_key], embedding_key: [vision_key]})


def _fill_florence2_runtime(runtime, mode):
    """Load Florence2 plus Llama 3.2 (mode 1) or JoyCaption (other modes) into ``runtime``."""
    caption_model, caption_processor = load_florence2(fl.locate_folder(FLORENCE2_FOLDER), attn_implementation="sdpa")
    runtime.image_caption_model = caption_model
    runtime.image_caption_processor = caption_processor
    caption_model._model_dtype = torch.float
    caption_model.eval()
    runtime.pipe_models["prompt_enhancer_image_caption_model"] = caption_model

    load_llm = _load_llama32_prompt_enhancer if mode == 1 else _load_joycaption_prompt_enhancer
    llm_model, llm_tokenizer, budget = load_llm()
    runtime.llm_model = llm_model
    runtime.llm_tokenizer = llm_tokenizer
    runtime.pipe_models["prompt_enhancer_llm_model"] = llm_model
    runtime.budgets["prompt_enhancer_llm_model"] = budget


def _load_state_dict(weights_path: Path) -> dict:
    """Load a checkpoint file into a ``{name: tensor}`` dict on the CPU.

    Files with a ``.safetensors`` suffix are read tensor-by-tensor through
    ``safe_open``; anything else is treated as a ``torch.save`` checkpoint.

    Raises:
        Whatever ``safe_open`` / ``torch.load`` raise for missing or malformed
        files (including the unpickling error ``torch.load`` raises when a
        checkpoint contains non-tensor objects, see below).
    """
    if weights_path.suffix == ".safetensors":
        with safe_open(str(weights_path), framework="pt", device="cpu") as handle:
            return {key: handle.get_tensor(key) for key in handle.keys()}
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered downloaded pytorch_model.bin cannot execute code on load. This is
    # torch's default from 2.6 on; it is explicit here for older releases.
    return torch.load(str(weights_path), map_location="cpu", weights_only=True)


def _resolve_weights_path(model_path: Path) -> Path:
    """Return the Florence2 weights file to load from ``model_path``.

    Candidates are checked in priority order: ``xmodel.safetensors`` (fp32
    weights, preferred for stability/quality), then ``model.safetensors``, then
    ``pytorch_model.bin``.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    for candidate_name in ("xmodel.safetensors", "model.safetensors", "pytorch_model.bin"):
        candidate = model_path / candidate_name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f"No Florence2 weights found in {model_path} (expected model.safetensors/xmodel.safetensors/pytorch_model.bin)"
    )


def load_florence2(
    model_dir: str,
    attn_implementation: str = "sdpa",
) -> Tuple[Florence2ForConditionalGeneration, Florence2Processor]:
    """Build the Florence2 captioning model and its processor from a local folder.

    Args:
        model_dir: folder holding ``config.json``, a weights file, the
            preprocessor config and the BART tokenizer files.
        attn_implementation: written to ``config._attn_implementation`` when
            non-empty.

    Returns:
        ``(model, processor)``; the model is already in eval mode.

    Raises:
        FileNotFoundError: if ``model_dir`` does not exist or holds no
            recognised weights file.
        RuntimeError: if neither ``BartTokenizerFast`` nor ``BartTokenizer``
            can be loaded from ``model_dir``.
    """
    root = Path(model_dir)
    if not root.exists():
        raise FileNotFoundError(f"Florence2 folder not found: {root}")

    model = _instantiate_florence2_model(root, attn_implementation)
    image_processor = Florence2ImageProcessorLite.from_preprocessor_config(root)
    tokenizer = _load_florence2_tokenizer(root)
    processor = _make_florence2_processor(root, image_processor, tokenizer)
    return model, processor


def _instantiate_florence2_model(root: Path, attn_implementation: str):
    """Create the Florence2 model from its config and load the CPU weights non-strictly."""
    config = Florence2Config.from_pretrained(str(root))
    if attn_implementation:
        config._attn_implementation = attn_implementation
    weights = _load_state_dict(_resolve_weights_path(root))

    model = Florence2ForConditionalGeneration(config)
    load_info = model.load_state_dict(weights, strict=False)
    del weights  # drop the only extra reference to the loaded tensors early

    # These two embedding weights may be absent from the checkpoint without a
    # warning (presumably tied to a shared embedding — confirm in the model code).
    tolerated_missing = {
        "language_model.model.encoder.embed_tokens.weight",
        "language_model.model.decoder.embed_tokens.weight",
    }
    reported_missing = [key for key in load_info.missing_keys if key not in tolerated_missing]
    if reported_missing:
        print(f"Florence2 missing keys: {reported_missing}")
    if load_info.unexpected_keys:
        print(f"Florence2 unexpected keys: {len(load_info.unexpected_keys)}")
    model.eval()
    return model


def _load_florence2_tokenizer(root: Path):
    """Load the BART tokenizer, trying the fast class first and then the slow one."""
    tokenizer = None
    failures = []
    for tokenizer_cls in (BartTokenizerFast, BartTokenizer):
        try:
            tokenizer = tokenizer_cls.from_pretrained(str(root))
        except Exception as exc:
            failures.append(exc)
        else:
            break
    if tokenizer is None:
        raise RuntimeError(f"Unable to load Florence2 tokenizer: {failures}")
    return tokenizer


def _make_florence2_processor(root: Path, image_processor, tokenizer):
    """Build the Florence2 processor, retrying with a CLIPImageProcessor if required."""
    try:
        return Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)
    except TypeError as exc:
        # Only a rejection mentioning CLIPImageProcessor is recoverable.
        if "CLIPImageProcessor" not in str(exc):
            raise
        try:
            from transformers import CLIPImageProcessor
        except Exception:
            from transformers.models.clip import CLIPImageProcessor
        clip_image_processor = CLIPImageProcessor.from_pretrained(str(root))
        return Florence2Processor(image_processor=clip_image_processor, tokenizer=tokenizer)