| import os |
|
|
| import torch |
| from PIL import Image |
| from .prompt_enhancer import HIDREAM_PROMPT_ENHANCER_INSTRUCTIONS |
|
|
|
|
| _PROJECT_REPO = "DeepBeepMeep/HiDream" |
| _ASSET_FOLDER = "hidream_o1" |
| _ASSET_FILES = [ |
| "chat_template.json", |
| "config.json", |
| "configuration.json", |
| "generation_config.json", |
| "merges.txt", |
| "preprocessor_config.json", |
| "tokenizer.json", |
| "tokenizer_config.json", |
| "video_preprocessor_config.json", |
| "vocab.json", |
| ] |
|
|
|
|
| class family_handler: |
| @staticmethod |
| def query_model_def(base_model_type, model_def): |
| is_dev = base_model_type == "hidream_o1_dev" |
| return { |
| "image_outputs": True, |
| "sample_solvers": [("Flash", "flash")] if is_dev else [("Default", "default")], |
| "guidance_max_phases": 0 if is_dev else 1, |
| "fit_into_canvas_image_refs": 0, |
| "profiles_dir": [base_model_type], |
| "flow_shift": True, |
| "no_negative_prompt": True, |
| "no_background_removal": True, |
| "processor_folder": _ASSET_FOLDER, |
| "vae_block_size": 32, |
| "text_prompt_enhancer_instructions": HIDREAM_PROMPT_ENHANCER_INSTRUCTIONS, |
| "image_prompt_enhancer_instructions": HIDREAM_PROMPT_ENHANCER_INSTRUCTIONS, |
| "text_prompt_enhancer_max_tokens": 512, |
| "image_prompt_enhancer_max_tokens": 512, |
| "guide_preprocessing": { |
| "selection": ["", "V", "PV", "DV", "EV"], |
| "labels": {"V": "Use Control Image Unchanged"}, |
| }, |
| "image_ref_choices": { |
| "choices": [ |
| ("None", ""), |
| ("Conditional Image is first Main Subject / Landscape and may be followed by People / Objects", "KI"), |
| ("Conditional Images are References", "I"), |
| ], |
| "letters_filter": "KI", |
| "default": "", |
| }, |
| } |
|
|
| @staticmethod |
| def query_supported_types(): |
| return ["hidream_o1", "hidream_o1_dev"] |
|
|
| @staticmethod |
| def query_family_maps(): |
| return {}, {"hidream_o1": ["hidream_o1", "hidream_o1_dev"]} |
|
|
| @staticmethod |
| def query_model_family(): |
| return "hidream" |
|
|
| @staticmethod |
| def query_family_infos(): |
| return {"hidream": (130, "HiDream")} |
|
|
| @staticmethod |
| def register_lora_cli_args(parser, lora_root): |
| parser.add_argument( |
| "--lora-dir-hidream-o1", |
| type=str, |
| default=None, |
| help=f"Path to a directory that contains HiDream O1 LoRAs (default: {os.path.join(lora_root, 'hidream_o1')})", |
| ) |
|
|
| @staticmethod |
| def get_lora_dir(base_model_type, args, lora_root): |
| return getattr(args, "lora_dir_hidream_o1", None) or os.path.join(lora_root, "hidream_o1") |
|
|
| @staticmethod |
| def query_model_files(computeList, base_model_type, model_def=None): |
| return [ |
| { |
| "repoId": _PROJECT_REPO, |
| "sourceFolderList": [_ASSET_FOLDER], |
| "fileList": [_ASSET_FILES], |
| } |
| ] |
|
|
| @staticmethod |
| def load_model( |
| model_filename, |
| model_type=None, |
| base_model_type=None, |
| model_def=None, |
| quantizeTransformer=False, |
| text_encoder_quantization=None, |
| dtype=torch.bfloat16, |
| VAE_dtype=torch.float32, |
| mixed_precision_transformer=False, |
| save_quantized=False, |
| submodel_no_list=None, |
| text_encoder_filename=None, |
| **kwargs, |
| ): |
| from .hidream_main import model_factory |
|
|
| pipe_processor = model_factory( |
| checkpoint_dir="ckpts", |
| model_filename=model_filename, |
| model_type=model_type, |
| model_def=model_def, |
| base_model_type=base_model_type, |
| quantizeTransformer=quantizeTransformer, |
| dtype=dtype, |
| save_quantized=save_quantized, |
| ) |
| return pipe_processor, {"transformer": pipe_processor.transformer} |
|
|
| @staticmethod |
| def update_default_settings(base_model_type, model_def, ui_defaults): |
| if base_model_type == "hidream_o1_dev": |
| ui_defaults.update({ |
| "guidance_scale": 0, |
| "num_inference_steps": 28, |
| "sample_solver": "flash", |
| "flow_shift": 1.0, |
| }) |
| else: |
| ui_defaults.update({ |
| "guidance_scale": 5, |
| "num_inference_steps": 50, |
| "sample_solver": "default", |
| "flow_shift": 3.0, |
| }) |
|
|
| @staticmethod |
| def fix_settings(base_model_type, settings_version, model_def, ui_defaults): |
| if base_model_type == "hidream_o1_dev" and ui_defaults.get("sample_solver", "") in ("", "default"): |
| ui_defaults["sample_solver"] = "flash" |
| elif ui_defaults.get("sample_solver", "") == "": |
| ui_defaults["sample_solver"] = "default" |
|
|
| @staticmethod |
| def preview_latents(base_model_type, latents, meta): |
| if not torch.is_tensor(latents) or latents.dim() != 4 or latents.shape[0] != 3: |
| return None |
| image = latents.detach().float().cpu().clamp(-1, 1) |
| channels, frames, height, width = image.shape |
| image = image.permute(0, 2, 1, 3).reshape(channels, height, frames * width) |
| image = image.add(1).mul(127.5).clamp(0, 255).to(torch.uint8) |
| preview = Image.fromarray(image.permute(1, 2, 0).numpy()) |
| if preview.height > 0: |
| scale = 200 / preview.height |
| resampling_module = getattr(Image, "Resampling", Image) |
| preview = preview.resize((max(1, int(round(preview.width * scale))), 200), resample=getattr(resampling_module, "BILINEAR", Image.BILINEAR)) |
| return preview |
|
|