|
|
|
|
|
import os |
|
|
from collections.abc import Mapping |
|
|
import torch |
|
|
import huggingface_hub |
|
|
from .dataset.utils import process_vision_info |
|
|
from .dataset.data_collator_qwen import prompt_with_special_token, prompt_without_special_token, INSTRUCTION |
|
|
from .utils.parser import ModelConfig, PEFTLoraConfig, TrainingConfig, DataConfig, parse_args_with_yaml |
|
|
from .train import create_model_and_processor |
|
|
from pathlib import Path |
|
|
|
|
|
_MODEL_CONFIG_PATH = Path(__file__).parent / f"config/" |
|
|
|
|
|
class HPSv3RewardInferencer:
    """Inference wrapper for the HPSv3 image-preference reward model.

    Loads a YAML config, builds the model/processor pair via
    ``create_model_and_processor``, restores checkpoint weights, and exposes
    :meth:`reward` to score (image, prompt) pairs.
    """

    def __init__(self, config_path=None, checkpoint_path=None, device='cuda', differentiable=False):
        """Build the model and load its weights.

        Args:
            config_path: YAML config file; defaults to the bundled
                ``config/HPSv3_7B.yaml`` next to this package.
            checkpoint_path: Weights file (``.safetensors`` or a torch
                pickle); defaults to downloading ``HPSv3.safetensors`` from
                the ``MizzenAI/HPSv3`` Hub repository.
            device: Device the model is moved to (e.g. ``'cuda'``).
            differentiable: Forwarded to ``create_model_and_processor``;
                presumably keeps the reward forward pass differentiable —
                TODO confirm against that function's implementation.
        """
        if config_path is None:
            config_path = str(_MODEL_CONFIG_PATH / 'HPSv3_7B.yaml')

        if checkpoint_path is None:
            checkpoint_path = huggingface_hub.hf_hub_download(
                "MizzenAI/HPSv3", 'HPSv3.safetensors', repo_type='model'
            )

        (data_config, training_args, model_config, peft_lora_config), config_path = (
            parse_args_with_yaml(
                (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig), config_path, is_train=False
            )
        )
        # Scope output_dir by the config file's stem (e.g. ".../HPSv3_7B").
        # Path(...).name is separator-safe, unlike splitting on "/".
        training_args.output_dir = os.path.join(
            training_args.output_dir, Path(config_path).name.split(".")[0]
        )
        model, processor, peft_config = create_model_and_processor(
            model_config=model_config,
            peft_lora_config=peft_lora_config,
            training_args=training_args,
            differentiable=differentiable,
        )

        self.device = device
        self.use_special_tokens = model_config.use_special_tokens

        if checkpoint_path.endswith('.safetensors'):
            import safetensors.torch  # local import: only needed for .safetensors checkpoints
            state_dict = safetensors.torch.load_file(checkpoint_path, device="cpu")
        else:
            # NOTE(review): torch.load without weights_only=True unpickles
            # arbitrary objects — only load checkpoints from trusted sources.
            state_dict = torch.load(checkpoint_path, map_location="cpu")

        # Some trainer checkpoints nest the weights under a "model" key.
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model.load_state_dict(state_dict, strict=True)
        model.eval()

        self.model = model
        self.processor = processor

        self.model.to(self.device)
        self.data_config = data_config

    def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'):
        """Pad token ids and their attention mask up to ``max_len``.

        Ids are padded with the tokenizer's ``pad_token_id`` and the mask
        with 0, on the requested side. Inputs already at least ``max_len``
        long are returned unchanged.
        """
        assert padding_side in ['right', 'left']
        if sequences.shape[1] >= max_len:
            return sequences, attention_mask

        pad_len = max_len - sequences.shape[1]
        padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0)

        sequences_padded = torch.nn.functional.pad(
            sequences, padding, 'constant', self.processor.tokenizer.pad_token_id
        )
        attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0)

        return sequences_padded, attention_mask_padded

    def _prepare_input(self, data):
        """Recursively move tensors nested in mappings/sequences to ``self.device``.

        Container types are preserved; non-tensor leaves pass through as-is.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            return data.to(device=self.device)
        return data

    def _prepare_inputs(self, inputs):
        """Move a model-input batch to the target device.

        Raises:
            ValueError: If the batch is empty.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "Received an empty batch; check that the processor produced inputs."
            )
        return inputs

    def prepare_batch(self, image_paths, prompts):
        """Preprocess images and prompts into a model-ready batch.

        Args:
            image_paths: One image path/URL per sample.
            prompts: One text prompt per sample (zipped with ``image_paths``).

        Returns:
            Processor output (ids, attention mask, pixel values, ...) already
            moved to ``self.device``.
        """
        # Fixed visual token budget: 256 patches of 28x28 pixels each.
        max_pixels = 256 * 28 * 28
        min_pixels = 256 * 28 * 28
        message_list = []
        for text, image in zip(prompts, image_paths):
            out_message = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            # Fix: this key previously passed max_pixels.
                            # Same value today, but a latent bug if the two
                            # budgets ever diverge.
                            "min_pixels": min_pixels,
                            "max_pixels": max_pixels,
                        },
                        {
                            "type": "text",
                            "text": (
                                # NOTE(review): precedence makes this
                                # (INSTRUCTION + special_suffix) if flag else
                                # plain suffix — the INSTRUCTION prefix is
                                # dropped in the no-special-token branch.
                                # Preserved as-is; confirm it is intentional.
                                INSTRUCTION.format(text_prompt=text)
                                + prompt_with_special_token
                                if self.use_special_tokens
                                else prompt_without_special_token
                            ),
                        },
                    ],
                }
            ]

            message_list.append(out_message)

        image_inputs, _ = process_vision_info(message_list)

        batch = self.processor(
            text=self.processor.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True),
            images=image_inputs,
            padding=True,
            return_tensors="pt",
            videos_kwargs={"do_rescale": True},
        )
        batch = self._prepare_inputs(batch)
        return batch

    def reward(self, image_paths, prompts):
        """Score each (image, prompt) pair with the reward model.

        Deliberately no ``torch.no_grad()`` here: the class supports a
        ``differentiable`` mode, so gradient control is left to the caller.

        Returns:
            The model's "logits" output, one entry per input pair.
        """
        batch = self.prepare_batch(image_paths, prompts)
        rewards = self.model(
            return_dict=True,
            **batch
        )["logits"]

        return rewards
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Example: score two images against the same prompt and print the first
    # component of each reward (presumably the mean score — confirm against
    # the model's output head).
    config_path = 'config/inference/HPSv3_7B.yaml'
    checkpoint_path = 'checkpoints/HPSv3_7B.pth'
    device = 'cuda'
    # (removed unused `dtype = torch.bfloat16` local)
    inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, device=device)

    image_paths = [
        "assets/example1.png",
        "assets/example2.png",
    ]
    prompts = [
        "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker",
        "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker",
    ]
    rewards = inferencer.reward(image_paths, prompts)
    for reward in rewards:
        print(reward[0].item())