import os
from collections.abc import Mapping
import torch
import numpy as np
from PIL import Image
import huggingface_hub
from hpsv3.dataset.utils import process_vision_info
from hpsv3.dataset.data_collator_qwen import prompt_with_special_token, prompt_without_special_token, INSTRUCTION
from hpsv3.utils.parser import ModelConfig, PEFTLoraConfig, TrainingConfig, DataConfig, parse_args_with_yaml
from hpsv3.train import create_model_and_processor
from pathlib import Path

_MODEL_CONFIG_PATH = Path(__file__).parent / "config"
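

# Inference wrapper around the HPSv3 reward model: loads the YAML config and released
# checkpoint, builds the Qwen-VL processor, and scores (image, prompt) pairs.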
class HPSv3RewardInferencer:
    def __init__(self, config_path=None, checkpoint_path=None, device='cuda', differentiable=False):
        if config_path is None:
            config_path = os.path.join(_MODEL_CONFIG_PATH, 'HPSv3_7B.yaml')

        if checkpoint_path is None:
            # Fetch the released checkpoint (model.pth) from the Hugging Face Hub.
            checkpoint_path = huggingface_hub.hf_hub_download("xilanhua12138/HPSv3", "model.pth")

        (data_config, training_args, model_config, peft_lora_config), config_path = (
            parse_args_with_yaml(
                (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig), config_path, is_train=False
            )
        )
        training_args.output_dir = os.path.join(
            training_args.output_dir, config_path.split("/")[-1].split(".")[0]
        )
        model, processor, peft_config = create_model_and_processor(
            model_config=model_config,
            peft_lora_config=peft_lora_config,
            training_args=training_args,
            differentiable=differentiable,
        )

        self.device = device
        self.use_special_tokens = model_config.use_special_tokens
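
        # Load the fine-tuned weights; training checkpoints may wrap the state dict
        # under a "model" key, so unwrap it before loading.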
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model.load_state_dict(state_dict, strict=False)
        model.eval()

        self.model = model
        self.processor = processor

        self.model.to(self.device)
        self.data_config = data_config

    def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'):
        """
        Pad the sequences to the maximum length.
        """
        assert padding_side in ['right', 'left']
        if sequences.shape[1] >= max_len:
            return sequences, attention_mask

        pad_len = max_len - sequences.shape[1]
        padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0)

        sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id)
        attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0)

        return sequences_padded, attention_mask_padded

    def _prepare_input(self, data):
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = {"device": self.device}
            return data.to(**kwargs)
        return data
    def _prepare_inputs(self, inputs):
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError("Received an empty batch; there are no inputs to move to the device.")
        return inputs

    def prepare_batch(self, image_paths, prompts):
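        """
        Build a Qwen-VL chat message for each (prompt, image) pair, run the vision
        preprocessor, and return a tokenized batch already moved to the target device.
        """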
        # Fixed visual-token budget: 256 patches of 28x28 pixels each (Qwen2-VL convention).
        max_pixels = 256 * 28 * 28
        min_pixels = 256 * 28 * 28
        message_list = []
        for text, image in zip(prompts, image_paths):
            out_message = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "min_pixels": min_pixels,
                            "max_pixels": max_pixels,
                        },
                        {
                            "type": "text",
                            "text": (
                                INSTRUCTION.format(text_prompt=text)
                                + prompt_with_special_token
                                if self.use_special_tokens
                                else prompt_without_special_token
                            ),
                        },
                    ],
                }
            ]

            message_list.append(out_message)

        image_inputs, _ = process_vision_info(message_list)

        batch = self.processor(
            text=self.processor.apply_chat_template(message_list, tokenize=False, add_generation_prompt=True),
            images=image_inputs,
            padding=True,
            return_tensors="pt",
            videos_kwargs={"do_rescale": True},
        )
        batch = self._prepare_inputs(batch)
        return batch

    def reward(self, image_paths, prompts):
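        """
        Score each (image, prompt) pair and return the model's reward logits.
        """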
        batch = self.prepare_batch(image_paths, prompts)
        rewards = self.model(
            return_dict=True,
            **batch
        )["logits"]

        return rewards


if __name__ == "__main__":
    config_path = '/preflab/shuiyunhao/tasks/HPSv3_official/hpsv3/config/HPSv3_7B.yaml'
    checkpoint_path = '/preflab/shuiyunhao/tasks/HPSv3_official/checkpoints/HPSv3_7B/model.pth'
    device = 'cuda'
    dtype = torch.bfloat16
    inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, differentiable=True, device=device)

    images = [
        torch.from_numpy(np.array(Image.open("assets/example1.png"))),
        torch.from_numpy(np.array(Image.open("assets/example2.png")))
    ]
    prompts = [
        "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker",
        "cute chibi anime cartoon fox, smiling wagging tail with a small cartoon heart above sticker"
    ]
    rewards = inferencer.reward(images, prompts)
    print(rewards[0][0].item())
    print(rewards[1][0].item())
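
    # Sanity check: backpropagate through one reward score to confirm that gradients
    # reach the model parameters when the inferencer is built with differentiable=True.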
    loss = rewards[0][0]
    print(f"Loss value: {loss.item()}")
    print(f"Loss requires_grad: {loss.requires_grad}")

    if loss.requires_grad:
        loss.backward(retain_graph=True)

        has_grad = False
        for name, param in inferencer.model.named_parameters():
            if param.grad is not None:
                has_grad = True
                print(f"Parameter {name} has gradient norm: {param.grad.norm().item():.6f}")
                break

        if has_grad:
            print("has grad")
        else:
            print("NO GRAD!!")
    else:
        print("Final loss does not require gradient computation.")
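
    # Compare the standard (non-differentiable) preprocessing against the tensor-based
    # differentiable path; their pixel_values should agree closely.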
    img_pil = Image.open("assets/example1.png").convert('RGB')
    non_diff_result = inferencer.processor.preprocess(
        images=[img_pil],
        return_tensors="pt"
    )

    img_tensor = torch.from_numpy(np.array(img_pil)).float().permute(2, 0, 1) / 255.0
    diff_result = inferencer.processor.preprocess_tensor(img_tensor)

    if non_diff_result['pixel_values'].shape == diff_result['pixel_values'].shape:
        diff = torch.abs(non_diff_result['pixel_values'] - diff_result['pixel_values']).mean()
        if diff.item() < 0.01:
            print("Outputs match within tolerance.")
        else:
            print("Outputs differ beyond tolerance.")
    else:
        print("Shape mismatch between non-differentiable and differentiable outputs.")