from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Tuple

import torch
from mmgp import offload
from safetensors import safe_open

from shared.utils import files_locator as fl

from .florence2 import Florence2Config, Florence2ForConditionalGeneration, Florence2Processor
from .florence2.image_processing_florence2 import Florence2ImageProcessorLite
from transformers import AutoTokenizer, BartTokenizer, BartTokenizerFast

FLORENCE2_FOLDER = "Florence2"
LLAMA32_FOLDER = "Llama3_2"
LLAMAJOY_FOLDER = "llama-joycaption-beta-one-hf-llava"
PROMPT_ENHANCER_REPO = "DeepBeepMeep/LTX_Video"

# Files requested for the Florence2 captioner; shared by enhancer modes 1 and 2.
_FLORENCE2_FILES = (
    "config.json",
    "configuration_florence2.py",
    "model.safetensors",
    "preprocessor_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
)

# Enhancer mode -> (LLM asset folder, files requested from that folder).
_LLM_ASSETS_BY_MODE = {
    1: (
        LLAMA32_FOLDER,
        (
            "config.json",
            "generation_config.json",
            "Llama3_2_quanto_bf16_int8.safetensors",
            "special_tokens_map.json",
            "tokenizer.json",
            "tokenizer_config.json",
        ),
    ),
    2: (
        LLAMAJOY_FOLDER,
        (
            "config.json",
            "llama_config.json",
            "llama_joycaption_quanto_bf16_int8.safetensors",
            "special_tokens_map.json",
            "tokenizer.json",
            "tokenizer_config.json",
        ),
    ),
}


@dataclass(slots=True)
class PromptEnhancerRuntime:
    """Container for the models loaded by ``load_prompt_enhancer_runtime``.

    All model/tokenizer slots default to ``None`` and both dicts default to
    empty, which is the state returned when the enhancer is disabled.
    """

    # Image captioning model (Florence2, or the Qwen3.5 VL model in modes 3/4).
    image_caption_model: Any = None
    # Processor paired with ``image_caption_model``.
    image_caption_processor: Any = None
    # Text LLM used to rewrite prompts.
    llm_model: Any = None
    # Tokenizer paired with ``llm_model``.
    llm_tokenizer: Any = None
    # Name -> model mapping filled alongside the models above.
    pipe_models: dict[str, Any] = field(default_factory=dict)
    # Name -> integer budget per model; presumably a memory budget consumed by
    # the offloading layer -- TODO confirm the unit against the caller.
    budgets: dict[str, int] = field(default_factory=dict)


def ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled: int, qwen_backend: str = "quanto_int8"):
    """Request the asset files needed by the selected prompt-enhancer mode.

    Args:
        process_files_def: Callable invoked with ``repoId``, ``sourceFolderList``
            and ``fileList`` keyword arguments for modes 1 and 2, and forwarded
            to the Qwen3.5 asset helper for modes 3 and 4.
        enhancer_enabled: Enhancer mode; coerced with ``int()``. 1 selects
            Florence2 + Llama 3.2, 2 selects Florence2 + JoyCaption, 3 and 4
            select a Qwen3.5 variant. Any other value requests nothing.
        qwen_backend: Backend name forwarded to the Qwen3.5 asset helper.
    """
    enhancer_enabled = int(enhancer_enabled)

    llm_assets = _LLM_ASSETS_BY_MODE.get(enhancer_enabled)
    if llm_assets is not None:
        llm_folder, llm_files = llm_assets
        process_files_def(
            repoId=PROMPT_ENHANCER_REPO,
            sourceFolderList=[FLORENCE2_FOLDER, llm_folder],
            # Fresh lists on every call so the callee can never mutate the
            # module-level tables.
            fileList=[list(_FLORENCE2_FILES), list(llm_files)],
        )
        return

    if enhancer_enabled in (3, 4):
        # Imported lazily so the Qwen3.5 modules are only loaded when needed.
        from .qwen35_vl import ensure_qwen35_prompt_enhancer_assets, get_qwen35_prompt_enhancer_variant

        ensure_qwen35_prompt_enhancer_assets(
            process_files_def,
            backend=qwen_backend,
            variant=get_qwen35_prompt_enhancer_variant(enhancer_enabled),
        )


def unload_prompt_enhancer_models(*models):
    """Call ``unload()`` once on every distinct model that provides it.

    ``None`` entries are skipped, and the same object passed several times is
    only unloaded once (objects are de-duplicated by identity). Models without
    a callable ``unload`` attribute are ignored.
    """
    seen = set()
    for model in models:
        if model is None:
            continue
        model_id = id(model)
        if model_id in seen:
            continue
        seen.add(model_id)
        unload = getattr(model, "unload", None)
        if callable(unload):
            unload()


def _set_pad_token_from_tokenizer(model, tokenizer):
    """Make sure ``model.generation_config`` has a usable pad token.

    The tokenizer's EOS token string is always stored as ``pad_token``. When
    ``pad_token_id`` is unset it is filled from the generation config's
    ``eos_token_id`` (the first entry when that is a sequence), falling back to
    the tokenizer's own ``eos_token_id`` when the config provides none.
    """
    generation_config = model.generation_config
    generation_config.pad_token = tokenizer.eos_token
    if generation_config.pad_token_id is not None:
        return

    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, (list, tuple)):
        # Several EOS ids may be configured; use the first one and tolerate an
        # empty sequence instead of raising IndexError.
        eos_token_id = eos_token_id[0] if eos_token_id else None
    if eos_token_id is None:
        eos_token_id = getattr(tokenizer, "eos_token_id", None)
    generation_config.pad_token_id = eos_token_id


def _load_llama32_prompt_enhancer():
    """Load the quantized Llama 3.2 LLM and its tokenizer.

    Returns:
        ``(llm_model, llm_tokenizer, budget)`` where ``budget`` is the integer
        5000 that the caller stores in ``PromptEnhancerRuntime.budgets``.
    """
    llm_model = offload.fast_load_transformers_model(
        fl.locate_file(f"{LLAMA32_FOLDER}/Llama3_2_quanto_bf16_int8.safetensors"),
        defaultConfigPath=fl.locate_file(f"{LLAMA32_FOLDER}/config.json", error_if_none=False),
        configKwargs={"attn_implementation": "sdpa", "hidden_act": "silu"},
        writable_tensors=False,
    )
    # Replace the kwargs validation hook with a no-op.
    llm_model._validate_model_kwargs = lambda *_args, **_kwargs: None
    # NOTE(review): presumably read by the offloading layer to decide which
    # methods to hook -- confirm against mmgp.
    llm_model._offload_hooks = ["generate"]
    llm_tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMA32_FOLDER))
    _set_pad_token_from_tokenizer(llm_model, llm_tokenizer)
    llm_model.eval()
    return llm_model, llm_tokenizer, 5000


def _load_joycaption_prompt_enhancer():
    """Load the language model of the quantized JoyCaption checkpoint.

    Returns:
        ``(llm_model, llm_tokenizer, budget)`` where ``budget`` is the integer
        10000 that the caller stores in ``PromptEnhancerRuntime.budgets``.
    """

    def preprocess_sd(sd, quant_map=None, tied_map=None):
        # Remap the language-model prefix to "model"; the vision tower and
        # projector prefixes map to None (presumably dropping those entries --
        # confirm against offload.map_state_dict).
        rules = {"model.language_model": "model", "model.vision_tower": None, "model.multi_modal_projector": None}
        return tuple(offload.map_state_dict([sd, quant_map, tied_map], rules))

    llm_model = offload.fast_load_transformers_model(
        fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_joycaption_quanto_bf16_int8.safetensors"),
        forcedConfigPath=fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_config.json", error_if_none=False),
        configKwargs={"attn_implementation": "sdpa", "hidden_act": "silu"},
        preprocess_sd=preprocess_sd,
        writable_tensors=False,
    )
    llm_tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMAJOY_FOLDER))
    _set_pad_token_from_tokenizer(llm_model, llm_tokenizer)
    llm_model.eval()
    return llm_model, llm_tokenizer, 10000


def _load_qwen35_runtime(
    runtime: PromptEnhancerRuntime,
    enhancer_enabled: int,
    lm_decoder_engine: str,
    qwen_backend: str,
) -> PromptEnhancerRuntime:
    """Fill ``runtime`` with the Qwen3.5 text and vision models (modes 3 and 4)."""
    # Imported lazily so the Qwen3.5 modules are only loaded when needed.
    from .qwen35_text import load_qwen35_text_prompt_enhancer
    from .qwen35_vl import (
        enhancer_quantization_QUANTO_INT8,
        get_qwen35_assets_dir_name,
        get_qwen35_prompt_enhancer_variant,
        load_qwen35_vl_prompt_enhancer,
    )

    backend = qwen_backend or enhancer_quantization_QUANTO_INT8
    qwen35_variant = get_qwen35_prompt_enhancer_variant(enhancer_enabled)
    assets_dir_name = get_qwen35_assets_dir_name(qwen35_variant)
    # Use the existing folder when one is found, else the download location.
    assets_dir = fl.locate_folder(assets_dir_name, error_if_none=False) or fl.get_download_location(assets_dir_name)

    runtime.llm_model = load_qwen35_text_prompt_enhancer(
        assets_dir=assets_dir,
        backend=backend,
        attn_implementation="sdpa",
        requested_lm_engine=lm_decoder_engine,
        variant=qwen35_variant,
    )
    runtime.llm_tokenizer = getattr(runtime.llm_model, "_prompt_enhancer_tokenizer", None)
    runtime.llm_model.eval()

    # The already-loaded text model is handed to the VL loader.
    runtime.image_caption_model, vision_tower_model = load_qwen35_vl_prompt_enhancer(
        assets_dir=assets_dir,
        attn_implementation="sdpa",
        text_model=runtime.llm_model,
        backend=backend,
        variant=qwen35_variant,
    )
    runtime.image_caption_processor = getattr(runtime.image_caption_model, "_prompt_enhancer_processor", None)
    runtime.image_caption_model.eval()

    runtime.pipe_models["prompt_enhancer_image_caption_vision_tower_model"] = vision_tower_model
    runtime.pipe_models["prompt_enhancer_llm_model"] = runtime.llm_model
    runtime.budgets["prompt_enhancer_image_caption_vision_tower_model"] = 3000
    runtime.budgets["prompt_enhancer_llm_model"] = 10000
    return runtime


def load_prompt_enhancer_runtime(
    process_files_def,
    enhancer_enabled: int,
    lm_decoder_engine: str = "",
    qwen_backend: str = "quanto_int8",
) -> PromptEnhancerRuntime:
    """Fetch assets for, and load, the models of the selected enhancer mode.

    Args:
        process_files_def: Forwarded to ``ensure_prompt_enhancer_assets``.
        enhancer_enabled: Enhancer mode; coerced with ``int()``. Values <= 0
            return an empty runtime without loading anything.
        lm_decoder_engine: Forwarded to the Qwen3.5 text loader (modes 3/4).
        qwen_backend: Backend name for the Qwen3.5 loaders (modes 3/4).

    Returns:
        A ``PromptEnhancerRuntime`` with models, ``pipe_models`` and
        ``budgets`` populated for the selected mode.
    """
    enhancer_enabled = int(enhancer_enabled)
    runtime = PromptEnhancerRuntime()
    if enhancer_enabled <= 0:
        return runtime

    ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled=enhancer_enabled, qwen_backend=qwen_backend)

    if enhancer_enabled in (3, 4):
        return _load_qwen35_runtime(runtime, enhancer_enabled, lm_decoder_engine, qwen_backend)

    runtime.image_caption_model, runtime.image_caption_processor = load_florence2(
        fl.locate_folder(FLORENCE2_FOLDER),
        attn_implementation="sdpa",
    )
    runtime.image_caption_model._model_dtype = torch.float
    runtime.image_caption_model.eval()
    runtime.pipe_models["prompt_enhancer_image_caption_model"] = runtime.image_caption_model

    # NOTE(review): every mode other than 1 reaching this point (not only 2)
    # loads JoyCaption, although assets are only ensured for modes 1-4.
    if enhancer_enabled == 1:
        runtime.llm_model, runtime.llm_tokenizer, budget = _load_llama32_prompt_enhancer()
    else:
        runtime.llm_model, runtime.llm_tokenizer, budget = _load_joycaption_prompt_enhancer()
    runtime.pipe_models["prompt_enhancer_llm_model"] = runtime.llm_model
    runtime.budgets["prompt_enhancer_llm_model"] = budget
    return runtime


def _load_state_dict(weights_path: Path) -> dict:
    """Read a checkpoint into a CPU state dict.

    ``.safetensors`` files are read tensor by tensor through ``safe_open``;
    any other suffix is handed to ``torch.load``.
    """
    if weights_path.suffix == ".safetensors":
        with safe_open(str(weights_path), framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    # SECURITY NOTE(review): torch.load unpickles the file; only use trusted
    # checkpoints here, or consider passing weights_only=True once the minimum
    # supported torch version is confirmed.
    return torch.load(str(weights_path), map_location="cpu")


def _resolve_weights_path(model_path: Path) -> Path:
    """Return the first existing Florence2 weights file inside ``model_path``.

    Raises:
        FileNotFoundError: If none of the candidate files exists.
    """
    # Prefer fp32 weights for stability/quality when available.
    candidate_names = ("xmodel.safetensors", "model.safetensors", "pytorch_model.bin")
    for candidate_name in candidate_names:
        candidate = model_path / candidate_name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f"No Florence2 weights found in {model_path} (expected model.safetensors/xmodel.safetensors/pytorch_model.bin)"
    )


def load_florence2(
    model_dir: str,
    attn_implementation: str = "sdpa",
) -> Tuple[Florence2ForConditionalGeneration, Florence2Processor]:
    """Build the Florence2 model and processor from a local folder.

    Args:
        model_dir: Folder holding the Florence2 config, weights, tokenizer and
            preprocessor config.
        attn_implementation: Stored on the config as ``_attn_implementation``
            when non-empty.

    Returns:
        ``(model, processor)`` with the model already switched to eval mode.

    Raises:
        FileNotFoundError: If the folder or every candidate weights file is
            missing.
        RuntimeError: If neither the fast nor the slow BART tokenizer loads.
    """
    model_path = Path(model_dir)
    if not model_path.exists():
        raise FileNotFoundError(f"Florence2 folder not found: {model_path}")

    config = Florence2Config.from_pretrained(str(model_path))
    if attn_implementation:
        config._attn_implementation = attn_implementation

    weights_path = _resolve_weights_path(model_path)
    state_dict = _load_state_dict(weights_path)
    model = Florence2ForConditionalGeneration(config)
    load_info = model.load_state_dict(state_dict, strict=False)
    # Drop the reference to the loaded tensors as soon as they are copied in.
    del state_dict

    if load_info.missing_keys:
        # These two embedding tables are tolerated as missing; presumably they
        # are tied to a shared embedding -- confirm against the model code.
        allowed_missing = {
            "language_model.model.encoder.embed_tokens.weight",
            "language_model.model.decoder.embed_tokens.weight",
        }
        extra_missing = [k for k in load_info.missing_keys if k not in allowed_missing]
        if extra_missing:
            print(f"Florence2 missing keys: {extra_missing}")
    if load_info.unexpected_keys:
        print(f"Florence2 unexpected keys: {len(load_info.unexpected_keys)}")
    model.eval()

    image_processor = Florence2ImageProcessorLite.from_preprocessor_config(model_path)

    # Try the fast tokenizer first, then the slow one, keeping every failure
    # so the final error explains what went wrong.
    tokenizer = None
    tokenizer_errors = []
    for tok_cls in (BartTokenizerFast, BartTokenizer):
        try:
            tokenizer = tok_cls.from_pretrained(str(model_path))
            break
        except Exception as exc:
            tokenizer_errors.append(exc)
    if tokenizer is None:
        last_error = tokenizer_errors[-1] if tokenizer_errors else None
        raise RuntimeError(f"Unable to load Florence2 tokenizer: {tokenizer_errors}") from last_error

    try:
        processor = Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)
    except TypeError as exc:
        # Only recover when the processor rejects the lite image processor
        # because it wants a CLIPImageProcessor; re-raise anything else.
        if "CLIPImageProcessor" not in str(exc):
            raise
        try:
            from transformers import CLIPImageProcessor
        except Exception:
            # Kept broad on purpose: the top-level import may fail with
            # something other than ImportError.
            from transformers.models.clip import CLIPImageProcessor
        image_processor = CLIPImageProcessor.from_pretrained(str(model_path))
        processor = Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)

    return model, processor