import os
from collections.abc import Mapping
from pathlib import Path

import torch
import huggingface_hub

from .dataset.utils import process_vision_info
from .dataset.data_collator_qwen import (
    prompt_with_special_token,
    prompt_without_special_token,
    INSTRUCTION,
)
from .utils.parser import (
    ModelConfig,
    PEFTLoraConfig,
    TrainingConfig,
    DataConfig,
    parse_args_with_yaml,
)
from .train import create_model_and_processor

# Directory holding the YAML configs shipped next to this module.
_MODEL_CONFIG_PATH = Path(__file__).parent / "config"


class HPSv3RewardInferencer():
    """Score (image, prompt) pairs with a reward model loaded from a checkpoint.

    The model and processor are built by ``create_model_and_processor`` from a
    YAML config, the checkpoint weights are loaded strictly into the model, and
    the model is put in eval mode on ``device``.
    """

    def __init__(self, config_path=None, checkpoint_path=None, device='cuda', differentiable=False):
        """Build the model/processor and load the checkpoint weights.

        Args:
            config_path: Path to the YAML config. Defaults to the bundled
                ``HPSv3_7B.yaml`` under ``_MODEL_CONFIG_PATH``.
            checkpoint_path: Path to a ``.safetensors`` file or a checkpoint
                readable by ``torch.load``. Defaults to ``HPSv3.safetensors``
                downloaded from the ``MizzenAI/HPSv3`` Hugging Face repo.
            device: Device the model and the prepared inputs are moved to.
            differentiable: Forwarded unchanged to
                ``create_model_and_processor``; its exact effect is defined
                there (not visible in this file).

        Raises:
            RuntimeError: From ``load_state_dict(strict=True)`` when the
                checkpoint keys do not match the model.
        """
        if config_path is None:
            config_path = os.path.join(_MODEL_CONFIG_PATH, 'HPSv3_7B.yaml')

        if checkpoint_path is None:
            checkpoint_path = huggingface_hub.hf_hub_download(
                "MizzenAI/HPSv3", 'HPSv3.safetensors', repo_type='model'
            )
        # Accept pathlib.Path as well as str; ``endswith`` below needs a str.
        checkpoint_path = os.fspath(checkpoint_path)

        (data_config, training_args, model_config, peft_lora_config), config_path = (
            parse_args_with_yaml(
                (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig),
                config_path,
                is_train=False,
            )
        )

        # Sub-directory named after the config file (text before the first dot).
        # ``os.path.basename`` matches ``split("/")[-1]`` on POSIX and also
        # understands the platform separator elsewhere.
        config_name = os.path.basename(config_path).split(".")[0]
        training_args.output_dir = os.path.join(training_args.output_dir, config_name)

        model, processor, peft_config = create_model_and_processor(
            model_config=model_config,
            peft_lora_config=peft_lora_config,
            training_args=training_args,
            differentiable=differentiable,
        )

        self.device = device
        self.use_special_tokens = model_config.use_special_tokens

        if checkpoint_path.endswith('.safetensors'):
            import safetensors.torch
            state_dict = safetensors.torch.load_file(checkpoint_path, device="cpu")
        else:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from sources you trust.
            state_dict = torch.load(checkpoint_path, map_location="cpu")

        # Some checkpoints wrap the weights under a top-level "model" key.
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model.load_state_dict(state_dict, strict=True)
        model.eval()

        self.model = model
        self.processor = processor

        self.model.to(self.device)
        self.data_config = data_config

    def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'):
        """Pad ``sequences`` and ``attention_mask`` along the last dim to ``max_len``.

        ``sequences`` is padded with the tokenizer's ``pad_token_id`` and
        ``attention_mask`` with 0. If ``sequences.shape[1]`` is already at
        least ``max_len`` both inputs are returned unchanged.

        Args:
            sequences: Tensor whose ``shape[1]`` is the current length.
            attention_mask: Mask tensor padded the same way as ``sequences``.
            max_len: Target length.
            padding_side: ``'right'`` or ``'left'``.

        Returns:
            Tuple ``(sequences_padded, attention_mask_padded)``.
        """
        assert padding_side in ['right', 'left'], "padding_side must be 'right' or 'left'"
        if sequences.shape[1] >= max_len:
            return sequences, attention_mask

        pad_len = max_len - sequences.shape[1]
        # F.pad takes (left, right) amounts for the last dimension.
        padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0)

        sequences_padded = torch.nn.functional.pad(
            sequences, padding, 'constant', self.processor.tokenizer.pad_token_id
        )
        attention_mask_padded = torch.nn.functional.pad(
            attention_mask, padding, 'constant', 0
        )
        return sequences_padded, attention_mask_padded

    def _prepare_input(self, data):
        """Recursively move every tensor found in ``data`` to ``self.device``.

        Mappings, tuples and lists are rebuilt with the same container type;
        tensors are moved with ``Tensor.to``; any other value is returned
        unchanged.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            return data.to(device=self.device)
        return data

    def _prepare_inputs(self, inputs):
        """Move ``inputs`` to ``self.device`` and reject an empty result.

        Raises:
            ValueError: If the prepared inputs have length 0.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError("The batch received was empty; nothing to feed to the model.")
        return inputs

    def prepare_batch(self, image_paths, prompts):
        """Build a model-ready batch from images and their text prompts.

        Each (prompt, image) pair becomes a single-turn user chat message
        containing the image and the instruction text; the messages are run
        through the processor and the resulting tensors are moved to
        ``self.device``.

        Args:
            image_paths: Images, one per prompt, in whatever form
                ``process_vision_info`` accepts (e.g. file paths).
            prompts: Text prompts, one per image.

        Returns:
            The processor output with its tensors on ``self.device``.

        Raises:
            ValueError: If the number of images and prompts differ, or if the
                processed batch is empty.
        """
        image_paths = list(image_paths)
        prompts = list(prompts)
        if len(image_paths) != len(prompts):
            # zip() would silently drop the unmatched tail otherwise.
            raise ValueError(
                f"Got {len(image_paths)} images but {len(prompts)} prompts; "
                "they must have the same length."
            )

        # Min and max are equal, pinning the pixel budget to one value.
        max_pixels = 256 * 28 * 28
        min_pixels = 256 * 28 * 28

        message_list = []
        for text, image in zip(prompts, image_paths):
            # NOTE(review): the conditional expression binds looser than `+`,
            # so the instruction text is only included when special tokens are
            # enabled; otherwise the message is just
            # `prompt_without_special_token`. Behaviour is kept as-is because
            # the training-time prompt format is not visible here -- confirm
            # whether the instruction was meant to be included in both cases.
            if self.use_special_tokens:
                prompt_text = INSTRUCTION.format(text_prompt=text) + prompt_with_special_token
            else:
                prompt_text = prompt_without_special_token

            out_message = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "min_pixels": min_pixels,
                            "max_pixels": max_pixels,
                        },
                        {
                            "type": "text",
                            "text": prompt_text,
                        },
                    ],
                }
            ]
            message_list.append(out_message)

        image_inputs, _ = process_vision_info(message_list)

        batch = self.processor(
            text=self.processor.apply_chat_template(
                message_list, tokenize=False, add_generation_prompt=True
            ),
            images=image_inputs,
            padding=True,
            return_tensors="pt",
            videos_kwargs={"do_rescale": True},
        )
        batch = self._prepare_inputs(batch)
        return batch

    def reward(self, image_paths, prompts):
        """Return the model's ``"logits"`` output for the given pairs.

        Args:
            image_paths: Images, one per prompt (see ``prepare_batch``).
            prompts: Text prompts, one per image.

        Returns:
            The ``"logits"`` entry of the model output, indexed by sample
            along the first dimension.
        """
        batch = self.prepare_batch(image_paths, prompts)
        rewards = self.model(
            return_dict=True,
            **batch
        )["logits"]

        return rewards


if __name__ == "__main__":
    config_path = 'config/inference/HPSv3_7B.yaml'
    checkpoint_path = 'checkpoints/HPSv3_7B.pth'
    device = 'cuda'
    inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, device=device)

    image_paths = [
        "assets/example1.png",
        "assets/example2.png"
    ]
    prompts = [
        "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker",
        "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker"
    ]
    rewards = inferencer.reward(image_paths, prompts)
    # Each row holds (mu, sigma) for one sample; mu (index 0) is used as the
    # final score.
    print(rewards[0][0].item())
    print(rewards[1][0].item())