ColabWan / models /hidream /hidream_handler.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
5.81 kB
import os
import torch
from PIL import Image
from .prompt_enhancer import HIDREAM_PROMPT_ENHANCER_INSTRUCTIONS
# Repository id reported as "repoId" by family_handler.query_model_files.
_PROJECT_REPO = "DeepBeepMeep/HiDream"

# Folder name used both as the download source folder and as the
# "processor_folder" entry of the model definition.
_ASSET_FOLDER = "hidream_o1"

# Files requested from _ASSET_FOLDER (tokenizer, processor and config assets).
_ASSET_FILES = [
    # Configuration files.
    "chat_template.json",
    "config.json",
    "configuration.json",
    "generation_config.json",
    # Tokenizer / preprocessor files.
    "merges.txt",
    "preprocessor_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "video_preprocessor_config.json",
    "vocab.json",
]
class family_handler:
    """Static hooks describing the HiDream O1 model family.

    The class holds no state: every method is a ``@staticmethod`` that either
    returns descriptive data (supported types, UI capabilities, files to
    download), adjusts a settings dict in place, or builds the pipeline.
    NOTE(review): presumably these hooks are looked up by name by the host
    application's model registry -- confirm against the caller.
    """

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Return the capability / UI definition dict for a model type.

        Args:
            base_model_type: Model type id; ``"hidream_o1_dev"`` selects the
                dev variant (flash solver only, zero guidance phases).
            model_def: Accepted for interface compatibility; not used here.

        Returns:
            A new dict of capability flags and UI option descriptors.
        """
        if base_model_type == "hidream_o1_dev":
            solver_choices = [("Flash", "flash")]
            max_guidance_phases = 0
        else:
            solver_choices = [("Default", "default")]
            max_guidance_phases = 1

        control_image_options = {
            "selection": ["", "V", "PV", "DV", "EV"],
            "labels": {"V": "Use Control Image Unchanged"},
        }
        reference_image_options = {
            "choices": [
                ("None", ""),
                ("Conditional Image is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
                ("Conditional Images are References", "I"),
            ],
            "letters_filter": "KI",
            "default": "",
        }

        return {
            "image_outputs": True,
            "sample_solvers": solver_choices,
            "guidance_max_phases": max_guidance_phases,
            "fit_into_canvas_image_refs": 0,
            "profiles_dir": [base_model_type],
            "flow_shift": True,
            "no_negative_prompt": True,
            "no_background_removal": True,
            "processor_folder": _ASSET_FOLDER,
            "vae_block_size": 32,
            # The same instructions are used for text and image prompt enhancing.
            "text_prompt_enhancer_instructions": HIDREAM_PROMPT_ENHANCER_INSTRUCTIONS,
            "image_prompt_enhancer_instructions": HIDREAM_PROMPT_ENHANCER_INSTRUCTIONS,
            "text_prompt_enhancer_max_tokens": 512,
            "image_prompt_enhancer_max_tokens": 512,
            "guide_preprocessing": control_image_options,
            "image_ref_choices": reference_image_options,
        }

    @staticmethod
    def query_supported_types():
        """Return the list of model type ids handled by this family."""
        return ["hidream_o1", "hidream_o1_dev"]

    @staticmethod
    def query_family_maps():
        """Return a pair of mapping dicts for this family.

        The first dict is empty; the second maps ``"hidream_o1"`` to the list
        of both supported model type ids.
        """
        first_map = {}
        second_map = {"hidream_o1": ["hidream_o1", "hidream_o1_dev"]}
        return first_map, second_map

    @staticmethod
    def query_model_family():
        """Return the family identifier string, ``"hidream"``."""
        return "hidream"

    @staticmethod
    def query_family_infos():
        """Return ``{family id: (number, display name)}`` for this family.

        NOTE(review): the number (130) looks like a sort/priority key --
        confirm against the consumer of this mapping.
        """
        return {"hidream": (130, "HiDream")}

    @staticmethod
    def register_lora_cli_args(parser, lora_root):
        """Add the ``--lora-dir-hidream-o1`` option to ``parser``.

        Args:
            parser: Object exposing ``add_argument`` (e.g. an
                ``argparse.ArgumentParser``).
            lora_root: Root directory used only to show the fallback path in
                the help text; the option's actual default is ``None``.
        """
        fallback_dir = os.path.join(lora_root, "hidream_o1")
        parser.add_argument(
            "--lora-dir-hidream-o1",
            type=str,
            default=None,
            help=f"Path to a directory that contains HiDream O1 LoRAs (default: {fallback_dir})",
        )

    @staticmethod
    def get_lora_dir(base_model_type, args, lora_root):
        """Return the LoRA directory to use for this family.

        Args:
            base_model_type: Accepted for interface compatibility; not used.
            args: Parsed CLI namespace; its ``lora_dir_hidream_o1`` attribute
                wins when present and truthy.
            lora_root: Root directory for the ``hidream_o1`` fallback folder.
        """
        cli_override = getattr(args, "lora_dir_hidream_o1", None)
        if cli_override:
            return cli_override
        return os.path.join(lora_root, "hidream_o1")

    @staticmethod
    def query_model_files(computeList, base_model_type, model_def=None):
        """Return the download descriptors for the family's asset files.

        All three parameters are accepted for interface compatibility and are
        not used. The single descriptor pairs one source folder with one list
        of file names.
        """
        asset_request = {
            "repoId": _PROJECT_REPO,
            "sourceFolderList": [_ASSET_FOLDER],
            "fileList": [_ASSET_FILES],
        }
        return [asset_request]

    @staticmethod
    def load_model(
        model_filename,
        model_type=None,
        base_model_type=None,
        model_def=None,
        quantizeTransformer=False,
        text_encoder_quantization=None,
        dtype=torch.bfloat16,
        VAE_dtype=torch.float32,
        mixed_precision_transformer=False,
        save_quantized=False,
        submodel_no_list=None,
        text_encoder_filename=None,
        **kwargs,
    ):
        """Build the HiDream pipeline and expose its transformer.

        Only ``model_filename``, ``model_type``, ``model_def``,
        ``base_model_type``, ``quantizeTransformer``, ``dtype`` and
        ``save_quantized`` are forwarded to ``model_factory``; the remaining
        parameters (and ``**kwargs``) are accepted but ignored here.

        Returns:
            ``(pipeline, {"transformer": pipeline.transformer})``.
        """
        # Imported lazily so the handler module can be loaded without
        # importing the pipeline implementation.
        from .hidream_main import model_factory

        factory_options = dict(
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type=model_type,
            model_def=model_def,
            base_model_type=base_model_type,
            quantizeTransformer=quantizeTransformer,
            dtype=dtype,
            save_quantized=save_quantized,
        )
        pipeline = model_factory(**factory_options)
        return pipeline, {"transformer": pipeline.transformer}

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Write the family's default sampling settings into ``ui_defaults``.

        ``ui_defaults`` is updated in place; nothing is returned. The dev
        variant gets the flash solver with guidance disabled, every other type
        gets the default solver. ``model_def`` is not used.
        """
        presets = {
            True: {
                "guidance_scale": 0,
                "num_inference_steps": 28,
                "sample_solver": "flash",
                "flow_shift": 1.0,
            },
            False: {
                "guidance_scale": 5,
                "num_inference_steps": 50,
                "sample_solver": "default",
                "flow_shift": 3.0,
            },
        }
        ui_defaults.update(presets[base_model_type == "hidream_o1_dev"])

    @staticmethod
    def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
        """Fill in or correct ``sample_solver`` in ``ui_defaults`` in place.

        * dev variant: a missing, empty or ``"default"`` solver becomes
          ``"flash"``;
        * any other type: a missing or empty solver becomes ``"default"``.

        Other solver values are left untouched. ``settings_version`` and
        ``model_def`` are not used.
        """
        current_solver = ui_defaults.get("sample_solver", "")
        if base_model_type == "hidream_o1_dev":
            if current_solver in ("", "default"):
                ui_defaults["sample_solver"] = "flash"
        elif current_solver == "":
            ui_defaults["sample_solver"] = "default"

    @staticmethod
    def preview_latents(base_model_type, latents, meta):
        """Render ``latents`` as a 200-pixel-high PIL preview strip.

        Args:
            base_model_type: Accepted for interface compatibility; not used.
            latents: Expected to be a 4-D tensor whose first dimension is 3;
                it is unpacked as ``(channels, frames, height, width)`` with
                values read on a ``[-1, 1]`` scale.
            meta: Accepted for interface compatibility; not used.

        Returns:
            ``None`` when ``latents`` is not a tensor of the expected rank and
            channel count; otherwise a PIL image with all frames placed side
            by side, resized to height 200 (unless its height is 0).
        """
        is_previewable = (
            torch.is_tensor(latents)
            and latents.dim() == 4
            and latents.shape[0] == 3
        )
        if not is_previewable:
            return None

        pixels = latents.detach().float().cpu().clamp(-1, 1)
        channels, frames, height, width = pixels.shape
        # (C, F, H, W) -> (C, H, F, W) -> (C, H, F * W): frames sit side by side.
        strip = pixels.permute(0, 2, 1, 3).reshape(channels, height, frames * width)
        # Shift [-1, 1] to [0, 255] and convert to 8-bit.
        strip = strip.add(1).mul(127.5).clamp(0, 255).to(torch.uint8)
        # PIL expects channels last.
        preview = Image.fromarray(strip.permute(1, 2, 0).numpy())
        if preview.height <= 0:
            return preview

        target_height = 200
        scale = target_height / preview.height
        # Prefer Image.Resampling.BILINEAR when the installed Pillow has it,
        # otherwise fall back to the module-level constant.
        resampling_module = getattr(Image, "Resampling", Image)
        bilinear = getattr(resampling_module, "BILINEAR", Image.BILINEAR)
        target_width = max(1, int(round(preview.width * scale)))
        return preview.resize((target_width, target_height), resample=bilinear)