| from __future__ import annotations
|
|
|
| from dataclasses import dataclass, field
|
| from pathlib import Path
|
| from typing import Any, Tuple
|
|
|
| import torch
|
| from mmgp import offload
|
| from safetensors import safe_open
|
| from shared.utils import files_locator as fl
|
|
|
| from .florence2 import Florence2Config, Florence2ForConditionalGeneration, Florence2Processor
|
| from .florence2.image_processing_florence2 import Florence2ImageProcessorLite
|
|
|
| from transformers import AutoTokenizer, BartTokenizer, BartTokenizerFast
|
|
|
|
|
# Asset folder names. They are passed as ``sourceFolderList`` entries together with
# ``PROMPT_ENHANCER_REPO`` when downloading, and used with ``fl.locate_file`` /
# ``fl.locate_folder`` to find the files locally.
FLORENCE2_FOLDER = "Florence2"  # Florence2 image-captioning model (modes 1 and 2)

LLAMA32_FOLDER = "Llama3_2"  # Llama 3.2 prompt-rewriting LLM (mode 1)

LLAMAJOY_FOLDER = "llama-joycaption-beta-one-hf-llava"  # JoyCaption LLM (mode 2)

# Repository id given to ``process_files_def`` for the Florence2 / Llama 3.2 / JoyCaption assets.
PROMPT_ENHANCER_REPO = "DeepBeepMeep/LTX_Video"
|
|
|
|
|
| @dataclass(slots=True)
|
| class PromptEnhancerRuntime:
|
| image_caption_model: Any = None
|
| image_caption_processor: Any = None
|
| llm_model: Any = None
|
| llm_tokenizer: Any = None
|
| pipe_models: dict[str, Any] = field(default_factory=dict)
|
| budgets: dict[str, int] = field(default_factory=dict)
|
|
|
|
|
def ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled: int, qwen_backend: str = "quanto_int8"):
    """Fetch the files required by the selected prompt-enhancer mode.

    Modes:
        1: Florence2 + Llama 3.2, downloaded from ``PROMPT_ENHANCER_REPO``.
        2: Florence2 + JoyCaption, downloaded from ``PROMPT_ENHANCER_REPO``.
        3/4: Qwen3.5 assets, delegated to ``ensure_qwen35_prompt_enhancer_assets``
            with ``qwen_backend`` and the variant chosen for the mode.
    Any other value does nothing.

    Args:
        process_files_def: download callback, called with ``repoId``,
            ``sourceFolderList`` and ``fileList`` keyword arguments.
        enhancer_enabled: enhancer mode (coerced with ``int``).
        qwen_backend: backend name forwarded to the Qwen3.5 asset helper.
    """
    mode = int(enhancer_enabled)

    if mode in (3, 4):
        # Imported lazily so the Qwen3.5 code is only loaded when actually selected.
        from .qwen35_vl import ensure_qwen35_prompt_enhancer_assets, get_qwen35_prompt_enhancer_variant

        variant = get_qwen35_prompt_enhancer_variant(mode)
        ensure_qwen35_prompt_enhancer_assets(process_files_def, backend=qwen_backend, variant=variant)
        return

    if mode == 1:
        llm_folder = LLAMA32_FOLDER
        llm_files = ["config.json", "generation_config.json", "Llama3_2_quanto_bf16_int8.safetensors", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"]
    elif mode == 2:
        llm_folder = LLAMAJOY_FOLDER
        llm_files = ["config.json", "llama_config.json", "llama_joycaption_quanto_bf16_int8.safetensors", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"]
    else:
        return

    # Modes 1 and 2 share the same Florence2 captioning files.
    florence_files = ["config.json", "configuration_florence2.py", "model.safetensors", "preprocessor_config.json", "tokenizer.json", "tokenizer_config.json"]
    process_files_def(
        repoId=PROMPT_ENHANCER_REPO,
        sourceFolderList=[FLORENCE2_FOLDER, llm_folder],
        fileList=[florence_files, llm_files],
    )
|
|
|
|
|
def unload_prompt_enhancer_models(*models):
    """Call ``unload()`` once on every distinct model passed in.

    ``None`` entries are skipped. The same object passed several times is only
    handled once (identity-based). Objects without a callable ``unload``
    attribute are ignored.
    """
    handled_ids = set()
    for candidate in (m for m in models if m is not None):
        key = id(candidate)
        if key in handled_ids:
            continue
        handled_ids.add(key)
        release = getattr(candidate, "unload", None)
        if not callable(release):
            continue
        release()
|
|
|
|
|
def _set_pad_token_from_tokenizer(model, tokenizer):
    """Configure generation padding to reuse the end-of-sequence token.

    Always copies ``tokenizer.eos_token`` onto ``model.generation_config.pad_token``.
    When the generation config has no ``pad_token_id``, it is filled from the
    config's ``eos_token_id``; for a list/tuple the first entry is used. If the
    config provides no usable EOS id (``None`` or an empty sequence), it falls
    back to ``tokenizer.eos_token_id``. An existing ``pad_token_id`` is never
    overwritten.

    Args:
        model: object exposing a mutable ``generation_config``.
        tokenizer: object exposing ``eos_token`` and optionally ``eos_token_id``.
    """
    generation_config = model.generation_config
    generation_config.pad_token = tokenizer.eos_token
    if generation_config.pad_token_id is not None:
        return

    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, (list, tuple)):
        # Multiple EOS ids are allowed; pick the first, and guard against an empty sequence.
        eos_token_id = eos_token_id[0] if eos_token_id else None
    if eos_token_id is None:
        eos_token_id = getattr(tokenizer, "eos_token_id", None)
    generation_config.pad_token_id = eos_token_id
|
|
|
|
|
def _load_llama32_prompt_enhancer():
    """Load the int8-quantized Llama 3.2 prompt-enhancer LLM and its tokenizer.

    The weights come from ``Llama3_2_quanto_bf16_int8.safetensors`` in
    ``LLAMA32_FOLDER``. ``config.json`` from the same folder is used as the
    default config if present. The model is put in eval mode and its generation
    pad token is aligned with EOS.

    Returns:
        ``(model, tokenizer, budget)`` with a fixed budget of 5000.
        NOTE(review): budget is presumably an mmgp offload memory budget; confirm the unit.
    """
    weights_file = fl.locate_file(f"{LLAMA32_FOLDER}/Llama3_2_quanto_bf16_int8.safetensors")
    fallback_config = fl.locate_file(f"{LLAMA32_FOLDER}/config.json", error_if_none=False)
    model = offload.fast_load_transformers_model(
        weights_file,
        defaultConfigPath=fallback_config,
        configKwargs=dict(attn_implementation="sdpa", hidden_act="silu"),
        writable_tensors=False,
    )

    def _skip_model_kwargs_validation(*_args, **_kwargs):
        # No-op replacement: presumably disables transformers' generate() kwargs check — confirm why.
        return None

    model._validate_model_kwargs = _skip_model_kwargs_validation
    # NOTE(review): presumably tells mmgp offload which methods to hook; confirm.
    model._offload_hooks = ["generate"]

    tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMA32_FOLDER))
    _set_pad_token_from_tokenizer(model, tokenizer)
    model.eval()
    return model, tokenizer, 5000
|
|
|
|
|
def _load_joycaption_prompt_enhancer():
    """Load the language-model part of the int8-quantized JoyCaption checkpoint.

    The checkpoint's state dict is remapped before loading (see the helper
    below). ``llama_config.json`` is forced as the model config. The model is put
    in eval mode and its generation pad token is aligned with EOS.

    Returns:
        ``(model, tokenizer, budget)`` with a fixed budget of 10000.
        NOTE(review): budget is presumably an mmgp offload memory budget; confirm the unit.
    """

    def _language_model_only(sd, quant_map=None, tied_map=None):
        # Maps "model.language_model" onto "model"; the vision tower and projector prefixes
        # map to None (presumably dropped by offload.map_state_dict — confirm).
        prefix_rules = {
            "model.language_model": "model",
            "model.vision_tower": None,
            "model.multi_modal_projector": None,
        }
        remapped = offload.map_state_dict([sd, quant_map, tied_map], prefix_rules)
        return tuple(remapped)

    weights_file = fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_joycaption_quanto_bf16_int8.safetensors")
    forced_config = fl.locate_file(f"{LLAMAJOY_FOLDER}/llama_config.json", error_if_none=False)
    model = offload.fast_load_transformers_model(
        weights_file,
        forcedConfigPath=forced_config,
        configKwargs=dict(attn_implementation="sdpa", hidden_act="silu"),
        preprocess_sd=_language_model_only,
        writable_tensors=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder(LLAMAJOY_FOLDER))
    _set_pad_token_from_tokenizer(model, tokenizer)
    model.eval()
    return model, tokenizer, 10000
|
|
|
|
|
def load_prompt_enhancer_runtime(process_files_def, enhancer_enabled: int, lm_decoder_engine: str = "", qwen_backend: str = "quanto_int8") -> PromptEnhancerRuntime:
    """Ensure assets are present, then load the models for the chosen prompt enhancer.

    Args:
        process_files_def: download callback forwarded to ``ensure_prompt_enhancer_assets``.
        enhancer_enabled: enhancer mode (coerced with ``int``).
            ``<= 0``: disabled; an empty runtime is returned.
            ``1``: Florence2 captioner + Llama 3.2 LLM.
            ``2``: Florence2 captioner + JoyCaption LLM.
            ``3``/``4``: Qwen3.5 text + vision-language models (variant chosen by
            ``get_qwen35_prompt_enhancer_variant``).
        lm_decoder_engine: requested decoder engine, only forwarded to the Qwen3.5 text loader.
        qwen_backend: Qwen3.5 backend. An empty value falls back to
            ``enhancer_quantization_QUANTO_INT8``. The resolved value is used both
            for downloading and for loading.

    Returns:
        A ``PromptEnhancerRuntime`` with models, processors, ``pipe_models`` and ``budgets`` filled in.

    Raises:
        ValueError: ``enhancer_enabled`` is positive but not one of the supported modes.
    """
    enhancer_enabled = int(enhancer_enabled)
    runtime = PromptEnhancerRuntime()
    if enhancer_enabled <= 0:
        return runtime
    if enhancer_enabled not in (1, 2, 3, 4):
        # ensure_prompt_enhancer_assets downloads nothing for other modes; without this check
        # they would silently fall through to the JoyCaption branch with missing assets.
        raise ValueError(f"Unsupported prompt enhancer mode: {enhancer_enabled} (expected 1, 2, 3 or 4)")

    if enhancer_enabled in (3, 4):
        from .qwen35_text import load_qwen35_text_prompt_enhancer
        from .qwen35_vl import (
            enhancer_quantization_QUANTO_INT8,
            get_qwen35_assets_dir_name,
            get_qwen35_prompt_enhancer_variant,
            load_qwen35_vl_prompt_enhancer,
        )

        # Resolve the backend once so the provisioned assets match the ones loaded below.
        backend = qwen_backend or enhancer_quantization_QUANTO_INT8
        ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled=enhancer_enabled, qwen_backend=backend)

        qwen35_variant = get_qwen35_prompt_enhancer_variant(enhancer_enabled)
        assets_dir_name = get_qwen35_assets_dir_name(qwen35_variant)
        # Use an already-located folder if there is one, otherwise the download location.
        assets_dir = fl.locate_folder(assets_dir_name, error_if_none=False) or fl.get_download_location(assets_dir_name)
        runtime.llm_model = load_qwen35_text_prompt_enhancer(
            assets_dir=assets_dir,
            backend=backend,
            attn_implementation="sdpa",
            requested_lm_engine=lm_decoder_engine,
            variant=qwen35_variant,
        )
        runtime.llm_tokenizer = getattr(runtime.llm_model, "_prompt_enhancer_tokenizer", None)
        runtime.llm_model.eval()
        # The VL loader receives the already-loaded text model via text_model=.
        runtime.image_caption_model, vision_tower_model = load_qwen35_vl_prompt_enhancer(
            assets_dir=assets_dir,
            attn_implementation="sdpa",
            text_model=runtime.llm_model,
            backend=backend,
            variant=qwen35_variant,
        )
        runtime.image_caption_processor = getattr(runtime.image_caption_model, "_prompt_enhancer_processor", None)
        runtime.image_caption_model.eval()
        runtime.pipe_models["prompt_enhancer_image_caption_vision_tower_model"] = vision_tower_model
        runtime.pipe_models["prompt_enhancer_llm_model"] = runtime.llm_model
        runtime.budgets["prompt_enhancer_image_caption_vision_tower_model"] = 3000
        runtime.budgets["prompt_enhancer_llm_model"] = 10000
        return runtime

    # Modes 1 and 2: Florence2 captioner plus a Llama-family LLM.
    ensure_prompt_enhancer_assets(process_files_def, enhancer_enabled=enhancer_enabled, qwen_backend=qwen_backend)
    runtime.image_caption_model, runtime.image_caption_processor = load_florence2(fl.locate_folder(FLORENCE2_FOLDER), attn_implementation="sdpa")
    runtime.image_caption_model._model_dtype = torch.float
    runtime.image_caption_model.eval()
    runtime.pipe_models["prompt_enhancer_image_caption_model"] = runtime.image_caption_model
    if enhancer_enabled == 1:
        runtime.llm_model, runtime.llm_tokenizer, budget = _load_llama32_prompt_enhancer()
    else:
        runtime.llm_model, runtime.llm_tokenizer, budget = _load_joycaption_prompt_enhancer()
    runtime.pipe_models["prompt_enhancer_llm_model"] = runtime.llm_model
    runtime.budgets["prompt_enhancer_llm_model"] = budget
    return runtime
|
|
|
|
|
def _load_state_dict(weights_path: Path) -> dict:
    """Read a checkpoint file into a ``{name: tensor}`` dict on the CPU.

    ``.safetensors`` files are read tensor by tensor through ``safe_open``. Any
    other suffix is treated as a ``torch.save`` checkpoint.

    Args:
        weights_path: checkpoint file, e.g. one returned by ``_resolve_weights_path``.

    Returns:
        Mapping of parameter names to tensors loaded on the CPU.
    """
    if weights_path.suffix == ".safetensors":
        with safe_open(str(weights_path), framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    # SECURITY: torch.load unpickles arbitrary objects unless weights_only=True (not the
    # default before torch 2.6). The checkpoint comes from a downloaded folder, so only
    # plain tensor/container data is accepted.
    return torch.load(str(weights_path), map_location="cpu", weights_only=True)
|
|
|
|
|
def _resolve_weights_path(model_path: Path) -> Path:
    """Return the Florence2 checkpoint file found inside *model_path*.

    Candidates are checked in priority order ``xmodel.safetensors``,
    ``model.safetensors``, ``pytorch_model.bin``; the first existing one wins.

    Raises:
        FileNotFoundError: none of the candidate files exists.
    """
    for candidate_name in ("xmodel.safetensors", "model.safetensors", "pytorch_model.bin"):
        candidate = model_path / candidate_name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f"No Florence2 weights found in {model_path} (expected model.safetensors/xmodel.safetensors/pytorch_model.bin)"
    )
|
|
|
|
|
def load_florence2(
    model_dir: str,
    attn_implementation: str = "sdpa",
) -> Tuple[Florence2ForConditionalGeneration, Florence2Processor]:
    """Build a Florence2 model and processor from a local folder.

    Steps:
        1. Read ``Florence2Config`` from the folder, optionally overriding the
           attention implementation.
        2. Load weights non-strictly. Missing keys other than the encoder/decoder
           ``embed_tokens`` weights are printed, as is the count of unexpected keys.
        3. Build the processor from ``Florence2ImageProcessorLite`` and a BART
           tokenizer (fast first, then slow). If ``Florence2Processor`` rejects the
           image processor with a ``TypeError`` mentioning ``CLIPImageProcessor``,
           it retries with a ``CLIPImageProcessor`` loaded from the folder.

    Args:
        model_dir: folder holding config, weights, preprocessor and tokenizer files.
        attn_implementation: value written to ``config._attn_implementation``; falsy keeps the config's own.

    Returns:
        ``(model, processor)`` with the model in eval mode.

    Raises:
        FileNotFoundError: the folder or its weight file does not exist.
        RuntimeError: neither BART tokenizer class could load the tokenizer.
    """
    folder = Path(model_dir)
    if not folder.exists():
        raise FileNotFoundError(f"Florence2 folder not found: {folder}")

    config = Florence2Config.from_pretrained(str(folder))
    if attn_implementation:
        config._attn_implementation = attn_implementation
    weights = _load_state_dict(_resolve_weights_path(folder))

    model = Florence2ForConditionalGeneration(config)
    load_result = model.load_state_dict(weights, strict=False)
    del weights  # drop the CPU copy as soon as the parameters are populated

    # These two embedding weights are tolerated as missing.
    # NOTE(review): presumably tied to a shared embedding in the checkpoint; confirm.
    tolerated_missing = {
        "language_model.model.encoder.embed_tokens.weight",
        "language_model.model.decoder.embed_tokens.weight",
    }
    reportable_missing = [name for name in load_result.missing_keys if name not in tolerated_missing]
    if reportable_missing:
        print(f"Florence2 missing keys: {reportable_missing}")
    if load_result.unexpected_keys:
        print(f"Florence2 unexpected keys: {len(load_result.unexpected_keys)}")
    model.eval()

    image_processor = Florence2ImageProcessorLite.from_preprocessor_config(folder)

    tokenizer = None
    failures = []
    for tokenizer_cls in (BartTokenizerFast, BartTokenizer):
        try:
            tokenizer = tokenizer_cls.from_pretrained(str(folder))
        except Exception as exc:
            failures.append(exc)
        else:
            break
    if tokenizer is None:
        raise RuntimeError(f"Unable to load Florence2 tokenizer: {failures}")

    try:
        processor = Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)
    except TypeError as exc:
        if "CLIPImageProcessor" not in str(exc):
            raise
        # Some processor versions require a real CLIPImageProcessor instead of the lite one.
        try:
            from transformers import CLIPImageProcessor
        except Exception:
            from transformers.models.clip import CLIPImageProcessor
        processor = Florence2Processor(
            image_processor=CLIPImageProcessor.from_pretrained(str(folder)),
            tokenizer=tokenizer,
        )

    return model, processor
|
|
|