Spaces:
Running
Running
| ''' | |
| @File : AestheticScore.py | |
| @Time : 2023/02/12 14:54:00 | |
| @Author : Jiazheng Xu | |
| @Contact : xjz22@mails.tsinghua.edu.cn | |
| @Description: AestheticScore. | |
| * Based on improved-aesthetic-predictor code base | |
| * https://github.com/christophschuhmann/improved-aesthetic-predictor | |
| ''' | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import clip | |
| # If you changed the MLP architecture during training, change it here as well: | |
class MLP(nn.Module):
    """Fully connected regression head mapping a feature vector to one score.

    The stack is ``input_size -> 1024 -> 128 -> 64 -> 16 -> 1`` with dropout
    after the first three linear layers. The ReLU activations of the
    original layout are left disabled, so no non-linearity is applied
    between the linear layers.

    Args:
        input_size: Width of the incoming feature vectors.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size

        # (output width, dropout probability or None) for each hidden stage.
        # The resulting module order must stay fixed: it determines the
        # ``layers.<index>`` keys used when loading saved weights.
        hidden_stages = ((1024, 0.2), (128, 0.2), (64, 0.1), (16, None))

        modules = []
        width_in = self.input_size
        for width_out, drop_p in hidden_stages:
            modules.append(nn.Linear(width_in, width_out))
            if drop_p is not None:
                modules.append(nn.Dropout(drop_p))
            width_in = width_out
        # Final projection down to a single scalar per sample.
        modules.append(nn.Linear(width_in, 1))

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw scores."""
        scores = self.layers(x)
        return scores
class AestheticScore(nn.Module):
    """Aesthetic scorer: CLIP ViT-L/14 image embeddings fed into an MLP head.

    Each image is encoded with CLIP, the embedding is L2-normalised, and the
    MLP head maps it to a single scalar score.

    NOTE(review): the MLP head is created without an explicit device and it
    contains dropout layers; callers presumably move the whole module with
    ``.to(device)`` and switch it to ``.eval()`` before scoring -- confirm
    against the loading code.

    Args:
        download_root: Directory passed to ``clip.load`` as ``download_root``.
        device: Device the CLIP model is loaded on and inputs are moved to.
    """

    def __init__(self, download_root, device='cpu'):
        super().__init__()
        self.device = device
        self.clip_model, self.preprocess = clip.load(
            "ViT-L/14",
            device=self.device,
            jit=False,
            download_root=download_root,
        )
        # The head consumes the 768-wide image embeddings produced by
        # ``encode_image``.
        self.mlp = MLP(768)

        # Compare on the device *type* so that ``torch.device('cpu')`` (not
        # only the string "cpu") keeps the model in float32.
        if torch.device(device).type == "cpu":
            self.clip_model.float()
        else:
            # Kept from the original code, which notes this call is redundant
            # because CLIP is already in float16 by default.
            clip.model.convert_weights(self.clip_model)

        # Have clip.logit_scale require no grad.
        self.clip_model.logit_scale.requires_grad_(False)

    def _encode_image(self, image_path):
        """Load one image file and return its normalised CLIP embedding.

        The file handle is closed as soon as preprocessing has produced the
        input tensor. The result has a leading batch dimension of 1 and is
        cast to float32.
        """
        with Image.open(image_path) as pil_image:
            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
        return F.normalize(self.clip_model.encode_image(image)).float()

    def score(self, prompt, image_path):
        """Score a single image, or every image in a list.

        Args:
            prompt: Not used by this scorer; only forwarded to
                ``inference_rank`` when a list is given.
            image_path: Path of one image, or a list of image paths.

        Returns:
            A Python float for a single path, or a list of floats (one per
            image, in input order) when ``image_path`` is a list.
        """
        if isinstance(image_path, list):
            _, rewards = self.inference_rank(prompt, image_path)
            return rewards

        # Results are returned as plain Python numbers, so no autograd graph
        # is needed.
        with torch.no_grad():
            rewards = self.mlp(self._encode_image(image_path))
        return rewards.item()

    def inference_rank(self, prompt, generations_list):
        """Score a list of images and rank them by score.

        Args:
            prompt: Not used by this scorer.
            generations_list: Sequence of image paths.

        Returns:
            A tuple ``(ranks, rewards)`` of two lists aligned with the input
            order. ``ranks[i]`` is the 1-based position of image ``i`` when
            sorted by descending reward (1 = highest score). Both lists are
            empty for empty input, and have length one for a single image.
        """
        if not generations_list:
            return [], []

        with torch.no_grad():
            # [image_num, feature_dim]
            img_features = torch.cat(
                [self._encode_image(img_path) for img_path in generations_list], 0
            )
            # Drop only the trailing score dimension so that a single-image
            # batch stays 1-D and still yields lists.
            rewards = self.mlp(img_features).squeeze(-1)

        # ``order`` lists image indices from best to worst; the argsort of
        # that permutation gives each image's position in the ranking.
        order = torch.argsort(rewards, dim=0, descending=True)
        ranks = torch.argsort(order, dim=0) + 1
        return ranks.cpu().tolist(), rewards.cpu().tolist()