"""Loading helpers for the prompt-enhancer models.

Covers the Florence2 image captioner, the Llama 3.2 / JoyCaption text LLMs
and the Qwen3.5 text + vision-language variants.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Tuple

import torch
from mmgp import offload
from safetensors import safe_open

from shared.utils import files_locator as fl

from .florence2 import Florence2Config, Florence2ForConditionalGeneration, Florence2Processor
from .florence2.image_processing_florence2 import Florence2ImageProcessorLite
from transformers import AutoTokenizer, BartTokenizer, BartTokenizerFast
from .assets import (
    FLORENCE2_FILES,
    FLORENCE2_FOLDER,
    LLAMA32_FILES,
    LLAMA32_FOLDER,
    LLAMAJOY_FILES,
    LLAMAJOY_FOLDER,
    PROMPT_ENHANCER_REPO,
)


@dataclass(slots=True)
class PromptEnhancerRuntime:
    """Models and memory metadata produced by ``load_prompt_enhancer_runtime``.

    Every field keeps its default value when the prompt enhancer is disabled.
    """

    # Model that captions input images (Florence2, or the Qwen3.5 VL model).
    image_caption_model: Any = None
    # Processor paired with ``image_caption_model``.
    image_caption_processor: Any = None
    # Text LLM used to rewrite prompts.
    llm_model: Any = None
    # Tokenizer paired with ``llm_model``.
    llm_tokenizer: Any = None
    # Name -> model; presumably registered with the mmgp offloader by the caller (confirm).
    pipe_models: dict[str, Any] = field(default_factory=dict)
    # Name -> budget for the same names (NOTE(review): unit presumably MB — confirm).
    budgets: dict[str, int] = field(default_factory=dict)
    # Name -> names of other pipe models grouped with it (presumably kept resident together).
    co_tenants: dict[str, list[str]] = field(default_factory=dict)


def ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled: int, qwen_backend: str = "quanto_int8"):
    """Request the model files needed by the selected prompt enhancer.

    Args:
        process_files_def: Callable invoked with ``repoId``, ``sourceFolderList``
            and ``fileList`` keyword arguments for each asset group.
        enhancer_enabled: Enhancer selector (coerced with ``int``):
            1 -> Florence2 + Llama 3.2, 2 -> Florence2 + JoyCaption,
            3/4 -> Qwen3.5 variants. Any other value requests nothing.
        qwen_backend: Quantization backend forwarded to the Qwen3.5 asset helper.
    """
    enhancer_enabled = int(enhancer_enabled)
    if enhancer_enabled == 1:
        process_files_def(
            repoId=PROMPT_ENHANCER_REPO,
            sourceFolderList=[FLORENCE2_FOLDER, LLAMA32_FOLDER],
            fileList=[
                FLORENCE2_FILES,
                LLAMA32_FILES,
            ],
        )
        return
    if enhancer_enabled == 2:
        process_files_def(
            repoId=PROMPT_ENHANCER_REPO,
            sourceFolderList=[FLORENCE2_FOLDER, LLAMAJOY_FOLDER],
            fileList=[
                FLORENCE2_FILES,
                LLAMAJOY_FILES,
            ],
        )
        return
    if enhancer_enabled in (3, 4):
        # Imported lazily so the Qwen3.5 module is only loaded when selected.
        from .qwen35_vl import ensure_qwen35_prompt_enhancer_assets, get_qwen35_prompt_enhancer_variant

        ensure_qwen35_prompt_enhancer_assets(
            process_files_def,
            backend=qwen_backend,
            variant=get_qwen35_prompt_enhancer_variant(enhancer_enabled),
        )


def download_prompt_enhancer_assets(
    enhancer_enabled: int,
    qwen_backend: str = "quanto_int8",
    send_cmd=None,
    progress=None,
    status_text="Downloading Prompt Enhancer model files...",
):
    """Download any missing prompt-enhancer files.

    The status text is reported at most once: on the first asset group
    that has missing files, through ``progress(0, status_text)`` when
    ``progress`` is given, and via ``process_files_def_if_needed``.

    Returns:
        ``False`` when the enhancer is disabled (``<= 0``). Otherwise, a
        truthy value if any ``process_files_def_if_needed`` call returned
        truthy.
    """
    enhancer_enabled = int(enhancer_enabled)
    if enhancer_enabled <= 0:
        return False
    from shared.utils.download import download_def_missing_files, process_files_def_if_needed

    downloaded = False
    status_sent = False

    def process_download_def(**download_def):
        nonlocal downloaded, status_sent
        has_missing_files = len(download_def_missing_files(download_def)) > 0
        download_status_text = None
        if has_missing_files and not status_sent:
            if progress is not None:
                progress(0, status_text)
            download_status_text = status_text
            status_sent = True
        # Call first, then OR: every asset group must be processed even after
        # an earlier group already reported a download.
        downloaded = process_files_def_if_needed(download_def, send_cmd=send_cmd, status_text=download_status_text) or downloaded

    ensure_prompt_enhancer_assets(process_download_def, enhancer_enabled=enhancer_enabled, qwen_backend=qwen_backend)
    return downloaded


def unload_prompt_enhancer_models(*models):
    """Call ``unload()`` once on each distinct, non-None model that provides it.

    Duplicates are detected by object identity, so a model passed twice
    (e.g. shared between caption and LLM slots) is unloaded only once.
    """
    seen = set()
    for model in models:
        if model is None:
            continue
        model_id = id(model)
        if model_id in seen:
            continue
        seen.add(model_id)
        unload = getattr(model, "unload", None)
        if callable(unload):
            unload()


def _set_pad_token_from_tokenizer(model, tokenizer):
    """Use the EOS token as padding on ``model.generation_config``.

    ``pad_token_id`` is only filled in when it is currently ``None``. It
    takes the generation config's ``eos_token_id``, or the first entry when
    that is a list/tuple. An empty list/tuple leaves ``pad_token_id`` as
    ``None`` instead of raising ``IndexError``.
    """
    generation_config = model.generation_config
    generation_config.pad_token = tokenizer.eos_token
    if generation_config.pad_token_id is not None:
        return
    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, (list, tuple)):
        eos_token_id = eos_token_id[0] if eos_token_id else None
    generation_config.pad_token_id = eos_token_id


def _load_llama32_prompt_enhancer():
    """Load the quantized Llama 3.2 prompt-enhancer LLM and its tokenizer.

    Returns:
        ``(llm_model, llm_tokenizer, budget)`` with a budget of 5000.
    """
    llm_model = offload.fast_load_transformers_model(
        fl.locate_file(f"{LLAMA32_FOLDER}/Llama3_2_quanto_bf16_int8.safetensors"),
        defaultConfigPath=fl.locate_file(f"{LLAMA32_FOLDER}/config.json", error_if_none=False),
        configKwargs={"attn_implementation": "sdpa", "hidden_act": "silu"},
        writable_tensors=False,
    )
    # Disable transformers' model-kwargs validation for this model (no-op replacement).
    llm_model._validate_model_kwargs = lambda *_args, **_kwargs: None
    # Method name(s) the offloader should hook (consumed by mmgp; assumption — confirm).
    llm_model._offload_hooks = ["generate"]
    llm_tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMA32_FOLDER))
    _set_pad_token_from_tokenizer(llm_model, llm_tokenizer)
    llm_model.eval()
    return llm_model, llm_tokenizer, 5000


def _load_joycaption_prompt_enhancer():
    """Load only the language-model part of the quantized JoyCaption checkpoint.

    Returns:
        ``(llm_model, llm_tokenizer, budget)`` with a budget of 10000.
    """

    def preprocess_sd(sd, quant_map=None, tied_map=None):
        # Re-root "model.language_model" weights under "model" and map the
        # vision tower / projector prefixes to None (presumably dropping them).
        rules = {"model.language_model": "model", "model.vision_tower": None, "model.multi_modal_projector": None}
        return tuple(offload.map_state_dict([sd, quant_map, tied_map], rules))

    llm_model = offload.fast_load_transformers_model(
        fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_joycaption_quanto_bf16_int8.safetensors"),
        forcedConfigPath=fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_config.json", error_if_none=False),
        configKwargs={"attn_implementation": "sdpa", "hidden_act": "silu"},
        preprocess_sd=preprocess_sd,
        writable_tensors=False,
    )
    llm_tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMAJOY_FOLDER))
    _set_pad_token_from_tokenizer(llm_model, llm_tokenizer)
    llm_model.eval()
    return llm_model, llm_tokenizer, 10000


def load_prompt_enhancer_runtime(
    process_files_def,
    enhancer_enabled: int,
    lm_decoder_engine: str = "",
    qwen_backend: str = "quanto_int8",
) -> PromptEnhancerRuntime:
    """Ensure assets are present and load the selected prompt-enhancer models.

    Args:
        process_files_def: Asset fetcher passed to ``ensure_prompt_enhancer_assets``.
        enhancer_enabled: Enhancer selector (coerced with ``int``). ``<= 0``
            returns an empty runtime; 3/4 load Qwen3.5; 1 loads Florence2 +
            Llama 3.2; any other positive value loads Florence2 + JoyCaption.
        lm_decoder_engine: Requested LM engine forwarded to the Qwen3.5 text loader.
        qwen_backend: Quantization backend for Qwen3.5. Falls back to the
            Qwen module's quanto-int8 constant when empty.

    Returns:
        A populated ``PromptEnhancerRuntime``.
    """
    enhancer_enabled = int(enhancer_enabled)
    runtime = PromptEnhancerRuntime()
    if enhancer_enabled <= 0:
        return runtime
    ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled=enhancer_enabled, qwen_backend=qwen_backend)

    if enhancer_enabled in (3, 4):
        from .qwen35_text import load_qwen35_text_prompt_enhancer
        from .qwen35_vl import (
            enhancer_quantization_QUANTO_INT8,
            alias_qwen35_text_embedding_for_mmgp,
            get_qwen35_assets_dir_name,
            get_qwen35_prompt_enhancer_variant,
            load_qwen35_vl_prompt_enhancer,
        )

        backend = qwen_backend or enhancer_quantization_QUANTO_INT8
        qwen35_variant = get_qwen35_prompt_enhancer_variant(enhancer_enabled)
        assets_dir_name = get_qwen35_assets_dir_name(qwen35_variant)
        assets_dir = fl.locate_folder(assets_dir_name, error_if_none=False) or fl.get_download_location(assets_dir_name)

        runtime.llm_model = load_qwen35_text_prompt_enhancer(
            assets_dir=assets_dir,
            backend=backend,
            attn_implementation="sdpa",
            requested_lm_engine=lm_decoder_engine,
            variant=qwen35_variant,
        )
        runtime.llm_tokenizer = getattr(runtime.llm_model, "_prompt_enhancer_tokenizer", None)
        runtime.llm_model.eval()

        # The VL captioner is built on top of the already-loaded text model
        # (passed as text_model) and an embedding alias of it.
        caption_embedding_model = alias_qwen35_text_embedding_for_mmgp(runtime.llm_model)
        runtime.image_caption_model, vision_tower_model = load_qwen35_vl_prompt_enhancer(
            assets_dir=assets_dir,
            attn_implementation="sdpa",
            text_model=runtime.llm_model,
            input_embedding_model=caption_embedding_model,
            backend=backend,
            variant=qwen35_variant,
        )
        runtime.image_caption_processor = getattr(runtime.image_caption_model, "_prompt_enhancer_processor", None)
        runtime.image_caption_model.eval()

        runtime.pipe_models["prompt_enhancer_image_caption_vision_tower_model"] = vision_tower_model
        runtime.pipe_models["prompt_enhancer_image_caption_embedding_model"] = caption_embedding_model
        runtime.pipe_models["prompt_enhancer_llm_model"] = runtime.llm_model
        runtime.budgets["prompt_enhancer_image_caption_vision_tower_model"] = 3000
        runtime.budgets["prompt_enhancer_image_caption_embedding_model"] = 2000
        runtime.budgets["prompt_enhancer_llm_model"] = 10000
        # The vision tower and the embedding model are declared co-tenants of each other.
        runtime.co_tenants["prompt_enhancer_image_caption_vision_tower_model"] = ["prompt_enhancer_image_caption_embedding_model"]
        runtime.co_tenants["prompt_enhancer_image_caption_embedding_model"] = ["prompt_enhancer_image_caption_vision_tower_model"]
        return runtime

    # Florence2 captioner is shared by the Llama 3.2 and JoyCaption paths.
    runtime.image_caption_model, runtime.image_caption_processor = load_florence2(
        fl.locate_folder(FLORENCE2_FOLDER), attn_implementation="sdpa"
    )
    runtime.image_caption_model._model_dtype = torch.float
    runtime.image_caption_model.eval()
    runtime.pipe_models["prompt_enhancer_image_caption_model"] = runtime.image_caption_model

    if enhancer_enabled == 1:
        runtime.llm_model, runtime.llm_tokenizer, budget = _load_llama32_prompt_enhancer()
    else:
        # NOTE(review): every positive value other than 1, 3 and 4 lands here,
        # even though ensure_prompt_enhancer_assets only fetches JoyCaption for 2.
        runtime.llm_model, runtime.llm_tokenizer, budget = _load_joycaption_prompt_enhancer()
    runtime.pipe_models["prompt_enhancer_llm_model"] = runtime.llm_model
    runtime.budgets["prompt_enhancer_llm_model"] = budget
    return runtime


def _load_state_dict(weights_path: Path) -> dict:
    """Load a checkpoint onto CPU from ``.safetensors`` or a torch pickle file."""
    if weights_path.suffix == ".safetensors":
        with safe_open(str(weights_path), framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code for an untrusted/downloaded checkpoint. Consider weights_only=True
    # (the default in recent torch releases) if the installed torch supports it.
    return torch.load(str(weights_path), map_location="cpu")


def _resolve_weights_path(model_path: Path) -> Path:
    """Return the first Florence2 weights file found in ``model_path``.

    Checked in order: ``xmodel.safetensors``, ``model.safetensors``,
    ``pytorch_model.bin``.

    Raises:
        FileNotFoundError: if none of the candidates exist.
    """
    # Prefer fp32 weights (xmodel.safetensors) for stability/quality when available.
    for file_name in ("xmodel.safetensors", "model.safetensors", "pytorch_model.bin"):
        candidate = model_path / file_name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f"No Florence2 weights found in {model_path} (expected model.safetensors/xmodel.safetensors/pytorch_model.bin)"
    )


def load_florence2(
    model_dir: str,
    attn_implementation: str = "sdpa",
) -> Tuple[Florence2ForConditionalGeneration, Florence2Processor]:
    """Build the Florence2 model and processor from a local folder.

    Args:
        model_dir: Folder holding the Florence2 config, weights, preprocessor
            config and Bart tokenizer files.
        attn_implementation: Written to ``config._attn_implementation`` when truthy.

    Returns:
        ``(model, processor)`` with the model in eval mode.

    Raises:
        FileNotFoundError: if the folder or its weights file is missing.
        RuntimeError: if neither Bart tokenizer class can load from the folder.
    """
    model_path = Path(model_dir)
    if not model_path.exists():
        raise FileNotFoundError(f"Florence2 folder not found: {model_path}")

    config = Florence2Config.from_pretrained(str(model_path))
    if attn_implementation:
        config._attn_implementation = attn_implementation

    weights_path = _resolve_weights_path(model_path)
    state_dict = _load_state_dict(weights_path)
    model = Florence2ForConditionalGeneration(config)
    load_info = model.load_state_dict(state_dict, strict=False)
    # Drop our reference to the loaded checkpoint once it has been copied in.
    del state_dict

    if load_info.missing_keys:
        # The embed_tokens weights are expected to be absent from the checkpoint;
        # only other missing keys are reported.
        allowed_missing = {
            "language_model.model.encoder.embed_tokens.weight",
            "language_model.model.decoder.embed_tokens.weight",
        }
        extra_missing = [k for k in load_info.missing_keys if k not in allowed_missing]
        if extra_missing:
            print(f"Florence2 missing keys: {extra_missing}")
    if load_info.unexpected_keys:
        print(f"Florence2 unexpected keys: {len(load_info.unexpected_keys)}")
    model.eval()

    image_processor = Florence2ImageProcessorLite.from_preprocessor_config(model_path)

    # Try the fast tokenizer first, then the slow one; keep every error for the report.
    tokenizer = None
    tokenizer_errors = []
    for tok_cls in (BartTokenizerFast, BartTokenizer):
        try:
            tokenizer = tok_cls.from_pretrained(str(model_path))
            break
        except Exception as exc:
            tokenizer_errors.append(exc)
    if tokenizer is None:
        raise RuntimeError(f"Unable to load Florence2 tokenizer: {tokenizer_errors}")

    try:
        processor = Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)
    except TypeError as exc:
        # Some processor versions reject the lite image processor with a
        # TypeError mentioning CLIPImageProcessor; retry with a real one.
        if "CLIPImageProcessor" not in str(exc):
            raise
        try:
            from transformers import CLIPImageProcessor
        except Exception:
            from transformers.models.clip import CLIPImageProcessor
        image_processor = CLIPImageProcessor.from_pretrained(str(model_path))
        processor = Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)
    return model, processor