|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoProcessor, AutoModel |
|
|
from typing import List, Union |
|
|
import os |
|
|
from .config import MODEL_PATHS |
|
|
|
|
|
class PickScore(torch.nn.Module):
    """Score how well images match a text prompt.

    The prompt and every image are embedded with the model loaded from the
    ``"pickscore"`` entry of ``path``; both embeddings are L2-normalised and
    their dot product is the score of the image.
    """

    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS):
        """Load the processor and the scoring model.

        Args:
            device (Union[str, torch.device]): The device to load the model on.
            path (dict): Mapping queried with ``.get("clip")`` for the processor
                location and ``.get("pickscore")`` for the model weights.
        """
        super().__init__()
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        processor_name_or_path = path.get("clip")
        model_pretrained_name_or_path = path.get("pickscore")
        self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
        self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)

    def _embed_text(self, prompt: str) -> torch.Tensor:
        """Return the L2-normalised text embedding of ``prompt``."""
        text_inputs = self.processor(
            text=prompt,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(self.device)
        text_embs = self.model.get_text_features(**text_inputs)
        return text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

    def _embed_image(self, image: torch.Tensor) -> torch.Tensor:
        """Return the L2-normalised embedding of already processed pixel values."""
        image_embs = self.model.get_image_features(pixel_values=image)
        return image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

    def _preprocess_image(self, pil_image: Image.Image) -> torch.Tensor:
        """Run the processor on one PIL image and return its pixel values on ``self.device``."""
        # NOTE(review): padding/truncation/max_length look text-oriented; they
        # are kept exactly as in the original processor call - confirm whether
        # the image path needs them.
        image_inputs = self.processor(
            images=pil_image,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(self.device)
        return image_inputs["pixel_values"]

    def _load_pixel_values(self, image: Union[str, Image.Image]) -> torch.Tensor:
        """Turn an image path or a PIL image into processed pixel values.

        Raises:
            TypeError: If ``image`` is neither a ``str`` nor a PIL image.
        """
        if isinstance(image, str):
            # Preprocess while the file is still open, then let the context
            # manager close the handle instead of leaking it.
            with Image.open(image) as opened:
                return self._preprocess_image(opened)
        if isinstance(image, Image.Image):
            # Caller-owned image: use it as is and do not close it.
            return self._preprocess_image(image)
        raise TypeError("The type of parameter images is illegal.")

    def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
        """Calculate the score for a single image and prompt.

        Args:
            image (torch.Tensor): The processed image tensor.
            prompt (str): The prompt text.
            softmax (bool): Whether to apply softmax to the score. Because
                exactly one score is involved here, the softmax result is
                always 1.0; use ``score`` to normalise across several images.

        Returns:
            float: The score for the image.
        """
        with torch.no_grad():
            text_embs = self._embed_text(prompt)
            image_embs = self._embed_image(image)

            score = (text_embs @ image_embs.T)[0]
            if softmax:
                score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)

            # ``item()`` requires exactly one element, i.e. one image per call.
            return score.cpu().item()

    @torch.no_grad()
    def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
        """Score the images based on the prompt.

        Args:
            images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
            prompt (str): The prompt text.
            softmax (bool): When True, the similarities are multiplied by
                ``model.logit_scale.exp()`` and a softmax is taken across all
                images of this call, so the returned values sum to 1. When
                False, the raw similarities are returned.

        Returns:
            List[float]: List of scores for the images, in input order.

        Raises:
            RuntimeError: If anything fails while scoring, including an
                unsupported ``images`` type; the original exception is
                attached as ``__cause__``.
        """
        try:
            if isinstance(images, (str, Image.Image)):
                image_list = [images]
            elif isinstance(images, list):
                image_list = images
            else:
                raise TypeError("The type of parameter images is illegal.")

            if not image_list:
                return []

            # The prompt is the same for every image, so embed it only once.
            text_embs = self._embed_text(prompt)

            similarities = []
            for one_image in image_list:
                image_embs = self._embed_image(self._load_pixel_values(one_image))
                similarities.append((text_embs @ image_embs.T)[0])
            scores = torch.cat(similarities)

            if softmax:
                # Normalise over the whole set of images; a softmax taken per
                # image would always yield 1.0.
                scores = torch.softmax(self.model.logit_scale.exp() * scores, dim=-1)

            return scores.cpu().tolist()
        except Exception as e:
            raise RuntimeError(f"Error in scoring images: {e}") from e
|
|
|