from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Tuple
import torch
from mmgp import offload
from safetensors import safe_open
from shared.utils import files_locator as fl
from .florence2 import Florence2Config, Florence2ForConditionalGeneration, Florence2Processor
from .florence2.image_processing_florence2 import Florence2ImageProcessorLite
from transformers import AutoTokenizer, BartTokenizer, BartTokenizerFast
from .assets import (
FLORENCE2_FILES,
FLORENCE2_FOLDER,
LLAMA32_FILES,
LLAMA32_FOLDER,
LLAMAJOY_FILES,
LLAMAJOY_FOLDER,
PROMPT_ENHANCER_REPO,
)
@dataclass(slots=True)
class PromptEnhancerRuntime:
image_caption_model: Any = None
image_caption_processor: Any = None
llm_model: Any = None
llm_tokenizer: Any = None
pipe_models: dict[str, Any] = field(default_factory=dict)
budgets: dict[str, int] = field(default_factory=dict)
co_tenants: dict[str, list[str]] = field(default_factory=dict)
def ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled: int, qwen_backend: str = "quanto_int8"):
    """Request the asset files required by the selected prompt enhancer mode.

    Args:
        process_files_def: Callable that receives the download definition.
            For modes 1 and 2 it is called once with the keyword arguments
            ``repoId``, ``sourceFolderList`` and ``fileList``; for modes 3
            and 4 it is forwarded to the Qwen3.5 asset helper.
        enhancer_enabled: Mode selector, coerced with ``int()``.
            1 requests the Florence2 + Llama 3.2 files, 2 requests the
            Florence2 + JoyCaption files, 3 and 4 delegate to
            ``ensure_qwen35_prompt_enhancer_assets``. Any other value is a
            no-op.
        qwen_backend: Passed as ``backend`` to the Qwen3.5 helper; unused
            for modes 1 and 2.

    Returns:
        None in every case.
    """
    mode = int(enhancer_enabled)

    if mode in (3, 4):
        # Imported lazily so the Qwen3.5 module is only loaded when needed.
        from .qwen35_vl import ensure_qwen35_prompt_enhancer_assets, get_qwen35_prompt_enhancer_variant

        ensure_qwen35_prompt_enhancer_assets(
            process_files_def,
            backend=qwen_backend,
            variant=get_qwen35_prompt_enhancer_variant(mode),
        )
        return

    # Modes 1 and 2 differ only in which LLM folder/files go with Florence2.
    if mode == 1:
        llm_folder, llm_files = LLAMA32_FOLDER, LLAMA32_FILES
    elif mode == 2:
        llm_folder, llm_files = LLAMAJOY_FOLDER, LLAMAJOY_FILES
    else:
        return

    process_files_def(
        repoId=PROMPT_ENHANCER_REPO,
        sourceFolderList=[FLORENCE2_FOLDER, llm_folder],
        fileList=[FLORENCE2_FILES, llm_files],
    )
def download_prompt_enhancer_assets(
    enhancer_enabled: int,
    qwen_backend: str = "quanto_int8",
    send_cmd=None,
    progress=None,
    status_text="Downloading Prompt Enhancer model files...",
):
    """Download whichever prompt enhancer files are missing for the given mode.

    Args:
        enhancer_enabled: Mode selector, coerced with ``int()``. Values
            ``<= 0`` return ``False`` immediately.
        qwen_backend: Forwarded to ``ensure_prompt_enhancer_assets``.
        send_cmd: Forwarded to ``process_files_def_if_needed``.
        progress: Optional callable; invoked once as
            ``progress(0, status_text)`` for the first download definition
            that has missing files.
        status_text: Message announced for that first definition; later
            definitions receive ``status_text=None``.

    Returns:
        ``False`` when nothing was downloaded, otherwise the last truthy
        value returned by ``process_files_def_if_needed``.
    """
    mode = int(enhancer_enabled)
    if mode <= 0:
        return False

    from shared.utils.download import download_def_missing_files, process_files_def_if_needed

    downloaded = False
    status_sent = False

    def _download_if_needed(**download_def):
        nonlocal downloaded, status_sent
        missing_files = download_def_missing_files(download_def)
        # Announce only once, on the first definition that needs a download.
        announce = len(missing_files) > 0 and not status_sent
        if announce:
            if progress is not None:
                progress(0, status_text)
            status_sent = True
        outcome = process_files_def_if_needed(
            download_def,
            send_cmd=send_cmd,
            status_text=status_text if announce else None,
        )
        # Keep the previous result when this definition downloaded nothing.
        if outcome:
            downloaded = outcome

    ensure_prompt_enhancer_assets(_download_if_needed, enhancer_enabled=mode, qwen_backend=qwen_backend)
    return downloaded
def unload_prompt_enhancer_models(*models):
    """Call ``unload()`` once on every distinct model that provides it.

    ``None`` entries are skipped, the same object passed several times is
    handled only once (compared by identity), and objects whose ``unload``
    attribute is missing or not callable are ignored.
    """
    visited_ids = set()
    for candidate in models:
        if candidate is None or id(candidate) in visited_ids:
            continue
        visited_ids.add(id(candidate))
        unload_fn = getattr(candidate, "unload", None)
        if callable(unload_fn):
            unload_fn()
def _set_pad_token_from_tokenizer(model, tokenizer):
model.generation_config.pad_token = tokenizer.eos_token
if model.generation_config.pad_token_id is None:
eos_token_id = model.generation_config.eos_token_id
model.generation_config.pad_token_id = eos_token_id[0] if isinstance(eos_token_id, list) else eos_token_id
def _load_llama32_prompt_enhancer():
    """Load the quantized Llama 3.2 prompt enhancer and its tokenizer.

    Returns:
        A ``(model, tokenizer, budget)`` tuple; the budget is the constant
        5000, which the caller stores in its ``budgets`` registry.
    """
    weights_file = fl.locate_file(f"{LLAMA32_FOLDER}/Llama3_2_quanto_bf16_int8.safetensors")
    # The config is optional here: a missing file yields no default config path.
    config_file = fl.locate_file(f"{LLAMA32_FOLDER}/config.json", error_if_none=False)

    model = offload.fast_load_transformers_model(
        weights_file,
        defaultConfigPath=config_file,
        configKwargs={"attn_implementation": "sdpa", "hidden_act": "silu"},
        writable_tensors=False,
    )

    def _skip_model_kwargs_validation(*_args, **_kwargs):
        return None

    # Replace the model's kwargs validation with a no-op.
    model._validate_model_kwargs = _skip_model_kwargs_validation
    # NOTE(review): presumably tells mmgp to hook `generate` for offloading - confirm.
    model._offload_hooks = ["generate"]

    tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMA32_FOLDER))
    _set_pad_token_from_tokenizer(model, tokenizer)
    model.eval()
    return model, tokenizer, 5000
def _load_joycaption_prompt_enhancer():
    """Load the quantized JoyCaption language model and its tokenizer.

    Returns:
        A ``(model, tokenizer, budget)`` tuple; the budget is the constant
        10000, which the caller stores in its ``budgets`` registry.
    """
    # Key-prefix rules applied to the checkpoint: the language model prefix is
    # mapped to "model"; vision tower and projector prefixes are mapped to None.
    prefix_rules = {
        "model.language_model": "model",
        "model.vision_tower": None,
        "model.multi_modal_projector": None,
    }

    def preprocess_sd(sd, quant_map=None, tied_map=None):
        remapped = offload.map_state_dict([sd, quant_map, tied_map], prefix_rules)
        return tuple(remapped)

    weights_file = fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_joycaption_quanto_bf16_int8.safetensors")
    config_file = fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_config.json", error_if_none=False)

    model = offload.fast_load_transformers_model(
        weights_file,
        forcedConfigPath=config_file,
        configKwargs={"attn_implementation": "sdpa", "hidden_act": "silu"},
        preprocess_sd=preprocess_sd,
        writable_tensors=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMAJOY_FOLDER))
    _set_pad_token_from_tokenizer(model, tokenizer)
    model.eval()
    return model, tokenizer, 10000
def load_prompt_enhancer_runtime(
    process_files_def,
    enhancer_enabled: int,
    lm_decoder_engine: str = "",
    qwen_backend: str = "quanto_int8",
) -> PromptEnhancerRuntime:
    """Load the models for the selected prompt enhancer mode.

    Args:
        process_files_def: Forwarded to ``ensure_prompt_enhancer_assets``.
        enhancer_enabled: Mode selector, coerced with ``int()``. ``<= 0``
            returns an empty runtime. 3 and 4 load the Qwen3.5 text and
            vision-language models. 1 loads Florence2 + Llama 3.2; any other
            positive value loads Florence2 + JoyCaption.
        lm_decoder_engine: Passed as ``requested_lm_engine`` to the Qwen3.5
            text loader; unused for the Florence2-based modes.
        qwen_backend: Backend for the Qwen3.5 loaders; a falsy value falls
            back to ``enhancer_quantization_QUANTO_INT8``.

    Returns:
        A ``PromptEnhancerRuntime`` with the models, processor/tokenizer and
        the ``pipe_models`` / ``budgets`` / ``co_tenants`` registries filled.
    """
    mode = int(enhancer_enabled)
    runtime = PromptEnhancerRuntime()
    if mode <= 0:
        return runtime

    ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled=mode, qwen_backend=qwen_backend)

    llm_key = "prompt_enhancer_llm_model"

    if mode not in (3, 4):
        # Florence2 captions the image; a separate LLM handles the text side.
        caption_model, caption_processor = load_florence2(fl.locate_folder(FLORENCE2_FOLDER), attn_implementation="sdpa")
        runtime.image_caption_model = caption_model
        runtime.image_caption_processor = caption_processor
        caption_model._model_dtype = torch.float
        caption_model.eval()
        runtime.pipe_models["prompt_enhancer_image_caption_model"] = caption_model

        load_llm = _load_llama32_prompt_enhancer if mode == 1 else _load_joycaption_prompt_enhancer
        runtime.llm_model, runtime.llm_tokenizer, llm_budget = load_llm()
        runtime.pipe_models[llm_key] = runtime.llm_model
        runtime.budgets[llm_key] = llm_budget
        return runtime

    # Modes 3 and 4: Qwen3.5. Imported lazily so these modules are only
    # loaded when a Qwen3.5 mode is selected.
    from .qwen35_text import load_qwen35_text_prompt_enhancer
    from .qwen35_vl import (
        enhancer_quantization_QUANTO_INT8,
        alias_qwen35_text_embedding_for_mmgp,
        get_qwen35_assets_dir_name,
        get_qwen35_prompt_enhancer_variant,
        load_qwen35_vl_prompt_enhancer,
    )

    backend = qwen_backend or enhancer_quantization_QUANTO_INT8
    variant = get_qwen35_prompt_enhancer_variant(mode)
    assets_dir_name = get_qwen35_assets_dir_name(variant)
    # Use the existing folder when found, else the download location.
    assets_dir = fl.locate_folder(assets_dir_name, error_if_none=False) or fl.get_download_location(assets_dir_name)

    text_model = load_qwen35_text_prompt_enhancer(
        assets_dir=assets_dir,
        backend=backend,
        attn_implementation="sdpa",
        requested_lm_engine=lm_decoder_engine,
        variant=variant,
    )
    runtime.llm_model = text_model
    runtime.llm_tokenizer = getattr(text_model, "_prompt_enhancer_tokenizer", None)
    text_model.eval()

    embedding_model = alias_qwen35_text_embedding_for_mmgp(text_model)
    # The vision-language loader receives the text model and its embedding alias.
    caption_model, vision_tower = load_qwen35_vl_prompt_enhancer(
        assets_dir=assets_dir,
        attn_implementation="sdpa",
        text_model=text_model,
        input_embedding_model=embedding_model,
        backend=backend,
        variant=variant,
    )
    runtime.image_caption_model = caption_model
    runtime.image_caption_processor = getattr(caption_model, "_prompt_enhancer_processor", None)
    caption_model.eval()

    vision_key = "prompt_enhancer_image_caption_vision_tower_model"
    embedding_key = "prompt_enhancer_image_caption_embedding_model"
    runtime.pipe_models.update({vision_key: vision_tower, embedding_key: embedding_model, llm_key: text_model})
    runtime.budgets.update({vision_key: 3000, embedding_key: 2000, llm_key: 10000})
    # Vision tower and embedding model are registered as each other's co-tenants.
    runtime.co_tenants.update({vision_key: [embedding_key], embedding_key: [vision_key]})
    return runtime
def _load_state_dict(weights_path: Path) -> dict:
if weights_path.suffix == ".safetensors":
state_dict = {}
with safe_open(str(weights_path), framework="pt", device="cpu") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
return state_dict
return torch.load(str(weights_path), map_location="cpu")
def _resolve_weights_path(model_path: Path) -> Path:
# Prefer fp32 weights for stability/quality when available.
preferred = model_path / "xmodel.safetensors"
if preferred.exists():
return preferred
fallback = model_path / "model.safetensors"
if fallback.exists():
return fallback
fallback = model_path / "pytorch_model.bin"
if fallback.exists():
return fallback
raise FileNotFoundError(
f"No Florence2 weights found in {model_path} (expected model.safetensors/xmodel.safetensors/pytorch_model.bin)"
)
def load_florence2(
    model_dir: str,
    attn_implementation: str = "sdpa",
) -> Tuple[Florence2ForConditionalGeneration, Florence2Processor]:
    """Build a Florence2 model and processor from a local folder.

    Args:
        model_dir: Folder holding the config, weights, preprocessor config
            and tokenizer files.
        attn_implementation: Written to ``config._attn_implementation`` when
            truthy; a falsy value leaves the config untouched.

    Returns:
        ``(model, processor)`` with the model already switched to eval mode.

    Raises:
        FileNotFoundError: If the folder or every candidate weights file is
            missing.
        RuntimeError: If neither the fast nor the slow BART tokenizer loads.
        TypeError: Re-raised from the processor constructor unless the
            message mentions ``CLIPImageProcessor``.
    """
    model_path = Path(model_dir)
    if not model_path.exists():
        raise FileNotFoundError(f"Florence2 folder not found: {model_path}")

    config = Florence2Config.from_pretrained(str(model_path))
    if attn_implementation:
        config._attn_implementation = attn_implementation

    state_dict = _load_state_dict(_resolve_weights_path(model_path))
    model = Florence2ForConditionalGeneration(config)
    load_info = model.load_state_dict(state_dict, strict=False)
    # Drop the loaded tensors now that the model holds the weights.
    del state_dict

    # These two embedding weights may be absent without a warning
    # (presumably tied to a shared embedding - confirm).
    tolerated_missing = {
        "language_model.model.encoder.embed_tokens.weight",
        "language_model.model.decoder.embed_tokens.weight",
    }
    extra_missing = [key for key in load_info.missing_keys if key not in tolerated_missing]
    if extra_missing:
        print(f"Florence2 missing keys: {extra_missing}")
    if load_info.unexpected_keys:
        print(f"Florence2 unexpected keys: {len(load_info.unexpected_keys)}")
    model.eval()

    image_processor = Florence2ImageProcessorLite.from_preprocessor_config(model_path)

    # Try the fast tokenizer first and fall back to the slow one, keeping
    # every failure for the error message.
    tokenizer = None
    tokenizer_errors = []
    for tokenizer_cls in (BartTokenizerFast, BartTokenizer):
        try:
            tokenizer = tokenizer_cls.from_pretrained(str(model_path))
        except Exception as exc:
            tokenizer_errors.append(exc)
        else:
            break
    if tokenizer is None:
        raise RuntimeError(f"Unable to load Florence2 tokenizer: {tokenizer_errors}")

    try:
        processor = Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)
    except TypeError as exc:
        # Only a complaint about CLIPImageProcessor triggers the fallback;
        # any other TypeError propagates unchanged.
        if "CLIPImageProcessor" not in str(exc):
            raise
        try:
            from transformers import CLIPImageProcessor
        except Exception:
            from transformers.models.clip import CLIPImageProcessor
        image_processor = CLIPImageProcessor.from_pretrained(str(model_path))
        processor = Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)

    return model, processor
|